PageRenderTime 26ms CodeModel.GetById 41ms RepoModel.GetById 0ms app.codeStats 0ms

/zmq/tests/test_asyncio.py

https://github.com/zeromq/pyzmq
Python | 499 lines | 409 code | 73 blank | 17 comment | 20 complexity | 5bdd77e4091bb22327caa93bb5359b2c MD5 | raw file
Possible License(s): BSD-3-Clause, LGPL-3.0, Apache-2.0
  1. """Test asyncio support"""
  2. # Copyright (c) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. import json
  5. from multiprocessing import Process
  6. import os
  7. import sys
  8. import pytest
  9. from pytest import mark
  10. import zmq
  11. from zmq.utils.strtypes import u
  12. import asyncio
  13. import zmq.asyncio as zaio
  14. from zmq.auth.asyncio import AsyncioAuthenticator
  15. from concurrent.futures import CancelledError
  16. from zmq.tests import BaseZMQTestCase
  17. from zmq.tests.test_auth import TestThreadAuthentication
  18. class ProcessForTeardownTest(Process):
  19. def __init__(self, event_loop_policy_class):
  20. Process.__init__(self)
  21. self.event_loop_policy_class = event_loop_policy_class
  22. def run(self):
  23. """Leave context, socket and event loop upon implicit disposal"""
  24. asyncio.set_event_loop_policy(self.event_loop_policy_class())
  25. actx = zaio.Context.instance()
  26. socket = actx.socket(zmq.PAIR)
  27. socket.bind_to_random_port("tcp://127.0.0.1")
  28. async def never_ending_task(socket):
  29. await socket.recv() # never ever receive anything
  30. loop = asyncio.get_event_loop()
  31. coro = asyncio.wait_for(never_ending_task(socket), timeout=1)
  32. try:
  33. loop.run_until_complete(coro)
  34. except asyncio.TimeoutError:
  35. pass # expected timeout
  36. else:
  37. assert False, "never_ending_task was completed unexpectedly"
  38. class TestAsyncIOSocket(BaseZMQTestCase):
  39. Context = zaio.Context
  40. def setUp(self):
  41. self.loop = asyncio.new_event_loop()
  42. asyncio.set_event_loop(self.loop)
  43. super(TestAsyncIOSocket, self).setUp()
  44. def tearDown(self):
  45. super().tearDown()
  46. self.loop.close()
  47. # verify cleanup of references to selectors
  48. assert zaio._selectors == {}
  49. if 'zmq._asyncio_selector' in sys.modules:
  50. assert zmq._asyncio_selector._selector_loops == set()
  51. def test_socket_class(self):
  52. s = self.context.socket(zmq.PUSH)
  53. assert isinstance(s, zaio.Socket)
  54. s.close()
  55. def test_instance_subclass_first(self):
  56. actx = zmq.asyncio.Context.instance()
  57. ctx = zmq.Context.instance()
  58. ctx.term()
  59. actx.term()
  60. assert type(ctx) is zmq.Context
  61. assert type(actx) is zmq.asyncio.Context
  62. def test_instance_subclass_second(self):
  63. ctx = zmq.Context.instance()
  64. actx = zmq.asyncio.Context.instance()
  65. ctx.term()
  66. actx.term()
  67. assert type(ctx) is zmq.Context
  68. assert type(actx) is zmq.asyncio.Context
  69. def test_recv_multipart(self):
  70. async def test():
  71. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  72. f = b.recv_multipart()
  73. assert not f.done()
  74. await a.send(b"hi")
  75. recvd = await f
  76. self.assertEqual(recvd, [b"hi"])
  77. self.loop.run_until_complete(test())
  78. def test_recv(self):
  79. async def test():
  80. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  81. f1 = b.recv()
  82. f2 = b.recv()
  83. assert not f1.done()
  84. assert not f2.done()
  85. await a.send_multipart([b"hi", b"there"])
  86. recvd = await f2
  87. assert f1.done()
  88. self.assertEqual(f1.result(), b"hi")
  89. self.assertEqual(recvd, b"there")
  90. self.loop.run_until_complete(test())
  91. @mark.skipif(not hasattr(zmq, "RCVTIMEO"), reason="requires RCVTIMEO")
  92. def test_recv_timeout(self):
  93. async def test():
  94. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  95. b.rcvtimeo = 100
  96. f1 = b.recv()
  97. b.rcvtimeo = 1000
  98. f2 = b.recv_multipart()
  99. with self.assertRaises(zmq.Again):
  100. await f1
  101. await a.send_multipart([b"hi", b"there"])
  102. recvd = await f2
  103. assert f2.done()
  104. self.assertEqual(recvd, [b"hi", b"there"])
  105. self.loop.run_until_complete(test())
  106. @mark.skipif(not hasattr(zmq, "SNDTIMEO"), reason="requires SNDTIMEO")
  107. def test_send_timeout(self):
  108. async def test():
  109. s = self.socket(zmq.PUSH)
  110. s.sndtimeo = 100
  111. with self.assertRaises(zmq.Again):
  112. await s.send(b"not going anywhere")
  113. self.loop.run_until_complete(test())
  114. def test_recv_string(self):
  115. async def test():
  116. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  117. f = b.recv_string()
  118. assert not f.done()
  119. msg = u("πøøπ")
  120. await a.send_string(msg)
  121. recvd = await f
  122. assert f.done()
  123. self.assertEqual(f.result(), msg)
  124. self.assertEqual(recvd, msg)
  125. self.loop.run_until_complete(test())
  126. def test_recv_json(self):
  127. async def test():
  128. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  129. f = b.recv_json()
  130. assert not f.done()
  131. obj = dict(a=5)
  132. await a.send_json(obj)
  133. recvd = await f
  134. assert f.done()
  135. self.assertEqual(f.result(), obj)
  136. self.assertEqual(recvd, obj)
  137. self.loop.run_until_complete(test())
  138. def test_recv_json_cancelled(self):
  139. async def test():
  140. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  141. f = b.recv_json()
  142. assert not f.done()
  143. f.cancel()
  144. # cycle eventloop to allow cancel events to fire
  145. await asyncio.sleep(0)
  146. obj = dict(a=5)
  147. await a.send_json(obj)
  148. # CancelledError change in 3.8 https://bugs.python.org/issue32528
  149. if sys.version_info < (3, 8):
  150. with pytest.raises(CancelledError):
  151. recvd = await f
  152. else:
  153. with pytest.raises(asyncio.exceptions.CancelledError):
  154. recvd = await f
  155. assert f.done()
  156. # give it a chance to incorrectly consume the event
  157. events = await b.poll(timeout=5)
  158. assert events
  159. await asyncio.sleep(0)
  160. # make sure cancelled recv didn't eat up event
  161. f = b.recv_json()
  162. recvd = await asyncio.wait_for(f, timeout=5)
  163. assert recvd == obj
  164. self.loop.run_until_complete(test())
  165. def test_recv_pyobj(self):
  166. async def test():
  167. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  168. f = b.recv_pyobj()
  169. assert not f.done()
  170. obj = dict(a=5)
  171. await a.send_pyobj(obj)
  172. recvd = await f
  173. assert f.done()
  174. self.assertEqual(f.result(), obj)
  175. self.assertEqual(recvd, obj)
  176. self.loop.run_until_complete(test())
  177. def test_custom_serialize(self):
  178. def serialize(msg):
  179. frames = []
  180. frames.extend(msg.get("identities", []))
  181. content = json.dumps(msg["content"]).encode("utf8")
  182. frames.append(content)
  183. return frames
  184. def deserialize(frames):
  185. identities = frames[:-1]
  186. content = json.loads(frames[-1].decode("utf8"))
  187. return {
  188. "identities": identities,
  189. "content": content,
  190. }
  191. async def test():
  192. a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
  193. msg = {
  194. "content": {
  195. "a": 5,
  196. "b": "bee",
  197. }
  198. }
  199. await a.send_serialized(msg, serialize)
  200. recvd = await b.recv_serialized(deserialize)
  201. assert recvd["content"] == msg["content"]
  202. assert recvd["identities"]
  203. # bounce back, tests identities
  204. await b.send_serialized(recvd, serialize)
  205. r2 = await a.recv_serialized(deserialize)
  206. assert r2["content"] == msg["content"]
  207. assert not r2["identities"]
  208. self.loop.run_until_complete(test())
  209. def test_custom_serialize_error(self):
  210. async def test():
  211. a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
  212. msg = {
  213. "content": {
  214. "a": 5,
  215. "b": "bee",
  216. }
  217. }
  218. with pytest.raises(TypeError):
  219. await a.send_serialized(json, json.dumps)
  220. await a.send(b"not json")
  221. with pytest.raises(TypeError):
  222. recvd = await b.recv_serialized(json.loads)
  223. self.loop.run_until_complete(test())
  224. def test_recv_dontwait(self):
  225. async def test():
  226. push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  227. f = pull.recv(zmq.DONTWAIT)
  228. with self.assertRaises(zmq.Again):
  229. await f
  230. await push.send(b"ping")
  231. await pull.poll() # ensure message will be waiting
  232. f = pull.recv(zmq.DONTWAIT)
  233. assert f.done()
  234. msg = await f
  235. self.assertEqual(msg, b"ping")
  236. self.loop.run_until_complete(test())
  237. def test_recv_cancel(self):
  238. async def test():
  239. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  240. f1 = b.recv()
  241. f2 = b.recv_multipart()
  242. assert f1.cancel()
  243. assert f1.done()
  244. assert not f2.done()
  245. await a.send_multipart([b"hi", b"there"])
  246. recvd = await f2
  247. assert f1.cancelled()
  248. assert f2.done()
  249. self.assertEqual(recvd, [b"hi", b"there"])
  250. self.loop.run_until_complete(test())
  251. def test_poll(self):
  252. async def test():
  253. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  254. f = b.poll(timeout=0)
  255. await asyncio.sleep(0)
  256. self.assertEqual(f.result(), 0)
  257. f = b.poll(timeout=1)
  258. assert not f.done()
  259. evt = await f
  260. self.assertEqual(evt, 0)
  261. f = b.poll(timeout=1000)
  262. assert not f.done()
  263. await a.send_multipart([b"hi", b"there"])
  264. evt = await f
  265. self.assertEqual(evt, zmq.POLLIN)
  266. recvd = await b.recv_multipart()
  267. self.assertEqual(recvd, [b"hi", b"there"])
  268. self.loop.run_until_complete(test())
  269. def test_poll_base_socket(self):
  270. async def test():
  271. ctx = zmq.Context()
  272. url = "inproc://test"
  273. a = ctx.socket(zmq.PUSH)
  274. b = ctx.socket(zmq.PULL)
  275. self.sockets.extend([a, b])
  276. a.bind(url)
  277. b.connect(url)
  278. poller = zaio.Poller()
  279. poller.register(b, zmq.POLLIN)
  280. f = poller.poll(timeout=1000)
  281. assert not f.done()
  282. a.send_multipart([b"hi", b"there"])
  283. evt = await f
  284. self.assertEqual(evt, [(b, zmq.POLLIN)])
  285. recvd = b.recv_multipart()
  286. self.assertEqual(recvd, [b"hi", b"there"])
  287. self.loop.run_until_complete(test())
  288. def test_poll_on_closed_socket(self):
  289. async def test():
  290. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  291. f = b.poll(timeout=1)
  292. b.close()
  293. # The test might stall if we try to await f directly so instead just make a few
  294. # passes through the event loop to schedule and execute all callbacks
  295. for _ in range(5):
  296. await asyncio.sleep(0)
  297. if f.cancelled():
  298. break
  299. assert f.cancelled()
  300. self.loop.run_until_complete(test())
  301. @pytest.mark.skipif(
  302. sys.platform.startswith("win"),
  303. reason="Windows does not support polling on files",
  304. )
  305. def test_poll_raw(self):
  306. async def test():
  307. p = zaio.Poller()
  308. # make a pipe
  309. r, w = os.pipe()
  310. r = os.fdopen(r, "rb")
  311. w = os.fdopen(w, "wb")
  312. # POLLOUT
  313. p.register(r, zmq.POLLIN)
  314. p.register(w, zmq.POLLOUT)
  315. evts = await p.poll(timeout=1)
  316. evts = dict(evts)
  317. assert r.fileno() not in evts
  318. assert w.fileno() in evts
  319. assert evts[w.fileno()] == zmq.POLLOUT
  320. # POLLIN
  321. p.unregister(w)
  322. w.write(b"x")
  323. w.flush()
  324. evts = await p.poll(timeout=1000)
  325. evts = dict(evts)
  326. assert r.fileno() in evts
  327. assert evts[r.fileno()] == zmq.POLLIN
  328. assert r.read(1) == b"x"
  329. r.close()
  330. w.close()
  331. loop = asyncio.get_event_loop()
  332. loop.run_until_complete(test())
  333. def test_multiple_loops(self):
  334. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  335. async def test():
  336. await a.send(b'buf')
  337. msg = await b.recv()
  338. assert msg == b'buf'
  339. for i in range(3):
  340. loop = asyncio.new_event_loop()
  341. asyncio.set_event_loop(loop)
  342. loop.run_until_complete(asyncio.wait_for(test(), timeout=10))
  343. loop.close()
  344. def test_shadow(self):
  345. async def test():
  346. ctx = zmq.Context()
  347. s = ctx.socket(zmq.PULL)
  348. async_s = zaio.Socket(s)
  349. assert isinstance(async_s, self.socket_class)
  350. def test_process_teardown(self):
  351. event_loop_policy_class = type(asyncio.get_event_loop_policy())
  352. proc = ProcessForTeardownTest(event_loop_policy_class)
  353. proc.start()
  354. try:
  355. proc.join(10) # starting new Python process may cost a lot
  356. self.assertEqual(
  357. proc.exitcode,
  358. 0,
  359. "Python process died with code %d" % proc.exitcode
  360. if proc.exitcode
  361. else "process teardown hangs",
  362. )
  363. finally:
  364. proc.terminate()
  365. class TestAsyncioAuthentication(TestThreadAuthentication):
  366. """Test authentication running in a asyncio task"""
  367. Context = zaio.Context
  368. def shortDescription(self):
  369. """Rewrite doc strings from TestThreadAuthentication from
  370. 'threaded' to 'asyncio'.
  371. """
  372. doc = self._testMethodDoc
  373. if doc:
  374. doc = doc.split("\n")[0].strip()
  375. if doc.startswith("threaded auth"):
  376. doc = doc.replace("threaded auth", "asyncio auth")
  377. return doc
  378. def setUp(self):
  379. self.loop = asyncio.new_event_loop()
  380. asyncio.set_event_loop(self.loop)
  381. super().setUp()
  382. def tearDown(self):
  383. super().tearDown()
  384. self.loop.close()
  385. def make_auth(self):
  386. return AsyncioAuthenticator(self.context)
  387. def can_connect(self, server, client):
  388. """Check if client can connect to server using tcp transport"""
  389. async def go():
  390. result = False
  391. iface = "tcp://127.0.0.1"
  392. port = server.bind_to_random_port(iface)
  393. client.connect("%s:%i" % (iface, port))
  394. msg = [b"Hello World"]
  395. # set timeouts
  396. server.SNDTIMEO = client.RCVTIMEO = 1000
  397. try:
  398. await server.send_multipart(msg)
  399. except zmq.Again:
  400. return False
  401. try:
  402. rcvd_msg = await client.recv_multipart()
  403. except zmq.Again:
  404. return False
  405. else:
  406. assert rcvd_msg == msg
  407. result = True
  408. return result
  409. return self.loop.run_until_complete(go())
  410. def _select_recv(self, multipart, socket, **kwargs):
  411. recv = socket.recv_multipart if multipart else socket.recv
  412. async def coro():
  413. if not await socket.poll(5000):
  414. raise TimeoutError("Should have received a message")
  415. return await recv(**kwargs)
  416. return self.loop.run_until_complete(coro())