/tests/test_worker_pool.py

https://github.com/aiokitchen/aiomisc · Python · 205 lines · 153 code · 52 blank · 0 comment · 24 complexity · 6fd031c55fe2835f73dd4e7e265f8218 MD5 · raw file

import asyncio
import operator
import platform
import sys
import threading
from multiprocessing.context import ProcessError
from os import getpid
from time import sleep
from typing import AsyncIterator

import pytest

from aiomisc import WorkerPool
  11. skipif = pytest.mark.skipif(
  12. sys.version_info < (3, 7),
  13. reason="https://bugs.python.org/issue37380",
  14. )
  15. @pytest.fixture
  16. async def worker_pool(loop) -> WorkerPool:
  17. async with WorkerPool(4) as pool:
  18. yield pool
  19. @skipif
  20. async def test_success(worker_pool):
  21. results = await asyncio.gather(
  22. *[
  23. worker_pool.create_task(operator.mul, i, i)
  24. for i in range(worker_pool.workers * 2)
  25. ]
  26. )
  27. results = sorted(results)
  28. assert sorted(results) == [i * i for i in range(worker_pool.workers * 2)]
  29. @skipif
  30. async def test_incomplete_task_kill(worker_pool):
  31. await asyncio.gather(
  32. *[
  33. worker_pool.create_task(getpid)
  34. for _ in range(worker_pool.workers * 4)
  35. ]
  36. )
  37. with pytest.raises(asyncio.TimeoutError):
  38. await asyncio.wait_for(
  39. asyncio.gather(
  40. *[
  41. worker_pool.create_task(sleep, 3600)
  42. for _ in range(worker_pool.workers)
  43. ]
  44. ), timeout=1,
  45. )
  46. await asyncio.gather(
  47. *[
  48. worker_pool.create_task(getpid)
  49. for _ in range(worker_pool.workers * 4)
  50. ]
  51. )
  52. @pytest.mark.skipif(
  53. platform.system() == "Windows", reason="Flapping on windows",
  54. )
  55. @skipif
  56. async def test_incomplete_task_pool_reuse(worker_pool):
  57. pids_start = set(process.pid for process in worker_pool.processes)
  58. await asyncio.gather(
  59. *[
  60. worker_pool.create_task(getpid)
  61. for _ in range(worker_pool.workers * 4)
  62. ]
  63. )
  64. with pytest.raises(asyncio.TimeoutError):
  65. await asyncio.wait_for(
  66. asyncio.gather(
  67. *[
  68. worker_pool.create_task(sleep, 3600)
  69. for _ in range(worker_pool.workers)
  70. ]
  71. ), timeout=1,
  72. )
  73. await asyncio.gather(
  74. *[
  75. worker_pool.create_task(getpid)
  76. for _ in range(worker_pool.workers * 4)
  77. ]
  78. )
  79. pids_end = set(process.pid for process in worker_pool.processes)
  80. assert list(pids_start) == list(pids_end)
  81. @skipif
  82. async def test_exceptions(worker_pool):
  83. results = await asyncio.gather(
  84. *[
  85. worker_pool.create_task(operator.truediv, i, 0)
  86. for i in range(worker_pool.workers * 2)
  87. ], return_exceptions=True
  88. )
  89. assert len(results) == worker_pool.workers * 2
  90. for exc in results:
  91. assert isinstance(exc, ZeroDivisionError)
  92. @skipif
  93. async def test_exit(worker_pool):
  94. exceptions = await asyncio.gather(
  95. *[
  96. worker_pool.create_task(exit, 1)
  97. for _ in range(worker_pool.workers)
  98. ],
  99. return_exceptions=True
  100. )
  101. assert len(exceptions) == worker_pool.workers
  102. for exc in exceptions:
  103. assert isinstance(exc, ProcessError)
  104. @skipif
  105. async def test_exit_respawn(worker_pool):
  106. exceptions = await asyncio.gather(
  107. *[
  108. worker_pool.create_task(exit, 1)
  109. for _ in range(worker_pool.workers * 3)
  110. ],
  111. return_exceptions=True
  112. )
  113. assert len(exceptions) == worker_pool.workers * 3
  114. for exc in exceptions:
  115. assert isinstance(exc, ProcessError)
  116. INITIALIZER_ARGS = None
  117. INITIALIZER_KWARGS = None
  118. def initializer(*args, **kwargs):
  119. global INITIALIZER_ARGS, INITIALIZER_KWARGS
  120. INITIALIZER_ARGS = args
  121. INITIALIZER_KWARGS = kwargs
  122. def get_initializer_args():
  123. return INITIALIZER_ARGS, INITIALIZER_KWARGS
  124. @skipif
  125. async def test_initializer(worker_pool):
  126. pool = WorkerPool(
  127. 1, initializer=initializer, initializer_args=("foo",),
  128. initializer_kwargs={"spam": "egg"},
  129. )
  130. async with pool:
  131. args, kwargs = await pool.create_task(get_initializer_args)
  132. assert args == ("foo",)
  133. assert kwargs == {"spam": "egg"}
  134. async with WorkerPool(1, initializer=initializer) as pool:
  135. args, kwargs = await pool.create_task(get_initializer_args)
  136. assert args == ()
  137. assert kwargs == {}
  138. async with WorkerPool(1) as pool:
  139. args, kwargs = await pool.create_task(get_initializer_args)
  140. assert args is None
  141. assert kwargs is None
  142. def bad_initializer():
  143. return 1 / 0
  144. @skipif
  145. async def test_bad_initializer(worker_pool):
  146. pool = WorkerPool(1, initializer=bad_initializer)
  147. with pytest.raises(ZeroDivisionError):
  148. async with pool:
  149. await pool.create_task(get_initializer_args)
  150. @skipif
  151. async def test_threads_active_count_in_pool(worker_pool):
  152. threads = await worker_pool.create_task(threading.active_count)
  153. assert threads == 1