Commit 18f7fa3c authored by Stefan Behnel's avatar Stefan Behnel

fix tree structure for generator expressions

parent 47329982
...@@ -1046,6 +1046,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1046,6 +1046,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
class YieldNodeCollector(Visitor.TreeVisitor): class YieldNodeCollector(Visitor.TreeVisitor):
def __init__(self): def __init__(self):
Visitor.TreeVisitor.__init__(self) Visitor.TreeVisitor.__init__(self)
self.yield_stat_nodes = {}
self.yield_nodes = [] self.yield_nodes = []
visit_Node = Visitor.TreeVisitor.visitchildren visit_Node = Visitor.TreeVisitor.visitchildren
...@@ -1053,12 +1054,18 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1053,12 +1054,18 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
self.yield_nodes.append(node) self.yield_nodes.append(node)
self.visitchildren(node) self.visitchildren(node)
def visit_ExprStatNode(self, node):
self.visitchildren(node)
if node.expr in self.yield_nodes:
self.yield_stat_nodes[node.expr] = node
def _find_single_yield_node(self, node): def _find_single_yield_node(self, node):
collector = self.YieldNodeCollector() collector = self.YieldNodeCollector()
collector.visitchildren(node) collector.visitchildren(node)
if len(collector.yield_nodes) != 1: if len(collector.yield_nodes) != 1:
return None return None, None
return collector.yield_nodes[0] yield_node = collector.yield_nodes[0]
return (yield_node, collector.yield_stat_nodes.get(yield_node))
def _handle_simple_function_all(self, node, pos_args): def _handle_simple_function_all(self, node, pos_args):
"""Transform """Transform
...@@ -1107,8 +1114,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1107,8 +1114,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return node return node
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_node = self._find_single_yield_node(loop_node) yield_node, yield_stat_node = self._find_single_yield_node(loop_node)
if yield_node is None: if yield_node is None or yield_stat_node is None:
return node return node
yield_expression = yield_node.arg yield_expression = yield_node.arg
...@@ -1150,7 +1157,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1150,7 +1157,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any, rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any,
constant_result = not is_any)) constant_result = not is_any))
Visitor.recursively_replace_node(loop_node, yield_node, test_node) Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
return ExprNodes.InlinedGeneratorExpressionNode( return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = loop_node, result_node = result_ref, gen_expr_node.pos, loop = loop_node, result_node = result_ref,
...@@ -1166,8 +1173,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1166,8 +1173,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
gen_expr_node = pos_args[0] gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop loop_node = gen_expr_node.loop
yield_node = self._find_single_yield_node(loop_node) yield_node, yield_stat_node = self._find_single_yield_node(loop_node)
if yield_node is None: if yield_node is None or yield_stat_node is None:
return node return node
yield_expression = yield_node.arg yield_expression = yield_node.arg
...@@ -1183,7 +1190,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -1183,7 +1190,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
) )
Visitor.recursively_replace_node(loop_node, yield_node, add_node) Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
exec_code = Nodes.StatListNode( exec_code = Nodes.StatListNode(
node.pos, node.pos,
......
...@@ -958,7 +958,8 @@ def p_testlist_comp(s): ...@@ -958,7 +958,8 @@ def p_testlist_comp(s):
def p_genexp(s, expr): def p_genexp(s, expr):
# s.sy == 'for' # s.sy == 'for'
loop = p_comp_for(s, ExprNodes.YieldExprNode(expr.pos, arg=expr)) loop = p_comp_for(s, Nodes.ExprStatNode(
expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr)))
return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop) return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop)
expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE') expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment