Commit b8e2b990 authored by Stefan Behnel's avatar Stefan Behnel

merge

parents d1170ab8 a128757e
...@@ -2070,7 +2070,6 @@ class IndexNode(ExprNode): ...@@ -2070,7 +2070,6 @@ class IndexNode(ExprNode):
skip_child_analysis = False skip_child_analysis = False
buffer_access = False buffer_access = False
if self.base.type.is_buffer: if self.base.type.is_buffer:
assert hasattr(self.base, "entry") # Must be a NameNode-like node
if self.indices: if self.indices:
indices = self.indices indices = self.indices
else: else:
...@@ -2085,6 +2084,8 @@ class IndexNode(ExprNode): ...@@ -2085,6 +2084,8 @@ class IndexNode(ExprNode):
x.analyse_types(env) x.analyse_types(env)
if not x.type.is_int: if not x.type.is_int:
buffer_access = False buffer_access = False
if buffer_access:
assert hasattr(self.base, "entry") # Must be a NameNode-like node
# On cloning, indices is cloned. Otherwise, unpack index into indices # On cloning, indices is cloned. Otherwise, unpack index into indices
assert not (buffer_access and isinstance(self.index, CloneNode)) assert not (buffer_access and isinstance(self.index, CloneNode))
...@@ -2746,6 +2747,7 @@ class SimpleCallNode(CallNode): ...@@ -2746,6 +2747,7 @@ class SimpleCallNode(CallNode):
wrapper_call = False wrapper_call = False
has_optional_args = False has_optional_args = False
nogil = False nogil = False
analysed = False
def compile_time_value(self, denv): def compile_time_value(self, denv):
function = self.function.compile_time_value(denv) function = self.function.compile_time_value(denv)
...@@ -2799,6 +2801,9 @@ class SimpleCallNode(CallNode): ...@@ -2799,6 +2801,9 @@ class SimpleCallNode(CallNode):
def analyse_types(self, env): def analyse_types(self, env):
if self.analyse_as_type_constructor(env): if self.analyse_as_type_constructor(env):
return return
if self.analysed:
return
self.analysed = True
function = self.function function = self.function
function.is_called = 1 function.is_called = 1
self.function.analyse_types(env) self.function.analyse_types(env)
...@@ -5402,7 +5407,7 @@ class BinopNode(ExprNode): ...@@ -5402,7 +5407,7 @@ class BinopNode(ExprNode):
#print "BinopNode.generate_result_code:", self.operand1, self.operand2 ### #print "BinopNode.generate_result_code:", self.operand1, self.operand2 ###
if self.operand1.type.is_pyobject: if self.operand1.type.is_pyobject:
function = self.py_operation_function() function = self.py_operation_function()
if function == "PyNumber_Power": if self.operator == '**':
extra_args = ", Py_None" extra_args = ", Py_None"
else: else:
extra_args = "" extra_args = ""
...@@ -5505,7 +5510,10 @@ class NumBinopNode(BinopNode): ...@@ -5505,7 +5510,10 @@ class NumBinopNode(BinopNode):
BinopNode.is_py_operation_types(self, type1, type2)) BinopNode.is_py_operation_types(self, type1, type2))
def py_operation_function(self): def py_operation_function(self):
return self.py_functions[self.operator] fuction = self.py_functions[self.operator]
if self.inplace:
fuction = fuction.replace('PyNumber_', 'PyNumber_InPlace')
return fuction
py_functions = { py_functions = {
"|": "PyNumber_Or", "|": "PyNumber_Or",
...@@ -5522,7 +5530,6 @@ class NumBinopNode(BinopNode): ...@@ -5522,7 +5530,6 @@ class NumBinopNode(BinopNode):
"**": "PyNumber_Power" "**": "PyNumber_Power"
} }
class IntBinopNode(NumBinopNode): class IntBinopNode(NumBinopNode):
# Binary operation taking integer arguments. # Binary operation taking integer arguments.
...@@ -6637,13 +6644,14 @@ binop_node_classes = { ...@@ -6637,13 +6644,14 @@ binop_node_classes = {
"**": PowNode "**": PowNode
} }
def binop_node(pos, operator, operand1, operand2): def binop_node(pos, operator, operand1, operand2, inplace=False):
# Construct binop node of appropriate class for # Construct binop node of appropriate class for
# given operator. # given operator.
return binop_node_classes[operator](pos, return binop_node_classes[operator](pos,
operator = operator, operator = operator,
operand1 = operand1, operand1 = operand1,
operand2 = operand2) operand2 = operand2,
inplace = inplace)
#------------------------------------------------------------------- #-------------------------------------------------------------------
# #
......
...@@ -98,6 +98,7 @@ class Context(object): ...@@ -98,6 +98,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators
from TypeInference import MarkAssignments, MarkOverflowingArithmetic from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform from AnalysedTreeTransforms import AutoTestDictTransform
...@@ -143,6 +144,7 @@ class Context(object): ...@@ -143,6 +144,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
ExpandInplaceOperators(self),
OptimizeBuiltinCalls(self), ## Necessary? OptimizeBuiltinCalls(self), ## Necessary?
IterationTransform(), IterationTransform(),
SwitchTransform(), SwitchTransform(),
......
...@@ -3520,132 +3520,41 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -3520,132 +3520,41 @@ class InPlaceAssignmentNode(AssignmentNode):
# (it must be a NameNode, AttributeNode, or IndexNode). # (it must be a NameNode, AttributeNode, or IndexNode).
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
dup = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env) self.lhs.analyse_target_declaration(env)
def analyse_types(self, env): def analyse_types(self, env):
self.dup = self.create_dup_node(env) # re-assigns lhs to a shallow copy
self.rhs.analyse_types(env) self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env) self.lhs.analyse_target_types(env)
import ExprNodes
if self.lhs.type.is_pyobject:
self.rhs = self.rhs.coerce_to_pyobject(env)
elif self.rhs.type.is_pyobject or (self.lhs.type.is_numeric and self.rhs.type.is_numeric):
self.rhs = self.rhs.coerce_to(self.lhs.type, env)
if self.lhs.type.is_pyobject:
self.result_value_temp = ExprNodes.PyTempNode(self.pos, env)
self.result_value = self.result_value_temp.coerce_to(self.lhs.type, env)
def generate_execution_code(self, code): def generate_execution_code(self, code):
import ExprNodes import ExprNodes
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
self.dup.generate_subexpr_evaluation_code(code) self.lhs.generate_subexpr_evaluation_code(code)
if self.dup.is_temp: c_op = self.operator
self.dup.allocate_temp_result(code) if c_op == "//":
# self.dup.generate_result_code is run only if it is not buffer access c_op = "/"
if self.operator == "**": elif c_op == "**":
extra = ", Py_None" error(self.pos, "No C inplace power operator")
else: if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
extra = "" if self.lhs.type.is_pyobject:
if self.lhs.type.is_pyobject:
if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
error(self.pos, "In-place operators not allowed on object buffers in this release.") error(self.pos, "In-place operators not allowed on object buffers in this release.")
self.dup.generate_result_code(code) if c_op in ('/', '%') and self.lhs.type.is_int and not code.directives['cdivision']:
self.result_value_temp.allocate(code) error(self.pos, "In-place non-c divide operators not allowed on int buffers.")
code.putln( self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
"%s = %s(%s, %s%s); %s" % (
self.result_value.result(),
self.py_operation_function(),
self.dup.py_result(),
self.rhs.py_result(),
extra,
code.error_goto_if_null(self.result_value.py_result(), self.pos)))
code.put_gotref(self.result_value.py_result())
self.result_value.generate_evaluation_code(code) # May be a type check...
self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code)
self.dup.generate_disposal_code(code)
self.dup.free_temps(code)
self.lhs.generate_assignment_code(self.result_value, code)
self.result_value_temp.release(code)
else:
c_op = self.operator
if c_op == "//":
c_op = "/"
elif c_op == "**":
error(self.pos, "No C inplace power operator")
elif self.lhs.type.is_complex:
error(self.pos, "Inplace operators not implemented for complex types.")
# have to do assignment directly to avoid side-effects
if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
else:
self.dup.generate_result_code(code)
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) )
self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code)
if self.dup.is_temp:
self.dup.generate_subexpr_disposal_code(code)
self.dup.free_subexpr_temps(code)
def create_dup_node(self, env):
import ExprNodes
self.dup = self.lhs
self.dup.analyse_types(env)
if isinstance(self.lhs, ExprNodes.NameNode):
target_lhs = ExprNodes.NameNode(self.dup.pos,
name = self.dup.name,
is_temp = self.dup.is_temp,
entry = self.dup.entry)
elif isinstance(self.lhs, ExprNodes.AttributeNode):
target_lhs = ExprNodes.AttributeNode(self.dup.pos,
obj = ExprNodes.CloneNode(self.lhs.obj),
attribute = self.dup.attribute,
is_temp = self.dup.is_temp)
elif isinstance(self.lhs, ExprNodes.IndexNode):
if self.lhs.index:
index = ExprNodes.CloneNode(self.lhs.index)
else:
index = None
if self.lhs.indices:
indices = [ExprNodes.CloneNode(x) for x in self.lhs.indices]
else:
indices = []
target_lhs = ExprNodes.IndexNode(self.dup.pos,
base = ExprNodes.CloneNode(self.dup.base),
index = index,
indices = indices,
is_temp = self.dup.is_temp)
else: else:
assert False, "Unsupported node: %s" % type(self.lhs) # C++
self.lhs = target_lhs # TODO: make sure overload is declared
return self.dup code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()))
self.lhs.generate_subexpr_disposal_code(code)
def py_operation_function(self): self.lhs.free_subexpr_temps(code)
return self.py_functions[self.operator] self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code)
py_functions = {
"|": "PyNumber_InPlaceOr",
"^": "PyNumber_InPlaceXor",
"&": "PyNumber_InPlaceAnd",
"+": "PyNumber_InPlaceAdd",
"-": "PyNumber_InPlaceSubtract",
"*": "PyNumber_InPlaceMultiply",
"/": "__Pyx_PyNumber_InPlaceDivide",
"%": "PyNumber_InPlaceRemainder",
"<<": "PyNumber_InPlaceLshift",
">>": "PyNumber_InPlaceRshift",
"**": "PyNumber_InPlacePower",
"//": "PyNumber_InPlaceFloorDivide",
}
def annotate(self, code): def annotate(self, code):
self.lhs.annotate(code) self.lhs.annotate(code)
self.rhs.annotate(code) self.rhs.annotate(code)
self.dup.annotate(code)
def create_binop_node(self): def create_binop_node(self):
import ExprNodes import ExprNodes
......
...@@ -1194,7 +1194,73 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1194,7 +1194,73 @@ class AnalyseExpressionsTransform(CythonTransform):
node.analyse_scoped_expressions(node.expr_scope) node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
class ExpandInplaceOperators(CythonTransform):
def __call__(self, root):
self.env_stack = [root.scope]
return super(ExpandInplaceOperators, self).__call__(root)
def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
return node
def visit_InPlaceAssignmentNode(self, node):
lhs = node.lhs
rhs = node.rhs
if lhs.type.is_cpp_class:
# No getting around this exact operator here.
return node
if isinstance(lhs, IndexNode) and lhs.is_buffer_access:
# There is code to handle this case.
return node
def side_effect_free_reference(node, setting=False):
if isinstance(node, NameNode):
return node, []
elif node.type.is_pyobject and not setting:
node = LetRefNode(node)
return node, [node]
elif isinstance(node, IndexNode):
if node.is_buffer_access:
raise ValueError, "Buffer access"
base, temps = side_effect_free_reference(node.base)
index = LetRefNode(node.index)
return IndexNode(node.pos, base=base, index=index), temps + [index]
elif isinstance(node, AttributeNode):
obj, temps = side_effect_free_reference(node.obj)
return AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
else:
node = LetRefNode(node)
return node, [node]
try:
lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
except ValueError:
return node
dup = lhs.__class__(**lhs.__dict__)
binop = binop_node(node.pos,
operator = node.operator,
operand1 = dup,
operand2 = rhs,
inplace=True)
node = SingleAssignmentNode(node.pos, lhs=lhs, rhs=binop)
# Use LetRefNode to avoid side effects.
let_ref_nodes.reverse()
for t in let_ref_nodes:
node = LetNode(t, node)
node.analyse_expressions(self.env_stack[-1])
return node
def visit_ExprNode(self, node):
# In-place assignments can't happen within an expression.
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
class AlignFunctionDefinitions(CythonTransform): class AlignFunctionDefinitions(CythonTransform):
""" """
This class takes the signatures from a .pxd file and applies them to This class takes the signatures from a .pxd file and applies them to
......
...@@ -8,6 +8,7 @@ import Nodes ...@@ -8,6 +8,7 @@ import Nodes
import ExprNodes import ExprNodes
from Nodes import Node from Nodes import Node
from ExprNodes import AtomicExprNode from ExprNodes import AtomicExprNode
from PyrexTypes import c_ptr_type
class TempHandle(object): class TempHandle(object):
# THIS IS DEPRECATED, USE LetRefNode instead # THIS IS DEPRECATED, USE LetRefNode instead
...@@ -196,6 +197,8 @@ class LetNodeMixin: ...@@ -196,6 +197,8 @@ class LetNodeMixin:
def setup_temp_expr(self, code): def setup_temp_expr(self, code):
self.temp_expression.generate_evaluation_code(code) self.temp_expression.generate_evaluation_code(code)
self.temp_type = self.temp_expression.type self.temp_type = self.temp_expression.type
if self.temp_type.is_array:
self.temp_type = c_ptr_type(self.temp_type.base_type)
self._result_in_temp = self.temp_expression.result_in_temp() self._result_in_temp = self.temp_expression.result_in_temp()
if self._result_in_temp: if self._result_in_temp:
self.temp = self.temp_expression.result() self.temp = self.temp_expression.result()
......
__doc__ = u""" cimport cython
>>> str(f(5, 7))
'29509034655744'
"""
def f(a,b): def f(a,b):
"""
>>> str(f(5, 7))
'29509034655744'
"""
a += b a += b
a *= b a *= b
a **= b a **= b
...@@ -117,3 +117,130 @@ def test_side_effects(): ...@@ -117,3 +117,130 @@ def test_side_effects():
b[side_effect(3)] += 10 b[side_effect(3)] += 10
b[c_side_effect(4)] += 100 b[c_side_effect(4)] += 100
return a, [b[i] for i from 0 <= i < 5] return a, [b[i] for i from 0 <= i < 5]
@cython.cdivision(True)
def test_inplace_cdivision(int a, int b):
"""
>>> test_inplace_cdivision(13, 10)
3
>>> test_inplace_cdivision(13, -10)
3
>>> test_inplace_cdivision(-13, 10)
-3
>>> test_inplace_cdivision(-13, -10)
-3
"""
a %= b
return a
@cython.cdivision(False)
def test_inplace_pydivision(int a, int b):
"""
>>> test_inplace_pydivision(13, 10)
3
>>> test_inplace_pydivision(13, -10)
-7
>>> test_inplace_pydivision(-13, 10)
7
>>> test_inplace_pydivision(-13, -10)
-3
"""
a %= b
return a
def test_complex_inplace(double complex x, double complex y):
"""
>>> test_complex_inplace(1, 1)
(2+0j)
>>> test_complex_inplace(2, 3)
(15+0j)
>>> test_complex_inplace(2+3j, 4+5j)
(-16+62j)
"""
x += y
x *= y
return x
# The following is more subtle than one might expect.
cdef struct Inner:
int x
cdef struct Aa:
int value
Inner inner
cdef struct NestedA:
Aa a
cdef struct ArrayOfA:
Aa[10] a
def nested_struct_assignment():
"""
>>> nested_struct_assignment()
"""
cdef NestedA nested
nested.a.value = 2
nested.a.value += 3
assert nested.a.value == 5
nested.a.inner.x = 5
nested.a.inner.x += 10
assert nested.a.inner.x == 15
def nested_array_assignment():
"""
>>> nested_array_assignment()
c side effect 0
c side effect 1
"""
cdef ArrayOfA array
array.a[0].value = 2
array.a[c_side_effect(0)].value += 3
assert array.a[0].value == 5
array.a[1].inner.x = 5
array.a[c_side_effect(1)].inner.x += 10
assert array.a[1].inner.x == 15
cdef class VerboseDict(object):
cdef name
cdef dict dict
def __init__(self, name, **kwds):
self.name = name
self.dict = kwds
def __getitem__(self, key):
print self.name, "__getitem__", key
return self.dict[key]
def __setitem__(self, key, value):
print self.name, "__setitem__", key, value
self.dict[key] = value
def __repr__(self):
return repr(self.name)
def deref_and_increment(o, key):
"""
>>> deref_and_increment({'a': 1}, 'a')
side effect a
>>> v = VerboseDict('v', a=10)
>>> deref_and_increment(v, 'a')
side effect a
v __getitem__ a
v __setitem__ a 11
"""
o[side_effect(key)] += 1
def double_deref_and_increment(o, key1, key2):
"""
>>> v = VerboseDict('v', a=10)
>>> w = VerboseDict('w', vkey=v)
>>> double_deref_and_increment(w, 'vkey', 'a')
side effect vkey
w __getitem__ vkey
side effect a
v __getitem__ a
v __setitem__ a 11
"""
o[side_effect(key1)][side_effect(key2)] += 1
...@@ -350,8 +350,7 @@ cdef object some_float_value(): ...@@ -350,8 +350,7 @@ cdef object some_float_value():
@cython.test_fail_if_path_exists('//NameNode[@type.is_pyobject = True]') @cython.test_fail_if_path_exists('//NameNode[@type.is_pyobject = True]')
@cython.test_assert_path_exists('//InPlaceAssignmentNode/NameNode', @cython.test_assert_path_exists('//NameNode[@type.is_pyobject]',
'//NameNode[@type.is_pyobject]',
'//NameNode[@type.is_pyobject = False]') '//NameNode[@type.is_pyobject = False]')
@infer_types(None) @infer_types(None)
def double_loop(): def double_loop():
......
...@@ -243,7 +243,6 @@ def index_add(unicode ustring, Py_ssize_t i, Py_ssize_t j): ...@@ -243,7 +243,6 @@ def index_add(unicode ustring, Py_ssize_t i, Py_ssize_t j):
@cython.test_assert_path_exists("//CoerceToPyTypeNode", @cython.test_assert_path_exists("//CoerceToPyTypeNode",
"//IndexNode", "//IndexNode",
"//InPlaceAssignmentNode",
"//CoerceToPyTypeNode//IndexNode") "//CoerceToPyTypeNode//IndexNode")
@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode") @cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")
def index_concat_loop(unicode ustring): def index_concat_loop(unicode ustring):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment