Commit 55750171, authored by Stefan Behnel

optimise multiplication of constant lists/tuples as in [1,2,3]*5

parent 8be2ae86
......@@ -4331,11 +4331,13 @@ class SequenceNode(ExprNode):
# args                     [ExprNode]
# unpacked_items           [ExprNode] or None
# coerced_unpacked_items   [ExprNode] or None
# mult_factor              ExprNode or None    the integer number of content
#                                              repetitions, e.g. the 3 in ([1,2]*3)

# 'mult_factor' is listed as a sub-expression next to 'args'.  The earlier
# single-entry assignment was overwritten immediately and has been dropped.
subexprs = ['args', 'mult_factor']
is_sequence_constructor = 1
unpacked_items = None
# None means "no repetition"; the methods below test it for truthiness.
mult_factor = None
def compile_time_value_list(self, denv):
    """Return the compile-time values of all items, in order, as a list.

    Every entry of self.args is asked for its compile_time_value() under
    the compile-time environment denv.
    """
    values = []
    for item in self.args:
        values.append(item.compile_time_value(denv))
    return values
......@@ -4364,6 +4366,15 @@ class SequenceNode(ExprNode):
arg = self.args[i]
if not skip_children: arg.analyse_types(env)
self.args[i] = arg.coerce_to_pyobject(env)
if self.mult_factor:
self.mult_factor.analyse_types(env)
if not self.mult_factor.type.is_int:
if self.mult_factor.type.is_pyobject:
self.mult_factor = self.mult_factor.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
else:
error(self.pos, "can't multiply sequence by non-int of type '%s'" %
self.mult_factor.type)
self.is_temp = 1
# not setting self.type here, subtypes do this
......@@ -4371,6 +4382,8 @@ class SequenceNode(ExprNode):
return False
def analyse_target_types(self, env):
    # A multiplied sequence such as ([a, b] * 2) cannot be assigned to.
    # The error is reported at this node's own position: no name 'arg' is
    # bound at this point, so using arg.pos would fail with a NameError
    # instead of producing the intended compile error.
    if self.mult_factor:
        error(self.pos, "can't assign to multiplied sequence")
    # Start from an empty unpacking state.
    self.unpacked_items = []
    self.coerced_unpacked_items = []
    self.any_coerced_items = False
......@@ -4393,6 +4406,66 @@ class SequenceNode(ExprNode):
def generate_result_code(self, code):
    """Emit the result code by delegating to generate_operation_code()."""
    emit_operation = self.generate_operation_code
    emit_operation(code)
def generate_sequence_packing_code(self, code):
    """Emit C code that builds this list/tuple from its evaluated items.

    The container is created with PyList_New()/PyTuple_New() and filled
    through the matching *_SET_ITEM macro.  If self.mult_factor is set,
    the items are stored mult_factor times by a generated C 'for' loop,
    as in ([1,2,3] * 5).

    code -- the code writer receiving the generated C lines.

    Raises InternalError if self.type is neither the builtin list type
    nor the builtin tuple type.
    """
    if self.type is Builtin.list_type:
        create_func, set_item_func = 'PyList_New', 'PyList_SET_ITEM'
    elif self.type is Builtin.tuple_type:
        create_func, set_item_func = 'PyTuple_New', 'PyTuple_SET_ITEM'
    else:
        raise InternalError("sequence packing for unexpected type %s" % self.type)

    if self.mult_factor:
        mult = self.mult_factor.result()
        if isinstance(self.mult_factor.constant_result, (int,long)) \
               and self.mult_factor.constant_result > 0:
            # Known positive constant: put the literal into the size.
            size_factor = ' * %s' % self.mult_factor.constant_result
        else:
            # Otherwise clamp a negative factor to 0 in the generated C,
            # so that the allocated size is never negative.
            size_factor = ' * ((%s<0) ? 0:%s)' % (mult, mult)
    else:
        size_factor = ''
        mult = ''

    arg_count = len(self.args)
    code.putln("%s = %s(%s%s); %s" % (
        self.result(),
        create_func,
        arg_count,
        size_factor,
        code.error_goto_if_null(self.result(), self.pos)))
    code.put_gotref(self.py_result())

    if mult:
        # One loop iteration stores one full copy of the items; the
        # counter selects the slice [counter*arg_count, ...) to fill.
        counter = code.funcstate.allocate_temp(self.mult_factor.type, manage_ref=False)
        offset = '%s * %s + ' % (counter, arg_count)
        code.putln('for (%s=0; %s < %s; %s++) {' % (
            counter, counter, mult, counter
            ))
    else:
        offset = ''

    for i, arg in enumerate(self.args):
        # With a repetition loop every stored copy gets its own
        # reference; without one, only non-temp results are incref-ed.
        if mult or not arg.result_in_temp():
            code.put_incref(arg.result(), arg.ctype())
        code.putln("%s(%s, %s%s, %s);" % (
            set_item_func,
            self.result(),
            offset,
            i,
            arg.py_result()))
        code.put_giveref(arg.py_result())

    if mult:
        code.putln('}')
        code.funcstate.release_temp(counter)
def generate_subexpr_disposal_code(self, code):
    """Dispose of the sub-expressions after the sequence has been built.

    With a mult_factor the inherited disposal is used.  Without one,
    generate_post_assignment_code() is emitted for every item instead.
    """
    if not self.mult_factor:
        # The values were put into the sequence with a reference-stealing
        # operation, hence post-assignment code rather than
        # generate_disposal_code.
        for item in self.args:
            item.generate_post_assignment_code(code)
        # free_temps must NOT be called here -- the default
        # generate_evaluation_code invokes this method and does that.
        return
    super(SequenceNode, self).generate_subexpr_disposal_code(code)
def generate_assignment_code(self, rhs, code):
if self.starred_assignment:
self.generate_starred_assignment_code(rhs, code)
......@@ -4651,35 +4724,10 @@ class TupleNode(SequenceNode):
self.result_code = code.get_py_const(py_object_type, 'tuple_', cleanup_level=2)
code = code.get_cached_constants_writer()
code.mark_pos(self.pos)
code.putln(
"%s = PyTuple_New(%s); %s" % (
self.result(),
len(self.args),
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result())
for i in range(len(self.args)):
arg = self.args[i]
if not arg.result_in_temp():
code.put_incref(arg.result(), arg.ctype())
code.putln(
"PyTuple_SET_ITEM(%s, %s, %s);" % (
self.result(),
i,
arg.py_result()))
code.put_giveref(arg.py_result())
self.generate_sequence_packing_code(code)
if self.is_literal:
code.put_giveref(self.py_result())
def generate_subexpr_disposal_code(self, code):
    """Emit post-assignment code for every tuple item."""
    # generate_post_assignment_code is used in place of
    # generate_disposal_code: the tuple received the values through a
    # reference-stealing operation.
    for item in self.args:
        item.generate_post_assignment_code(code)
    # free_temps must NOT be called here -- the default
    # generate_evaluation_code invokes this method and does that.
class ListNode(SequenceNode):
# List constructor.
......@@ -4717,6 +4765,8 @@ class ListNode(SequenceNode):
self.obj_conversion_errors = []
if not self.type.subtype_of(dst_type):
error(self.pos, "Cannot coerce list to type '%s'" % dst_type)
elif self.mult_factor:
error(self.pos, "Cannot coerce multiplied list to '%s'" % dst_type)
elif dst_type.is_ptr and dst_type.base_type is not PyrexTypes.c_void_type:
base_type = dst_type.base_type
self.type = PyrexTypes.CArrayType(base_type, len(self.args))
......@@ -4750,31 +4800,22 @@ class ListNode(SequenceNode):
SequenceNode.release_temp(self, env)
def calculate_constant_result(self):
    """Store the list of the items' constant results on this node.

    Raises ValueError for a multiplied list, so that no constant result
    is computed for it.
    """
    if self.mult_factor:
        raise ValueError() # may exceed the compile time memory
    folded = []
    for item in self.args:
        folded.append(item.constant_result)
    self.constant_result = folded
def compile_time_value(self, denv):
    """Return the compile-time value of this list.

    The item values come from compile_time_value_list(denv).  If a
    mult_factor is set, the list is repeated by the factor's own
    compile-time value, matching ([1,2] * 3).

    The former unconditional early return of the plain item list has
    been removed; it made the repetition below unreachable.
    """
    items = self.compile_time_value_list(denv)
    if self.mult_factor:
        items *= self.mult_factor.compile_time_value(denv)
    return items
def generate_operation_code(self, code):
if self.type.is_pyobject:
for err in self.obj_conversion_errors:
report_error(err)
code.putln("%s = PyList_New(%s); %s" %
(self.result(),
len(self.args),
code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result())
for i in range(len(self.args)):
arg = self.args[i]
#if not arg.is_temp:
if not arg.result_in_temp():
code.put_incref(arg.result(), arg.ctype())
code.putln("PyList_SET_ITEM(%s, %s, %s);" %
(self.result(),
i,
arg.py_result()))
code.put_giveref(arg.py_result())
self.generate_sequence_packing_code(code)
elif self.type.is_array:
for i, arg in enumerate(self.args):
code.putln("%s[%s] = %s;" % (
......@@ -4790,15 +4831,6 @@ class ListNode(SequenceNode):
else:
raise InternalError("List type never specified")
def generate_subexpr_disposal_code(self, code):
    """Emit post-assignment code for every list item."""
    # generate_post_assignment_code is used in place of
    # generate_disposal_code: the list received the values through a
    # reference-stealing operation.
    for element in self.args:
        element.generate_post_assignment_code(code)
    # free_temps must NOT be called here -- the default
    # generate_evaluation_code invokes this method and does that.
class ScopedExprNode(ExprNode):
# Abstract base class for ExprNodes that have their own local
......
......@@ -3501,6 +3501,17 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
constant_result = node.constant_result)
return new_node
def visit_MulNode(self, node):
    """Fold 'sequence literal * factor' into the sequence node itself.

    When the left operand is a ListNode or TupleNode, the right operand
    becomes its mult_factor (unless its constant result is 1) and the
    sequence node replaces the multiplication.  Every other
    multiplication is handled by visit_BinopNode().
    """
    if not isinstance(node.operand1, (ExprNodes.ListNode, ExprNodes.TupleNode)):
        return self.visit_BinopNode(node)

    sequence_node, factor = node.operand1, node.operand2
    self._calculate_const(factor)
    # A factor of exactly 1 leaves the sequence as it is.
    if factor.constant_result != 1:
        sequence_node.mult_factor = factor
    self.visitchildren(sequence_node)
    return sequence_node
def visit_PrimaryCmpNode(self, node):
self._calculate_const(node)
if node.constant_result is ExprNodes.not_a_constant:
......
......@@ -91,6 +91,61 @@ def lists():
"""
return [1,2,3] + [4,5,6]
# List literal repeated by a positive integer constant.  The doctest
# compares the result with plain list repetition.
# NOTE(review): the decorator presumably fails the test if a BinopNode is
# still present in the tree -- confirm against the test runner.
@cython.test_fail_if_path_exists("//BinopNode")
def multiplied_lists():
"""
>>> multiplied_lists() == [1,2,3] * 5
True
"""
return [1,2,3] * 5
# List literal repeated by a negative integer constant.  The doctest
# compares the result with plain list repetition by -5.
@cython.test_fail_if_path_exists("//BinopNode")
def multiplied_lists_neg():
"""
>>> multiplied_lists_neg() == [1,2,3] * -5
True
"""
return [1,2,3] * -5
# The factor is a function argument, so it is only known at run time.
# The doctest covers a positive, a negative and a zero factor.
@cython.test_fail_if_path_exists("//BinopNode")
def multiplied_lists_nonconst(x):
"""
>>> multiplied_lists_nonconst(5) == [1,2,3] * 5
True
>>> multiplied_lists_nonconst(-5) == [1,2,3] * -5
True
>>> multiplied_lists_nonconst(0) == [1,2,3] * 0
True
"""
return [1,2,3] * x
# The factor is a non-constant expression (x*2) rather than a plain name.
# The doctest covers positive, negative and zero values of x.
@cython.test_fail_if_path_exists("//BinopNode")
def multiplied_lists_nonconst_expression(x):
"""
>>> multiplied_lists_nonconst_expression(5) == [1,2,3] * (5 * 2)
True
>>> multiplied_lists_nonconst_expression(-5) == [1,2,3] * (-5 * 2)
True
>>> multiplied_lists_nonconst_expression(0) == [1,2,3] * (0 * 2)
True
"""
return [1,2,3] * (x*2)
# Helper for the side-effect test: prints its argument and returns it, so
# that every evaluation shows up in the doctest output.
cdef side_effect(int x):
print x
return x
# The list items are calls with visible side effects.  The doctest expects
# 1, 2 and 3 to be printed exactly once each although the list is repeated
# five times, i.e. the items are evaluated once and not once per repetition.
@cython.test_fail_if_path_exists("//BinopNode")
def multiplied_lists_with_side_effects():
"""
>>> multiplied_lists_with_side_effects() == [1,2,3] * 5
1
2
3
True
"""
return [side_effect(1), side_effect(2), side_effect(3)] * 5
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def compile_time_DEF():
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment