PageRenderTime 2008ms CodeModel.GetById 23ms RepoModel.GetById 1ms app.codeStats 0ms

/zmq/tests/asyncio/_test_asyncio.py

http://github.com/zeromq/pyzmq
Python | 425 lines | 362 code | 48 blank | 15 comment | 17 complexity | a4782379955e837d84cff8d78c0547ab MD5 | raw file
Possible License(s): BSD-3-Clause, LGPL-3.0, Apache-2.0
  1. """Test asyncio support"""
  2. # Copyright (c) PyZMQ Developers
  3. # Distributed under the terms of the Modified BSD License.
  4. import json
  5. import os
  6. import sys
  7. import pytest
  8. from pytest import mark
  9. import zmq
  10. from zmq.utils.strtypes import u
  11. try:
  12. import asyncio
  13. import zmq.asyncio as zaio
  14. from zmq.auth.asyncio import AsyncioAuthenticator
  15. except ImportError:
  16. if sys.version_info >= (3,4):
  17. raise
  18. asyncio = None
  19. from concurrent.futures import CancelledError
  20. from zmq.tests import BaseZMQTestCase, SkipTest
  21. from zmq.tests.test_auth import TestThreadAuthentication
  22. class TestAsyncIOSocket(BaseZMQTestCase):
  23. if asyncio is not None:
  24. Context = zaio.Context
  25. def setUp(self):
  26. if asyncio is None:
  27. raise SkipTest()
  28. self.loop = asyncio.new_event_loop()
  29. asyncio.set_event_loop(self.loop)
  30. super(TestAsyncIOSocket, self).setUp()
  31. def tearDown(self):
  32. self.loop.close()
  33. super().tearDown()
  34. def test_socket_class(self):
  35. s = self.context.socket(zmq.PUSH)
  36. assert isinstance(s, zaio.Socket)
  37. s.close()
  38. def test_recv_multipart(self):
  39. @asyncio.coroutine
  40. def test():
  41. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  42. f = b.recv_multipart()
  43. assert not f.done()
  44. yield from a.send(b'hi')
  45. recvd = yield from f
  46. self.assertEqual(recvd, [b'hi'])
  47. self.loop.run_until_complete(test())
  48. def test_recv(self):
  49. @asyncio.coroutine
  50. def test():
  51. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  52. f1 = b.recv()
  53. f2 = b.recv()
  54. assert not f1.done()
  55. assert not f2.done()
  56. yield from a.send_multipart([b'hi', b'there'])
  57. recvd = yield from f2
  58. assert f1.done()
  59. self.assertEqual(f1.result(), b'hi')
  60. self.assertEqual(recvd, b'there')
  61. self.loop.run_until_complete(test())
  62. @mark.skipif(not hasattr(zmq, 'RCVTIMEO'), reason="requires RCVTIMEO")
  63. def test_recv_timeout(self):
  64. @asyncio.coroutine
  65. def test():
  66. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  67. b.rcvtimeo = 100
  68. f1 = b.recv()
  69. b.rcvtimeo = 1000
  70. f2 = b.recv_multipart()
  71. with self.assertRaises(zmq.Again):
  72. yield from f1
  73. yield from a.send_multipart([b'hi', b'there'])
  74. recvd = yield from f2
  75. assert f2.done()
  76. self.assertEqual(recvd, [b'hi', b'there'])
  77. self.loop.run_until_complete(test())
  78. @mark.skipif(not hasattr(zmq, 'SNDTIMEO'), reason="requires SNDTIMEO")
  79. def test_send_timeout(self):
  80. @asyncio.coroutine
  81. def test():
  82. s = self.socket(zmq.PUSH)
  83. s.sndtimeo = 100
  84. with self.assertRaises(zmq.Again):
  85. yield from s.send(b'not going anywhere')
  86. self.loop.run_until_complete(test())
  87. def test_recv_string(self):
  88. @asyncio.coroutine
  89. def test():
  90. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  91. f = b.recv_string()
  92. assert not f.done()
  93. msg = u('πøøπ')
  94. yield from a.send_string(msg)
  95. recvd = yield from f
  96. assert f.done()
  97. self.assertEqual(f.result(), msg)
  98. self.assertEqual(recvd, msg)
  99. self.loop.run_until_complete(test())
  100. def test_recv_json(self):
  101. @asyncio.coroutine
  102. def test():
  103. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  104. f = b.recv_json()
  105. assert not f.done()
  106. obj = dict(a=5)
  107. yield from a.send_json(obj)
  108. recvd = yield from f
  109. assert f.done()
  110. self.assertEqual(f.result(), obj)
  111. self.assertEqual(recvd, obj)
  112. self.loop.run_until_complete(test())
  113. def test_recv_json_cancelled(self):
  114. @asyncio.coroutine
  115. def test():
  116. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  117. f = b.recv_json()
  118. assert not f.done()
  119. f.cancel()
  120. # cycle eventloop to allow cancel events to fire
  121. yield from asyncio.sleep(0)
  122. obj = dict(a=5)
  123. yield from a.send_json(obj)
  124. # CancelledError change in 3.8 https://bugs.python.org/issue32528
  125. if sys.version_info < (3, 8):
  126. with pytest.raises(CancelledError):
  127. recvd = yield from f
  128. else:
  129. with pytest.raises(asyncio.exceptions.CancelledError):
  130. recvd = yield from f
  131. assert f.done()
  132. # give it a chance to incorrectly consume the event
  133. events = yield from b.poll(timeout=5)
  134. assert events
  135. yield from asyncio.sleep(0)
  136. # make sure cancelled recv didn't eat up event
  137. f = b.recv_json()
  138. recvd = yield from asyncio.wait_for(f, timeout=5)
  139. assert recvd == obj
  140. self.loop.run_until_complete(test())
  141. def test_recv_pyobj(self):
  142. @asyncio.coroutine
  143. def test():
  144. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  145. f = b.recv_pyobj()
  146. assert not f.done()
  147. obj = dict(a=5)
  148. yield from a.send_pyobj(obj)
  149. recvd = yield from f
  150. assert f.done()
  151. self.assertEqual(f.result(), obj)
  152. self.assertEqual(recvd, obj)
  153. self.loop.run_until_complete(test())
  154. def test_custom_serialize(self):
  155. def serialize(msg):
  156. frames = []
  157. frames.extend(msg.get('identities', []))
  158. content = json.dumps(msg['content']).encode('utf8')
  159. frames.append(content)
  160. return frames
  161. def deserialize(frames):
  162. identities = frames[:-1]
  163. content = json.loads(frames[-1].decode('utf8'))
  164. return {
  165. 'identities': identities,
  166. 'content': content,
  167. }
  168. @asyncio.coroutine
  169. def test():
  170. a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
  171. msg = {
  172. 'content': {
  173. 'a': 5,
  174. 'b': 'bee',
  175. }
  176. }
  177. yield from a.send_serialized(msg, serialize)
  178. recvd = yield from b.recv_serialized(deserialize)
  179. assert recvd['content'] == msg['content']
  180. assert recvd['identities']
  181. # bounce back, tests identities
  182. yield from b.send_serialized(recvd, serialize)
  183. r2 = yield from a.recv_serialized(deserialize)
  184. assert r2['content'] == msg['content']
  185. assert not r2['identities']
  186. self.loop.run_until_complete(test())
  187. def test_custom_serialize_error(self):
  188. @asyncio.coroutine
  189. def test():
  190. a, b = self.create_bound_pair(zmq.DEALER, zmq.ROUTER)
  191. msg = {
  192. 'content': {
  193. 'a': 5,
  194. 'b': 'bee',
  195. }
  196. }
  197. with pytest.raises(TypeError):
  198. yield from a.send_serialized(json, json.dumps)
  199. yield from a.send(b'not json')
  200. with pytest.raises(TypeError):
  201. recvd = yield from b.recv_serialized(json.loads)
  202. self.loop.run_until_complete(test())
  203. def test_recv_dontwait(self):
  204. @asyncio.coroutine
  205. def test():
  206. push, pull = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  207. f = pull.recv(zmq.DONTWAIT)
  208. with self.assertRaises(zmq.Again):
  209. yield from f
  210. yield from push.send(b'ping')
  211. yield from pull.poll() # ensure message will be waiting
  212. f = pull.recv(zmq.DONTWAIT)
  213. assert f.done()
  214. msg = yield from f
  215. self.assertEqual(msg, b'ping')
  216. self.loop.run_until_complete(test())
  217. def test_recv_cancel(self):
  218. @asyncio.coroutine
  219. def test():
  220. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  221. f1 = b.recv()
  222. f2 = b.recv_multipart()
  223. assert f1.cancel()
  224. assert f1.done()
  225. assert not f2.done()
  226. yield from a.send_multipart([b'hi', b'there'])
  227. recvd = yield from f2
  228. assert f1.cancelled()
  229. assert f2.done()
  230. self.assertEqual(recvd, [b'hi', b'there'])
  231. self.loop.run_until_complete(test())
  232. def test_poll(self):
  233. @asyncio.coroutine
  234. def test():
  235. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  236. f = b.poll(timeout=0)
  237. yield from asyncio.sleep(0)
  238. self.assertEqual(f.result(), 0)
  239. f = b.poll(timeout=1)
  240. assert not f.done()
  241. evt = yield from f
  242. self.assertEqual(evt, 0)
  243. f = b.poll(timeout=1000)
  244. assert not f.done()
  245. yield from a.send_multipart([b'hi', b'there'])
  246. evt = yield from f
  247. self.assertEqual(evt, zmq.POLLIN)
  248. recvd = yield from b.recv_multipart()
  249. self.assertEqual(recvd, [b'hi', b'there'])
  250. self.loop.run_until_complete(test())
  251. def test_poll_base_socket(self):
  252. @asyncio.coroutine
  253. def test():
  254. ctx = zmq.Context()
  255. url = 'inproc://test'
  256. a = ctx.socket(zmq.PUSH)
  257. b = ctx.socket(zmq.PULL)
  258. self.sockets.extend([a, b])
  259. a.bind(url)
  260. b.connect(url)
  261. poller = zaio.Poller()
  262. poller.register(b, zmq.POLLIN)
  263. f = poller.poll(timeout=1000)
  264. assert not f.done()
  265. a.send_multipart([b'hi', b'there'])
  266. evt = yield from f
  267. self.assertEqual(evt, [(b, zmq.POLLIN)])
  268. recvd = b.recv_multipart()
  269. self.assertEqual(recvd, [b'hi', b'there'])
  270. self.loop.run_until_complete(test())
  271. def test_poll_on_closed_socket(self):
  272. @asyncio.coroutine
  273. def test():
  274. a, b = self.create_bound_pair(zmq.PUSH, zmq.PULL)
  275. f = b.poll(timeout=1)
  276. b.close()
  277. # The test might stall if we try to yield from f directly so instead just make a few
  278. # passes through the event loop to schedule and execute all callbacks
  279. for _ in range(5):
  280. yield from asyncio.sleep(0)
  281. if f.cancelled():
  282. break
  283. assert f.cancelled()
  284. self.loop.run_until_complete(test())
  285. @pytest.mark.skipif(
  286. sys.platform.startswith('win'),
  287. reason='Windows does not support polling on files')
  288. def test_poll_raw(self):
  289. @asyncio.coroutine
  290. def test():
  291. p = zaio.Poller()
  292. # make a pipe
  293. r, w = os.pipe()
  294. r = os.fdopen(r, 'rb')
  295. w = os.fdopen(w, 'wb')
  296. # POLLOUT
  297. p.register(r, zmq.POLLIN)
  298. p.register(w, zmq.POLLOUT)
  299. evts = yield from p.poll(timeout=1)
  300. evts = dict(evts)
  301. assert r.fileno() not in evts
  302. assert w.fileno() in evts
  303. assert evts[w.fileno()] == zmq.POLLOUT
  304. # POLLIN
  305. p.unregister(w)
  306. w.write(b'x')
  307. w.flush()
  308. evts = yield from p.poll(timeout=1000)
  309. evts = dict(evts)
  310. assert r.fileno() in evts
  311. assert evts[r.fileno()] == zmq.POLLIN
  312. assert r.read(1) == b'x'
  313. r.close()
  314. w.close()
  315. loop = asyncio.get_event_loop()
  316. loop.run_until_complete(test())
  317. def test_shadow(self):
  318. @asyncio.coroutine
  319. def test():
  320. ctx = zmq.Context()
  321. s = ctx.socket(zmq.PULL)
  322. async_s = zaio.Socket(s)
  323. assert isinstance(async_s, self.socket_class)
  324. class TestAsyncioAuthentication(TestThreadAuthentication):
  325. """Test authentication running in a asyncio task"""
  326. if asyncio is not None:
  327. Context = zaio.Context
  328. def shortDescription(self):
  329. """Rewrite doc strings from TestThreadAuthentication from
  330. 'threaded' to 'asyncio'.
  331. """
  332. doc = self._testMethodDoc
  333. if doc:
  334. doc = doc.split("\n")[0].strip()
  335. if doc.startswith('threaded auth'):
  336. doc = doc.replace('threaded auth', 'asyncio auth')
  337. return doc
  338. def setUp(self):
  339. if asyncio is None:
  340. raise SkipTest()
  341. self.loop = zaio.ZMQEventLoop()
  342. asyncio.set_event_loop(self.loop)
  343. super().setUp()
  344. def tearDown(self):
  345. super().tearDown()
  346. self.loop.close()
  347. def make_auth(self):
  348. return AsyncioAuthenticator(self.context)
  349. def can_connect(self, server, client):
  350. """Check if client can connect to server using tcp transport"""
  351. @asyncio.coroutine
  352. def go():
  353. result = False
  354. iface = 'tcp://127.0.0.1'
  355. port = server.bind_to_random_port(iface)
  356. client.connect("%s:%i" % (iface, port))
  357. msg = [b"Hello World"]
  358. yield from server.send_multipart(msg)
  359. if (yield from client.poll(1000)):
  360. rcvd_msg = yield from client.recv_multipart()
  361. self.assertEqual(rcvd_msg, msg)
  362. result = True
  363. return result
  364. return self.loop.run_until_complete(go())
  365. def _select_recv(self, multipart, socket, **kwargs):
  366. recv = socket.recv_multipart if multipart else socket.recv
  367. @asyncio.coroutine
  368. def coro():
  369. if not (yield from socket.poll(5000)):
  370. raise TimeoutError("Should have received a message")
  371. return (yield from recv(**kwargs))
  372. return self.loop.run_until_complete(coro())