From b9e0bdcffc33e2966d2a88afa925b96d6137d138 Mon Sep 17 00:00:00 2001 From: Stefan Behnel <stefan_ml@behnel.de> Date: Sat, 23 May 2015 21:29:09 +0200 Subject: [PATCH] implement "async def" statement and "await" expression (PEP 492) --- Cython/Compiler/Code.py | 3 +- Cython/Compiler/ExprNodes.py | 39 +- Cython/Compiler/Nodes.py | 22 +- Cython/Compiler/ParseTreeTransforms.py | 53 +- Cython/Compiler/Parsing.pxd | 6 +- Cython/Compiler/Parsing.py | 86 +- Cython/Compiler/Scanning.pxd | 3 + Cython/Compiler/Scanning.py | 16 + Cython/Parser/Grammar | 6 +- Cython/Utility/Coroutine.c | 190 ++++- tests/run/test_coroutines_pep492.pyx | 1050 ++++++++++++++++++++++++ 11 files changed, 1389 insertions(+), 85 deletions(-) create mode 100644 tests/run/test_coroutines_pep492.pyx diff --git a/Cython/Compiler/Code.py b/Cython/Compiler/Code.py index 7719c3932..8aa64b3e2 100644 --- a/Cython/Compiler/Code.py +++ b/Cython/Compiler/Code.py @@ -49,7 +49,8 @@ non_portable_builtins_map = { 'basestring' : ('PY_MAJOR_VERSION >= 3', 'str'), 'xrange' : ('PY_MAJOR_VERSION >= 3', 'range'), 'raw_input' : ('PY_MAJOR_VERSION >= 3', 'input'), - } + 'StopAsyncIteration': ('PY_VERSION_HEX < 0x030500B1', 'StopIteration'), +} basicsize_builtins_map = { # builtins whose type has a different tp_basicsize than sizeof(...) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 3f1263bf3..3b84f9858 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -8593,10 +8593,12 @@ class YieldExprNode(ExprNode): type = py_object_type label_num = 0 is_yield_from = False + is_await = False + expr_keyword = 'yield' def analyse_types(self, env): if not self.label_num: - error(self.pos, "'yield' not supported here") + error(self.pos, "'%s' not supported here" % self.expr_keyword) self.is_temp = 1 if self.arg is not None: self.arg = self.arg.analyse_types(env) @@ -8661,6 +8663,7 @@ class YieldExprNode(ExprNode): class YieldFromExprNode(YieldExprNode): # "yield from GEN" expression is_yield_from = True + expr_keyword = 'yield from' def coerce_yield_argument(self, env): if not self.arg.type.is_string: @@ -8668,14 +8671,17 @@ class YieldFromExprNode(YieldExprNode): error(self.pos, "yielding from non-Python object not supported") self.arg = self.arg.coerce_to_pyobject(env) - def generate_evaluation_code(self, code): - code.globalstate.use_utility_code(UtilityCode.load_cached("YieldFrom", "Coroutine.c")) + def yield_from_func(self, code): + code.globalstate.use_utility_code(UtilityCode.load_cached("GeneratorYieldFrom", "Coroutine.c")) + return "__Pyx_Generator_Yield_From" + def generate_evaluation_code(self, code): self.arg.generate_evaluation_code(code) - code.putln("%s = __Pyx_Generator_Yield_From(%s, %s);" % ( + code.putln("%s = %s(%s, %s);" % ( Naming.retval_cname, + self.yield_from_func(code), Naming.generator_cname, - self.arg.result_as(py_object_type))) + self.arg.py_result())) self.arg.generate_disposal_code(code) self.arg.free_temps(code) code.put_xgotref(Naming.retval_cname) @@ -8687,9 +8693,7 @@ class YieldFromExprNode(YieldExprNode): if self.result_is_used: # YieldExprNode has allocated the result temp for us code.putln("%s = NULL;" % self.result()) - code.putln("if (unlikely(__Pyx_PyGen_FetchStopIterationValue(&%s) < 0)) %s" % ( - self.result(), - code.error_goto(self.pos))) + code.put_error_if_neg(self.pos, "__Pyx_PyGen_FetchStopIterationValue(&%s)" % self.result()) code.put_gotref(self.result()) else: code.putln("PyObject* exc_type = PyErr_Occurred();") @@ -8700,6 +8704,25 @@ class 
YieldFromExprNode(YieldExprNode): code.putln("}") code.putln("}") + +class AwaitExprNode(YieldFromExprNode): + # 'await' expression node + # + # arg ExprNode the Awaitable value to await + # label_num integer yield label number + # is_yield_from boolean is a YieldFromExprNode to delegate to another generator + is_await = True + expr_keyword = 'await' + + def coerce_yield_argument(self, env): + # FIXME: use same check as in YieldFromExprNode.coerce_yield_argument() ? + self.arg = self.arg.coerce_to_pyobject(env) + + def yield_from_func(self, code): + code.globalstate.use_utility_code(UtilityCode.load_cached("CoroutineYieldFrom", "Coroutine.c")) + return "__Pyx_Coroutine_Yield_From" + + class GlobalsExprNode(AtomicExprNode): type = dict_type is_temp = 1 diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 7930307b2..855974b0f 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -1574,9 +1574,11 @@ class FuncDefNode(StatNode, BlockNode): # pymethdef_required boolean Force Python method struct generation # directive_locals { string : ExprNode } locals defined by cython.locals(...) # directive_returns [ExprNode] type defined by cython.returns(...) - # star_arg PyArgDeclNode or None * argument - # starstar_arg PyArgDeclNode or None ** argument - + # star_arg PyArgDeclNode or None * argument + # starstar_arg PyArgDeclNode or None ** argument + # + # is_async_def boolean is a Coroutine function + # # has_fused_arguments boolean # Whether this cdef function has fused parameters. This is needed # by AnalyseDeclarationsTransform, so it can replace CFuncDefNodes @@ -1588,6 +1590,7 @@ class FuncDefNode(StatNode, BlockNode): pymethdef_required = False is_generator = False is_generator_body = False + is_async_def = False modifiers = [] has_fused_arguments = False star_arg = None @@ -3936,6 +3939,7 @@ class GeneratorDefNode(DefNode): # is_generator = True + is_coroutine = False needs_closure = True child_attrs = DefNode.child_attrs + ["gbody"] @@ -3956,8 +3960,9 @@ class GeneratorDefNode(DefNode): qualname = code.intern_identifier(self.qualname) code.putln('{') - code.putln('__pyx_CoroutineObject *gen = __Pyx_Generator_New(' + code.putln('__pyx_CoroutineObject *gen = __Pyx_%s_New(' '(__pyx_coroutine_body_t) %s, (PyObject *) %s, %s, %s); %s' % ( + 'Coroutine' if self.is_coroutine else 'Generator', body_cname, Naming.cur_scope_cname, name, qualname, code.error_goto_if_null('gen', self.pos))) code.put_decref(Naming.cur_scope_cname, py_object_type) @@ -3972,13 +3977,18 @@ class GeneratorDefNode(DefNode): code.putln('}') def generate_function_definitions(self, env, code): - env.use_utility_code(UtilityCode.load_cached("Generator", "Coroutine.c")) + env.use_utility_code(UtilityCode.load_cached( + 'Coroutine' if self.is_coroutine else 'Generator', "Coroutine.c")) self.gbody.generate_function_header(code, proto=True) super(GeneratorDefNode, self).generate_function_definitions(env, code) self.gbody.generate_function_definitions(env, code) +class AsyncDefNode(GeneratorDefNode): + is_coroutine = True + + class GeneratorBodyDefNode(DefNode): # Main code body of a generator implemented as a DefNode. 
# @@ -7108,7 +7118,7 @@ class GILStatNode(NogilTryFinallyStatNode): from .ParseTreeTransforms import YieldNodeCollector collector = YieldNodeCollector() collector.visitchildren(body) - if not collector.yields: + if not collector.yields and not collector.awaits: return if state == 'gil': diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index ed83d2dec..3594baa22 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -200,7 +200,7 @@ class PostParse(ScopeTrackingTransform): node.lambda_name = EncodedString(u'lambda%d' % lambda_id) collector = YieldNodeCollector() collector.visitchildren(node.result_expr) - if collector.yields or isinstance(node.result_expr, ExprNodes.YieldExprNode): + if collector.yields or collector.awaits or isinstance(node.result_expr, ExprNodes.YieldExprNode): body = Nodes.ExprStatNode( node.result_expr.pos, expr=node.result_expr) else: @@ -2205,6 +2205,7 @@ class YieldNodeCollector(TreeVisitor): def __init__(self): super(YieldNodeCollector, self).__init__() self.yields = [] + self.awaits = [] self.returns = [] self.has_return_value = False @@ -2215,6 +2216,10 @@ class YieldNodeCollector(TreeVisitor): self.yields.append(node) self.visitchildren(node) + def visit_AwaitExprNode(self, node): + self.awaits.append(node) + self.visitchildren(node) + def visit_ReturnStatNode(self, node): self.visitchildren(node) if node.value: @@ -2250,27 +2255,36 @@ class MarkClosureVisitor(CythonTransform): collector = YieldNodeCollector() collector.visitchildren(node) - if collector.yields: - if isinstance(node, Nodes.CFuncDefNode): - # Will report error later - return node - for i, yield_expr in enumerate(collector.yields, 1): - yield_expr.label_num = i - for retnode in collector.returns: - retnode.in_generator = True + if node.is_async_def: + if collector.yields: + error(collector.yields[0].pos, "'yield' not allowed in async coroutines (use 'await')") + yields = collector.awaits + elif collector.yields: + if collector.awaits: + error(collector.yields[0].pos, "'await' not allowed in generators (use 'yield')") + yields = collector.yields + else: + return node - gbody = Nodes.GeneratorBodyDefNode( - pos=node.pos, name=node.name, body=node.body) - generator = Nodes.GeneratorDefNode( - pos=node.pos, name=node.name, args=node.args, - star_arg=node.star_arg, starstar_arg=node.starstar_arg, - doc=node.doc, decorators=node.decorators, - gbody=gbody, lambda_name=node.lambda_name) - return generator - return node + for i, yield_expr in enumerate(yields, 1): + yield_expr.label_num = i + for retnode in collector.returns: + retnode.in_generator = True + + gbody = Nodes.GeneratorBodyDefNode( + pos=node.pos, name=node.name, body=node.body) + coroutine = (Nodes.AsyncDefNode if node.is_async_def else Nodes.GeneratorDefNode)( + pos=node.pos, name=node.name, args=node.args, + star_arg=node.star_arg, starstar_arg=node.starstar_arg, + doc=node.doc, decorators=node.decorators, + gbody=gbody, lambda_name=node.lambda_name) + return coroutine def visit_CFuncDefNode(self, node): - self.visit_FuncDefNode(node) + self.needs_closure = False + self.visitchildren(node) + node.needs_closure = self.needs_closure + self.needs_closure = True if node.needs_closure and node.overridable: error(node.pos, "closures inside cpdef functions not yet supported") return node @@ -2287,6 +2301,7 @@ class MarkClosureVisitor(CythonTransform): self.needs_closure = True return node + class CreateClosureClasses(CythonTransform): # Output closure classes 
in module scope for all functions # that really need it. diff --git a/Cython/Compiler/Parsing.pxd b/Cython/Compiler/Parsing.pxd index 9a96d1acc..a6348b1fc 100644 --- a/Cython/Compiler/Parsing.pxd +++ b/Cython/Compiler/Parsing.pxd @@ -44,6 +44,8 @@ cdef p_typecast(PyrexScanner s) cdef p_sizeof(PyrexScanner s) cdef p_yield_expression(PyrexScanner s) cdef p_yield_statement(PyrexScanner s) +cdef p_await_expression(PyrexScanner s) +cdef p_async_statement(PyrexScanner s, ctx) cdef p_power(PyrexScanner s) cdef p_new_expr(PyrexScanner s) cdef p_trailer(PyrexScanner s, node1) @@ -128,7 +130,7 @@ cdef p_IF_statement(PyrexScanner s, ctx) cdef p_statement(PyrexScanner s, ctx, bint first_statement = *) cdef p_statement_list(PyrexScanner s, ctx, bint first_statement = *) cdef p_suite(PyrexScanner s, ctx = *) -cdef tuple p_suite_with_docstring(PyrexScanner s, ctx, with_doc_only = *) +cdef tuple p_suite_with_docstring(PyrexScanner s, ctx, bint with_doc_only=*) cdef tuple _extract_docstring(node) cdef p_positional_and_keyword_args(PyrexScanner s, end_sy_set, templates = *) @@ -176,7 +178,7 @@ cdef p_c_modifiers(PyrexScanner s) cdef p_c_func_or_var_declaration(PyrexScanner s, pos, ctx) cdef p_ctypedef_statement(PyrexScanner s, ctx) cdef p_decorators(PyrexScanner s) -cdef p_def_statement(PyrexScanner s, list decorators = *) +cdef p_def_statement(PyrexScanner s, list decorators=*, bint is_async_def=*) cdef p_varargslist(PyrexScanner s, terminator=*, bint annotated = *) cdef p_py_arg_decl(PyrexScanner s, bint annotated = *) cdef p_class_statement(PyrexScanner s, decorators) diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 06259cd4d..7d179cb18 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -55,6 +55,7 @@ class Ctx(object): d.update(kwds) return ctx + def p_ident(s, message="Expected an identifier"): if s.sy == 'IDENT': name = s.systring @@ -350,6 +351,7 @@ def p_sizeof(s): s.expect(')') return node + def p_yield_expression(s): # s.sy == "yield" pos = s.position() @@ -370,19 +372,52 @@ def p_yield_expression(s): else: return ExprNodes.YieldExprNode(pos, arg=arg) + def p_yield_statement(s): # s.sy == "yield" yield_expr = p_yield_expression(s) return Nodes.ExprStatNode(yield_expr.pos, expr=yield_expr) -#power: atom trailer* ('**' factor)* + +def p_async_statement(s, ctx, decorators): + # s.sy >> 'async' ... 
+ if s.sy == 'def': + # 'async def' statements aren't allowed in pxd files + if 'pxd' in ctx.level: + s.error('def statement not allowed here') + s.level = ctx.level + return p_def_statement(s, decorators, is_async_def=True) + elif decorators: + s.error("Decorators can only be followed by functions or classes") + elif s.sy == 'for': + #s.error("'async for' is not currently supported", fatal=False) + return p_statement(s, ctx) # TODO: implement + elif s.sy == 'with': + #s.error("'async with' is not currently supported", fatal=False) + return p_statement(s, ctx) # TODO: implement + else: + s.error("expected one of 'def', 'for', 'with' after 'async'") + + +def p_await_expression(s): + n1 = p_atom(s) + + +#power: atom_expr ('**' factor)* +#atom_expr: ['await'] atom trailer* def p_power(s): if s.systring == 'new' and s.peek()[0] == 'IDENT': return p_new_expr(s) + await_pos = None + if s.sy == 'await': + await_pos = s.position() + s.next() n1 = p_atom(s) while s.sy in ('(', '[', '.'): n1 = p_trailer(s, n1) + if await_pos: + n1 = ExprNodes.AwaitExprNode(await_pos, arg=n1) if s.sy == '**': pos = s.position() s.next() @@ -390,6 +425,7 @@ def p_power(s): n1 = ExprNodes.binop_node(pos, '**', n1, n2) return n1 + def p_new_expr(s): # s.systring == 'new'. pos = s.position() @@ -1929,12 +1965,14 @@ def p_statement(s, ctx, first_statement = 0): s.error('decorator not allowed here') s.level = ctx.level decorators = p_decorators(s) - bad_toks = 'def', 'cdef', 'cpdef', 'class' - if not ctx.allow_struct_enum_decorator and s.sy not in bad_toks: - s.error("Decorators can only be followed by functions or classes") + if not ctx.allow_struct_enum_decorator and s.sy not in ('def', 'cdef', 'cpdef', 'class'): + if s.sy == 'IDENT' and s.systring == 'async': + pass # handled below + else: + s.error("Decorators can only be followed by functions or classes") elif s.sy == 'pass' and cdef_flag: # empty cdef block - return p_pass_statement(s, with_newline = 1) + return p_pass_statement(s, with_newline=1) overridable = 0 if s.sy == 'cdef': @@ -1948,11 +1986,11 @@ def p_statement(s, ctx, first_statement = 0): if ctx.level not in ('module', 'module_pxd', 'function', 'c_class', 'c_class_pxd'): s.error('cdef statement not allowed here') s.level = ctx.level - node = p_cdef_statement(s, ctx(overridable = overridable)) + node = p_cdef_statement(s, ctx(overridable=overridable)) if decorators is not None: - tup = Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode + tup = (Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode) if ctx.allow_struct_enum_decorator: - tup += Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode + tup += (Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode) if not isinstance(node, tup): s.error("Decorators can only be followed by functions or classes") node.decorators = decorators @@ -1995,9 +2033,25 @@ def p_statement(s, ctx, first_statement = 0): return p_try_statement(s) elif s.sy == 'with': return p_with_statement(s) + elif s.sy == 'async': + s.next() + return p_async_statement(s, ctx, decorators) else: - return p_simple_statement_list( - s, ctx, first_statement = first_statement) + if s.sy == 'IDENT' and s.systring == 'async': + # PEP 492 enables the async/await keywords when it spots "async def ..." 
+ s.next() + if s.sy == 'def': + s.enable_keyword('async') + s.enable_keyword('await') + result = p_async_statement(s, ctx, decorators) + s.enable_keyword('await') + s.disable_keyword('async') + return result + elif decorators: + s.error("Decorators can only be followed by functions or classes") + s.put_back('IDENT', 'async') + return p_simple_statement_list(s, ctx, first_statement=first_statement) + def p_statement_list(s, ctx, first_statement = 0): # Parse a series of statements separated by newlines. @@ -3002,7 +3056,8 @@ def p_decorators(s): s.expect_newline("Expected a newline after decorator") return decorators -def p_def_statement(s, decorators=None): + +def p_def_statement(s, decorators=None, is_async_def=False): # s.sy == 'def' pos = s.position() s.next() @@ -3017,10 +3072,11 @@ def p_def_statement(s, decorators=None): s.next() return_type_annotation = p_test(s) doc, body = p_suite_with_docstring(s, Ctx(level='function')) - return Nodes.DefNode(pos, name = name, args = args, - star_arg = star_arg, starstar_arg = starstar_arg, - doc = doc, body = body, decorators = decorators, - return_type_annotation = return_type_annotation) + return Nodes.DefNode( + pos, name=name, args=args, star_arg=star_arg, starstar_arg=starstar_arg, + doc=doc, body=body, decorators=decorators, is_async_def=is_async_def, + return_type_annotation=return_type_annotation) + def p_varargslist(s, terminator=')', annotated=1): args = p_c_arg_list(s, in_pyfunc = 1, nonempty_declarators = 1, diff --git a/Cython/Compiler/Scanning.pxd b/Cython/Compiler/Scanning.pxd index ccde720ef..54bf40fd4 100644 --- a/Cython/Compiler/Scanning.pxd +++ b/Cython/Compiler/Scanning.pxd @@ -30,6 +30,7 @@ cdef class PyrexScanner(Scanner): cdef public bint in_python_file cdef public source_encoding cdef set keywords + cdef public dict keywords_stack cdef public list indentation_stack cdef public indentation_char cdef public int bracket_nesting_level @@ -57,3 +58,5 @@ cdef class PyrexScanner(Scanner): cdef expect_indent(self) cdef expect_dedent(self) cdef expect_newline(self, message=*, bint ignore_semicolon=*) + cdef enable_keyword(self, name) + cdef disable_keyword(self, name) diff --git a/Cython/Compiler/Scanning.py b/Cython/Compiler/Scanning.py index 6170cea01..aeedf85fc 100644 --- a/Cython/Compiler/Scanning.py +++ b/Cython/Compiler/Scanning.py @@ -319,6 +319,7 @@ class PyrexScanner(Scanner): self.in_python_file = False self.keywords = set(pyx_reserved_words) self.trace = trace_scanner + self.keywords_stack = {} self.indentation_stack = [0] self.indentation_char = None self.bracket_nesting_level = 0 @@ -497,3 +498,18 @@ class PyrexScanner(Scanner): self.expect('NEWLINE', message) if useless_trailing_semicolon is not None: warning(useless_trailing_semicolon, "useless trailing semicolon") + + def enable_keyword(self, name): + if name in self.keywords_stack: + self.keywords_stack[name] += 1 + else: + self.keywords_stack[name] = 1 + self.keywords.add(name) + + def disable_keyword(self, name): + count = self.keywords_stack.get(name, 1) + if count == 1: + self.keywords.discard(name) + del self.keywords_stack[name] + else: + self.keywords_stack[name] = count - 1 diff --git a/Cython/Parser/Grammar b/Cython/Parser/Grammar index cb66a36b3..8ce663ee6 100644 --- a/Cython/Parser/Grammar +++ b/Cython/Parser/Grammar @@ -13,7 +13,8 @@ eval_input: testlist NEWLINE* ENDMARKER decorator: '@' dotted_PY_NAME [ '(' [arglist] ')' ] NEWLINE decorators: decorator+ -decorated: decorators (classdef | funcdef | cdef_stmt) +decorated: decorators (classdef | funcdef 
| async_funcdef | cdef_stmt) +async_funcdef: 'async' funcdef funcdef: 'def' PY_NAME parameters ['->' test] ':' suite parameters: '(' [typedargslist] ')' typedargslist: (tfpdef ['=' (test | '*')] (',' tfpdef ['=' (test | '*')])* [',' @@ -96,7 +97,8 @@ shift_expr: arith_expr (('<<'|'>>') arith_expr)* arith_expr: term (('+'|'-') term)* term: factor (('*'|'/'|'%'|'//') factor)* factor: ('+'|'-'|'~') factor | power | address | size_of | cast -power: atom trailer* ['**' factor] +power: atom_expr ['**' factor] +atom_expr: ['await'] atom trailer* atom: ('(' [yield_expr|testlist_comp] ')' | '[' [testlist_comp] ']' | '{' [dictorsetmaker] '}' | diff --git a/Cython/Utility/Coroutine.c b/Cython/Utility/Coroutine.c index 7dfa098e0..455781c8f 100644 --- a/Cython/Utility/Coroutine.c +++ b/Cython/Utility/Coroutine.c @@ -1,8 +1,8 @@ -//////////////////// YieldFrom.proto //////////////////// +//////////////////// GeneratorYieldFrom.proto //////////////////// static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject *gen, PyObject *source); -//////////////////// YieldFrom //////////////////// +//////////////////// GeneratorYieldFrom //////////////////// //@requires: Generator static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) { @@ -21,6 +21,125 @@ static CYTHON_INLINE PyObject* __Pyx_Generator_Yield_From(__pyx_CoroutineObject } +//////////////////// CoroutineYieldFrom.proto //////////////////// + +static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source); + +//////////////////// CoroutineYieldFrom //////////////////// +//@requires: Coroutine +//@requires: GetAwaitIter + +static CYTHON_INLINE PyObject* __Pyx_Coroutine_Yield_From(__pyx_CoroutineObject *gen, PyObject *source) { + PyObject *retval; + if (__Pyx_Coroutine_CheckExact(source)) { + retval = __Pyx_Generator_Next(source); + if (retval) { + Py_INCREF(source); + gen->yieldfrom = source; + return retval; + } + } else { + PyObject *source_gen = __Pyx__Coroutine_GetAwaitableIter(source); + if (unlikely(!source_gen)) + return NULL; + // source_gen is now the iterator, make the first next() call + if (__Pyx_Coroutine_CheckExact(source_gen)) { + retval = __Pyx_Generator_Next(source_gen); + } else { + retval = Py_TYPE(source_gen)->tp_iternext(source_gen); + } + if (retval) { + gen->yieldfrom = source_gen; + return retval; + } + Py_DECREF(source_gen); + } + return NULL; +} + + +//////////////////// GetAwaitIter.proto //////////////////// + +static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAwaitableIter(PyObject *o); /*proto*/ +static PyObject *__Pyx__Coroutine_GetAwaitableIter(PyObject *o); /*proto*/ + +//////////////////// GetAwaitIter //////////////////// +//@requires: Coroutine +//@requires: ObjectHandling.c::PyObjectGetAttrStr +//@requires: ObjectHandling.c::PyObjectCallNoArg +//@requires: ObjectHandling.c::PyObjectCallOneArg + +static CYTHON_INLINE PyObject *__Pyx_Coroutine_GetAwaitableIter(PyObject *o) { +#ifdef __Pyx_Coroutine_USED + if (__Pyx_Coroutine_CheckExact(o)) { + Py_INCREF(o); + return o; + } +#endif + return __Pyx__Coroutine_GetAwaitableIter(o); +} + +// copied and adapted from genobject.c in Py3.5 +static PyObject *__Pyx__Coroutine_GetAwaitableIter(PyObject *o) { + PyObject *res; +#if PY_VERSION_HEX >= 0x030500B1 + unaryfunc getter = NULL; + PyTypeObject *ot; + + ot = Py_TYPE(o); + if (likely(ot->tp_as_async)) { + getter = (unaryfunc) ot->tp_as_async->am_await; + } + if (unlikely(getter)) goto slot_error; + res = 
(*getter)(o); +#else + PyObject *method = __Pyx_PyObject_GetAttrStr(o, PYIDENT("__await__")); + if (unlikely(!method)) goto slot_error; + #if CYTHON_COMPILING_IN_CPYTHON + if (likely(PyMethod_Check(method))) { + PyObject *self = PyMethod_GET_SELF(method); + if (likely(self)) { + PyObject *function = PyMethod_GET_FUNCTION(method); + res = __Pyx_PyObject_CallOneArg(function, self); + } else + res = __Pyx_PyObject_CallNoArg(method); + } else + #endif + res = __Pyx_PyObject_CallNoArg(method); + Py_DECREF(method); +#endif + if (unlikely(!res)) goto bad; + if (!PyIter_Check(res)) { + PyErr_Format(PyExc_TypeError, + "__await__() returned non-iterator of type '%.100s'", + Py_TYPE(res)->tp_name); + Py_CLEAR(res); + } else { + int is_coroutine = 0; + #ifdef __Pyx_Coroutine_USED + is_coroutine |= __Pyx_Coroutine_CheckExact(res); + #endif + #if PY_VERSION_HEX >= 0x030500B1 + is_coroutine |= PyGen_CheckCoroutineExact(res); + #endif + if (unlikely(is_coroutine)) { + /* __await__ must return an *iterator*, not + a coroutine or another awaitable (see PEP 492) */ + PyErr_SetString(PyExc_TypeError, + "__await__() returned a coroutine"); + Py_CLEAR(res); + } + } + return res; +slot_error: + PyErr_Format(PyExc_TypeError, + "object %.100s can't be used in 'await' expression", + Py_TYPE(o)->tp_name); +bad: + return NULL; +} + + //////////////////// pep479.proto //////////////////// static void __Pyx_Generator_Replace_StopIteration(void); /*proto*/ @@ -107,6 +226,7 @@ static int __pyx_Generator_init(void); #include <structmember.h> #include <frameobject.h> +static PyObject *__Pyx_Generator_Next(PyObject *self); static PyObject *__Pyx_Coroutine_Send(PyObject *self, PyObject *value); static PyObject *__Pyx_Coroutine_Close(PyObject *self); static PyObject *__Pyx_Coroutine_Throw(PyObject *gen, PyObject *args); @@ -403,6 +523,28 @@ static int __Pyx_Coroutine_CloseIter(__pyx_CoroutineObject *gen, PyObject *yf) { return err; } +static PyObject *__Pyx_Generator_Next(PyObject *self) { + __pyx_CoroutineObject *gen = (__pyx_CoroutineObject*) self; + PyObject *yf = gen->yieldfrom; + if (unlikely(__Pyx_Coroutine_CheckRunning(gen))) + return NULL; + if (yf) { + PyObject *ret; + // FIXME: does this really need an INCREF() ? + //Py_INCREF(yf); + // YieldFrom code ensures that yf is an iterator + gen->is_running = 1; + ret = Py_TYPE(yf)->tp_iternext(yf); + gen->is_running = 0; + //Py_DECREF(yf); + if (likely(ret)) { + return ret; + } + return __Pyx_Coroutine_FinishDelegation(gen); + } + return __Pyx_Coroutine_SendEx(gen, Py_None); +} + static PyObject *__Pyx_Coroutine_Close(PyObject *self) { __pyx_CoroutineObject *gen = (__pyx_CoroutineObject *) self; PyObject *retval, *raised_exception; @@ -729,6 +871,14 @@ static __pyx_CoroutineObject *__Pyx__Coroutine_New(PyTypeObject* type, __pyx_cor //@requires: CoroutineBase //@requires: PatchGeneratorABC +#if PY_VERSION_HEX >= 0x030500B1 +static PyAsyncMethods __pyx_Coroutine_as_async { + 0, /*am_await*/ + 0, /*am_aiter*/ + 0, /*am_anext*/ +} +#endif + static PyTypeObject __pyx_CoroutineType_type = { PyVarObject_HEAD_INIT(0, 0) "coroutine", /*tp_name*/ @@ -738,10 +888,10 @@ static PyTypeObject __pyx_CoroutineType_type = { 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ -#if PY_MAJOR_VERSION < 3 - 0, /*tp_compare*/ +#if PY_VERSION_HEX >= 0x030500B1 + __pyx_Coroutine_as_async, /*tp_as_async*/ #else - 0, /*reserved*/ + 0, /*tp_reserved resp. 
tp_compare*/ #endif 0, /*tp_repr*/ 0, /*tp_as_number*/ @@ -759,8 +909,9 @@ static PyTypeObject __pyx_CoroutineType_type = { 0, /*tp_clear*/ 0, /*tp_richcompare*/ offsetof(__pyx_CoroutineObject, gi_weakreflist), /*tp_weaklistoffset*/ +// no tp_iter() as iterator is only available through __await__() 0, /*tp_iter*/ - 0, /*tp_iternext*/ + (iternextfunc) __Pyx_Generator_Next, /*tp_iternext*/ __pyx_Coroutine_methods, /*tp_methods*/ __pyx_Coroutine_memberlist, /*tp_members*/ __pyx_Coroutine_getsets, /*tp_getset*/ @@ -790,7 +941,7 @@ static PyTypeObject __pyx_CoroutineType_type = { #endif }; -static int __pyx_Generator_init(void) { +static int __pyx_Coroutine_init(void) { // on Windows, C-API functions can't be used in slots statically __pyx_CoroutineType_type.tp_getattro = PyObject_GenericGetAttr; @@ -801,35 +952,10 @@ static int __pyx_Generator_init(void) { return 0; } - //////////////////// Generator //////////////////// //@requires: CoroutineBase //@requires: PatchGeneratorABC -static PyObject *__Pyx_Generator_Next(PyObject *self); - -static PyObject *__Pyx_Generator_Next(PyObject *self) { - __pyx_CoroutineObject *gen = (__pyx_CoroutineObject*) self; - PyObject *yf = gen->yieldfrom; - if (unlikely(__Pyx_Coroutine_CheckRunning(gen))) - return NULL; - if (yf) { - PyObject *ret; - // FIXME: does this really need an INCREF() ? - //Py_INCREF(yf); - // YieldFrom code ensures that yf is an iterator - gen->is_running = 1; - ret = Py_TYPE(yf)->tp_iternext(yf); - gen->is_running = 0; - //Py_DECREF(yf); - if (likely(ret)) { - return ret; - } - return __Pyx_Coroutine_FinishDelegation(gen); - } - return __Pyx_Coroutine_SendEx(gen, Py_None); -} - static PyTypeObject __pyx_GeneratorType_type = { PyVarObject_HEAD_INIT(0, 0) "generator", /*tp_name*/ diff --git a/tests/run/test_coroutines_pep492.pyx b/tests/run/test_coroutines_pep492.pyx new file mode 100644 index 000000000..1369e1dd2 --- /dev/null +++ b/tests/run/test_coroutines_pep492.pyx @@ -0,0 +1,1050 @@ +# cython: language_level=3, binding=True + +import gc +import sys +import types +import inspect +import unittest +import warnings +import contextlib + + +class AsyncYieldFrom: + def __init__(self, obj): + self.obj = obj + + def __await__(self): + yield from self.obj + + +class AsyncYield: + def __init__(self, value): + self.value = value + + def __await__(self): + yield self.value + + +def run_async(coro): + #assert coro.__class__ is types.GeneratorType + assert coro.__class__.__name__ == 'coroutine' + + buffer = [] + result = None + while True: + try: + buffer.append(coro.send(None)) + except StopIteration as ex: + result = ex.args[0] if ex.args else None + break + return buffer, result + + +@contextlib.contextmanager +def silence_coro_gc(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + yield + gc.collect() + + +class AsyncBadSyntaxTest(unittest.TestCase): + + def test_badsyntax_1(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async1 + + def test_badsyntax_2(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async2 + + def test_badsyntax_3(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async3 + + def test_badsyntax_4(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async4 + + def test_badsyntax_5(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async5 + + def test_badsyntax_6(self): + with 
self.assertRaisesRegex( + SyntaxError, "'yield' inside async function"): + + import test.badsyntax_async6 + + def test_badsyntax_7(self): + with self.assertRaisesRegex( + SyntaxError, "'yield from' inside async function"): + + import test.badsyntax_async7 + + def test_badsyntax_8(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async8 + + def test_badsyntax_9(self): + with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): + import test.badsyntax_async9 + + +class TokenizerRegrTest(unittest.TestCase): + + def test_oneline_defs(self): + buf = [] + for i in range(500): + buf.append('def i{i}(): return {i}'.format(i=i)) + buf = '\n'.join(buf) + + # Test that 500 consequent, one-line defs is OK + ns = {} + exec(buf, ns, ns) + self.assertEqual(ns['i499'](), 499) + + # Test that 500 consequent, one-line defs *and* + # one 'async def' following them is OK + buf += '\nasync def foo():\n return' + ns = {} + exec(buf, ns, ns) + self.assertEqual(ns['i499'](), 499) + self.assertTrue(inspect.iscoroutinefunction(ns['foo'])) + + +class CoroutineTest(unittest.TestCase): + + @contextlib.contextmanager + def assertRaisesRegex(self, exc_type, regex): + # the error messages usually don't match, so we just ignore them + try: + yield + except exc_type: + self.assertTrue(True) + else: + self.assertTrue(False) + + def test_gen_1(self): + def gen(): yield + self.assertFalse(hasattr(gen, '__await__')) + + def test_func_1(self): + async def foo(): + return 10 + + f = foo() + self.assertEqual(f.__class__.__name__, 'coroutine') + #self.assertIsInstance(f, types.GeneratorType) + #self.assertTrue(bool(foo.__code__.co_flags & 0x80)) + #self.assertTrue(bool(foo.__code__.co_flags & 0x20)) + #self.assertTrue(bool(f.gi_code.co_flags & 0x80)) + #self.assertTrue(bool(f.gi_code.co_flags & 0x20)) + self.assertEqual(run_async(f), ([], 10)) + + def bar(): pass + self.assertFalse(bool(bar.__code__.co_flags & 0x80)) + + def test_func_2(self): + async def foo(): + raise StopIteration + + with self.assertRaisesRegex( + RuntimeError, "generator raised StopIteration"): + + run_async(foo()) + + def test_func_3(self): + async def foo(): + raise StopIteration + + with silence_coro_gc(): + self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$') + + def test_func_4(self): + async def foo(): + raise StopIteration + + check = lambda: self.assertRaisesRegex( + TypeError, "coroutine-objects do not support iteration") + + with check(): + list(foo()) + + with check(): + tuple(foo()) + + with check(): + sum(foo()) + + with check(): + iter(foo()) + + with check(): + next(foo()) + + with silence_coro_gc(), check(): + for i in foo(): + pass + + with silence_coro_gc(), check(): + [i for i in foo()] + + def test_func_5(self): + @types.coroutine + def bar(): + yield 1 + + async def foo(): + await bar() + + check = lambda: self.assertRaisesRegex( + TypeError, "coroutine-objects do not support iteration") + + with check(): + for el in foo(): pass + + # the following should pass without an error + for el in bar(): + self.assertEqual(el, 1) + self.assertEqual([el for el in bar()], [1]) + self.assertEqual(tuple(bar()), (1,)) + self.assertEqual(next(iter(bar())), 1) + + def test_func_6(self): + @types.coroutine + def bar(): + yield 1 + yield 2 + + async def foo(): + await bar() + + f = foo() + self.assertEqual(f.send(None), 1) + self.assertEqual(f.send(None), 2) + with self.assertRaises(StopIteration): + f.send(None) + + def test_func_7(self): + async def bar(): + return 10 + + def foo(): + yield from 
bar() + + with silence_coro_gc(), self.assertRaisesRegex( + TypeError, + "cannot 'yield from' a coroutine object from a generator"): + + list(foo()) + + def test_func_8(self): + @types.coroutine + def bar(): + return (yield from foo()) + + async def foo(): + return 'spam' + + self.assertEqual(run_async(bar()), ([], 'spam') ) + + def test_func_9(self): + async def foo(): pass + + with self.assertWarnsRegex( + RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"): + + foo() + gc.collect() + + def test_await_1(self): + + async def foo(): + await 1 + with self.assertRaisesRegex(TypeError, "object int can.t.*await"): + run_async(foo()) + + def test_await_2(self): + async def foo(): + await [] + with self.assertRaisesRegex(TypeError, "object list can.t.*await"): + run_async(foo()) + + def test_await_3(self): + async def foo(): + await AsyncYieldFrom([1, 2, 3]) + + self.assertEqual(run_async(foo()), ([1, 2, 3], None)) + + def test_await_4(self): + async def bar(): + return 42 + + async def foo(): + return await bar() + + self.assertEqual(run_async(foo()), ([], 42)) + + def test_await_5(self): + class Awaitable: + def __await__(self): + return + + async def foo(): + return (await Awaitable()) + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type"): + + run_async(foo()) + + def test_await_6(self): + class Awaitable: + def __await__(self): + return iter([52]) + + async def foo(): + return (await Awaitable()) + + self.assertEqual(run_async(foo()), ([52], None)) + + def test_await_7(self): + class Awaitable: + def __await__(self): + yield 42 + return 100 + + async def foo(): + return (await Awaitable()) + + self.assertEqual(run_async(foo()), ([42], 100)) + + def test_await_8(self): + class Awaitable: + pass + + async def foo(): + return (await Awaitable()) + + with self.assertRaisesRegex( + TypeError, "object Awaitable can't be used in 'await' expression"): + + run_async(foo()) + + def test_await_9(self): + def wrap(): + return bar + + async def bar(): + return 42 + + async def foo(): + b = bar() + + db = {'b': lambda: wrap} + + class DB: + b = staticmethod(wrap) + + return (await bar() + await wrap()() + await db['b']()()() + + await bar() * 1000 + await DB.b()()) + + async def foo2(): + return -await bar() + + self.assertEqual(run_async(foo()), ([], 42168)) + self.assertEqual(run_async(foo2()), ([], -42)) + + def test_await_10(self): + async def baz(): + return 42 + + async def bar(): + return baz() + + async def foo(): + return await (await bar()) + + self.assertEqual(run_async(foo()), ([], 42)) + + def test_await_11(self): + def ident(val): + return val + + async def bar(): + return 'spam' + + async def foo(): + return ident(val=await bar()) + + async def foo2(): + return await bar(), 'ham' + + self.assertEqual(run_async(foo2()), ([], ('spam', 'ham'))) + + def test_await_12(self): + async def coro(): + return 'spam' + + class Awaitable: + def __await__(self): + return coro() + + async def foo(): + return await Awaitable() + + with self.assertRaisesRegex( + TypeError, "__await__\(\) returned a coroutine"): + + run_async(foo()) + + def test_await_13(self): + class Awaitable: + def __await__(self): + return self + + async def foo(): + return await Awaitable() + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type"): + + run_async(foo()) + + def test_with_1(self): + class Manager: + def __init__(self, name): + self.name = name + + async def __aenter__(self): + await AsyncYieldFrom(['enter-1-' + self.name, + 'enter-2-' + 
self.name]) + return self + + async def __aexit__(self, *args): + await AsyncYieldFrom(['exit-1-' + self.name, + 'exit-2-' + self.name]) + + if self.name == 'B': + return True + + + async def foo(): + async with Manager("A") as a, Manager("B") as b: + await AsyncYieldFrom([('managers', a.name, b.name)]) + 1/0 + + f = foo() + result, _ = run_async(f) + + self.assertEqual( + result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B', + ('managers', 'A', 'B'), + 'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A'] + ) + + async def foo(): + async with Manager("A") as a, Manager("C") as c: + await AsyncYieldFrom([('managers', a.name, c.name)]) + 1/0 + + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + + def test_with_2(self): + class CM: + def __aenter__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aexit__'): + run_async(foo()) + + def test_with_3(self): + class CM: + def __aexit__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aenter__'): + run_async(foo()) + + def test_with_4(self): + class CM: + def __enter__(self): + pass + + def __exit__(self): + pass + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex(AttributeError, '__aexit__'): + run_async(foo()) + + def test_with_5(self): + # While this test doesn't make a lot of sense, + # it's a regression test for an early bug with opcodes + # generation + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + pass + + async def func(): + async with CM(): + assert (1, ) == 1 + + with self.assertRaises(AssertionError): + run_async(func()) + + def test_with_6(self): + class CM: + def __aenter__(self): + return 123 + + def __aexit__(self, *e): + return 456 + + async def foo(): + async with CM(): + pass + + with self.assertRaisesRegex( + TypeError, "object int can't be used in 'await' expression"): + # it's important that __aexit__ wasn't called + run_async(foo()) + + def test_with_7(self): + class CM: + async def __aenter__(self): + return self + + def __aexit__(self, *e): + return 444 + + async def foo(): + async with CM(): + 1/0 + + try: + run_async(foo()) + except TypeError as exc: + self.assertRegex( + exc.args[0], "object int can't be used in 'await' expression") + self.assertTrue(exc.__context__ is not None) + self.assertTrue(isinstance(exc.__context__, ZeroDivisionError)) + else: + self.fail('invalid asynchronous context manager did not fail') + + + def test_with_8(self): + CNT = 0 + + class CM: + async def __aenter__(self): + return self + + def __aexit__(self, *e): + return 456 + + async def foo(): + nonlocal CNT + async with CM(): + CNT += 1 + + + with self.assertRaisesRegex( + TypeError, "object int can't be used in 'await' expression"): + + run_async(foo()) + + self.assertEqual(CNT, 1) + + + def test_with_9(self): + CNT = 0 + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *e): + 1/0 + + async def foo(): + nonlocal CNT + async with CM(): + CNT += 1 + + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + + self.assertEqual(CNT, 1) + + def test_with_10(self): + CNT = 0 + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *e): + 1/0 + + async def foo(): + nonlocal CNT + async with CM(): + async with CM(): + raise RuntimeError + + try: + run_async(foo()) + except ZeroDivisionError as exc: + self.assertTrue(exc.__context__ is not None) + 
self.assertTrue(isinstance(exc.__context__, ZeroDivisionError)) + self.assertTrue(isinstance(exc.__context__.__context__, + RuntimeError)) + else: + self.fail('exception from __aexit__ did not propagate') + + def test_with_11(self): + CNT = 0 + + class CM: + async def __aenter__(self): + raise NotImplementedError + + async def __aexit__(self, *e): + 1/0 + + async def foo(): + nonlocal CNT + async with CM(): + raise RuntimeError + + try: + run_async(foo()) + except NotImplementedError as exc: + self.assertTrue(exc.__context__ is None) + else: + self.fail('exception from __aenter__ did not propagate') + + def test_with_12(self): + CNT = 0 + + class CM: + async def __aenter__(self): + return self + + async def __aexit__(self, *e): + return True + + async def foo(): + nonlocal CNT + async with CM() as cm: + self.assertIs(cm.__class__, CM) + raise RuntimeError + + run_async(foo()) + + def test_with_13(self): + CNT = 0 + + class CM: + async def __aenter__(self): + 1/0 + + async def __aexit__(self, *e): + return True + + async def foo(): + nonlocal CNT + CNT += 1 + async with CM(): + CNT += 1000 + CNT += 10000 + + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + self.assertEqual(CNT, 1) + + def test_for_1(self): + aiter_calls = 0 + + class AsyncIter: + def __init__(self): + self.i = 0 + + async def __aiter__(self): + nonlocal aiter_calls + aiter_calls += 1 + return self + + async def __anext__(self): + self.i += 1 + + if not (self.i % 10): + await AsyncYield(self.i * 10) + + if self.i > 100: + raise StopAsyncIteration + + return self.i, self.i + + + buffer = [] + async def test1(): + async for i1, i2 in AsyncIter(): + buffer.append(i1 + i2) + + yielded, _ = run_async(test1()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 1) + self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) + self.assertEqual(buffer, [i*2 for i in range(1, 101)]) + + + buffer = [] + async def test2(): + nonlocal buffer + async for i in AsyncIter(): + buffer.append(i[0]) + if i[0] == 20: + break + else: + buffer.append('what?') + buffer.append('end') + + yielded, _ = run_async(test2()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 2) + self.assertEqual(yielded, [100, 200]) + self.assertEqual(buffer, [i for i in range(1, 21)] + ['end']) + + + buffer = [] + async def test3(): + nonlocal buffer + async for i in AsyncIter(): + if i[0] > 20: + continue + buffer.append(i[0]) + else: + buffer.append('what?') + buffer.append('end') + + yielded, _ = run_async(test3()) + # Make sure that __aiter__ was called only once + self.assertEqual(aiter_calls, 3) + self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) + self.assertEqual(buffer, [i for i in range(1, 21)] + + ['what?', 'end']) + + def test_for_2(self): + tup = (1, 2, 3) + refs_before = sys.getrefcount(tup) + + async def foo(): + async for i in tup: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, "async for' requires an object.*__aiter__.*tuple"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(tup), refs_before) + + def test_for_3(self): + class I: + def __aiter__(self): + return self + + aiter = I() + refs_before = sys.getrefcount(aiter) + + async def foo(): + async for i in aiter: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__aiter.*\: I"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(aiter), refs_before) + + def test_for_4(self): + class I: + async def 
__aiter__(self): + return self + + def __anext__(self): + return () + + aiter = I() + refs_before = sys.getrefcount(aiter) + + async def foo(): + async for i in aiter: + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__anext__.*tuple"): + + run_async(foo()) + + self.assertEqual(sys.getrefcount(aiter), refs_before) + + def test_for_5(self): + class I: + async def __aiter__(self): + return self + + def __anext__(self): + return 123 + + async def foo(): + async for i in I(): + print('never going to happen') + + with self.assertRaisesRegex( + TypeError, + "async for' received an invalid object.*__anext.*int"): + + run_async(foo()) + + def test_for_6(self): + I = 0 + + class Manager: + async def __aenter__(self): + nonlocal I + I += 10000 + + async def __aexit__(self, *args): + nonlocal I + I += 100000 + + class Iterable: + def __init__(self): + self.i = 0 + + async def __aiter__(self): + return self + + async def __anext__(self): + if self.i > 10: + raise StopAsyncIteration + self.i += 1 + return self.i + + ############## + + manager = Manager() + iterable = Iterable() + mrefs_before = sys.getrefcount(manager) + irefs_before = sys.getrefcount(iterable) + + async def main(): + nonlocal I + + async with manager: + async for i in iterable: + I += 1 + I += 1000 + + run_async(main()) + self.assertEqual(I, 111011) + + self.assertEqual(sys.getrefcount(manager), mrefs_before) + self.assertEqual(sys.getrefcount(iterable), irefs_before) + + ############## + + async def main(): + nonlocal I + + async with Manager(): + async for i in Iterable(): + I += 1 + I += 1000 + + async with Manager(): + async for i in Iterable(): + I += 1 + I += 1000 + + run_async(main()) + self.assertEqual(I, 333033) + + ############## + + async def main(): + nonlocal I + + async with Manager(): + I += 100 + async for i in Iterable(): + I += 1 + else: + I += 10000000 + I += 1000 + + async with Manager(): + I += 100 + async for i in Iterable(): + I += 1 + else: + I += 10000000 + I += 1000 + + run_async(main()) + self.assertEqual(I, 20555255) + + def test_for_7(self): + CNT = 0 + class AI: + async def __aiter__(self): + 1/0 + async def foo(): + nonlocal CNT + async for i in AI(): + CNT += 1 + CNT += 10 + with self.assertRaises(ZeroDivisionError): + run_async(foo()) + self.assertEqual(CNT, 0) + + +class CoroAsyncIOCompatTest(unittest.TestCase): + + def test_asyncio_1(self): + import asyncio + + class MyException(Exception): + pass + + buffer = [] + + class CM: + async def __aenter__(self): + buffer.append(1) + await asyncio.sleep(0.01) + buffer.append(2) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await asyncio.sleep(0.01) + buffer.append(exc_type.__name__) + + async def f(): + async with CM() as c: + await asyncio.sleep(0.01) + raise MyException + buffer.append('unreachable') + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(f()) + except MyException: + pass + finally: + loop.close() + asyncio.set_event_loop(None) + + self.assertEqual(buffer, [1, 2, 'MyException']) + + +class SysSetCoroWrapperTest(unittest.TestCase): + + def test_set_wrapper_1(self): + async def foo(): + return 'spam' + + wrapped = None + def wrap(gen): + nonlocal wrapped + wrapped = gen + return gen + + self.assertIsNone(sys.get_coroutine_wrapper()) + + sys.set_coroutine_wrapper(wrap) + self.assertIs(sys.get_coroutine_wrapper(), wrap) + try: + f = foo() + self.assertTrue(wrapped) + + self.assertEqual(run_async(f), 
([], 'spam')) + finally: + sys.set_coroutine_wrapper(None) + + self.assertIsNone(sys.get_coroutine_wrapper()) + + wrapped = None + with silence_coro_gc(): + foo() + self.assertFalse(wrapped) + + def test_set_wrapper_2(self): + self.assertIsNone(sys.get_coroutine_wrapper()) + with self.assertRaisesRegex(TypeError, "callable expected, got int"): + sys.set_coroutine_wrapper(1) + self.assertIsNone(sys.get_coroutine_wrapper()) + + +class CAPITest(unittest.TestCase): + + def test_tp_await_1(self): + from _testcapi import awaitType as at + + async def foo(): + future = at(iter([1])) + return (await future) + + self.assertEqual(foo().send(None), 1) + + def test_tp_await_2(self): + # Test tp_await to __await__ mapping + from _testcapi import awaitType as at + future = at(iter([1])) + self.assertEqual(next(future.__await__()), 1) + + def test_tp_await_3(self): + from _testcapi import awaitType as at + + async def foo(): + future = at(1) + return (await future) + + with self.assertRaisesRegex( + TypeError, "__await__.*returned non-iterator of type 'int'"): + self.assertEqual(foo().send(None), 1) + + +# disable some tests that only apply to CPython + +CAPITest = None # no CAPI module + +if sys.version_info < (3, 5): + SysSetCoroWrapperTest = None + +if __name__=="__main__": + unittest.main() -- 2.30.9
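
For reference, a minimal sketch (not part of the patch) of the PEP 492 behaviour this change implements, following the manual send(None) driving pattern of the run_async() helper and the __await__ tests in tests/run/test_coroutines_pep492.pyx above; the Awaitable/coro/run names here are illustrative only:

    # Drive a PEP 492 coroutine by hand, as an event loop would.
    class Awaitable:
        # __await__() must return an iterator (PEP 492); a generator works,
        # and its 'return' value becomes the result of the await expression.
        def __await__(self):
            yield 'suspended'   # value surfaced to whoever drives the coroutine
            return 42

    async def coro():
        result = await Awaitable()   # delegates to __await__()'s iterator
        return result + 1

    def run(c):
        # Keep sending None until StopIteration carries the return value,
        # mirroring run_async() in the test file.
        buffered = []
        while True:
            try:
                buffered.append(c.send(None))
            except StopIteration as exc:
                return buffered, exc.value

    print(run(coro()))   # -> (['suspended'], 43)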