PageRenderTime 54ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 0ms

/Lib/test/test_coroutines.py

https://bitbucket.org/mirror/cpython/
Python | 1720 lines | 1401 code | 284 blank | 35 comment | 152 complexity | 490ec1e9146e2bebed72a1e11af52b52 MD5 | raw file
Possible License(s): Unlicense, 0BSD, BSD-3-Clause
  1. import contextlib
  2. import copy
  3. import inspect
  4. import pickle
  5. import sys
  6. import types
  7. import unittest
  8. import warnings
  9. from test import support
  10. class AsyncYieldFrom:
  11. def __init__(self, obj):
  12. self.obj = obj
  13. def __await__(self):
  14. yield from self.obj
  15. class AsyncYield:
  16. def __init__(self, value):
  17. self.value = value
  18. def __await__(self):
  19. yield self.value
  20. def run_async(coro):
  21. assert coro.__class__ in {types.GeneratorType, types.CoroutineType}
  22. buffer = []
  23. result = None
  24. while True:
  25. try:
  26. buffer.append(coro.send(None))
  27. except StopIteration as ex:
  28. result = ex.args[0] if ex.args else None
  29. break
  30. return buffer, result
  31. def run_async__await__(coro):
  32. assert coro.__class__ is types.CoroutineType
  33. aw = coro.__await__()
  34. buffer = []
  35. result = None
  36. i = 0
  37. while True:
  38. try:
  39. if i % 2:
  40. buffer.append(next(aw))
  41. else:
  42. buffer.append(aw.send(None))
  43. i += 1
  44. except StopIteration as ex:
  45. result = ex.args[0] if ex.args else None
  46. break
  47. return buffer, result
@contextlib.contextmanager
def silence_coro_gc():
    """Suppress all warnings raised while the ``with`` body runs.

    On normal exit it also forces a garbage collection while the "ignore"
    filter is still installed, so warnings emitted when leftover objects
    are collected (e.g. coroutines that were created but never run) are
    suppressed too. If the body raises, the collection is skipped.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        yield
        # Must stay inside catch_warnings(): collection is what emits the
        # warnings this helper is meant to hide.
        support.gc_collect()
  54. class AsyncBadSyntaxTest(unittest.TestCase):
  55. def test_badsyntax_1(self):
  56. with self.assertRaisesRegex(SyntaxError, "'await' outside"):
  57. import test.badsyntax_async1
  58. def test_badsyntax_2(self):
  59. with self.assertRaisesRegex(SyntaxError, "'await' outside"):
  60. import test.badsyntax_async2
  61. def test_badsyntax_3(self):
  62. with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
  63. import test.badsyntax_async3
  64. def test_badsyntax_4(self):
  65. with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
  66. import test.badsyntax_async4
  67. def test_badsyntax_5(self):
  68. with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
  69. import test.badsyntax_async5
  70. def test_badsyntax_6(self):
  71. with self.assertRaisesRegex(
  72. SyntaxError, "'yield' inside async function"):
  73. import test.badsyntax_async6
  74. def test_badsyntax_7(self):
  75. with self.assertRaisesRegex(
  76. SyntaxError, "'yield from' inside async function"):
  77. import test.badsyntax_async7
  78. def test_badsyntax_8(self):
  79. with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
  80. import test.badsyntax_async8
  81. def test_badsyntax_9(self):
  82. ns = {}
  83. for comp in {'(await a for a in b)',
  84. '[await a for a in b]',
  85. '{await a for a in b}',
  86. '{await a: c for a in b}'}:
  87. with self.assertRaisesRegex(SyntaxError, 'await.*in comprehen'):
  88. exec('async def f():\n\t{}'.format(comp), ns, ns)
  89. def test_badsyntax_10(self):
  90. # Tests for issue 24619
  91. samples = [
  92. """async def foo():
  93. def bar(): pass
  94. await = 1
  95. """,
  96. """async def foo():
  97. def bar(): pass
  98. await = 1
  99. """,
  100. """async def foo():
  101. def bar(): pass
  102. if 1:
  103. await = 1
  104. """,
  105. """def foo():
  106. async def bar(): pass
  107. if 1:
  108. await a
  109. """,
  110. """def foo():
  111. async def bar(): pass
  112. await a
  113. """,
  114. """def foo():
  115. def baz(): pass
  116. async def bar(): pass
  117. await a
  118. """,
  119. """def foo():
  120. def baz(): pass
  121. # 456
  122. async def bar(): pass
  123. # 123
  124. await a
  125. """,
  126. """async def foo():
  127. def baz(): pass
  128. # 456
  129. async def bar(): pass
  130. # 123
  131. await = 2
  132. """,
  133. """def foo():
  134. def baz(): pass
  135. async def bar(): pass
  136. await a
  137. """,
  138. """async def foo():
  139. def baz(): pass
  140. async def bar(): pass
  141. await = 2
  142. """,
  143. """async def foo():
  144. def async(): pass
  145. """,
  146. """async def foo():
  147. def await(): pass
  148. """,
  149. """async def foo():
  150. def bar():
  151. await
  152. """,
  153. """async def foo():
  154. return lambda async: await
  155. """,
  156. """async def foo():
  157. return lambda a: await
  158. """,
  159. """await a()""",
  160. """async def foo(a=await b):
  161. pass
  162. """,
  163. """async def foo(a:await b):
  164. pass
  165. """,
  166. """def baz():
  167. async def foo(a=await b):
  168. pass
  169. """,
  170. """async def foo(async):
  171. pass
  172. """,
  173. """async def foo():
  174. def bar():
  175. def baz():
  176. async = 1
  177. """,
  178. """async def foo():
  179. def bar():
  180. def baz():
  181. pass
  182. async = 1
  183. """,
  184. """def foo():
  185. async def bar():
  186. async def baz():
  187. pass
  188. def baz():
  189. 42
  190. async = 1
  191. """,
  192. """async def foo():
  193. def bar():
  194. def baz():
  195. pass\nawait foo()
  196. """,
  197. """def foo():
  198. def bar():
  199. async def baz():
  200. pass\nawait foo()
  201. """,
  202. """async def foo(await):
  203. pass
  204. """,
  205. """def foo():
  206. async def bar(): pass
  207. await a
  208. """,
  209. """def foo():
  210. async def bar():
  211. pass\nawait a
  212. """]
  213. for code in samples:
  214. with self.subTest(code=code), self.assertRaises(SyntaxError):
  215. compile(code, "<test>", "exec")
  216. def test_goodsyntax_1(self):
  217. # Tests for issue 24619
  218. def foo(await):
  219. async def foo(): pass
  220. async def foo():
  221. pass
  222. return await + 1
  223. self.assertEqual(foo(10), 11)
  224. def foo(await):
  225. async def foo(): pass
  226. async def foo(): pass
  227. return await + 2
  228. self.assertEqual(foo(20), 22)
  229. def foo(await):
  230. async def foo(): pass
  231. async def foo(): pass
  232. return await + 2
  233. self.assertEqual(foo(20), 22)
  234. def foo(await):
  235. """spam"""
  236. async def foo(): \
  237. pass
  238. # 123
  239. async def foo(): pass
  240. # 456
  241. return await + 2
  242. self.assertEqual(foo(20), 22)
  243. def foo(await):
  244. def foo(): pass
  245. def foo(): pass
  246. async def bar(): return await_
  247. await_ = await
  248. try:
  249. bar().send(None)
  250. except StopIteration as ex:
  251. return ex.args[0]
  252. self.assertEqual(foo(42), 42)
  253. async def f():
  254. async def g(): pass
  255. await z
  256. await = 1
  257. self.assertTrue(inspect.iscoroutinefunction(f))
  258. class TokenizerRegrTest(unittest.TestCase):
  259. def test_oneline_defs(self):
  260. buf = []
  261. for i in range(500):
  262. buf.append('def i{i}(): return {i}'.format(i=i))
  263. buf = '\n'.join(buf)
  264. # Test that 500 consequent, one-line defs is OK
  265. ns = {}
  266. exec(buf, ns, ns)
  267. self.assertEqual(ns['i499'](), 499)
  268. # Test that 500 consequent, one-line defs *and*
  269. # one 'async def' following them is OK
  270. buf += '\nasync def foo():\n return'
  271. ns = {}
  272. exec(buf, ns, ns)
  273. self.assertEqual(ns['i499'](), 499)
  274. self.assertTrue(inspect.iscoroutinefunction(ns['foo']))
  275. class CoroutineTest(unittest.TestCase):
    def test_gen_1(self):
        # A plain generator function object must not grow an __await__.
        def gen(): yield
        self.assertFalse(hasattr(gen, '__await__'))
    def test_func_1(self):
        # 'async def' creates coroutine objects whose code has CO_COROUTINE
        # (and not CO_GENERATOR) set. Both drivers return the function's
        # result. A plain def must not be flagged CO_COROUTINE.
        async def foo():
            return 10

        f = foo()
        self.assertIsInstance(f, types.CoroutineType)
        self.assertTrue(bool(foo.__code__.co_flags & inspect.CO_COROUTINE))
        self.assertFalse(bool(foo.__code__.co_flags & inspect.CO_GENERATOR))
        self.assertTrue(bool(f.cr_code.co_flags & inspect.CO_COROUTINE))
        self.assertFalse(bool(f.cr_code.co_flags & inspect.CO_GENERATOR))
        self.assertEqual(run_async(f), ([], 10))

        self.assertEqual(run_async__await__(foo()), ([], 10))

        def bar(): pass
        self.assertFalse(bool(bar.__code__.co_flags & inspect.CO_COROUTINE))
    def test_func_2(self):
        # A StopIteration escaping a coroutine is converted to RuntimeError.
        async def foo():
            raise StopIteration

        with self.assertRaisesRegex(
                RuntimeError, "coroutine raised StopIteration"):

            run_async(foo())
    def test_func_3(self):
        # repr() of a coroutine looks like '<coroutine object ... at 0x...>'.
        # silence_coro_gc hides warnings about the coroutine never being run.
        async def foo():
            raise StopIteration

        with silence_coro_gc():
            self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$')
    def test_func_4(self):
        # Coroutine objects are not iterable. Every iteration entry point
        # (constructors, sum, iter, for-loops, comprehensions) must raise
        # TypeError.
        async def foo():
            raise StopIteration

        check = lambda: self.assertRaisesRegex(
            TypeError, "'coroutine' object is not iterable")

        with check():
            list(foo())

        with check():
            tuple(foo())

        with check():
            sum(foo())

        with check():
            iter(foo())

        with silence_coro_gc(), check():
            for i in foo():
                pass

        with silence_coro_gc(), check():
            [i for i in foo()]
    def test_func_5(self):
        # A coroutine that awaits a @types.coroutine generator is still not
        # iterable, while the decorated generator itself stays iterable.
        @types.coroutine
        def bar():
            yield 1

        async def foo():
            await bar()

        check = lambda: self.assertRaisesRegex(
            TypeError, "'coroutine' object is not iterable")

        with check():
            for el in foo(): pass

        # the following should pass without an error
        for el in bar():
            self.assertEqual(el, 1)
        self.assertEqual([el for el in bar()], [1])
        self.assertEqual(tuple(bar()), (1,))
        self.assertEqual(next(iter(bar())), 1)
    def test_func_6(self):
        # Values yielded by an awaited @types.coroutine generator come out
        # of the outer coroutine's send(). After the last one, send()
        # raises StopIteration.
        @types.coroutine
        def bar():
            yield 1
            yield 2

        async def foo():
            await bar()

        f = foo()
        self.assertEqual(f.send(None), 1)
        self.assertEqual(f.send(None), 2)
        with self.assertRaises(StopIteration):
            f.send(None)
    def test_func_7(self):
        # A plain (non-@types.coroutine) generator may not 'yield from' a
        # native coroutine.
        async def bar():
            return 10

        def foo():
            yield from bar()

        with silence_coro_gc(), self.assertRaisesRegex(
            TypeError,
            "cannot 'yield from' a coroutine object in a non-coroutine generator"):

            list(foo())
    def test_func_8(self):
        # A @types.coroutine generator may 'yield from' a native coroutine
        # and receives the coroutine's return value.
        @types.coroutine
        def bar():
            return (yield from foo())

        async def foo():
            return 'spam'

        self.assertEqual(run_async(bar()), ([], 'spam') )
    def test_func_9(self):
        # A coroutine that is created and then collected without being run
        # triggers a RuntimeWarning naming its qualified name.
        async def foo(): pass

        with self.assertWarnsRegex(
            RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):

            foo()
            support.gc_collect()
    def test_func_10(self):
        # The iterator from coro.__await__() is its own iterator and forwards
        # next/send/close/throw to the awaited generator. N records which
        # handlers ran: +1 for each finally, +100 for each ZeroDivisionError.
        N = 0

        @types.coroutine
        def gen():
            nonlocal N
            try:
                a = yield
                yield (a ** 2)
            except ZeroDivisionError:
                N += 100
                raise
            finally:
                N += 1

        async def foo():
            await gen()

        coro = foo()
        aw = coro.__await__()
        self.assertIs(aw, iter(aw))
        next(aw)
        self.assertEqual(aw.send(10), 100)

        self.assertEqual(N, 0)
        aw.close()
        self.assertEqual(N, 1)

        coro = foo()
        aw = coro.__await__()
        next(aw)
        with self.assertRaises(ZeroDivisionError):
            aw.throw(ZeroDivisionError, None, None)
        self.assertEqual(N, 102)
    def test_func_11(self):
        async def func(): pass
        coro = func()
        # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly
        # initialized
        self.assertIn('__await__', dir(coro))
        self.assertIn('__iter__', dir(coro.__await__()))
        self.assertIn('coroutine_wrapper', repr(coro.__await__()))
        coro.close() # avoid RuntimeWarning
    def test_func_12(self):
        # A coroutine that calls send() on itself while running raises
        # ValueError. The 'await foo' line is never reached, so the
        # undefined name foo is never looked up.
        async def g():
            i = me.send(None)
            await foo
        me = g()
        with self.assertRaisesRegex(ValueError,
                                    "coroutine already executing"):
            me.send(None)
    def test_func_13(self):
        # The first send() into a fresh coroutine must be None.
        async def g():
            pass
        with self.assertRaisesRegex(
            TypeError,
            "can't send non-None value to a just-started coroutine"):

            g().send('spam')
    def test_func_14(self):
        # A coroutine that catches GeneratorExit and suspends again makes
        # close() raise RuntimeError.
        @types.coroutine
        def gen():
            yield
        async def coro():
            try:
                await gen()
            except GeneratorExit:
                await gen()
        c = coro()
        c.send(None)
        with self.assertRaisesRegex(RuntimeError,
                                    "coroutine ignored GeneratorExit"):
            c.close()
    def test_func_15(self):
        # See http://bugs.python.org/issue25887 for details
        # Awaiting a coroutine that has already finished raises RuntimeError.

        async def spammer():
            return 'spam'
        async def reader(coro):
            return await coro

        spammer_coro = spammer()

        with self.assertRaisesRegex(StopIteration, 'spam'):
            reader(spammer_coro).send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader(spammer_coro).send(None)
    def test_func_16(self):
        # See http://bugs.python.org/issue25887 for details
        # 'spammer' ends when Exception('ham') is thrown through the first
        # reader. A second reader awaiting it gets RuntimeError, and that
        # second reader is then exhausted too, so throw() into it raises
        # the same RuntimeError.

        @types.coroutine
        def nop():
            yield
        async def send():
            await nop()
            return 'spam'
        async def read(coro):
            await nop()
            return await coro

        spammer = send()

        reader = read(spammer)
        reader.send(None)
        reader.send(None)
        with self.assertRaisesRegex(Exception, 'ham'):
            reader.throw(Exception('ham'))

        reader = read(spammer)
        reader.send(None)
        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            reader.throw(Exception('wat'))
    def test_func_17(self):
        # See http://bugs.python.org/issue25887 for details
        # send()/throw() on an exhausted coroutine raise RuntimeError.

        async def coroutine():
            return 'spam'

        coro = coroutine()
        with self.assertRaisesRegex(StopIteration, 'spam'):
            coro.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            coro.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            coro.throw(Exception('wat'))

        # Closing a coroutine shouldn't raise any exception even if it's
        # already closed/exhausted (similar to generators)
        coro.close()
        coro.close()
    def test_func_18(self):
        # See http://bugs.python.org/issue25887 for details
        # Same checks as for an exhausted coroutine, but made through the
        # iterator returned by its __await__().

        async def coroutine():
            return 'spam'

        coro = coroutine()
        await_iter = coro.__await__()
        it = iter(await_iter)

        with self.assertRaisesRegex(StopIteration, 'spam'):
            it.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            # Although the iterator protocol requires iterators to
            # raise another StopIteration here, we don't want to do
            # that. In this particular case, the iterator will raise
            # a RuntimeError, so that 'yield from' and 'await'
            # expressions will trigger the error, instead of silently
            # ignoring the call.
            next(it)

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.throw(Exception('wat'))

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot reuse already awaited coroutine'):
            it.throw(Exception('wat'))

        # Closing a coroutine shouldn't raise any exception even if it's
        # already closed/exhausted (similar to generators)
        it.close()
        it.close()
    def test_func_19(self):
        # close() delivers GeneratorExit to the awaited generator exactly
        # once. CHK counts how many times it was caught.
        CHK = 0

        @types.coroutine
        def foo():
            nonlocal CHK
            yield
            try:
                yield
            except GeneratorExit:
                CHK += 1

        async def coroutine():
            await foo()

        coro = coroutine()

        coro.send(None)
        coro.send(None)

        self.assertEqual(CHK, 0)
        coro.close()
        self.assertEqual(CHK, 1)

        for _ in range(3):
            # Closing a coroutine shouldn't raise any exception even if it's
            # already closed/exhausted (similar to generators)
            coro.close()
            self.assertEqual(CHK, 1)
    def test_cr_await(self):
        # cr_await is None before, while running, and after completion. While
        # suspended it points at the awaited object: b awaits c, which awaits
        # generator a. inspect.getcoroutinestate tracks the same lifecycle.
        @types.coroutine
        def a():
            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)
            yield
            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)

        async def c():
            await a()

        async def b():
            self.assertIsNone(coro_b.cr_await)
            await c()
            self.assertIsNone(coro_b.cr_await)

        coro_b = b()
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CREATED)
        self.assertIsNone(coro_b.cr_await)

        coro_b.send(None)
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_SUSPENDED)
        self.assertEqual(coro_b.cr_await.cr_await.gi_code.co_name, 'a')

        with self.assertRaises(StopIteration):
            coro_b.send(None) # complete coroutine
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CLOSED)
        self.assertIsNone(coro_b.cr_await)
    def test_corotype_1(self):
        # CoroutineType's method/descriptor docstrings, its __name__ and the
        # repr of its instances all refer to "coroutine".
        ct = types.CoroutineType
        self.assertIn('into coroutine', ct.send.__doc__)
        self.assertIn('inside coroutine', ct.close.__doc__)
        self.assertIn('in coroutine', ct.throw.__doc__)
        self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__)
        self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__)
        self.assertEqual(ct.__name__, 'coroutine')

        async def f(): pass
        c = f()
        self.assertIn('coroutine object', repr(c))
        c.close()
    def test_await_1(self):
        # An int is not awaitable.

        async def foo():
            await 1
        with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
            run_async(foo())
    def test_await_2(self):
        # A list is not awaitable, even though it is iterable.
        async def foo():
            await []
        with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
            run_async(foo())
    def test_await_3(self):
        # 'await' delegates to a custom __await__. Both drivers observe every
        # yielded value, and the await evaluates to None.
        async def foo():
            await AsyncYieldFrom([1, 2, 3])
        self.assertEqual(run_async(foo()), ([1, 2, 3], None))
        self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None))
    def test_await_4(self):
        # Awaiting a coroutine evaluates to its return value.
        async def bar():
            return 42

        async def foo():
            return await bar()

        self.assertEqual(run_async(foo()), ([], 42))
    def test_await_5(self):
        # __await__ returning None (not an iterator) is a TypeError.
        class Awaitable:
            def __await__(self):
                return

        async def foo():
            return (await Awaitable())

        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())
    def test_await_6(self):
        # __await__ may return any iterator. Its items are yielded, and the
        # await evaluates to None once the iterator is exhausted.
        class Awaitable:
            def __await__(self):
                return iter([52])

        async def foo():
            return (await Awaitable())

        self.assertEqual(run_async(foo()), ([52], None))
    def test_await_7(self):
        # __await__ written as a generator: its yields pass through, and its
        # return value becomes the result of the await.
        class Awaitable:
            def __await__(self):
                yield 42
                return 100

        async def foo():
            return (await Awaitable())

        self.assertEqual(run_async(foo()), ([42], 100))
    def test_await_8(self):
        # An instance of a class without __await__ cannot be awaited.
        class Awaitable:
            pass

        async def foo(): return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, "object Awaitable can't be used in 'await' expression"):

            run_async(foo())
    def test_await_9(self):
        # 'await' applies to arbitrary call/subscript/attribute chains and
        # binds tighter than the binary operators:
        # 42 + 42 + 42 + 42 * 1000 + 42 == 42168; unary minus gives -42.
        def wrap():
            return bar

        async def bar():
            return 42

        async def foo():
            b = bar()  # NOTE(review): b is never used or awaited

            db = {'b':  lambda: wrap}

            class DB:
                b = wrap

            return (await bar() + await wrap()() + await db['b']()()() +
                    await bar() * 1000 + await DB.b()())

        async def foo2():
            return -await bar()

        self.assertEqual(run_async(foo()), ([], 42168))
        self.assertEqual(run_async(foo2()), ([], -42))
    def test_await_10(self):
        # A coroutine may return another coroutine, which can then be
        # awaited in turn.
        async def baz():
            return 42

        async def bar():
            return baz()

        async def foo():
            return await (await bar())

        self.assertEqual(run_async(foo()), ([], 42))
  658. def test_await_11(self):
  659. def ident(val):
  660. return val
  661. async def bar():
  662. return 'spam'
  663. async def foo():
  664. return ident(val=await bar())
  665. async def foo2():
  666. return await bar(), 'ham'
  667. self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))
  668. def test_await_12(self):
  669. async def coro():
  670. return 'spam'
  671. class Awaitable:
  672. def __await__(self):
  673. return coro()
  674. async def foo():
  675. return await Awaitable()
  676. with self.assertRaisesRegex(
  677. TypeError, "__await__\(\) returned a coroutine"):
  678. run_async(foo())
    def test_await_13(self):
        # __await__ returning a non-iterator object (here self) is a
        # TypeError.
        class Awaitable:
            def __await__(self):
                return self

        async def foo():
            return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())
    def test_await_14(self):
        # Values passed to send() and exceptions passed to throw() must travel
        # through CoroutineType.__await__ down to the innermost awaitable.
        class Wrapper:
            # Forces the interpreter to use CoroutineType.__await__
            def __init__(self, coro):
                assert coro.__class__ is types.CoroutineType
                self.coro = coro
            def __await__(self):
                return self.coro.__await__()

        class FutureLike:
            def __await__(self):
                return (yield)

        class Marker(Exception):
            pass

        async def coro1():
            try:
                return await FutureLike()
            except ZeroDivisionError:
                raise Marker
        async def coro2():
            return await Wrapper(coro1())

        c = coro2()
        c.send(None)
        with self.assertRaisesRegex(StopIteration, 'spam'):
            c.send('spam')

        c = coro2()
        c.send(None)
        with self.assertRaises(Marker):
            c.throw(ZeroDivisionError)
    def test_await_15(self):
        # A coroutine left suspended mid-await (here driven directly with
        # send()) cannot also be awaited by another coroutine.
        @types.coroutine
        def nop():
            yield

        async def coroutine():
            await nop()

        async def waiter(coro):
            await coro

        coro = coroutine()
        coro.send(None)

        with self.assertRaisesRegex(RuntimeError,
                                    "coroutine is being awaited already"):
            waiter(coro).send(None)
    def test_with_1(self):
        # Two managers in one 'async with' enter in order and exit in reverse
        # order. B's __aexit__ returns True and suppresses the
        # ZeroDivisionError. C's returns None, so the error propagates.
        class Manager:
            def __init__(self, name):
                self.name = name

            async def __aenter__(self):
                await AsyncYieldFrom(['enter-1-' + self.name,
                                      'enter-2-' + self.name])
                return self

            async def __aexit__(self, *args):
                await AsyncYieldFrom(['exit-1-' + self.name,
                                      'exit-2-' + self.name])

                if self.name == 'B':
                    return True


        async def foo():
            async with Manager("A") as a, Manager("B") as b:
                await AsyncYieldFrom([('managers', a.name, b.name)])
                1/0

        f = foo()
        result, _ = run_async(f)

        self.assertEqual(
            result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
                     ('managers', 'A', 'B'),
                     'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
        )

        async def foo():
            async with Manager("A") as a, Manager("C") as c:
                await AsyncYieldFrom([('managers', a.name, c.name)])
                1/0

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
    def test_with_2(self):
        # A manager without __aexit__ is rejected with AttributeError.
        class CM:
            def __aenter__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())
    def test_with_3(self):
        # A manager without __aenter__ is rejected with AttributeError.
        class CM:
            def __aexit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aenter__'):
            run_async(foo())
    def test_with_4(self):
        # Synchronous __enter__/__exit__ do not satisfy 'async with'.
        class CM:
            def __enter__(self):
                pass

            def __exit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())
    def test_with_5(self):
        # While this test doesn't make a lot of sense,
        # it's a regression test for an early bug with opcodes
        # generation

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *exc):
                pass

        async def func():
            async with CM():
                # Always false: the AssertionError must escape the block,
                # because __aexit__ returns None and so does not suppress it.
                assert (1, ) == 1

        with self.assertRaises(AssertionError):
            run_async(func())
  802. def test_with_6(self):
  803. class CM:
  804. def __aenter__(self):
  805. return 123
  806. def __aexit__(self, *e):
  807. return 456
  808. async def foo():
  809. async with CM():
  810. pass
  811. with self.assertRaisesRegex(
  812. TypeError, "object int can't be used in 'await' expression"):
  813. # it's important that __aexit__ wasn't called
  814. run_async(foo())
  815. def test_with_7(self):
  816. class CM:
  817. async def __aenter__(self):
  818. return self
  819. def __aexit__(self, *e):
  820. return 444
  821. async def foo():
  822. async with CM():
  823. 1/0
  824. try:
  825. run_async(foo())
  826. except TypeError as exc:
  827. self.assertRegex(
  828. exc.args[0], "object int can't be used in 'await' expression")
  829. self.assertTrue(exc.__context__ is not None)
  830. self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
  831. else:
  832. self.fail('invalid asynchronous context manager did not fail')
    def test_with_8(self):
        # A non-awaitable result from __aexit__ raises TypeError only after
        # the body has run exactly once (CNT == 1).
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 456

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1


        with self.assertRaisesRegex(
            TypeError, "object int can't be used in 'await' expression"):

            run_async(foo())

        self.assertEqual(CNT, 1)
    def test_with_9(self):
        # An exception raised inside __aexit__ propagates; the body still ran
        # once beforehand.
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

        self.assertEqual(CNT, 1)
  862. def test_with_10(self):
  863. CNT = 0
  864. class CM:
  865. async def __aenter__(self):
  866. return self
  867. async def __aexit__(self, *e):
  868. 1/0
  869. async def foo():
  870. nonlocal CNT
  871. async with CM():
  872. async with CM():
  873. raise RuntimeError
  874. try:
  875. run_async(foo())
  876. except ZeroDivisionError as exc:
  877. self.assertTrue(exc.__context__ is not None)
  878. self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
  879. self.assertTrue(isinstance(exc.__context__.__context__,
  880. RuntimeError))
  881. else:
  882. self.fail('exception from __aexit__ did not propagate')
  883. def test_with_11(self):
  884. CNT = 0
  885. class CM:
  886. async def __aenter__(self):
  887. raise NotImplementedError
  888. async def __aexit__(self, *e):
  889. 1/0
  890. async def foo():
  891. nonlocal CNT
  892. async with CM():
  893. raise RuntimeError
  894. try:
  895. run_async(foo())
  896. except NotImplementedError as exc:
  897. self.assertTrue(exc.__context__ is None)
  898. else:
  899. self.fail('exception from __aenter__ did not propagate')
  900. def test_with_12(self):
  901. CNT = 0
  902. class CM:
  903. async def __aenter__(self):
  904. return self
  905. async def __aexit__(self, *e):
  906. return True
  907. async def foo():
  908. nonlocal CNT
  909. async with CM() as cm:
  910. self.assertIs(cm.__class__, CM)
  911. raise RuntimeError
  912. run_async(foo())
    def test_with_13(self):
        # If __aenter__ raises, neither the body nor the code after the block
        # runs, and the exception propagates even though __aexit__ would
        # return True. Only the first increment lands, so CNT == 1.
        CNT = 0

        class CM:
            async def __aenter__(self):
                1/0

            async def __aexit__(self, *e):
                return True

        async def foo():
            nonlocal CNT
            CNT += 1
            async with CM():
                CNT += 1000
            CNT += 10000

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 1)
    def test_for_1(self):
        # Legacy 'async for' protocol: __aiter__ is itself awaitable, which
        # triggers a PendingDeprecationWarning mentioning "legacy". Three
        # scenarios are covered: full iteration, 'break' (else clause
        # skipped) and 'continue' (else clause runs). AsyncIter yields
        # i * 10 on every 10th step and stops after step 100.
        aiter_calls = 0

        class AsyncIter:
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                nonlocal aiter_calls
                aiter_calls += 1
                return self

            async def __anext__(self):
                self.i += 1

                if not (self.i % 10):
                    await AsyncYield(self.i * 10)

                if self.i > 100:
                    raise StopAsyncIteration

                return self.i, self.i


        buffer = []
        async def test1():
            with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
                async for i1, i2 in AsyncIter():
                    buffer.append(i1 + i2)

        yielded, _ = run_async(test1())
        # Make sure that __aiter__ was called only once
        self.assertEqual(aiter_calls, 1)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i*2 for i in range(1, 101)])


        buffer = []
        async def test2():
            nonlocal buffer
            with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
                async for i in AsyncIter():
                    buffer.append(i[0])
                    if i[0] == 20:
                        break
                else:
                    buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test2())
        # aiter_calls is cumulative: exactly one more __aiter__ call for
        # this loop.
        self.assertEqual(aiter_calls, 2)
        self.assertEqual(yielded, [100, 200])
        self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])


        buffer = []
        async def test3():
            nonlocal buffer
            with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
                async for i in AsyncIter():
                    if i[0] > 20:
                        continue
                    buffer.append(i[0])
                else:
                    buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test3())
        # aiter_calls is cumulative: exactly one more __aiter__ call for
        # this loop.
        self.assertEqual(aiter_calls, 3)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i for i in range(1, 21)] +
                                 ['what?', 'end'])
    def test_for_2(self):
        # 'async for' over an object without __aiter__ raises TypeError and
        # must not leak a reference to that object.
        tup = (1, 2, 3)
        refs_before = sys.getrefcount(tup)

        async def foo():
            async for i in tup:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError, "async for' requires an object.*__aiter__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(tup), refs_before)
  998. def test_for_3(self):
  999. class I:
  1000. def __aiter__(self):
  1001. return self
  1002. aiter = I()
  1003. refs_before = sys.getrefcount(aiter)
  1004. async def foo():
  1005. async for i in aiter:
  1006. print('never going to happen')
  1007. with self.assertRaisesRegex(
  1008. TypeError,
  1009. "async for' received an invalid object.*__aiter.*\: I"):
  1010. run_async(foo())
  1011. self.assertEqual(sys.getrefcount(aiter), refs_before)
    def test_for_4(self):
        # __anext__ returning a non-awaitable (a tuple) raises TypeError and
        # must not leak a reference to the iterable.
        class I:
            def __aiter__(self):
                return self

            def __anext__(self):
                return ()

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                "async for' received an invalid object.*__anext__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)
    def test_for_5(self):
        # Same failure via the legacy awaitable __aiter__ (which warns with
        # "legacy"), with __anext__ returning an int.
        class I:
            async def __aiter__(self):
                return self

            def __anext__(self):
                return 123

        async def foo():
            with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
                async for i in I():
                    print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                "async for' received an invalid object.*__anext.*int"):

            run_async(foo())
def test_for_6(self):
    # Exercise 'async for' nested inside 'async with' (including the
    # 'else' clause of 'async for'), using the counter I to record which
    # steps ran, and check that no references to the objects are leaked.
    I = 0

    class Manager:
        # Entering adds 10000 to I, exiting adds 100000.
        async def __aenter__(self):
            nonlocal I
            I += 10000

        async def __aexit__(self, *args):
            nonlocal I
            I += 100000

    class Iterable:
        # Asynchronous iterator producing 1, 2, ..., 11 (eleven values).
        def __init__(self):
            self.i = 0

        def __aiter__(self):
            return self

        async def __anext__(self):
            if self.i > 10:
                raise StopAsyncIteration
            self.i += 1
            return self.i

    ##############

    manager = Manager()
    iterable = Iterable()
    mrefs_before = sys.getrefcount(manager)
    irefs_before = sys.getrefcount(iterable)

    async def main():
        nonlocal I

        async with manager:
            async for i in iterable:
                I += 1
            I += 1000

    with warnings.catch_warnings():
        warnings.simplefilter("error")
        # Test that __aiter__ that returns an asynchronous iterator
        # directly does not throw any warnings.
        run_async(main())
    # 10000 (enter) + 11 (iterations) + 1000 + 100000 (exit)
    self.assertEqual(I, 111011)

    # Running the coroutine must not leave extra references behind.
    self.assertEqual(sys.getrefcount(manager), mrefs_before)
    self.assertEqual(sys.getrefcount(iterable), irefs_before)

    ##############

    # Two more rounds with fresh objects: 111011 + 2 * 111011.
    async def main():
        nonlocal I

        async with Manager():
            async for i in Iterable():
                I += 1
            I += 1000

        async with Manager():
            async for i in Iterable():
                I += 1
            I += 1000

    run_async(main())
    self.assertEqual(I, 333033)

    ##############

    # The loops run to exhaustion, so each 'else' clause executes too:
    # each block adds 10000 + 100 + 11 + 10000000 + 1000 + 100000.
    async def main():
        nonlocal I

        async with Manager():
            I += 100
            async for i in Iterable():
                I += 1
            else:
                I += 10000000
            I += 1000

        async with Manager():
            I += 100
            async for i in Iterable():
                I += 1
            else:
                I += 10000000
            I += 1000

    run_async(main())
    self.assertEqual(I, 20555255)
  1112. def test_for_7(self):
  1113. CNT = 0
  1114. class AI:
  1115. async def __aiter__(self):
  1116. 1/0
  1117. async def foo():
  1118. nonlocal CNT
  1119. with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
  1120. async for i in AI():
  1121. CNT += 1
  1122. CNT += 10
  1123. with self.assertRaises(ZeroDivisionError):
  1124. run_async(foo())
  1125. self.assertEqual(CNT, 0)
  1126. def test_for_8(self):
  1127. CNT = 0
  1128. class AI:
  1129. def __aiter__(self):
  1130. 1/0
  1131. async def foo():
  1132. nonlocal CNT
  1133. async for i in AI():
  1134. CNT += 1
  1135. CNT += 10
  1136. with self.assertRaises(ZeroDivisionError):
  1137. with warnings.catch_warnings():
  1138. warnings.simplefilter("error")
  1139. # Test that if __aiter__ raises an exception it propagates
  1140. # without any kind of warning.
  1141. run_async(foo())
  1142. self.assertEqual(CNT, 0)
  1143. def test_for_9(self):
  1144. # Test that PendingDeprecationWarning can safely be converted into
  1145. # an exception (__aiter__ should not have a chance to raise
  1146. # a ZeroDivisionError.)
  1147. class AI:
  1148. async def __aiter__(self):
  1149. 1/0
  1150. async def foo():
  1151. async for i in AI():
  1152. pass
  1153. with self.assertRaises(PendingDeprecationWarning):
  1154. with warnings.catch_warnings():
  1155. warnings.simplefilter("error")
  1156. run_async(foo())
  1157. def test_for_10(self):
  1158. # Test that PendingDeprecationWarning can safely be converted into
  1159. # an exception.
  1160. class AI:
  1161. async def __aiter__(self):
  1162. pass
  1163. async def foo():
  1164. async for i in AI():
  1165. pass
  1166. with self.assertRaises(PendingDeprecationWarning):
  1167. with warnings.catch_warnings():
  1168. warnings.simplefilter("error")
  1169. run_async(foo())
  1170. def test_copy(self):
  1171. async def func(): pass
  1172. coro = func()
  1173. with self.assertRaises(TypeError):
  1174. copy.copy(coro)
  1175. aw = coro.__await__()
  1176. try:
  1177. with self.assertRaises(TypeError):
  1178. copy.copy(aw)
  1179. finally:
  1180. aw.close()
  1181. def test_pickle(self):
  1182. async def func(): pass
  1183. coro = func()
  1184. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  1185. with self.assertRaises((TypeError, pickle.PicklingError)):
  1186. pickle.dumps(coro, proto)
  1187. aw = coro.__await__()
  1188. try:
  1189. for proto in range(pickle.HIGHEST_PROTOCOL + 1):
  1190. with self.assertRaises((TypeError, pickle.PicklingError)):
  1191. pickle.dumps(aw, proto)
  1192. finally:
  1193. aw.close()
  1194. class CoroAsyncIOCompatTest(unittest.TestCase):
  1195. def test_asyncio_1(self):
  1196. # asyncio cannot be imported when Python is compiled without thread
  1197. # support
  1198. asyncio = support.import_module('asyncio')
  1199. class MyException(Exception):
  1200. pass
  1201. buffer = []
  1202. class CM:
  1203. async def __aenter__(self):
  1204. buffer.append(1)
  1205. await asyncio.sleep(0.01)
  1206. buffer.append(2)
  1207. return self
  1208. async def __aexit__(self, exc_type, exc_val, exc_tb):
  1209. await asyncio.sleep(0.01)
  1210. buffer.append(exc_type.__name__)
  1211. async def f():
  1212. async with CM() as c:
  1213. await asyncio.sleep(0.01)
  1214. raise MyException
  1215. buffer.append('unreachable')
  1216. loop = asyncio.new_event_loop()
  1217. asyncio.set_event_loop(loop)
  1218. try:
  1219. loop.run_until_complete(f())
  1220. except MyException:
  1221. pass
  1222. finally:
  1223. loop.close()
  1224. asyncio.set_event_loop(None)
  1225. self.assertEqual(buffer, [1, 2, 'MyException'])
  1226. class SysSetCoroWrapperTest(unittest.TestCase):
  1227. def test_set_wrapper_1(self):
  1228. async def foo():
  1229. return 'spam'
  1230. wrapped = None
  1231. def wrap(gen):
  1232. nonlocal wrapped
  1233. wrapped = gen
  1234. return gen
  1235. self.assertIsNone(sys.get_coroutine_wrapper())
  1236. sys.set_coroutine_wrapper(wrap)
  1237. self.assertIs(sys.get_coroutine_wrapper(), wrap)
  1238. try:
  1239. f = foo()
  1240. self.assertTrue(wrapped)
  1241. self.assertEqual(run_async(f), ([], 'spam'))
  1242. finally:
  1243. sys.set_coroutine_wrapper(None)
  1244. self.assertIsNone(sys.get_coroutine_wrapper())
  1245. wrapped = None
  1246. with silence_coro_gc():
  1247. foo()
  1248. self.assertFalse(wrapped)
  1249. def test_set_wrapper_2(self):
  1250. self.assertIsNone(sys.get_coroutine_wrapper())
  1251. with self.assertRaisesRegex(TypeError, "callable expected, got int"):
  1252. sys.set_coroutine_wrapper(1)
  1253. self.assertIsNone(sys.get_coroutine_wrapper())
  1254. def test_set_wrapper_3(self):
  1255. async def foo():
  1256. return 'spam'
  1257. def wrapper(coro):
  1258. async def wrap(coro):
  1259. return await coro
  1260. return wrap(coro)
  1261. sys.set_coroutine_wrapper(wrapper)
  1262. try:
  1263. with silence_coro_gc(), self.assertRaisesRegex(
  1264. RuntimeError,
  1265. "coroutine wrapper.*\.wrapper at 0x.*attempted to "
  1266. "recursively wrap .* wrap .*"):
  1267. foo()
  1268. finally:
  1269. sys.set_coroutine_wrapper(None)
  1270. def test_set_wrapper_4(self):
  1271. @types.coroutine
  1272. def foo():
  1273. return 'spam'
  1274. wrapped = None
  1275. def wrap(gen):
  1276. nonlocal wrapped
  1277. wrapped = gen
  1278. return gen
  1279. sys.set_coroutine_wrapper(wrap)
  1280. try:
  1281. foo()
  1282. self.assertIs(
  1283. wrapped, None,
  1284. "generator-based coroutine was wrapped via "
  1285. "sys.set_coroutine_wrapper")
  1286. finally:
  1287. sys.set_coroutine_wrapper(None)
  1288. class CAPITest(unittest.TestCase):
  1289. def test_tp_await_1(self):
  1290. from _testcapi import awaitType as at
  1291. async def foo():
  1292. future = at(iter([1]))
  1293. return (await future)
  1294. self.assertEqual(foo().send(None), 1)
  1295. def test_tp_await_2(self):
  1296. # Test tp_await to __await__ mapping
  1297. from _testcapi import awaitType as at
  1298. future = at(iter([1]))
  1299. self.assertEqual(next(future.__await__()), 1)
  1300. def test_tp_await_3(self):
  1301. from _testcapi import awaitType as at
  1302. async def foo():
  1303. future = at(1)
  1304. return (await future)
  1305. with self.assertRaisesRegex(
  1306. TypeError, "__await__.*returned non-iterator of type 'int'"):
  1307. self.assertEqual(foo().send(None), 1)
  1308. if __name__=="__main__":
  1309. unittest.main()