/Lib/test/test_coroutines.py
Python | 1728 lines | 1407 code | 285 blank | 36 comment | 153 complexity | e85c1910ce0ff883c71a957bb345a60f MD5 | raw file
Possible License(s): BSD-3-Clause, Unlicense, CC-BY-SA-3.0, 0BSD
- import contextlib
- import copy
- import inspect
- import pickle
- import sys
- import types
- import unittest
- import warnings
- from test import support
- class AsyncYieldFrom:
- def __init__(self, obj):
- self.obj = obj
- def __await__(self):
- yield from self.obj
- class AsyncYield:
- def __init__(self, value):
- self.value = value
- def __await__(self):
- yield self.value
- def run_async(coro):
- assert coro.__class__ in {types.GeneratorType, types.CoroutineType}
- buffer = []
- result = None
- while True:
- try:
- buffer.append(coro.send(None))
- except StopIteration as ex:
- result = ex.args[0] if ex.args else None
- break
- return buffer, result
- def run_async__await__(coro):
- assert coro.__class__ is types.CoroutineType
- aw = coro.__await__()
- buffer = []
- result = None
- i = 0
- while True:
- try:
- if i % 2:
- buffer.append(next(aw))
- else:
- buffer.append(aw.send(None))
- i += 1
- except StopIteration as ex:
- result = ex.args[0] if ex.args else None
- break
- return buffer, result
- @contextlib.contextmanager
- def silence_coro_gc():
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
- yield
- support.gc_collect()
- class AsyncBadSyntaxTest(unittest.TestCase):
- def test_badsyntax_1(self):
- with self.assertRaisesRegex(SyntaxError, "'await' outside"):
- import test.badsyntax_async1
- def test_badsyntax_2(self):
- with self.assertRaisesRegex(SyntaxError, "'await' outside"):
- import test.badsyntax_async2
- def test_badsyntax_3(self):
- with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
- import test.badsyntax_async3
- def test_badsyntax_4(self):
- with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
- import test.badsyntax_async4
- def test_badsyntax_5(self):
- with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
- import test.badsyntax_async5
- def test_badsyntax_6(self):
- with self.assertRaisesRegex(
- SyntaxError, "'yield' inside async function"):
- import test.badsyntax_async6
- def test_badsyntax_7(self):
- with self.assertRaisesRegex(
- SyntaxError, "'yield from' inside async function"):
- import test.badsyntax_async7
- def test_badsyntax_8(self):
- with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
- import test.badsyntax_async8
- def test_badsyntax_9(self):
- ns = {}
- for comp in {'(await a for a in b)',
- '[await a for a in b]',
- '{await a for a in b}',
- '{await a: c for a in b}'}:
- with self.assertRaisesRegex(SyntaxError, 'await.*in comprehen'):
- exec('async def f():\n\t{}'.format(comp), ns, ns)
- def test_badsyntax_10(self):
- # Tests for issue 24619
- samples = [
- """async def foo():
- def bar(): pass
- await = 1
- """,
- """async def foo():
- def bar(): pass
- await = 1
- """,
- """async def foo():
- def bar(): pass
- if 1:
- await = 1
- """,
- """def foo():
- async def bar(): pass
- if 1:
- await a
- """,
- """def foo():
- async def bar(): pass
- await a
- """,
- """def foo():
- def baz(): pass
- async def bar(): pass
- await a
- """,
- """def foo():
- def baz(): pass
- # 456
- async def bar(): pass
- # 123
- await a
- """,
- """async def foo():
- def baz(): pass
- # 456
- async def bar(): pass
- # 123
- await = 2
- """,
- """def foo():
- def baz(): pass
- async def bar(): pass
- await a
- """,
- """async def foo():
- def baz(): pass
- async def bar(): pass
- await = 2
- """,
- """async def foo():
- def async(): pass
- """,
- """async def foo():
- def await(): pass
- """,
- """async def foo():
- def bar():
- await
- """,
- """async def foo():
- return lambda async: await
- """,
- """async def foo():
- return lambda a: await
- """,
- """await a()""",
- """async def foo(a=await b):
- pass
- """,
- """async def foo(a:await b):
- pass
- """,
- """def baz():
- async def foo(a=await b):
- pass
- """,
- """async def foo(async):
- pass
- """,
- """async def foo():
- def bar():
- def baz():
- async = 1
- """,
- """async def foo():
- def bar():
- def baz():
- pass
- async = 1
- """,
- """def foo():
- async def bar():
- async def baz():
- pass
- def baz():
- 42
- async = 1
- """,
- """async def foo():
- def bar():
- def baz():
- pass\nawait foo()
- """,
- """def foo():
- def bar():
- async def baz():
- pass\nawait foo()
- """,
- """async def foo(await):
- pass
- """,
- """def foo():
- async def bar(): pass
- await a
- """,
- """def foo():
- async def bar():
- pass\nawait a
- """]
- for code in samples:
- with self.subTest(code=code), self.assertRaises(SyntaxError):
- compile(code, "<test>", "exec")
- def test_goodsyntax_1(self):
- # Tests for issue 24619
- def foo(await):
- async def foo(): pass
- async def foo():
- pass
- return await + 1
- self.assertEqual(foo(10), 11)
- def foo(await):
- async def foo(): pass
- async def foo(): pass
- return await + 2
- self.assertEqual(foo(20), 22)
- def foo(await):
- async def foo(): pass
- async def foo(): pass
- return await + 2
- self.assertEqual(foo(20), 22)
- def foo(await):
- """spam"""
- async def foo(): \
- pass
- # 123
- async def foo(): pass
- # 456
- return await + 2
- self.assertEqual(foo(20), 22)
- def foo(await):
- def foo(): pass
- def foo(): pass
- async def bar(): return await_
- await_ = await
- try:
- bar().send(None)
- except StopIteration as ex:
- return ex.args[0]
- self.assertEqual(foo(42), 42)
- async def f():
- async def g(): pass
- await z
- await = 1
- self.assertTrue(inspect.iscoroutinefunction(f))
- class TokenizerRegrTest(unittest.TestCase):
- def test_oneline_defs(self):
- buf = []
- for i in range(500):
- buf.append('def i{i}(): return {i}'.format(i=i))
- buf = '\n'.join(buf)
- # Test that 500 consequent, one-line defs is OK
- ns = {}
- exec(buf, ns, ns)
- self.assertEqual(ns['i499'](), 499)
- # Test that 500 consequent, one-line defs *and*
- # one 'async def' following them is OK
- buf += '\nasync def foo():\n return'
- ns = {}
- exec(buf, ns, ns)
- self.assertEqual(ns['i499'](), 499)
- self.assertTrue(inspect.iscoroutinefunction(ns['foo']))
- class CoroutineTest(unittest.TestCase):
- def test_gen_1(self):
- def gen(): yield
- self.assertFalse(hasattr(gen, '__await__'))
- def test_func_1(self):
- async def foo():
- return 10
- f = foo()
- self.assertIsInstance(f, types.CoroutineType)
- self.assertTrue(bool(foo.__code__.co_flags & inspect.CO_COROUTINE))
- self.assertFalse(bool(foo.__code__.co_flags & inspect.CO_GENERATOR))
- self.assertTrue(bool(f.cr_code.co_flags & inspect.CO_COROUTINE))
- self.assertFalse(bool(f.cr_code.co_flags & inspect.CO_GENERATOR))
- self.assertEqual(run_async(f), ([], 10))
- self.assertEqual(run_async__await__(foo()), ([], 10))
- def bar(): pass
- self.assertFalse(bool(bar.__code__.co_flags & inspect.CO_COROUTINE))
- def test_func_2(self):
- async def foo():
- raise StopIteration
- with self.assertRaisesRegex(
- RuntimeError, "coroutine raised StopIteration"):
- run_async(foo())
- def test_func_3(self):
- async def foo():
- raise StopIteration
- with silence_coro_gc():
- self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$')
- def test_func_4(self):
- async def foo():
- raise StopIteration
- check = lambda: self.assertRaisesRegex(
- TypeError, "'coroutine' object is not iterable")
- with check():
- list(foo())
- with check():
- tuple(foo())
- with check():
- sum(foo())
- with check():
- iter(foo())
- with silence_coro_gc(), check():
- for i in foo():
- pass
- with silence_coro_gc(), check():
- [i for i in foo()]
- def test_func_5(self):
- @types.coroutine
- def bar():
- yield 1
- async def foo():
- await bar()
- check = lambda: self.assertRaisesRegex(
- TypeError, "'coroutine' object is not iterable")
- with check():
- for el in foo(): pass
- # the following should pass without an error
- for el in bar():
- self.assertEqual(el, 1)
- self.assertEqual([el for el in bar()], [1])
- self.assertEqual(tuple(bar()), (1,))
- self.assertEqual(next(iter(bar())), 1)
- def test_func_6(self):
- @types.coroutine
- def bar():
- yield 1
- yield 2
- async def foo():
- await bar()
- f = foo()
- self.assertEqual(f.send(None), 1)
- self.assertEqual(f.send(None), 2)
- with self.assertRaises(StopIteration):
- f.send(None)
- def test_func_7(self):
- async def bar():
- return 10
- def foo():
- yield from bar()
- with silence_coro_gc(), self.assertRaisesRegex(
- TypeError,
- "cannot 'yield from' a coroutine object in a non-coroutine generator"):
- list(foo())
- def test_func_8(self):
- @types.coroutine
- def bar():
- return (yield from foo())
- async def foo():
- return 'spam'
- self.assertEqual(run_async(bar()), ([], 'spam') )
- def test_func_9(self):
- async def foo(): pass
- with self.assertWarnsRegex(
- RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):
- foo()
- support.gc_collect()
- def test_func_10(self):
- N = 0
- @types.coroutine
- def gen():
- nonlocal N
- try:
- a = yield
- yield (a ** 2)
- except ZeroDivisionError:
- N += 100
- raise
- finally:
- N += 1
- async def foo():
- await gen()
- coro = foo()
- aw = coro.__await__()
- self.assertIs(aw, iter(aw))
- next(aw)
- self.assertEqual(aw.send(10), 100)
- self.assertEqual(N, 0)
- aw.close()
- self.assertEqual(N, 1)
- coro = foo()
- aw = coro.__await__()
- next(aw)
- with self.assertRaises(ZeroDivisionError):
- aw.throw(ZeroDivisionError, None, None)
- self.assertEqual(N, 102)
- def test_func_11(self):
- async def func(): pass
- coro = func()
- # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly
- # initialized
- self.assertIn('__await__', dir(coro))
- self.assertIn('__iter__', dir(coro.__await__()))
- self.assertIn('coroutine_wrapper', repr(coro.__await__()))
- coro.close() # avoid RuntimeWarning
- def test_func_12(self):
- async def g():
- i = me.send(None)
- await foo
- me = g()
- with self.assertRaisesRegex(ValueError,
- "coroutine already executing"):
- me.send(None)
- def test_func_13(self):
- async def g():
- pass
- with self.assertRaisesRegex(
- TypeError,
- "can't send non-None value to a just-started coroutine"):
- g().send('spam')
- def test_func_14(self):
- @types.coroutine
- def gen():
- yield
- async def coro():
- try:
- await gen()
- except GeneratorExit:
- await gen()
- c = coro()
- c.send(None)
- with self.assertRaisesRegex(RuntimeError,
- "coroutine ignored GeneratorExit"):
- c.close()
- def test_func_15(self):
- # See http://bugs.python.org/issue25887 for details
- async def spammer():
- return 'spam'
- async def reader(coro):
- return await coro
- spammer_coro = spammer()
- with self.assertRaisesRegex(StopIteration, 'spam'):
- reader(spammer_coro).send(None)
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- reader(spammer_coro).send(None)
- def test_func_16(self):
- # See http://bugs.python.org/issue25887 for details
- @types.coroutine
- def nop():
- yield
- async def send():
- await nop()
- return 'spam'
- async def read(coro):
- await nop()
- return await coro
- spammer = send()
- reader = read(spammer)
- reader.send(None)
- reader.send(None)
- with self.assertRaisesRegex(Exception, 'ham'):
- reader.throw(Exception('ham'))
- reader = read(spammer)
- reader.send(None)
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- reader.send(None)
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- reader.throw(Exception('wat'))
- def test_func_17(self):
- # See http://bugs.python.org/issue25887 for details
- async def coroutine():
- return 'spam'
- coro = coroutine()
- with self.assertRaisesRegex(StopIteration, 'spam'):
- coro.send(None)
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- coro.send(None)
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- coro.throw(Exception('wat'))
- # Closing a coroutine shouldn't raise any exception even if it's
- # already closed/exhausted (similar to generators)
- coro.close()
- coro.close()
- def test_func_18(self):
- # See http://bugs.python.org/issue25887 for details
- async def coroutine():
- return 'spam'
- coro = coroutine()
- await_iter = coro.__await__()
- it = iter(await_iter)
- with self.assertRaisesRegex(StopIteration, 'spam'):
- it.send(None)
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- it.send(None)
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- # Although the iterator protocol requires iterators to
- # raise another StopIteration here, we don't want to do
- # that. In this particular case, the iterator will raise
- # a RuntimeError, so that 'yield from' and 'await'
- # expressions will trigger the error, instead of silently
- # ignoring the call.
- next(it)
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- it.throw(Exception('wat'))
- with self.assertRaisesRegex(RuntimeError,
- 'cannot reuse already awaited coroutine'):
- it.throw(Exception('wat'))
- # Closing a coroutine shouldn't raise any exception even if it's
- # already closed/exhausted (similar to generators)
- it.close()
- it.close()
- def test_func_19(self):
- CHK = 0
- @types.coroutine
- def foo():
- nonlocal CHK
- yield
- try:
- yield
- except GeneratorExit:
- CHK += 1
- async def coroutine():
- await foo()
- coro = coroutine()
- coro.send(None)
- coro.send(None)
- self.assertEqual(CHK, 0)
- coro.close()
- self.assertEqual(CHK, 1)
- for _ in range(3):
- # Closing a coroutine shouldn't raise any exception even if it's
- # already closed/exhausted (similar to generators)
- coro.close()
- self.assertEqual(CHK, 1)
- def test_cr_await(self):
- @types.coroutine
- def a():
- self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
- self.assertIsNone(coro_b.cr_await)
- yield
- self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
- self.assertIsNone(coro_b.cr_await)
- async def c():
- await a()
- async def b():
- self.assertIsNone(coro_b.cr_await)
- await c()
- self.assertIsNone(coro_b.cr_await)
- coro_b = b()
- self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CREATED)
- self.assertIsNone(coro_b.cr_await)
- coro_b.send(None)
- self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_SUSPENDED)
- self.assertEqual(coro_b.cr_await.cr_await.gi_code.co_name, 'a')
- with self.assertRaises(StopIteration):
- coro_b.send(None) # complete coroutine
- self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CLOSED)
- self.assertIsNone(coro_b.cr_await)
- def test_corotype_1(self):
- ct = types.CoroutineType
- self.assertIn('into coroutine', ct.send.__doc__)
- self.assertIn('inside coroutine', ct.close.__doc__)
- self.assertIn('in coroutine', ct.throw.__doc__)
- self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__)
- self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__)
- self.assertEqual(ct.__name__, 'coroutine')
- async def f(): pass
- c = f()
- self.assertIn('coroutine object', repr(c))
- c.close()
- def test_await_1(self):
- async def foo():
- await 1
- with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
- run_async(foo())
- def test_await_2(self):
- async def foo():
- await []
- with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
- run_async(foo())
- def test_await_3(self):
- async def foo():
- await AsyncYieldFrom([1, 2, 3])
- self.assertEqual(run_async(foo()), ([1, 2, 3], None))
- self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None))
- def test_await_4(self):
- async def bar():
- return 42
- async def foo():
- return await bar()
- self.assertEqual(run_async(foo()), ([], 42))
- def test_await_5(self):
- class Awaitable:
- def __await__(self):
- return
- async def foo():
- return (await Awaitable())
- with self.assertRaisesRegex(
- TypeError, "__await__.*returned non-iterator of type"):
- run_async(foo())
- def test_await_6(self):
- class Awaitable:
- def __await__(self):
- return iter([52])
- async def foo():
- return (await Awaitable())
- self.assertEqual(run_async(foo()), ([52], None))
- def test_await_7(self):
- class Awaitable:
- def __await__(self):
- yield 42
- return 100
- async def foo():
- return (await Awaitable())
- self.assertEqual(run_async(foo()), ([42], 100))
- def test_await_8(self):
- class Awaitable:
- pass
- async def foo(): return await Awaitable()
- with self.assertRaisesRegex(
- TypeError, "object Awaitable can't be used in 'await' expression"):
- run_async(foo())
- def test_await_9(self):
- def wrap():
- return bar
- async def bar():
- return 42
- async def foo():
- b = bar()
- db = {'b': lambda: wrap}
- class DB:
- b = wrap
- return (await bar() + await wrap()() + await db['b']()()() +
- await bar() * 1000 + await DB.b()())
- async def foo2():
- return -await bar()
- self.assertEqual(run_async(foo()), ([], 42168))
- self.assertEqual(run_async(foo2()), ([], -42))
- def test_await_10(self):
- async def baz():
- return 42
- async def bar():
- return baz()
- async def foo():
- return await (await bar())
- self.assertEqual(run_async(foo()), ([], 42))
- def test_await_11(self):
- def ident(val):
- return val
- async def bar():
- return 'spam'
- async def foo():
- return ident(val=await bar())
- async def foo2():
- return await bar(), 'ham'
- self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))
- def test_await_12(self):
- async def coro():
- return 'spam'
- class Awaitable:
- def __await__(self):
- return coro()
- async def foo():
- return await Awaitable()
- with self.assertRaisesRegex(
- TypeError, "__await__\(\) returned a coroutine"):
- run_async(foo())
- def test_await_13(self):
- class Awaitable:
- def __await__(self):
- return self
- async def foo():
- return await Awaitable()
- with self.assertRaisesRegex(
- TypeError, "__await__.*returned non-iterator of type"):
- run_async(foo())
- def test_await_14(self):
- class Wrapper:
- # Forces the interpreter to use CoroutineType.__await__
- def __init__(self, coro):
- assert coro.__class__ is types.CoroutineType
- self.coro = coro
- def __await__(self):
- return self.coro.__await__()
- class FutureLike:
- def __await__(self):
- return (yield)
- class Marker(Exception):
- pass
- async def coro1():
- try:
- return await FutureLike()
- except ZeroDivisionError:
- raise Marker
- async def coro2():
- return await Wrapper(coro1())
- c = coro2()
- c.send(None)
- with self.assertRaisesRegex(StopIteration, 'spam'):
- c.send('spam')
- c = coro2()
- c.send(None)
- with self.assertRaises(Marker):
- c.throw(ZeroDivisionError)
- def test_await_15(self):
- @types.coroutine
- def nop():
- yield
- async def coroutine():
- await nop()
- async def waiter(coro):
- await coro
- coro = coroutine()
- coro.send(None)
- with self.assertRaisesRegex(RuntimeError,
- "coroutine is being awaited already"):
- waiter(coro).send(None)
- def test_with_1(self):
- class Manager:
- def __init__(self, name):
- self.name = name
- async def __aenter__(self):
- await AsyncYieldFrom(['enter-1-' + self.name,
- 'enter-2-' + self.name])
- return self
- async def __aexit__(self, *args):
- await AsyncYieldFrom(['exit-1-' + self.name,
- 'exit-2-' + self.name])
- if self.name == 'B':
- return True
- async def foo():
- async with Manager("A") as a, Manager("B") as b:
- await AsyncYieldFrom([('managers', a.name, b.name)])
- 1/0
- f = foo()
- result, _ = run_async(f)
- self.assertEqual(
- result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
- ('managers', 'A', 'B'),
- 'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
- )
- async def foo():
- async with Manager("A") as a, Manager("C") as c:
- await AsyncYieldFrom([('managers', a.name, c.name)])
- 1/0
- with self.assertRaises(ZeroDivisionError):
- run_async(foo())
- def test_with_2(self):
- class CM:
- def __aenter__(self):
- pass
- async def foo():
- async with CM():
- pass
- with self.assertRaisesRegex(AttributeError, '__aexit__'):
- run_async(foo())
- def test_with_3(self):
- class CM:
- def __aexit__(self):
- pass
- async def foo():
- async with CM():
- pass
- with self.assertRaisesRegex(AttributeError, '__aenter__'):
- run_async(foo())
- def test_with_4(self):
- class CM:
- def __enter__(self):
- pass
- def __exit__(self):
- pass
- async def foo():
- async with CM():
- pass
- with self.assertRaisesRegex(AttributeError, '__aexit__'):
- run_async(foo())
- def test_with_5(self):
- # While this test doesn't make a lot of sense,
- # it's a regression test for an early bug with opcodes
- # generation
- class CM:
- async def __aenter__(self):
- return self
- async def __aexit__(self, *exc):
- pass
- async def func():
- async with CM():
- assert (1, ) == 1
- with self.assertRaises(AssertionError):
- run_async(func())
- def test_with_6(self):
- class CM:
- def __aenter__(self):
- return 123
- def __aexit__(self, *e):
- return 456
- async def foo():
- async with CM():
- pass
- with self.assertRaisesRegex(
- TypeError, "object int can't be used in 'await' expression"):
- # it's important that __aexit__ wasn't called
- run_async(foo())
- def test_with_7(self):
- class CM:
- async def __aenter__(self):
- return self
- def __aexit__(self, *e):
- return 444
- async def foo():
- async with CM():
- 1/0
- try:
- run_async(foo())
- except TypeError as exc:
- self.assertRegex(
- exc.args[0], "object int can't be used in 'await' expression")
- self.assertTrue(exc.__context__ is not None)
- self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
- else:
- self.fail('invalid asynchronous context manager did not fail')
- def test_with_8(self):
- CNT = 0
- class CM:
- async def __aenter__(self):
- return self
- def __aexit__(self, *e):
- return 456
- async def foo():
- nonlocal CNT
- async with CM():
- CNT += 1
- with self.assertRaisesRegex(
- TypeError, "object int can't be used in 'await' expression"):
- run_async(foo())
- self.assertEqual(CNT, 1)
- def test_with_9(self):
- CNT = 0
- class CM:
- async def __aenter__(self):
- return self
- async def __aexit__(self, *e):
- 1/0
- async def foo():
- nonlocal CNT
- async with CM():
- CNT += 1
- with self.assertRaises(ZeroDivisionError):
- run_async(foo())
- self.assertEqual(CNT, 1)
- def test_with_10(self):
- CNT = 0
- class CM:
- async def __aenter__(self):
- return self
- async def __aexit__(self, *e):
- 1/0
- async def foo():
- nonlocal CNT
- async with CM():
- async with CM():
- raise RuntimeError
- try:
- run_async(foo())
- except ZeroDivisionError as exc:
- self.assertTrue(exc.__context__ is not None)
- self.assertTrue(isinstance(exc.__context__, ZeroDivisionError))
- self.assertTrue(isinstance(exc.__context__.__context__,
- RuntimeError))
- else:
- self.fail('exception from __aexit__ did not propagate')
- def test_with_11(self):
- CNT = 0
- class CM:
- async def __aenter__(self):
- raise NotImplementedError
- async def __aexit__(self, *e):
- 1/0
- async def foo():
- nonlocal CNT
- async with CM():
- raise RuntimeError
- try:
- run_async(foo())
- except NotImplementedError as exc:
- self.assertTrue(exc.__context__ is None)
- else:
- self.fail('exception from __aenter__ did not propagate')
- def test_with_12(self):
- CNT = 0
- class CM:
- async def __aenter__(self):
- return self
- async def __aexit__(self, *e):
- return True
- async def foo():
- nonlocal CNT
- async with CM() as cm:
- self.assertIs(cm.__class__, CM)
- raise RuntimeError
- run_async(foo())
- def test_with_13(self):
- CNT = 0
- class CM:
- async def __aenter__(self):
- 1/0
- async def __aexit__(self, *e):
- return True
- async def foo():
- nonlocal CNT
- CNT += 1
- async with CM():
- CNT += 1000
- CNT += 10000
- with self.assertRaises(ZeroDivisionError):
- run_async(foo())
- self.assertEqual(CNT, 1)
- def test_for_1(self):
- aiter_calls = 0
- class AsyncIter:
- def __init__(self):
- self.i = 0
- async def __aiter__(self):
- nonlocal aiter_calls
- aiter_calls += 1
- return self
- async def __anext__(self):
- self.i += 1
- if not (self.i % 10):
- await AsyncYield(self.i * 10)
- if self.i > 100:
- raise StopAsyncIteration
- return self.i, self.i
- buffer = []
- async def test1():
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
- async for i1, i2 in AsyncIter():
- buffer.append(i1 + i2)
- yielded, _ = run_async(test1())
- # Make sure that __aiter__ was called only once
- self.assertEqual(aiter_calls, 1)
- self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
- self.assertEqual(buffer, [i*2 for i in range(1, 101)])
- buffer = []
- async def test2():
- nonlocal buffer
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
- async for i in AsyncIter():
- buffer.append(i[0])
- if i[0] == 20:
- break
- else:
- buffer.append('what?')
- buffer.append('end')
- yielded, _ = run_async(test2())
- # Make sure that __aiter__ was called only once
- self.assertEqual(aiter_calls, 2)
- self.assertEqual(yielded, [100, 200])
- self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])
- buffer = []
- async def test3():
- nonlocal buffer
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
- async for i in AsyncIter():
- if i[0] > 20:
- continue
- buffer.append(i[0])
- else:
- buffer.append('what?')
- buffer.append('end')
- yielded, _ = run_async(test3())
- # Make sure that __aiter__ was called only once
- self.assertEqual(aiter_calls, 3)
- self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
- self.assertEqual(buffer, [i for i in range(1, 21)] +
- ['what?', 'end'])
- def test_for_2(self):
- tup = (1, 2, 3)
- refs_before = sys.getrefcount(tup)
- async def foo():
- async for i in tup:
- print('never going to happen')
- with self.assertRaisesRegex(
- TypeError, "async for' requires an object.*__aiter__.*tuple"):
- run_async(foo())
- self.assertEqual(sys.getrefcount(tup), refs_before)
- def test_for_3(self):
- class I:
- def __aiter__(self):
- return self
- aiter = I()
- refs_before = sys.getrefcount(aiter)
- async def foo():
- async for i in aiter:
- print('never going to happen')
- with self.assertRaisesRegex(
- TypeError,
- "async for' received an invalid object.*__aiter.*\: I"):
- run_async(foo())
- self.assertEqual(sys.getrefcount(aiter), refs_before)
- def test_for_4(self):
- class I:
- def __aiter__(self):
- return self
- def __anext__(self):
- return ()
- aiter = I()
- refs_before = sys.getrefcount(aiter)
- async def foo():
- async for i in aiter:
- print('never going to happen')
- with self.assertRaisesRegex(
- TypeError,
- "async for' received an invalid object.*__anext__.*tuple"):
- run_async(foo())
- self.assertEqual(sys.getrefcount(aiter), refs_before)
- def test_for_5(self):
- class I:
- async def __aiter__(self):
- return self
- def __anext__(self):
- return 123
- async def foo():
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
- async for i in I():
- print('never going to happen')
- with self.assertRaisesRegex(
- TypeError,
- "async for' received an invalid object.*__anext.*int"):
- run_async(foo())
- def test_for_6(self):
- I = 0
- class Manager:
- async def __aenter__(self):
- nonlocal I
- I += 10000
- async def __aexit__(self, *args):
- nonlocal I
- I += 100000
- class Iterable:
- def __init__(self):
- self.i = 0
- def __aiter__(self):
- return self
- async def __anext__(self):
- if self.i > 10:
- raise StopAsyncIteration
- self.i += 1
- return self.i
- ##############
- manager = Manager()
- iterable = Iterable()
- mrefs_before = sys.getrefcount(manager)
- irefs_before = sys.getrefcount(iterable)
- async def main():
- nonlocal I
- async with manager:
- async for i in iterable:
- I += 1
- I += 1000
- with warnings.catch_warnings():
- warnings.simplefilter("error")
- # Test that __aiter__ that returns an asynchronous iterator
- # directly does not throw any warnings.
- run_async(main())
- self.assertEqual(I, 111011)
- self.assertEqual(sys.getrefcount(manager), mrefs_before)
- self.assertEqual(sys.getrefcount(iterable), irefs_before)
- ##############
- async def main():
- nonlocal I
- async with Manager():
- async for i in Iterable():
- I += 1
- I += 1000
- async with Manager():
- async for i in Iterable():
- I += 1
- I += 1000
- run_async(main())
- self.assertEqual(I, 333033)
- ##############
- async def main():
- nonlocal I
- async with Manager():
- I += 100
- async for i in Iterable():
- I += 1
- else:
- I += 10000000
- I += 1000
- async with Manager():
- I += 100
- async for i in Iterable():
- I += 1
- else:
- I += 10000000
- I += 1000
- run_async(main())
- self.assertEqual(I, 20555255)
- def test_for_7(self):
- CNT = 0
- class AI:
- async def __aiter__(self):
- 1/0
- async def foo():
- nonlocal CNT
- with self.assertWarnsRegex(PendingDeprecationWarning, "legacy"):
- async for i in AI():
- CNT += 1
- CNT += 10
- with self.assertRaises(ZeroDivisionError):
- run_async(foo())
- self.assertEqual(CNT, 0)
- def test_for_8(self):
- CNT = 0
- class AI:
- def __aiter__(self):
- 1/0
- async def foo():
- nonlocal CNT
- async for i in AI():
- CNT += 1
- CNT += 10
- with self.assertRaises(ZeroDivisionError):
- with warnings.catch_warnings():
- warnings.simplefilter("error")
- # Test that if __aiter__ raises an exception it propagates
- # without any kind of warning.
- run_async(foo())
- self.assertEqual(CNT, 0)
- def test_for_9(self):
- # Test that PendingDeprecationWarning can safely be converted into
- # an exception (__aiter__ should not have a chance to raise
- # a ZeroDivisionError.)
- class AI:
- async def __aiter__(self):
- 1/0
- async def foo():
- async for i in AI():
- pass
- with self.assertRaises(PendingDeprecationWarning):
- with warnings.catch_warnings():
- warnings.simplefilter("error")
- run_async(foo())
- def test_for_10(self):
- # Test that PendingDeprecationWarning can safely be converted into
- # an exception.
- class AI:
- async def __aiter__(self):
- pass
- async def foo():
- async for i in AI():
- pass
- with self.assertRaises(PendingDeprecationWarning):
- with warnings.catch_warnings():
- warnings.simplefilter("error")
- run_async(foo())
- def test_copy(self):
- async def func(): pass
- coro = func()
- with self.assertRaises(TypeError):
- copy.copy(coro)
- aw = coro.__await__()
- try:
- with self.assertRaises(TypeError):
- copy.copy(aw)
- finally:
- aw.close()
- def test_pickle(self):
- async def func(): pass
- coro = func()
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- with self.assertRaises((TypeError, pickle.PicklingError)):
- pickle.dumps(coro, proto)
- aw = coro.__await__()
- try:
- for proto in range(pickle.HIGHEST_PROTOCOL + 1):
- with self.assertRaises((TypeError, pickle.PicklingError)):
- pickle.dumps(aw, proto)
- finally:
- aw.close()
- def test_fatal_coro_warning(self):
- # Issue 27811
- async def func(): pass
- with warnings.catch_warnings():
- warnings.filterwarnings("error")
- func()
- support.gc_collect()
- class CoroAsyncIOCompatTest(unittest.TestCase):
- def test_asyncio_1(self):
- # asyncio cannot be imported when Python is compiled without thread
- # support
- asyncio = support.import_module('asyncio')
- class MyException(Exception):
- pass
- buffer = []
- class CM:
- async def __aenter__(self):
- buffer.append(1)
- await asyncio.sleep(0.01)
- buffer.append(2)
- return self
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- await asyncio.sleep(0.01)
- buffer.append(exc_type.__name__)
- async def f():
- async with CM() as c:
- await asyncio.sleep(0.01)
- raise MyException
- buffer.append('unreachable')
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(f())
- except MyException:
- pass
- finally:
- loop.close()
- asyncio.set_event_loop(None)
- self.assertEqual(buffer, [1, 2, 'MyException'])
- class SysSetCoroWrapperTest(unittest.TestCase):
- def test_set_wrapper_1(self):
- async def foo():
- return 'spam'
- wrapped = None
- def wrap(gen):
- nonlocal wrapped
- wrapped = gen
- return gen
- self.assertIsNone(sys.get_coroutine_wrapper())
- sys.set_coroutine_wrapper(wrap)
- self.assertIs(sys.get_coroutine_wrapper(), wrap)
- try:
- f = foo()
- self.assertTrue(wrapped)
- self.assertEqual(run_async(f), ([], 'spam'))
- finally:
- sys.set_coroutine_wrapper(None)
- self.assertIsNone(sys.get_coroutine_wrapper())
- wrapped = None
- with silence_coro_gc():
- foo()
- self.assertFalse(wrapped)
- def test_set_wrapper_2(self):
- self.assertIsNone(sys.get_coroutine_wrapper())
- with self.assertRaisesRegex(TypeError, "callable expected, got int"):
- sys.set_coroutine_wrapper(1)
- self.assertIsNone(sys.get_coroutine_wrapper())
- def test_set_wrapper_3(self):
- async def foo():
- return 'spam'
- def wrapper(coro):
- async def wrap(coro):
- return await coro
- return wrap(coro)
- sys.set_coroutine_wrapper(wrapper)
- try:
- with silence_coro_gc(), self.assertRaisesRegex(
- RuntimeError,
- "coroutine wrapper.*\.wrapper at 0x.*attempted to "
- "recursively wrap .* wrap .*"):
- foo()
- finally:
- sys.set_coroutine_wrapper(None)
- def test_set_wrapper_4(self):
- @types.coroutine
- def foo():
- return 'spam'
- wrapped = None
- def wrap(gen):
- nonlocal wrapped
- wrapped = gen
- return gen
- sys.set_coroutine_wrapper(wrap)
- try:
- foo()
- self.assertIs(
- wrapped, None,
- "generator-based coroutine was wrapped via "
- "sys.set_coroutine_wrapper")
- finally:
- sys.set_coroutine_wrapper(None)
- class CAPITest(unittest.TestCase):
- def test_tp_await_1(self):
- from _testcapi import awaitType as at
- async def foo():
- future = at(iter([1]))
- return (await future)
- self.assertEqual(foo().send(None), 1)
- def test_tp_await_2(self):
- # Test tp_await to __await__ mapping
- from _testcapi import awaitType as at
- future = at(iter([1]))
- self.assertEqual(next(future.__await__()), 1)
- def test_tp_await_3(self):
- from _testcapi import awaitType as at
- async def foo():
- future = at(1)
- return (await future)
- with self.assertRaisesRegex(
- TypeError, "__await__.*returned non-iterator of type 'int'"):
- self.assertEqual(foo().send(None), 1)
- if __name__=="__main__":
- unittest.main()