Commit 33f8fe4e authored by Stefan Behnel

fix ticket 467: restore eval-once semantics for all rhs items in parallel...

fix ticket 467: restore eval-once semantics for all rhs items in parallel assignments by extracting common subexpressions into temps
parent eba0c000
...@@ -265,6 +265,9 @@ class PostParse(CythonTransform): ...@@ -265,6 +265,9 @@ class PostParse(CythonTransform):
expr_list_list = [] expr_list_list = []
flatten_parallel_assignments(expr_list, expr_list_list) flatten_parallel_assignments(expr_list, expr_list_list)
temp_refs = []
eliminate_rhs_duplicates(expr_list_list, temp_refs)
nodes = [] nodes = []
for expr_list in expr_list_list: for expr_list in expr_list_list:
lhs_list = expr_list[:-1] lhs_list = expr_list[:-1]
...@@ -276,11 +279,94 @@ class PostParse(CythonTransform): ...@@ -276,11 +279,94 @@ class PostParse(CythonTransform):
node = Nodes.CascadedAssignmentNode(rhs.pos, node = Nodes.CascadedAssignmentNode(rhs.pos,
lhs_list = lhs_list, rhs = rhs) lhs_list = lhs_list, rhs = rhs)
nodes.append(node) nodes.append(node)
if len(nodes) == 1: if len(nodes) == 1:
return nodes[0] assign_node = nodes[0]
else:
assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
if temp_refs:
duplicates_and_temps = [ (temp.expression, temp)
for temp in temp_refs ]
sort_common_subsequences(duplicates_and_temps)
for _, temp_ref in duplicates_and_temps[::-1]:
assign_node = LetNode(temp_ref, assign_node)
return assign_node
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
    """Replace rhs items by LetRefNodes if they appear more than once.

    Creates a sequence of LetRefNodes that set up the required temps
    and appends them to ref_node_sequence.  The input list is modified
    in-place.

    expr_list_list    -- list of expression lists; the last entry of
                         each inner list is treated as the rhs
    ref_node_sequence -- output list receiving one LetRefNode per
                         duplicated rhs node, in order of discovery

    Returns None.  If no duplicates are found, nothing is modified.
    """
    seen_nodes = set()
    ref_nodes = {}

    def find_duplicates(node):
        if node.is_literal or node.is_name:
            # no need to replace those; can't include attributes here
            # as their access is not necessarily side-effect free
            return
        if node in seen_nodes:
            # second (or later) occurrence: allocate exactly one ref node
            if node not in ref_nodes:
                ref_node = LetRefNode(node)
                ref_nodes[node] = ref_node
                ref_node_sequence.append(ref_node)
        else:
            seen_nodes.add(node)
            # only descend on the first visit; the children of an
            # already seen node have been recorded before
            if node.is_sequence_constructor:
                for item in node.args:
                    find_duplicates(item)

    for expr_list in expr_list_list:
        find_duplicates(expr_list[-1])
    if not ref_nodes:
        return

    def substitute_nodes(node):
        if node in ref_nodes:
            return ref_nodes[node]
        elif node.is_sequence_constructor:
            # build a real list: a lazy map() object (Python 3) could be
            # iterated only once and would defer the substitution
            node.args = [substitute_nodes(arg) for arg in node.args]
        return node

    # replace nodes inside of the common subexpressions
    for node in ref_nodes:
        if node.is_sequence_constructor:
            node.args = [substitute_nodes(arg) for arg in node.args]

    # replace common subexpressions on all rhs items
    for expr_list in expr_list_list:
        expr_list[-1] = substitute_nodes(expr_list[-1])
def sort_common_subsequences(items):
    """Sort items/subsequences so that all items and subsequences that
    an item contains appear before the item itself.

    items -- list of (expression, temp_ref) pairs; reordered in place,
             nothing is returned.

    An item counts as contained in a sequence if either its expression
    or its temp_ref occurs (by identity, at any nesting depth) in the
    args of that sequence.  The temp_ref has to be checked because
    nested duplicates may already have been replaced by their
    reference nodes before this function is called.

    This implies a partial order, and the sort must be stable to
    preserve the original order as much as possible, so we use a
    simple insertion sort.
    """
    def contains(seq, expr, ref):
        for item in seq:
            if item is expr or item is ref:
                return True
            elif item.is_sequence_constructor and contains(item.args, expr, ref):
                return True
        return False

    def lower_than(expr, ref, b):
        return b.is_sequence_constructor and contains(b.args, expr, ref)

    for pos, item in enumerate(items):
        expr, ref = item[0], item[1]
        new_pos = pos
        # scan all predecessors to find the earliest one that depends
        # on this item
        for i in range(pos - 1, -1, -1):
            if lower_than(expr, ref, items[i][0]):
                new_pos = i
        if new_pos != pos:
            # shift the intermediate items one slot to the right and
            # insert the current item in front of its first dependent
            items[new_pos + 1:pos + 1] = items[new_pos:pos]
            items[new_pos] = item
def flatten_parallel_assignments(input, output): def flatten_parallel_assignments(input, output):
# The input is a list of expression nodes, representing the LHSs # The input is a list of expression nodes, representing the LHSs
......
...@@ -130,6 +130,9 @@ class ResultRefNode(AtomicExprNode): ...@@ -130,6 +130,9 @@ class ResultRefNode(AtomicExprNode):
def infer_type(self, env): def infer_type(self, env):
return self.expression.infer_type(env) return self.expression.infer_type(env)
def is_simple(self):
    # Unconditionally reports this reference node as "simple".
    # NOTE(review): presumably justified because result() only hands
    # back the stored result_code -- confirm against the is_simple()
    # contract of the base expression node class.
    return True
def result(self): def result(self):
return self.result_code return self.result_code
...@@ -222,7 +225,8 @@ class LetNode(Nodes.StatNode, LetNodeMixin): ...@@ -222,7 +225,8 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
# BLOCK (can modify temp) # BLOCK (can modify temp)
# if temp is an object, decref # if temp is an object, decref
# #
# To be used after analysis phase, does no analysis. # Usually used after analysis phase, but forwards analysis methods
# to its children
child_attrs = ['temp_expression', 'body'] child_attrs = ['temp_expression', 'body']
...@@ -231,6 +235,17 @@ class LetNode(Nodes.StatNode, LetNodeMixin): ...@@ -231,6 +235,17 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
self.pos = body.pos self.pos = body.pos
self.body = body self.body = body
def analyse_control_flow(self, env):
    # Forward control flow analysis to the wrapped body.  Only the body
    # is visited here; temp_expression is not touched in this phase.
    self.body.analyse_control_flow(env)
def analyse_declarations(self, env):
    # Forward declaration analysis to both children: the temp
    # expression first, then the body.
    self.temp_expression.analyse_declarations(env)
    self.body.analyse_declarations(env)
def analyse_expressions(self, env):
    # Forward expression analysis to both children: the temp
    # expression first, then the body.
    self.temp_expression.analyse_expressions(env)
    self.body.analyse_expressions(env)
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.setup_temp_expr(code) self.setup_temp_expr(code)
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment