Commit 72d0f571 authored by Stefan Behnel's avatar Stefan Behnel

moved constant folding before type analysis, disabled for type casts and float expressions

parent a5bba3a5
...@@ -901,7 +901,9 @@ class FloatNode(ConstNode): ...@@ -901,7 +901,9 @@ class FloatNode(ConstNode):
type = PyrexTypes.c_double_type type = PyrexTypes.c_double_type
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = float(self.value) # calculating float values is usually not a good idea
#self.constant_result = float(self.value)
pass
def compile_time_value(self, denv): def compile_time_value(self, denv):
return float(self.value) return float(self.value)
...@@ -3927,7 +3929,9 @@ class TypecastNode(NewTempExprNode): ...@@ -3927,7 +3929,9 @@ class TypecastNode(NewTempExprNode):
self.operand.check_const() self.operand.check_const()
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.operand.constant_result # we usually do not know the result of a type cast at code
# generation time
pass
def calculate_result_code(self): def calculate_result_code(self):
opnd = self.operand opnd = self.operand
...@@ -4939,7 +4943,8 @@ class CoercionNode(NewTempExprNode): ...@@ -4939,7 +4943,8 @@ class CoercionNode(NewTempExprNode):
print("%s Coercing %s" % (self, self.arg)) print("%s Coercing %s" % (self, self.arg))
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.arg.constant_result # constant folding can break type coercion, so this is disabled
pass
def annotate(self, code): def annotate(self, code):
self.arg.annotate(code) self.arg.annotate(code)
...@@ -4986,7 +4991,11 @@ class PyTypeTestNode(CoercionNode): ...@@ -4986,7 +4991,11 @@ class PyTypeTestNode(CoercionNode):
def is_ephemeral(self): def is_ephemeral(self):
return self.arg.is_ephemeral() return self.arg.is_ephemeral()
def calculate_constant_result(self):
# FIXME
pass
def calculate_result_code(self): def calculate_result_code(self):
return self.arg.result() return self.arg.result()
......
...@@ -115,6 +115,7 @@ class Context(object): ...@@ -115,6 +115,7 @@ class Context(object):
_specific_post_parse, _specific_post_parse,
InterpretCompilerDirectives(self, self.pragma_overrides), InterpretCompilerDirectives(self, self.pragma_overrides),
_align_function_definitions, _align_function_definitions,
ConstantFolding(),
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(self), WithTransform(self),
DecoratorTransform(self), DecoratorTransform(self),
...@@ -125,7 +126,6 @@ class Context(object): ...@@ -125,7 +126,6 @@ class Context(object):
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
FlattenBuiltinTypeCreation(), FlattenBuiltinTypeCreation(),
ConstantFolding(),
# ComprehensionTransform(), # ComprehensionTransform(),
IterationTransform(), IterationTransform(),
SwitchTransform(), SwitchTransform(),
......
...@@ -566,44 +566,60 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -566,44 +566,60 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
import traceback, sys import traceback, sys
traceback.print_exc(file=sys.stdout) traceback.print_exc(file=sys.stdout)
NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
ExprNodes.LongNode, ExprNodes.FloatNode)
def _widest_node_class(self, *nodes):
try:
return self.NODE_TYPE_ORDER[
max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
except ValueError:
return None
def visit_ExprNode(self, node): def visit_ExprNode(self, node):
self._calculate_const(node) self._calculate_const(node)
return node return node
# def visit_NumBinopNode(self, node):
def visit_BinopNode(self, node): def visit_BinopNode(self, node):
self._calculate_const(node) self._calculate_const(node)
if node.type is PyrexTypes.py_object_type:
return node
if node.constant_result is ExprNodes.not_a_constant: if node.constant_result is ExprNodes.not_a_constant:
return node return node
# print node.constant_result, node.operand1, node.operand2, node.pos try:
if node.operand1.type is None or node.operand2.type is None:
return node
except AttributeError:
return node
type1, type2 = node.operand1.type, node.operand2.type
if isinstance(node.operand1, ExprNodes.ConstNode) and \ if isinstance(node.operand1, ExprNodes.ConstNode) and \
node.type is node.operand1.type: isinstance(node.operand1, ExprNodes.ConstNode):
new_node = node.operand1 if type1 is type2:
elif isinstance(node.operand2, ExprNodes.ConstNode) and \ new_node = node.operand1
node.type is node.operand2.type: else:
new_node = node.operand2 widest_type = PyrexTypes.widest_numeric_type(type1, type2)
if type(node.operand1) is type(node.operand2):
new_node = node.operand1
new_node.type = widest_type
elif type1 is widest_type:
new_node = node.operand1
elif type2 is widest_type:
new_node = node.operand2
else:
target_class = self._widest_node_class(
node.operand1, node.operand2)
if target_class is None:
return node
new_node = target_class(type = widest_type)
else: else:
return node return node
new_node.value = new_node.constant_result = node.constant_result
new_node = new_node.coerce_to(node.type, self.current_scope) new_node.constant_result = node.constant_result
new_node.value = str(node.constant_result)
#new_node = new_node.coerce_to(node.type, self.current_scope)
return new_node return new_node
# in the future, other nodes can have their own handler method here # in the future, other nodes can have their own handler method here
# that can replace them with a constant result node # that can replace them with a constant result node
def visit_ModuleNode(self, node):
self.current_scope = node.scope
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
old_scope = self.current_scope
self.current_scope = node.entry.scope
self.visitchildren(node)
self.current_scope = old_scope
return node
visit_Node = Visitor.VisitorTransform.recurse_to_children visit_Node = Visitor.VisitorTransform.recurse_to_children
......
...@@ -5,6 +5,13 @@ True ...@@ -5,6 +5,13 @@ True
True True
>>> neg() == -1 -2 - (-3+4) >>> neg() == -1 -2 - (-3+4)
True True
>>> int_mix() == 1 + (2 * 3) // 2
True
>>> if IS_PY3: type(int_mix()) is int
... else: type(int_mix()) is long
True
>>> int_cast() == 1 + 2 * 6000
True
>>> mul() == 1*60*1000 >>> mul() == 1*60*1000
True True
>>> arithm() == 9*2+3*8/6-10 >>> arithm() == 9*2+3*8/6-10
...@@ -15,6 +22,9 @@ True ...@@ -15,6 +22,9 @@ True
True True
""" """
import sys
IS_PY3 = sys.version_info[0] >= 3
def _func(a,b,c): def _func(a,b,c):
return a+b+c return a+b+c
...@@ -27,6 +37,12 @@ def add_var(a): ...@@ -27,6 +37,12 @@ def add_var(a):
def neg(): def neg():
return -1 -2 - (-3+4) return -1 -2 - (-3+4)
def int_mix():
return 1L + (2 * 3L) // 2
def int_cast():
return <int>(1 + 2 * 6000)
def mul(): def mul():
return 1*60*1000 return 1*60*1000
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment