/distributed/tests/test_utils_test.py

https://github.com/dask/distributed · Python · 218 lines · 167 code · 51 blank · 0 comment · 24 complexity · 0f6260372968513d298c26dae2ed3a26 MD5 · raw file

  1. import asyncio
  2. from contextlib import contextmanager
  3. import socket
  4. import threading
  5. from time import sleep
  6. import pytest
  7. from tornado import gen
  8. from distributed import Scheduler, Worker, Client, config, default_client
  9. from distributed.core import rpc
  10. from distributed.metrics import time
  11. from distributed.utils_test import ( # noqa: F401
  12. cleanup,
  13. cluster,
  14. gen_cluster,
  15. inc,
  16. gen_test,
  17. wait_for_port,
  18. new_config,
  19. )
  20. from distributed.utils_test import ( # noqa: F401
  21. loop,
  22. tls_only_security,
  23. security,
  24. tls_client,
  25. tls_cluster,
  26. )
  27. from distributed.utils import get_ip
  28. def test_bare_cluster(loop):
  29. with cluster(nworkers=10) as (s, _):
  30. pass
  31. def test_cluster(loop):
  32. with cluster() as (s, [a, b]):
  33. with rpc(s["address"]) as s:
  34. ident = loop.run_sync(s.identity)
  35. assert ident["type"] == "Scheduler"
  36. assert len(ident["workers"]) == 2
  37. @gen_cluster(client=True)
  38. async def test_gen_cluster(c, s, a, b):
  39. assert isinstance(c, Client)
  40. assert isinstance(s, Scheduler)
  41. for w in [a, b]:
  42. assert isinstance(w, Worker)
  43. assert s.nthreads == {w.address: w.nthreads for w in [a, b]}
  44. assert await c.submit(lambda: 123) == 123
  45. @gen_cluster(client=True)
  46. def test_gen_cluster_legacy_implicit(c, s, a, b):
  47. assert isinstance(c, Client)
  48. assert isinstance(s, Scheduler)
  49. for w in [a, b]:
  50. assert isinstance(w, Worker)
  51. assert s.nthreads == {w.address: w.nthreads for w in [a, b]}
  52. assert (yield c.submit(lambda: 123)) == 123
  53. @gen_cluster(client=True)
  54. @gen.coroutine
  55. def test_gen_cluster_legacy_explicit(c, s, a, b):
  56. assert isinstance(c, Client)
  57. assert isinstance(s, Scheduler)
  58. for w in [a, b]:
  59. assert isinstance(w, Worker)
  60. assert s.nthreads == {w.address: w.nthreads for w in [a, b]}
  61. assert (yield c.submit(lambda: 123)) == 123
  62. @pytest.mark.skip(reason="This hangs on travis")
  63. def test_gen_cluster_cleans_up_client(loop):
  64. import dask.context
  65. assert not dask.config.get("get", None)
  66. @gen_cluster(client=True)
  67. async def f(c, s, a, b):
  68. assert dask.config.get("get", None)
  69. await c.submit(inc, 1)
  70. f()
  71. assert not dask.config.get("get", None)
  72. @gen_cluster(client=False)
  73. async def test_gen_cluster_without_client(s, a, b):
  74. assert isinstance(s, Scheduler)
  75. for w in [a, b]:
  76. assert isinstance(w, Worker)
  77. assert s.nthreads == {w.address: w.nthreads for w in [a, b]}
  78. async with Client(s.address, asynchronous=True) as c:
  79. future = c.submit(lambda x: x + 1, 1)
  80. result = await future
  81. assert result == 2
  82. @gen_cluster(
  83. client=True,
  84. scheduler="tls://127.0.0.1",
  85. nthreads=[("tls://127.0.0.1", 1), ("tls://127.0.0.1", 2)],
  86. security=tls_only_security(),
  87. )
  88. async def test_gen_cluster_tls(e, s, a, b):
  89. assert isinstance(e, Client)
  90. assert isinstance(s, Scheduler)
  91. assert s.address.startswith("tls://")
  92. for w in [a, b]:
  93. assert isinstance(w, Worker)
  94. assert w.address.startswith("tls://")
  95. assert s.nthreads == {w.address: w.nthreads for w in [a, b]}
  96. @gen_test()
  97. async def test_gen_test():
  98. await asyncio.sleep(0.01)
  99. @gen_test()
  100. def test_gen_test_legacy_implicit():
  101. yield asyncio.sleep(0.01)
  102. @gen_test()
  103. @gen.coroutine
  104. def test_gen_test_legacy_explicit():
  105. yield asyncio.sleep(0.01)
  106. @contextmanager
  107. def _listen(delay=0):
  108. serv = socket.socket()
  109. serv.bind(("127.0.0.1", 0))
  110. e = threading.Event()
  111. def do_listen():
  112. e.set()
  113. sleep(delay)
  114. serv.listen(5)
  115. ret = serv.accept()
  116. if ret is not None:
  117. cli, _ = ret
  118. cli.close()
  119. serv.close()
  120. t = threading.Thread(target=do_listen)
  121. t.daemon = True
  122. t.start()
  123. try:
  124. e.wait()
  125. sleep(0.01)
  126. yield serv
  127. finally:
  128. t.join(5.0)
  129. def test_wait_for_port():
  130. t1 = time()
  131. with pytest.raises(RuntimeError):
  132. wait_for_port((get_ip(), 9999), 0.5)
  133. t2 = time()
  134. assert t2 - t1 >= 0.5
  135. with _listen(0) as s1:
  136. t1 = time()
  137. wait_for_port(s1.getsockname())
  138. t2 = time()
  139. assert t2 - t1 <= 1.0
  140. with _listen(1) as s1:
  141. t1 = time()
  142. wait_for_port(s1.getsockname())
  143. t2 = time()
  144. assert t2 - t1 <= 2.0
  145. def test_new_config():
  146. c = config.copy()
  147. with new_config({"xyzzy": 5}):
  148. config["xyzzy"] == 5
  149. assert config == c
  150. assert "xyzzy" not in config
  151. def test_lingering_client():
  152. @gen_cluster()
  153. async def f(s, a, b):
  154. await Client(s.address, asynchronous=True)
  155. f()
  156. with pytest.raises(ValueError):
  157. default_client()
  158. def test_lingering_client(loop):
  159. with cluster() as (s, [a, b]):
  160. client = Client(s["address"], loop=loop)
  161. def test_tls_cluster(tls_client):
  162. tls_client.submit(lambda x: x + 1, 10).result() == 11
  163. assert tls_client.security
  164. @pytest.mark.asyncio
  165. async def test_tls_scheduler(security, cleanup):
  166. async with Scheduler(security=security, host="localhost") as s:
  167. assert s.address.startswith("tls")