Commit 13ace3a9 authored by Stefan Behnel's avatar Stefan Behnel

re-establish a simple form of generator inlining for any() and all() that does...

re-establish a simple form of generator inlining for any() and all() that does not remove the generator but inlines the evaluation into the inner loop
parent 0710b467
...@@ -7315,7 +7315,39 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): ...@@ -7315,7 +7315,39 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
self.value_expr.annotate(code) self.value_expr.annotate(code)
class InlinedGeneratorExpressionNode(ScopedExprNode): class InlinedGeneratorExpressionNode(ExprNode):
# An inlined generator expression for which the result is
# calculated inside of the loop. This will only be created by
# transforms when replacing builtin calls on generator
# expressions.
#
# gen GeneratorExpressionNode the generator, not containing any YieldExprNodes
# orig_func String the name of the builtin function this node replaces
subexprs = ["gen"]
orig_func = None
type = py_object_type
def may_be_none(self):
return self.orig_func not in ('any', 'all')
def infer_type(self, env):
return py_object_type
def analyse_types(self, env):
self.gen = self.gen.analyse_expressions(env)
self.is_temp = True
return self
def generate_result_code(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("GetGenexpResult", "Coroutine.c"))
code.putln("%s = __Pyx_Generator_GetGenexpResult(%s); %s" % (
self.result(), self.gen.result(),
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.result())
class __InlinedGeneratorExpressionNode(ScopedExprNode):
# An inlined generator expression for which the result is # An inlined generator expression for which the result is
# calculated inside of the loop. This will only be created by # calculated inside of the loop. This will only be created by
# transforms when replacing builtin calls on generator # transforms when replacing builtin calls on generator
......
...@@ -1455,16 +1455,16 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1455,16 +1455,16 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
visit_Node = Visitor.TreeVisitor.visitchildren visit_Node = Visitor.TreeVisitor.visitchildren
# XXX: disable inlining while it's not back supported # XXX: disable inlining while it's not back supported
def __visit_YieldExprNode(self, node): def visit_YieldExprNode(self, node):
self.yield_nodes.append(node) self.yield_nodes.append(node)
self.visitchildren(node) self.visitchildren(node)
def __visit_ExprStatNode(self, node): def visit_ExprStatNode(self, node):
self.visitchildren(node) self.visitchildren(node)
if node.expr in self.yield_nodes: if node.expr in self.yield_nodes:
self.yield_stat_nodes[node.expr] = node self.yield_stat_nodes[node.expr] = node
def __visit_GeneratorExpressionNode(self, node): def visit_GeneratorExpressionNode(self, node):
# enable when we support generic generator expressions # enable when we support generic generator expressions
# #
# everything below this node is out of scope # everything below this node is out of scope
...@@ -1527,7 +1527,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1527,7 +1527,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
return node return node
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop generator_body = gen_expr_node.def_node.gbody
loop_node = generator_body.body
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
if yield_expression is None: if yield_expression is None:
return node return node
...@@ -1535,46 +1536,37 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1535,46 +1536,37 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if is_any: if is_any:
condition = yield_expression condition = yield_expression
else: else:
condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression) condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression)
result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
test_node = Nodes.IfStatNode( test_node = Nodes.IfStatNode(
yield_expression.pos, yield_expression.pos, else_clause=None, if_clauses=[
else_clause = None, Nodes.IfClauseNode(
if_clauses = [ Nodes.IfClauseNode( yield_expression.pos,
yield_expression.pos, condition=condition,
condition = condition, body=Nodes.ReturnStatNode(
body = Nodes.StatListNode( node.pos,
node.pos, value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any),
stats = [ in_generator=True)
Nodes.SingleAssignmentNode( )]
node.pos, )
lhs = result_ref,
rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
constant_result = is_any)),
Nodes.BreakStatNode(node.pos)
])) ]
)
loop = loop_node loop = loop_node
while isinstance(loop.body, Nodes.LoopNode): while isinstance(loop.body, Nodes.LoopNode):
next_loop = loop.body next_loop = loop.body
loop.body = Nodes.StatListNode(loop.body.pos, stats = [ loop.body = Nodes.StatListNode(loop.body.pos, stats=[
loop.body, loop.body,
Nodes.BreakStatNode(yield_expression.pos) Nodes.BreakStatNode(yield_expression.pos)
]) ])
next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos) next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
loop = next_loop loop = next_loop
loop_node.else_clause = Nodes.SingleAssignmentNode( loop_node.else_clause = Nodes.ReturnStatNode(
node.pos, node.pos,
lhs = result_ref, value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any),
rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any, in_generator=True)
constant_result = not is_any))
Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node) Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
return ExprNodes.InlinedGeneratorExpressionNode( return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = loop_node, result_node = result_ref, gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all')
expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
PySequence_List_func_type = PyrexTypes.CFuncType( PySequence_List_func_type = PyrexTypes.CFuncType(
Builtin.list_type, Builtin.list_type,
...@@ -1597,6 +1589,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1597,6 +1589,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None: if yield_expression is None:
return node return node
...@@ -1642,7 +1636,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1642,7 +1636,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
result_node, result_node,
Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ])) Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
def _handle_simple_function_sum(self, node, pos_args): def __handle_simple_function_sum(self, node, pos_args):
"""Transform sum(genexpr) into an equivalent inlined aggregation loop. """Transform sum(genexpr) into an equivalent inlined aggregation loop.
""" """
if len(pos_args) not in (1,2): if len(pos_args) not in (1,2):
...@@ -1655,6 +1649,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1655,6 +1649,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode): if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None: if yield_expression is None:
return node return node
else: # ComprehensionNode else: # ComprehensionNode
...@@ -1786,6 +1782,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1786,6 +1782,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None: if yield_expression is None:
return node return node
...@@ -1818,6 +1816,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1818,6 +1816,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node) yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
# FIXME: currently nonfunctional
yield_expression = None
if yield_expression is None: if yield_expression is None:
return node return node
......
...@@ -246,6 +246,28 @@ static void __Pyx_Generator_Replace_StopIteration(void) { ...@@ -246,6 +246,28 @@ static void __Pyx_Generator_Replace_StopIteration(void) {
} }
//////////////////// GetGenexpResult.proto ////////////////////
static CYTHON_INLINE PyObject* __Pyx_Generator_GetGenexpResult(PyObject* gen); /*proto*/
//////////////////// GetGenexpResult ////////////////////
//@requires: Generator
static CYTHON_INLINE PyObject* __Pyx_Generator_GetGenexpResult(PyObject* gen) {
PyObject *result;
result = __Pyx_Generator_Next(gen);
if (unlikely(result)) {
PyErr_Format(PyExc_RuntimeError, "Generator expression returned with non-StopIteration result '%.100s'",
result ? Py_TYPE(result)->tp_name : "NULL");
Py_XDECREF(result);
return NULL;
}
if (unlikely(__Pyx_PyGen_FetchStopIterationValue(&result) < 0))
return NULL;
return result;
}
//////////////////// CoroutineBase.proto //////////////////// //////////////////// CoroutineBase.proto ////////////////////
typedef PyObject *(*__pyx_coroutine_body_t)(PyObject *, PyObject *); typedef PyObject *(*__pyx_coroutine_body_t)(PyObject *, PyObject *);
......
...@@ -52,6 +52,4 @@ pyregr.test_urllib2net ...@@ -52,6 +52,4 @@ pyregr.test_urllib2net
pyregr.test_urllibnet pyregr.test_urllibnet
# Inlined generators # Inlined generators
all
any
inlined_generator_expressions inlined_generator_expressions
...@@ -53,10 +53,14 @@ def all_item(x): ...@@ -53,10 +53,14 @@ def all_item(x):
""" """
return all(x) return all(x)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode") "//ForInStatNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode"
"//YieldExprNode") )
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def all_in_simple_gen(seq): def all_in_simple_gen(seq):
""" """
>>> all_in_simple_gen([1,1,1]) >>> all_in_simple_gen([1,1,1])
...@@ -82,10 +86,14 @@ def all_in_simple_gen(seq): ...@@ -82,10 +86,14 @@ def all_in_simple_gen(seq):
""" """
return all(x for x in seq) return all(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode") "//ForInStatNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode"
"//YieldExprNode") )
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def all_in_simple_gen_scope(seq): def all_in_simple_gen_scope(seq):
""" """
>>> all_in_simple_gen_scope([1,1,1]) >>> all_in_simple_gen_scope([1,1,1])
...@@ -114,10 +122,14 @@ def all_in_simple_gen_scope(seq): ...@@ -114,10 +122,14 @@ def all_in_simple_gen_scope(seq):
assert x == 'abc' assert x == 'abc'
return result return result
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode") "//ForInStatNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode"
"//YieldExprNode") )
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def all_in_conditional_gen(seq): def all_in_conditional_gen(seq):
""" """
>>> all_in_conditional_gen([3,6,9]) >>> all_in_conditional_gen([3,6,9])
...@@ -150,10 +162,14 @@ mixed_ustring = u'AbcDefGhIjKlmnoP' ...@@ -150,10 +162,14 @@ mixed_ustring = u'AbcDefGhIjKlmnoP'
lower_ustring = mixed_ustring.lower() lower_ustring = mixed_ustring.lower()
upper_ustring = mixed_ustring.upper() upper_ustring = mixed_ustring.upper()
@cython.test_assert_path_exists('//PythonCapiCallNode', @cython.test_assert_path_exists(
'//ForFromStatNode') '//PythonCapiCallNode',
@cython.test_fail_if_path_exists('//SimpleCallNode', '//ForFromStatNode'
'//ForInStatNode') )
@cython.test_fail_if_path_exists(
'//SimpleCallNode',
'//ForInStatNode'
)
def all_lower_case_characters(unicode ustring): def all_lower_case_characters(unicode ustring):
""" """
>>> all_lower_case_characters(mixed_ustring) >>> all_lower_case_characters(mixed_ustring)
...@@ -165,12 +181,16 @@ def all_lower_case_characters(unicode ustring): ...@@ -165,12 +181,16 @@ def all_lower_case_characters(unicode ustring):
""" """
return all(uchar.islower() for uchar in ustring) return all(uchar.islower() for uchar in ustring)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode", "//ForInStatNode",
"//InlinedGeneratorExpressionNode//IfStatNode") "//InlinedGeneratorExpressionNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode//IfStatNode"
"//YieldExprNode", )
"//IfStatNode//CoerceToBooleanNode") @cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode",
# "//IfStatNode//CoerceToBooleanNode"
)
def all_in_typed_gen(seq): def all_in_typed_gen(seq):
""" """
>>> all_in_typed_gen([1,1,1]) >>> all_in_typed_gen([1,1,1])
...@@ -197,12 +217,16 @@ def all_in_typed_gen(seq): ...@@ -197,12 +217,16 @@ def all_in_typed_gen(seq):
cdef int x cdef int x
return all(x for x in seq) return all(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode", "//ForInStatNode",
"//InlinedGeneratorExpressionNode//IfStatNode") "//InlinedGeneratorExpressionNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode//IfStatNode"
"//YieldExprNode", )
"//IfStatNode//CoerceToBooleanNode") @cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode",
# "//IfStatNode//CoerceToBooleanNode"
)
def all_in_double_gen(seq): def all_in_double_gen(seq):
""" """
>>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L) >>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L)
......
...@@ -52,10 +52,14 @@ def any_item(x): ...@@ -52,10 +52,14 @@ def any_item(x):
return any(x) return any(x)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode") "//ForInStatNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode"
"//YieldExprNode") )
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def any_in_simple_gen(seq): def any_in_simple_gen(seq):
""" """
>>> any_in_simple_gen([0,1,0]) >>> any_in_simple_gen([0,1,0])
...@@ -80,10 +84,14 @@ def any_in_simple_gen(seq): ...@@ -80,10 +84,14 @@ def any_in_simple_gen(seq):
return any(x for x in seq) return any(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode") "//ForInStatNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode"
"//YieldExprNode") )
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def any_in_simple_gen_scope(seq): def any_in_simple_gen_scope(seq):
""" """
>>> any_in_simple_gen_scope([0,1,0]) >>> any_in_simple_gen_scope([0,1,0])
...@@ -111,10 +119,14 @@ def any_in_simple_gen_scope(seq): ...@@ -111,10 +119,14 @@ def any_in_simple_gen_scope(seq):
return result return result
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode") "//ForInStatNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode"
"//YieldExprNode") )
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def any_in_conditional_gen(seq): def any_in_conditional_gen(seq):
""" """
>>> any_in_conditional_gen([3,6,9]) >>> any_in_conditional_gen([3,6,9])
...@@ -146,11 +158,15 @@ lower_ustring = mixed_ustring.lower() ...@@ -146,11 +158,15 @@ lower_ustring = mixed_ustring.lower()
upper_ustring = mixed_ustring.upper() upper_ustring = mixed_ustring.upper()
@cython.test_assert_path_exists('//PythonCapiCallNode', @cython.test_assert_path_exists(
'//ForFromStatNode', '//PythonCapiCallNode',
"//InlinedGeneratorExpressionNode") '//ForFromStatNode',
@cython.test_fail_if_path_exists('//SimpleCallNode', "//InlinedGeneratorExpressionNode"
'//ForInStatNode') )
@cython.test_fail_if_path_exists(
'//SimpleCallNode',
'//ForInStatNode'
)
def any_lower_case_characters(unicode ustring): def any_lower_case_characters(unicode ustring):
""" """
>>> any_lower_case_characters(upper_ustring) >>> any_lower_case_characters(upper_ustring)
...@@ -163,12 +179,16 @@ def any_lower_case_characters(unicode ustring): ...@@ -163,12 +179,16 @@ def any_lower_case_characters(unicode ustring):
return any(uchar.islower() for uchar in ustring) return any(uchar.islower() for uchar in ustring)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode", "//ForInStatNode",
"//InlinedGeneratorExpressionNode//IfStatNode") "//InlinedGeneratorExpressionNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode//IfStatNode"
"//YieldExprNode", )
"//IfStatNode//CoerceToBooleanNode") @cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode",
# "//IfStatNode//CoerceToBooleanNode"
)
def any_in_typed_gen(seq): def any_in_typed_gen(seq):
""" """
>>> any_in_typed_gen([0,1,0]) >>> any_in_typed_gen([0,1,0])
...@@ -194,11 +214,15 @@ def any_in_typed_gen(seq): ...@@ -194,11 +214,15 @@ def any_in_typed_gen(seq):
return any(x for x in seq) return any(x for x in seq)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode", "//ForInStatNode",
"//InlinedGeneratorExpressionNode//IfStatNode") "//InlinedGeneratorExpressionNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode//IfStatNode"
"//YieldExprNode") )
@cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode"
)
def any_in_gen_builtin_name(seq): def any_in_gen_builtin_name(seq):
""" """
>>> any_in_gen_builtin_name([0,1,0]) >>> any_in_gen_builtin_name([0,1,0])
...@@ -223,12 +247,16 @@ def any_in_gen_builtin_name(seq): ...@@ -223,12 +247,16 @@ def any_in_gen_builtin_name(seq):
return any(type for type in seq) return any(type for type in seq)
@cython.test_assert_path_exists("//ForInStatNode", @cython.test_assert_path_exists(
"//InlinedGeneratorExpressionNode", "//ForInStatNode",
"//InlinedGeneratorExpressionNode//IfStatNode") "//InlinedGeneratorExpressionNode",
@cython.test_fail_if_path_exists("//SimpleCallNode", "//InlinedGeneratorExpressionNode//IfStatNode"
"//YieldExprNode", )
"//IfStatNode//CoerceToBooleanNode") @cython.test_fail_if_path_exists(
"//SimpleCallNode",
"//YieldExprNode",
# "//IfStatNode//CoerceToBooleanNode"
)
def any_in_double_gen(seq): def any_in_double_gen(seq):
""" """
>>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L) >>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment