Commit 698f7535 authored by scoder's avatar scoder

Merge pull request #118 from vitek/_markassignments

Use assignments collected by CF for type inference
parents 7c852f79 e99f6ac7
...@@ -26,6 +26,7 @@ cdef class ExitBlock(ControlBlock): ...@@ -26,6 +26,7 @@ cdef class ExitBlock(ControlBlock):
cdef class NameAssignment: cdef class NameAssignment:
cdef public bint is_arg cdef public bint is_arg
cdef public bint is_deletion
cdef public object lhs cdef public object lhs
cdef public object rhs cdef public object rhs
cdef public object entry cdef public object entry
......
...@@ -8,7 +8,8 @@ cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object, ...@@ -8,7 +8,8 @@ cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
import Builtin import Builtin
import ExprNodes import ExprNodes
import Nodes import Nodes
from PyrexTypes import py_object_type from PyrexTypes import py_object_type, unspecified_type
import PyrexTypes
from Visitor import TreeVisitor, CythonTransform from Visitor import TreeVisitor, CythonTransform
from Errors import error, warning, InternalError from Errors import error, warning, InternalError
...@@ -24,6 +25,9 @@ class TypedExprNode(ExprNodes.ExprNode): ...@@ -24,6 +25,9 @@ class TypedExprNode(ExprNodes.ExprNode):
object_expr = TypedExprNode(py_object_type, may_be_none=True) object_expr = TypedExprNode(py_object_type, may_be_none=True)
object_expr_not_none = TypedExprNode(py_object_type, may_be_none=False) object_expr_not_none = TypedExprNode(py_object_type, may_be_none=False)
# Fake rhs to silence "unused variable" warning
fake_rhs_expr = TypedExprNode(unspecified_type)
class ControlBlock(object): class ControlBlock(object):
"""Control flow graph node. Sequence of assignments and name references. """Control flow graph node. Sequence of assignments and name references.
...@@ -174,7 +178,7 @@ class ControlFlow(object): ...@@ -174,7 +178,7 @@ class ControlFlow(object):
def mark_deletion(self, node, entry): def mark_deletion(self, node, entry):
if self.block and self.is_tracked(entry): if self.block and self.is_tracked(entry):
assignment = NameAssignment(node, None, entry) assignment = NameDeletion(node, entry)
self.block.stats.append(assignment) self.block.stats.append(assignment)
self.block.gen[entry] = Uninitialized self.block.gen[entry] = Uninitialized
self.entries.add(entry) self.entries.add(entry)
...@@ -293,6 +297,7 @@ class ExceptionDescr(object): ...@@ -293,6 +297,7 @@ class ExceptionDescr(object):
self.finally_enter = finally_enter self.finally_enter = finally_enter
self.finally_exit = finally_exit self.finally_exit = finally_exit
class NameAssignment(object): class NameAssignment(object):
def __init__(self, lhs, rhs, entry): def __init__(self, lhs, rhs, entry):
if lhs.cf_state is None: if lhs.cf_state is None:
...@@ -303,15 +308,24 @@ class NameAssignment(object): ...@@ -303,15 +308,24 @@ class NameAssignment(object):
self.pos = lhs.pos self.pos = lhs.pos
self.refs = set() self.refs = set()
self.is_arg = False self.is_arg = False
self.is_deletion = False
def __repr__(self): def __repr__(self):
return '%s(entry=%r)' % (self.__class__.__name__, self.entry) return '%s(entry=%r)' % (self.__class__.__name__, self.entry)
class Argument(NameAssignment): class Argument(NameAssignment):
def __init__(self, lhs, rhs, entry): def __init__(self, lhs, rhs, entry):
NameAssignment.__init__(self, lhs, rhs, entry) NameAssignment.__init__(self, lhs, rhs, entry)
self.is_arg = True self.is_arg = True
class NameDeletion(NameAssignment):
def __init__(self, lhs, entry):
NameAssignment.__init__(self, lhs, lhs, entry)
self.is_deletion = True
class Uninitialized(object): class Uninitialized(object):
pass pass
...@@ -462,12 +476,13 @@ def check_definitions(flow, compiler_directives): ...@@ -462,12 +476,13 @@ def check_definitions(flow, compiler_directives):
stat.lhs.cf_state.update(state) stat.lhs.cf_state.update(state)
assmt_nodes.add(stat.lhs) assmt_nodes.add(stat.lhs)
i_state = i_state & ~i_assmts.mask i_state = i_state & ~i_assmts.mask
if stat.rhs: if stat.is_deletion:
i_state |= stat.bit
else:
i_state |= i_assmts.bit i_state |= i_assmts.bit
else:
i_state |= stat.bit
assignments.add(stat) assignments.add(stat)
stat.entry.cf_assignments.append(stat) if stat.rhs is not fake_rhs_expr:
stat.entry.cf_assignments.append(stat)
elif isinstance(stat, NameReference): elif isinstance(stat, NameReference):
references[stat.node] = stat.entry references[stat.node] = stat.entry
stat.entry.cf_references.append(stat) stat.entry.cf_references.append(stat)
...@@ -754,7 +769,8 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -754,7 +769,8 @@ class ControlFlowAnalysis(CythonTransform):
entry = self.env.lookup(node.name) entry = self.env.lookup(node.name)
if entry: if entry:
may_be_none = not node.not_none may_be_none = not node.not_none
self.flow.mark_argument(node, TypedExprNode(entry.type, may_be_none), entry) self.flow.mark_argument(
node, TypedExprNode(entry.type, may_be_none), entry)
return node return node
def visit_NameNode(self, node): def visit_NameNode(self, node):
...@@ -838,6 +854,59 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -838,6 +854,59 @@ class ControlFlowAnalysis(CythonTransform):
self.flow.block = None self.flow.block = None
return node return node
def mark_forloop_target(self, node):
# TODO: Remove redundancy with range optimization...
is_special = False
sequence = node.iterator.sequence
target = node.target
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and function.is_name:
entry = self.env.lookup(function.name)
if not entry or entry.is_builtin:
if function.name == 'reversed' and len(sequence.args) == 1:
sequence = sequence.args[0]
elif function.name == 'enumerate' and len(sequence.args) == 1:
if target.is_sequence_constructor and len(target.args) == 2:
iterator = sequence.args[0]
if iterator.is_name:
iterator_type = iterator.infer_type(self.env)
if iterator_type.is_builtin_type:
# assume that builtin types have a length within Py_ssize_t
self.mark_assignment(
target.args[0],
ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
type=PyrexTypes.c_py_ssize_t_type))
target = target.args[1]
sequence = sequence.args[0]
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and function.is_name:
entry = self.env.lookup(function.name)
if not entry or entry.is_builtin:
if function.name in ('range', 'xrange'):
is_special = True
for arg in sequence.args[:2]:
self.mark_assignment(target, arg)
if len(sequence.args) > 2:
self.mark_assignment(
target,
ExprNodes.binop_node(node.pos,
'+',
sequence.args[0],
sequence.args[2]))
if not is_special:
# A for-loop basically translates to subsequent calls to
# __getitem__(), so using an IndexNode here allows us to
# naturally infer the base type of pointers, C arrays,
# Python strings, etc., while correctly falling back to an
# object type when the base type cannot be handled.
self.mark_assignment(target, ExprNodes.IndexNode(
node.pos,
base = sequence,
index = ExprNodes.IntNode(node.pos, value = '0')))
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
condition_block = self.flow.nextblock() condition_block = self.flow.nextblock()
next_block = self.flow.newblock() next_block = self.flow.newblock()
...@@ -846,7 +915,11 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -846,7 +915,11 @@ class ControlFlowAnalysis(CythonTransform):
self.visit(node.iterator) self.visit(node.iterator)
# Target assignment # Target assignment
self.flow.nextblock() self.flow.nextblock()
self.mark_assignment(node.target)
if isinstance(node, Nodes.ForInStatNode):
self.mark_forloop_target(node)
else: # Parallel
self.mark_assignment(node.target)
# Body block # Body block
if isinstance(node, Nodes.ParallelRangeNode): if isinstance(node, Nodes.ParallelRangeNode):
...@@ -916,12 +989,15 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -916,12 +989,15 @@ class ControlFlowAnalysis(CythonTransform):
self.flow.loops.append(LoopDescr(next_block, condition_block)) self.flow.loops.append(LoopDescr(next_block, condition_block))
self.visit(node.bound1) self.visit(node.bound1)
self.visit(node.bound2) self.visit(node.bound2)
if node.step: if node.step is not None:
self.visit(node.step) self.visit(node.step)
# Target assignment # Target assignment
self.flow.nextblock() self.flow.nextblock()
self.mark_assignment(node.target) self.mark_assignment(node.target, node.bound1)
if node.step is not None:
self.mark_assignment(node.target,
ExprNodes.binop_node(node.pos, '+',
node.bound1, node.step))
# Body block # Body block
self.flow.nextblock() self.flow.nextblock()
self.visit(node.body) self.visit(node.body)
...@@ -1143,6 +1219,6 @@ class ControlFlowAnalysis(CythonTransform): ...@@ -1143,6 +1219,6 @@ class ControlFlowAnalysis(CythonTransform):
def visit_AmpersandNode(self, node): def visit_AmpersandNode(self, node):
if node.operand.is_name: if node.operand.is_name:
# Fake assignment to silence warning # Fake assignment to silence warning
self.mark_assignment(node.operand) self.mark_assignment(node.operand, fake_rhs_expr)
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -129,7 +129,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -129,7 +129,7 @@ def create_pipeline(context, mode, exclude_classes=()):
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
from TypeInference import MarkAssignments, MarkOverflowingArithmetic from TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions from ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions
from ParseTreeTransforms import RemoveUnreachableCode, GilCheck from ParseTreeTransforms import RemoveUnreachableCode, GilCheck
from FlowControl import ControlFlowAnalysis from FlowControl import ControlFlowAnalysis
...@@ -179,10 +179,10 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -179,10 +179,10 @@ def create_pipeline(context, mode, exclude_classes=()):
EmbedSignature(context), EmbedSignature(context),
EarlyReplaceBuiltinCalls(context), ## Necessary? EarlyReplaceBuiltinCalls(context), ## Necessary?
TransformBuiltinMethods(context), ## Necessary? TransformBuiltinMethods(context), ## Necessary?
MarkAssignments(context), MarkParallelAssignments(context),
ControlFlowAnalysis(context), ControlFlowAnalysis(context),
RemoveUnreachableCode(context), RemoveUnreachableCode(context),
# MarkAssignments(context), # MarkParallelAssignments(context),
MarkOverflowingArithmetic(context), MarkOverflowingArithmetic(context),
IntroduceBufferAuxiliaryVars(context), IntroduceBufferAuxiliaryVars(context),
_check_c_declarations, _check_c_declarations,
......
...@@ -112,7 +112,6 @@ class Entry(object): ...@@ -112,7 +112,6 @@ class Entry(object):
# buffer_aux BufferAux or None Extra information needed for buffer variables # buffer_aux BufferAux or None Extra information needed for buffer variables
# inline_func_in_pxd boolean Hacky special case for inline function in pxd file. # inline_func_in_pxd boolean Hacky special case for inline function in pxd file.
# Ideally this should not be necesarry. # Ideally this should not be necesarry.
# assignments [ExprNode] List of expressions that get assigned to this entry.
# might_overflow boolean In an arithmetic expression that could cause # might_overflow boolean In an arithmetic expression that could cause
# overflow (used for type inference). # overflow (used for type inference).
# utility_code_definition For some Cython builtins, the utility code # utility_code_definition For some Cython builtins, the utility code
...@@ -193,7 +192,6 @@ class Entry(object): ...@@ -193,7 +192,6 @@ class Entry(object):
self.pos = pos self.pos = pos
self.init = init self.init = init
self.overloaded_alternatives = [] self.overloaded_alternatives = []
self.assignments = []
self.cf_assignments = [] self.cf_assignments = []
self.cf_references = [] self.cf_references = []
......
...@@ -15,7 +15,11 @@ class TypedExprNode(ExprNodes.ExprNode): ...@@ -15,7 +15,11 @@ class TypedExprNode(ExprNodes.ExprNode):
object_expr = TypedExprNode(py_object_type) object_expr = TypedExprNode(py_object_type)
class MarkAssignments(EnvTransform):
class MarkParallelAssignments(EnvTransform):
# Collects assignments inside parallel blocks prange, with parallel.
# Perhaps it's better to move it to ControlFlowAnalysis.
# tells us whether we're in a normal loop # tells us whether we're in a normal loop
in_loop = False in_loop = False
...@@ -24,14 +28,13 @@ class MarkAssignments(EnvTransform): ...@@ -24,14 +28,13 @@ class MarkAssignments(EnvTransform):
def __init__(self, context): def __init__(self, context):
# Track the parallel block scopes (with parallel, for i in prange()) # Track the parallel block scopes (with parallel, for i in prange())
self.parallel_block_stack = [] self.parallel_block_stack = []
return super(MarkAssignments, self).__init__(context) return super(MarkParallelAssignments, self).__init__(context)
def mark_assignment(self, lhs, rhs, inplace_op=None): def mark_assignment(self, lhs, rhs, inplace_op=None):
if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)): if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
if lhs.entry is None: if lhs.entry is None:
# TODO: This shouldn't happen... # TODO: This shouldn't happen...
return return
lhs.entry.assignments.append(rhs)
if self.parallel_block_stack: if self.parallel_block_stack:
parallel_node = self.parallel_block_stack[-1] parallel_node = self.parallel_block_stack[-1]
...@@ -359,8 +362,8 @@ class SimpleAssignmentTypeInferer(object): ...@@ -359,8 +362,8 @@ class SimpleAssignmentTypeInferer(object):
entry.type = py_object_type entry.type = py_object_type
continue continue
all = set() all = set()
for expr in entry.assignments: for assmt in entry.cf_assignments:
all.update(expr.type_dependencies(scope)) all.update(assmt.rhs.type_dependencies(scope))
if all: if all:
dependancies_by_entry[entry] = all dependancies_by_entry[entry] = all
for dep in all: for dep in all:
...@@ -384,7 +387,8 @@ class SimpleAssignmentTypeInferer(object): ...@@ -384,7 +387,8 @@ class SimpleAssignmentTypeInferer(object):
while True: while True:
while ready_to_infer: while ready_to_infer:
entry = ready_to_infer.pop() entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments] types = [assmt.rhs.infer_type(scope)
for assmt in entry.cf_assignments]
if types and Utils.all(types): if types and Utils.all(types):
entry.type = spanning_type(types, entry.might_overflow) entry.type = spanning_type(types, entry.might_overflow)
else: else:
...@@ -397,10 +401,13 @@ class SimpleAssignmentTypeInferer(object): ...@@ -397,10 +401,13 @@ class SimpleAssignmentTypeInferer(object):
# Deal with simple circular dependancies... # Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items(): for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]): if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()] types = [assmt.rhs.infer_type(scope)
for assmt in entry.cf_assignments
if assmt.rhs.type_dependencies(scope) == ()]
if types: if types:
entry.type = spanning_type(types, entry.might_overflow) entry.type = spanning_type(types, entry.might_overflow)
types = [expr.infer_type(scope) for expr in entry.assignments] types = [assmt.rhs.infer_type(scope)
for assmt in entry.cf_assignments]
entry.type = spanning_type(types, entry.might_overflow) # might be wider... entry.type = spanning_type(types, entry.might_overflow) # might be wider...
resolve_dependancy(entry) resolve_dependancy(entry)
del dependancies_by_entry[entry] del dependancies_by_entry[entry]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment