Commit 28c98012 authored by scoder's avatar scoder Committed by GitHub

Optimise "[...] * N" where N is a non-literal C integer value. (GH-4233)

Closes https://github.com/cython/cython/issues/3922
parent 07b8fed9
......@@ -34,8 +34,11 @@ from . import PyrexTypes
from .PyrexTypes import py_object_type, c_long_type, typecast, error_type, \
unspecified_type
from . import TypeSlots
from .Builtin import list_type, tuple_type, set_type, dict_type, type_type, \
unicode_type, str_type, bytes_type, bytearray_type, basestring_type, slice_type
from .Builtin import (
list_type, tuple_type, set_type, dict_type, type_type,
unicode_type, str_type, bytes_type, bytearray_type, basestring_type,
slice_type, long_type,
)
from . import Builtin
from . import Symtab
from .. import Utils
......@@ -7523,9 +7526,10 @@ class SequenceNode(ExprNode):
arg = arg.analyse_types(env)
self.args[i] = arg.coerce_to_pyobject(env)
if self.mult_factor:
self.mult_factor = self.mult_factor.analyse_types(env)
if not self.mult_factor.type.is_int:
self.mult_factor = self.mult_factor.coerce_to_pyobject(env)
mult_factor = self.mult_factor.analyse_types(env)
if not mult_factor.type.is_int:
mult_factor = mult_factor.coerce_to_pyobject(env)
self.mult_factor = mult_factor.coerce_to_simple(env)
self.is_temp = 1
# not setting self.type here, subtypes do this
return self
......@@ -11598,6 +11602,24 @@ class SubNode(NumBinopNode):
class MulNode(NumBinopNode):
# '*' operator.
def analyse_types(self, env):
# TODO: we could also optimise the case of "[...] * 2 * n", i.e. with an existing 'mult_factor'
if self.operand1.is_sequence_constructor and self.operand1.mult_factor is None:
operand2 = self.operand2.analyse_types(env)
if operand2.type.is_int or operand2.type is long_type:
return self.analyse_sequence_mul(env, self.operand1, operand2)
elif self.operand2.is_sequence_constructor and self.operand2.mult_factor is None:
operand1 = self.operand1.analyse_types(env)
if operand1.type.is_int or operand1.type is long_type:
return self.analyse_sequence_mul(env, self.operand2, operand1)
return NumBinopNode.analyse_types(self, env)
def analyse_sequence_mul(self, env, seq, mult):
assert seq.mult_factor is None
seq.mult_factor = mult
return seq.analyse_types(env)
def is_py_operation_types(self, type1, type2):
if ((type1.is_string and type2.is_int) or
(type2.is_string and type1.is_int)):
......
# mode: run
# tag: list, mulop, pure3.0
import cython
@cython.test_fail_if_path_exists("//MulNode")
@cython.test_assert_path_exists("//ListNode[@mult_factor]")
def cint_times_list(n: cython.int):
"""
>>> cint_times_list(3)
[]
[None, None, None]
[3, 3, 3]
[1, 2, 3, 1, 2, 3, 1, 2, 3]
"""
a = n * []
b = n * [None]
c = n * [n]
d = n * [1, 2, 3]
print(a)
print(b)
print(c)
print(d)
@cython.test_fail_if_path_exists("//MulNode")
@cython.test_assert_path_exists("//ListNode[@mult_factor]")
def list_times_cint(n: cython.int):
"""
>>> list_times_cint(3)
[]
[None, None, None]
[3, 3, 3]
[1, 2, 3, 1, 2, 3, 1, 2, 3]
"""
a = [] * n
b = [None] * n
c = [n] * n
d = [1, 2, 3] * n
print(a)
print(b)
print(c)
print(d)
@cython.test_fail_if_path_exists("//MulNode")
@cython.test_assert_path_exists("//TupleNode[@mult_factor]")
def cint_times_tuple(n: cython.int):
"""
>>> cint_times_tuple(3)
()
(None, None, None)
(3, 3, 3)
(1, 2, 3, 1, 2, 3, 1, 2, 3)
"""
a = n * ()
b = n * (None,)
c = n * (n,)
d = n * (1, 2, 3)
print(a)
print(b)
print(c)
print(d)
@cython.test_fail_if_path_exists("//MulNode")
@cython.test_assert_path_exists("//TupleNode[@mult_factor]")
def tuple_times_cint(n: cython.int):
"""
>>> tuple_times_cint(3)
()
(None, None, None)
(3, 3, 3)
(1, 2, 3, 1, 2, 3, 1, 2, 3)
"""
a = () * n
b = (None,) * n
c = (n,) * n
d = (1, 2, 3) * n
print(a)
print(b)
print(c)
print(d)
# TODO: enable in Cython 3.1 when we can infer unsafe C int operations as PyLong
#@cython.test_fail_if_path_exists("//MulNode")
def list_times_pyint(n: cython.longlong):
"""
>>> list_times_cint(3)
[]
[None, None, None]
[3, 3, 3]
[1, 2, 3, 1, 2, 3, 1, 2, 3]
"""
py_n = n + 1 # might overflow => should be inferred as Python long!
a = [] * py_n
b = [None] * py_n
c = py_n * [n]
d = py_n * [1, 2, 3]
print(a)
print(b)
print(c)
print(d)
@cython.cfunc
def sideeffect(x) -> cython.int:
global _sideeffect_value
_sideeffect_value += 1
return _sideeffect_value + x
def reset_sideeffect():
global _sideeffect_value
_sideeffect_value = 0
@cython.test_fail_if_path_exists("//MulNode")
@cython.test_assert_path_exists("//ListNode[@mult_factor]")
def complicated_cint_times_list(n: cython.int):
"""
>>> complicated_cint_times_list(3)
[]
[None, None, None, None]
[3, 3, 3, 3]
[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]
"""
reset_sideeffect()
a = [] * sideeffect((lambda: n)())
reset_sideeffect()
b = sideeffect((lambda: n)()) * [None]
reset_sideeffect()
c = [n] * sideeffect((lambda: n)())
reset_sideeffect()
d = sideeffect((lambda: n)()) * [1, 2, 3]
print(a)
print(b)
print(c)
print(d)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment