Commit 72d0f571 authored by Stefan Behnel's avatar Stefan Behnel

moved constant folding before type analysis, disabled for type casts and float expressions

parent a5bba3a5
......@@ -901,7 +901,9 @@ class FloatNode(ConstNode):
type = PyrexTypes.c_double_type
def calculate_constant_result(self):
self.constant_result = float(self.value)
# calculating float values is usually not a good idea
#self.constant_result = float(self.value)
pass
def compile_time_value(self, denv):
return float(self.value)
......@@ -3927,7 +3929,9 @@ class TypecastNode(NewTempExprNode):
self.operand.check_const()
def calculate_constant_result(self):
self.constant_result = self.operand.constant_result
# we usually do not know the result of a type cast at code
# generation time
pass
def calculate_result_code(self):
opnd = self.operand
......@@ -4939,7 +4943,8 @@ class CoercionNode(NewTempExprNode):
print("%s Coercing %s" % (self, self.arg))
def calculate_constant_result(self):
self.constant_result = self.arg.constant_result
# constant folding can break type coercion, so this is disabled
pass
def annotate(self, code):
self.arg.annotate(code)
......@@ -4986,7 +4991,11 @@ class PyTypeTestNode(CoercionNode):
def is_ephemeral(self):
return self.arg.is_ephemeral()
def calculate_constant_result(self):
# FIXME
pass
def calculate_result_code(self):
return self.arg.result()
......
......@@ -115,6 +115,7 @@ class Context(object):
_specific_post_parse,
InterpretCompilerDirectives(self, self.pragma_overrides),
_align_function_definitions,
ConstantFolding(),
FlattenInListTransform(),
WithTransform(self),
DecoratorTransform(self),
......@@ -125,7 +126,6 @@ class Context(object):
_check_c_declarations,
AnalyseExpressionsTransform(self),
FlattenBuiltinTypeCreation(),
ConstantFolding(),
# ComprehensionTransform(),
IterationTransform(),
SwitchTransform(),
......
......@@ -566,44 +566,60 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
import traceback, sys
traceback.print_exc(file=sys.stdout)
NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
ExprNodes.LongNode, ExprNodes.FloatNode)
def _widest_node_class(self, *nodes):
try:
return self.NODE_TYPE_ORDER[
max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
except ValueError:
return None
def visit_ExprNode(self, node):
self._calculate_const(node)
return node
# def visit_NumBinopNode(self, node):
def visit_BinopNode(self, node):
self._calculate_const(node)
if node.type is PyrexTypes.py_object_type:
return node
if node.constant_result is ExprNodes.not_a_constant:
return node
# print node.constant_result, node.operand1, node.operand2, node.pos
try:
if node.operand1.type is None or node.operand2.type is None:
return node
except AttributeError:
return node
type1, type2 = node.operand1.type, node.operand2.type
if isinstance(node.operand1, ExprNodes.ConstNode) and \
node.type is node.operand1.type:
new_node = node.operand1
elif isinstance(node.operand2, ExprNodes.ConstNode) and \
node.type is node.operand2.type:
new_node = node.operand2
isinstance(node.operand1, ExprNodes.ConstNode):
if type1 is type2:
new_node = node.operand1
else:
widest_type = PyrexTypes.widest_numeric_type(type1, type2)
if type(node.operand1) is type(node.operand2):
new_node = node.operand1
new_node.type = widest_type
elif type1 is widest_type:
new_node = node.operand1
elif type2 is widest_type:
new_node = node.operand2
else:
target_class = self._widest_node_class(
node.operand1, node.operand2)
if target_class is None:
return node
new_node = target_class(type = widest_type)
else:
return node
new_node.value = new_node.constant_result = node.constant_result
new_node = new_node.coerce_to(node.type, self.current_scope)
new_node.constant_result = node.constant_result
new_node.value = str(node.constant_result)
#new_node = new_node.coerce_to(node.type, self.current_scope)
return new_node
# in the future, other nodes can have their own handler method here
# that can replace them with a constant result node
def visit_ModuleNode(self, node):
self.current_scope = node.scope
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
old_scope = self.current_scope
self.current_scope = node.entry.scope
self.visitchildren(node)
self.current_scope = old_scope
return node
visit_Node = Visitor.VisitorTransform.recurse_to_children
......
......@@ -5,6 +5,13 @@ True
True
>>> neg() == -1 -2 - (-3+4)
True
>>> int_mix() == 1 + (2 * 3) // 2
True
>>> if IS_PY3: type(int_mix()) is int
... else: type(int_mix()) is long
True
>>> int_cast() == 1 + 2 * 6000
True
>>> mul() == 1*60*1000
True
>>> arithm() == 9*2+3*8/6-10
......@@ -15,6 +22,9 @@ True
True
"""
import sys
IS_PY3 = sys.version_info[0] >= 3
def _func(a,b,c):
return a+b+c
......@@ -27,6 +37,12 @@ def add_var(a):
def neg():
return -1 -2 - (-3+4)
def int_mix():
return 1L + (2 * 3L) // 2
def int_cast():
return <int>(1 + 2 * 6000)
def mul():
return 1*60*1000
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment