Commit e9f3800b authored by Stefan Behnel's avatar Stefan Behnel

merge

parents 2f682f9b 1f2b10af
from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor
from Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode from Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode
from ExprNodes import (DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode, from ExprNodes import DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode, \
ExprNode, AttributeNode, ModuleRefNode, DocstringRefNode) ExprNode, AttributeNode, ModuleRefNode, DocstringRefNode
from PyrexTypes import py_object_type from PyrexTypes import py_object_type
from Builtin import dict_type from Builtin import dict_type
from StringEncoding import EncodedString from StringEncoding import EncodedString
......
...@@ -12,8 +12,8 @@ import Naming ...@@ -12,8 +12,8 @@ import Naming
import Nodes import Nodes
from Nodes import Node from Nodes import Node
import PyrexTypes import PyrexTypes
from PyrexTypes import py_object_type, c_long_type, typecast, error_type from PyrexTypes import py_object_type, c_long_type, typecast, error_type, unspecified_type
from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type, bytes_type from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type, bytes_type, type_type
import Builtin import Builtin
import Symtab import Symtab
import Options import Options
...@@ -307,6 +307,27 @@ class ExprNode(Node): ...@@ -307,6 +307,27 @@ class ExprNode(Node):
temp_bool = bool.coerce_to_temp(env) temp_bool = bool.coerce_to_temp(env)
return temp_bool return temp_bool
# --------------- Type Inference -----------------
def type_dependencies(self, env):
# Returns the list of entries whose types must be determined
# before the type of self can be infered.
if hasattr(self, 'type') and self.type is not None:
return ()
return sum([node.type_dependencies(env) for node in self.subexpr_nodes()], ())
def infer_type(self, env):
# Attempt to deduce the type of self.
# Differs from analyse_types as it avoids unnecessary
# analysis of subexpressions, but can assume everything
# in self.type_dependencies() has been resolved.
if hasattr(self, 'type') and self.type is not None:
return self.type
elif hasattr(self, 'entry') and self.entry is not None:
return self.entry.type
else:
self.not_implemented("infer_type")
# --------------- Type Analysis ------------------ # --------------- Type Analysis ------------------
def analyse_as_module(self, env): def analyse_as_module(self, env):
...@@ -781,7 +802,6 @@ class StringNode(ConstNode): ...@@ -781,7 +802,6 @@ class StringNode(ConstNode):
# Arrange for a Python version of the string to be pre-allocated # Arrange for a Python version of the string to be pre-allocated
# when coercing to a Python type. # when coercing to a Python type.
if dst_type.is_pyobject and not self.type.is_pyobject: if dst_type.is_pyobject and not self.type.is_pyobject:
warn_once(self.pos, "String literals will no longer be Py3 bytes in Cython 0.12.", 1)
node = self.as_py_string_node(env) node = self.as_py_string_node(env)
else: else:
node = self node = self
...@@ -811,8 +831,9 @@ class StringNode(ConstNode): ...@@ -811,8 +831,9 @@ class StringNode(ConstNode):
def calculate_result_code(self): def calculate_result_code(self):
return self.result_code return self.result_code
class UnicodeNode(PyConstNode): class UnicodeNode(PyConstNode):
# entry Symtab.Entry
type = unicode_type type = unicode_type
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
...@@ -858,6 +879,8 @@ class LongNode(AtomicExprNode): ...@@ -858,6 +879,8 @@ class LongNode(AtomicExprNode):
# #
# value string # value string
type = py_object_type
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = long(self.value) self.constant_result = long(self.value)
...@@ -865,7 +888,6 @@ class LongNode(AtomicExprNode): ...@@ -865,7 +888,6 @@ class LongNode(AtomicExprNode):
return long(self.value) return long(self.value)
def analyse_types(self, env): def analyse_types(self, env):
self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
gil_message = "Constructing Python long int" gil_message = "Constructing Python long int"
...@@ -954,6 +976,27 @@ class NameNode(AtomicExprNode): ...@@ -954,6 +976,27 @@ class NameNode(AtomicExprNode):
create_analysed_rvalue = staticmethod(create_analysed_rvalue) create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def type_dependencies(self, env):
if self.entry is None:
self.entry = env.lookup(self.name)
if self.entry is not None and self.entry.type.is_unspecified:
return (self.entry,)
else:
return ()
def infer_type(self, env):
if self.entry is None:
self.entry = env.lookup(self.name)
if self.entry is None:
return py_object_type
elif (self.entry.type.is_extension_type or self.entry.type.is_builtin_type) and \
self.name == self.entry.type.name:
# Unfortunately the type attribute of type objects
# is used for the pointer to the type the represent.
return type_type
else:
return self.entry.type
def compile_time_value(self, denv): def compile_time_value(self, denv):
try: try:
return denv.lookup(self.name) return denv.lookup(self.name)
...@@ -1023,7 +1066,11 @@ class NameNode(AtomicExprNode): ...@@ -1023,7 +1066,11 @@ class NameNode(AtomicExprNode):
if not self.entry: if not self.entry:
self.entry = env.lookup_here(self.name) self.entry = env.lookup_here(self.name)
if not self.entry: if not self.entry:
self.entry = env.declare_var(self.name, py_object_type, self.pos) if env.directives['infer_types']:
type = unspecified_type
else:
type = py_object_type
self.entry = env.declare_var(self.name, type, self.pos)
env.control_flow.set_state(self.pos, (self.name, 'initalized'), True) env.control_flow.set_state(self.pos, (self.name, 'initalized'), True)
env.control_flow.set_state(self.pos, (self.name, 'source'), 'assignment') env.control_flow.set_state(self.pos, (self.name, 'source'), 'assignment')
if self.entry.is_declared_generic: if self.entry.is_declared_generic:
...@@ -1294,12 +1341,13 @@ class BackquoteNode(ExprNode): ...@@ -1294,12 +1341,13 @@ class BackquoteNode(ExprNode):
# #
# arg ExprNode # arg ExprNode
type = py_object_type
subexprs = ['arg'] subexprs = ['arg']
def analyse_types(self, env): def analyse_types(self, env):
self.arg.analyse_types(env) self.arg.analyse_types(env)
self.arg = self.arg.coerce_to_pyobject(env) self.arg = self.arg.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
gil_message = "Backquote expression" gil_message = "Backquote expression"
...@@ -1325,15 +1373,16 @@ class ImportNode(ExprNode): ...@@ -1325,15 +1373,16 @@ class ImportNode(ExprNode):
# module_name IdentifierStringNode dotted name of module # module_name IdentifierStringNode dotted name of module
# name_list ListNode or None list of names to be imported # name_list ListNode or None list of names to be imported
type = py_object_type
subexprs = ['module_name', 'name_list'] subexprs = ['module_name', 'name_list']
def analyse_types(self, env): def analyse_types(self, env):
self.module_name.analyse_types(env) self.module_name.analyse_types(env)
self.module_name = self.module_name.coerce_to_pyobject(env) self.module_name = self.module_name.coerce_to_pyobject(env)
if self.name_list: if self.name_list:
self.name_list.analyse_types(env) self.name_list.analyse_types(env)
self.name_list.coerce_to_pyobject(env) self.name_list.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
env.use_utility_code(import_utility_code) env.use_utility_code(import_utility_code)
...@@ -1363,12 +1412,13 @@ class IteratorNode(ExprNode): ...@@ -1363,12 +1412,13 @@ class IteratorNode(ExprNode):
# #
# sequence ExprNode # sequence ExprNode
type = py_object_type
subexprs = ['sequence'] subexprs = ['sequence']
def analyse_types(self, env): def analyse_types(self, env):
self.sequence.analyse_types(env) self.sequence.analyse_types(env)
self.sequence = self.sequence.coerce_to_pyobject(env) self.sequence = self.sequence.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
gil_message = "Iterating over Python object" gil_message = "Iterating over Python object"
...@@ -1420,10 +1470,11 @@ class NextNode(AtomicExprNode): ...@@ -1420,10 +1470,11 @@ class NextNode(AtomicExprNode):
# #
# iterator ExprNode # iterator ExprNode
type = py_object_type
def __init__(self, iterator, env): def __init__(self, iterator, env):
self.pos = iterator.pos self.pos = iterator.pos
self.iterator = iterator self.iterator = iterator
self.type = py_object_type
self.is_temp = 1 self.is_temp = 1
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -1476,9 +1527,10 @@ class ExcValueNode(AtomicExprNode): ...@@ -1476,9 +1527,10 @@ class ExcValueNode(AtomicExprNode):
# of an ExceptClauseNode to fetch the current # of an ExceptClauseNode to fetch the current
# exception value. # exception value.
type = py_object_type
def __init__(self, pos, env): def __init__(self, pos, env):
ExprNode.__init__(self, pos) ExprNode.__init__(self, pos)
self.type = py_object_type
def set_var(self, var): def set_var(self, var):
self.var = var self.var = var
...@@ -1594,6 +1646,19 @@ class IndexNode(ExprNode): ...@@ -1594,6 +1646,19 @@ class IndexNode(ExprNode):
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env))) return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
return None return None
def type_dependencies(self, env):
return self.base.type_dependencies(env)
def infer_type(self, env):
if isinstance(self.base, StringNode):
return py_object_type
base_type = self.base.infer_type(env)
if base_type.is_ptr or base_type.is_array:
return base_type.base_type
else:
# TODO: Handle buffers (hopefully without too much redundancy).
return py_object_type
def analyse_types(self, env): def analyse_types(self, env):
self.analyse_base_and_index_types(env, getting = 1) self.analyse_base_and_index_types(env, getting = 1)
...@@ -2102,6 +2167,9 @@ class SliceNode(ExprNode): ...@@ -2102,6 +2167,9 @@ class SliceNode(ExprNode):
# start ExprNode # start ExprNode
# stop ExprNode # stop ExprNode
# step ExprNode # step ExprNode
type = py_object_type
is_temp = 1
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.base.constant_result[ self.constant_result = self.base.constant_result[
...@@ -2133,8 +2201,6 @@ class SliceNode(ExprNode): ...@@ -2133,8 +2201,6 @@ class SliceNode(ExprNode):
self.start = self.start.coerce_to_pyobject(env) self.start = self.start.coerce_to_pyobject(env)
self.stop = self.stop.coerce_to_pyobject(env) self.stop = self.stop.coerce_to_pyobject(env)
self.step = self.step.coerce_to_pyobject(env) self.step = self.step.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1
gil_message = "Constructing Python slice object" gil_message = "Constructing Python slice object"
...@@ -2150,6 +2216,7 @@ class SliceNode(ExprNode): ...@@ -2150,6 +2216,7 @@ class SliceNode(ExprNode):
class CallNode(ExprNode): class CallNode(ExprNode):
def analyse_as_type_constructor(self, env): def analyse_as_type_constructor(self, env):
type = self.function.analyse_as_type(env) type = self.function.analyse_as_type(env)
if type and type.is_struct_or_union: if type and type.is_struct_or_union:
...@@ -2202,6 +2269,20 @@ class SimpleCallNode(CallNode): ...@@ -2202,6 +2269,20 @@ class SimpleCallNode(CallNode):
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def type_dependencies(self, env):
# TODO: Update when Danilo's C++ code merged in to handle the
# the case of function overloading.
return self.function.type_dependencies(env)
def infer_type(self, env):
func_type = self.function.infer_type(env)
if func_type.is_ptr:
func_type = func_type.base_type
if func_type.is_cfunction:
return func_type.return_type
else:
return py_object_type
def analyse_as_type(self, env): def analyse_as_type(self, env):
attr = self.function.as_cython_attribute() attr = self.function.as_cython_attribute()
if attr == 'pointer': if attr == 'pointer':
...@@ -2462,6 +2543,8 @@ class GeneralCallNode(CallNode): ...@@ -2462,6 +2543,8 @@ class GeneralCallNode(CallNode):
# keyword_args ExprNode or None Dict of keyword arguments # keyword_args ExprNode or None Dict of keyword arguments
# starstar_arg ExprNode or None Dict of extra keyword args # starstar_arg ExprNode or None Dict of extra keyword args
type = py_object_type
subexprs = ['function', 'positional_args', 'keyword_args', 'starstar_arg'] subexprs = ['function', 'positional_args', 'keyword_args', 'starstar_arg']
nogil_check = Node.gil_error nogil_check = Node.gil_error
...@@ -2639,6 +2722,18 @@ class AttributeNode(ExprNode): ...@@ -2639,6 +2722,18 @@ class AttributeNode(ExprNode):
return getattr(obj, attr) return getattr(obj, attr)
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def type_dependencies(self, env):
return self.obj.type_dependencies(env)
def infer_type(self, env):
if self.analyse_as_cimported_attribute(env, 0):
return self.entry.type
elif self.analyse_as_unbound_cmethod(env):
return self.entry.type
else:
self.analyse_attribute(env, obj_type = self.obj.infer_type(env))
return self.type
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
pass pass
...@@ -2742,13 +2837,17 @@ class AttributeNode(ExprNode): ...@@ -2742,13 +2837,17 @@ class AttributeNode(ExprNode):
self.is_temp = 1 self.is_temp = 1
self.result_ctype = py_object_type self.result_ctype = py_object_type
def analyse_attribute(self, env): def analyse_attribute(self, env, obj_type = None):
# Look up attribute and set self.type and self.member. # Look up attribute and set self.type and self.member.
self.is_py_attr = 0 self.is_py_attr = 0
self.member = self.attribute self.member = self.attribute
if self.obj.type.is_string: if obj_type is None:
self.obj = self.obj.coerce_to_pyobject(env) if self.obj.type.is_string:
obj_type = self.obj.type self.obj = self.obj.coerce_to_pyobject(env)
obj_type = self.obj.type
else:
if obj_type.is_string:
obj_type = py_object_type
if obj_type.is_ptr or obj_type.is_array: if obj_type.is_ptr or obj_type.is_array:
obj_type = obj_type.base_type obj_type = obj_type.base_type
self.op = "->" self.op = "->"
...@@ -2787,10 +2886,11 @@ class AttributeNode(ExprNode): ...@@ -2787,10 +2886,11 @@ class AttributeNode(ExprNode):
# type, or it is an extension type and the attribute is either not # type, or it is an extension type and the attribute is either not
# declared or is declared as a Python method. Treat it as a Python # declared or is declared as a Python method. Treat it as a Python
# attribute reference. # attribute reference.
self.analyse_as_python_attribute(env) self.analyse_as_python_attribute(env, obj_type)
def analyse_as_python_attribute(self, env): def analyse_as_python_attribute(self, env, obj_type = None):
obj_type = self.obj.type if obj_type is None:
obj_type = self.obj.type
self.member = self.attribute self.member = self.attribute
if obj_type.is_pyobject: if obj_type.is_pyobject:
self.type = py_object_type self.type = py_object_type
...@@ -2943,6 +3043,7 @@ class StarredTargetNode(ExprNode): ...@@ -2943,6 +3043,7 @@ class StarredTargetNode(ExprNode):
subexprs = ['target'] subexprs = ['target']
is_starred = 1 is_starred = 1
type = py_object_type type = py_object_type
is_temp = 1
def __init__(self, pos, target): def __init__(self, pos, target):
self.pos = pos self.pos = pos
...@@ -3203,6 +3304,8 @@ class SequenceNode(ExprNode): ...@@ -3203,6 +3304,8 @@ class SequenceNode(ExprNode):
class TupleNode(SequenceNode): class TupleNode(SequenceNode):
# Tuple constructor. # Tuple constructor.
type = tuple_type
gil_message = "Constructing Python tuple" gil_message = "Constructing Python tuple"
...@@ -3212,7 +3315,6 @@ class TupleNode(SequenceNode): ...@@ -3212,7 +3315,6 @@ class TupleNode(SequenceNode):
self.is_literal = 1 self.is_literal = 1
else: else:
SequenceNode.analyse_types(self, env, skip_children) SequenceNode.analyse_types(self, env, skip_children)
self.type = tuple_type
def calculate_result_code(self): def calculate_result_code(self):
if len(self.args) > 0: if len(self.args) > 0:
...@@ -3271,6 +3373,13 @@ class ListNode(SequenceNode): ...@@ -3271,6 +3373,13 @@ class ListNode(SequenceNode):
obj_conversion_errors = [] obj_conversion_errors = []
gil_message = "Constructing Python list" gil_message = "Constructing Python list"
def type_dependencies(self, env):
return ()
def infer_type(self, env):
# TOOD: Infer non-object list arrays.
return list_type
def analyse_expressions(self, env): def analyse_expressions(self, env):
SequenceNode.analyse_expressions(self, env) SequenceNode.analyse_expressions(self, env)
...@@ -3378,10 +3487,16 @@ class ComprehensionNode(ExprNode): ...@@ -3378,10 +3487,16 @@ class ComprehensionNode(ExprNode):
subexprs = ["target"] subexprs = ["target"]
child_attrs = ["loop", "append"] child_attrs = ["loop", "append"]
def infer_type(self, env):
return self.target.infer_type(env)
def analyse_types(self, env): def analyse_types(self, env):
self.target.analyse_expressions(env) self.target.analyse_expressions(env)
self.type = self.target.type self.type = self.target.type
self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
# We are analysing declarations to late.
self.loop.target.analyse_target_declaration(env)
env.infer_types()
self.loop.analyse_declarations(env) self.loop.analyse_declarations(env)
self.loop.analyse_expressions(env) self.loop.analyse_expressions(env)
...@@ -3451,10 +3566,12 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): ...@@ -3451,10 +3566,12 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
class SetNode(ExprNode): class SetNode(ExprNode):
# Set constructor. # Set constructor.
type = set_type
subexprs = ['args'] subexprs = ['args']
gil_message = "Constructing Python set" gil_message = "Constructing Python set"
def analyse_types(self, env): def analyse_types(self, env):
for i in range(len(self.args)): for i in range(len(self.args)):
arg = self.args[i] arg = self.args[i]
...@@ -3518,6 +3635,13 @@ class DictNode(ExprNode): ...@@ -3518,6 +3635,13 @@ class DictNode(ExprNode):
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def type_dependencies(self, env):
return ()
def infer_type(self, env):
# TOOD: Infer struct constructors.
return dict_type
def analyse_types(self, env): def analyse_types(self, env):
hold_errors() hold_errors()
for item in self.key_value_pairs: for item in self.key_value_pairs:
...@@ -3681,12 +3805,13 @@ class UnboundMethodNode(ExprNode): ...@@ -3681,12 +3805,13 @@ class UnboundMethodNode(ExprNode):
# #
# function ExprNode Function object # function ExprNode Function object
type = py_object_type
is_temp = 1
subexprs = ['function'] subexprs = ['function']
def analyse_types(self, env): def analyse_types(self, env):
self.function.analyse_types(env) self.function.analyse_types(env)
self.type = py_object_type
self.is_temp = 1
gil_message = "Constructing an unbound method" gil_message = "Constructing an unbound method"
...@@ -3707,10 +3832,12 @@ class PyCFunctionNode(AtomicExprNode): ...@@ -3707,10 +3832,12 @@ class PyCFunctionNode(AtomicExprNode):
# #
# pymethdef_cname string PyMethodDef structure # pymethdef_cname string PyMethodDef structure
type = py_object_type
is_temp = 1
def analyse_types(self, env): def analyse_types(self, env):
self.type = py_object_type pass
self.is_temp = 1
gil_message = "Constructing Python function" gil_message = "Constructing Python function"
def generate_result_code(self, code): def generate_result_code(self, code):
...@@ -3764,6 +3891,9 @@ class UnopNode(ExprNode): ...@@ -3764,6 +3891,9 @@ class UnopNode(ExprNode):
return func(operand) return func(operand)
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def infer_type(self, env):
return self.operand.infer_type(env)
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand.analyse_types(env)
...@@ -3812,7 +3942,11 @@ class NotNode(ExprNode): ...@@ -3812,7 +3942,11 @@ class NotNode(ExprNode):
# 'not' operator # 'not' operator
# #
# operand ExprNode # operand ExprNode
type = PyrexTypes.c_bint_type
subexprs = ['operand']
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = not self.operand.constant_result self.constant_result = not self.operand.constant_result
...@@ -3823,12 +3957,12 @@ class NotNode(ExprNode): ...@@ -3823,12 +3957,12 @@ class NotNode(ExprNode):
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
subexprs = ['operand'] def infer_type(self, env):
return PyrexTypes.c_bint_type
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand.analyse_types(env)
self.operand = self.operand.coerce_to_boolean(env) self.operand = self.operand.coerce_to_boolean(env)
self.type = PyrexTypes.c_bint_type
def calculate_result_code(self): def calculate_result_code(self):
return "(!%s)" % self.operand.result() return "(!%s)" % self.operand.result()
...@@ -3896,6 +4030,9 @@ class AmpersandNode(ExprNode): ...@@ -3896,6 +4030,9 @@ class AmpersandNode(ExprNode):
# operand ExprNode # operand ExprNode
subexprs = ['operand'] subexprs = ['operand']
def infer_type(self, env):
return PyrexTypes.c_ptr_type(self.operand.infer_type(env))
def analyse_types(self, env): def analyse_types(self, env):
self.operand.analyse_types(env) self.operand.analyse_types(env)
...@@ -3954,6 +4091,15 @@ class TypecastNode(ExprNode): ...@@ -3954,6 +4091,15 @@ class TypecastNode(ExprNode):
subexprs = ['operand'] subexprs = ['operand']
base_type = declarator = type = None base_type = declarator = type = None
def type_dependencies(self, env):
return ()
def infer_type(self, env):
if self.type is None:
base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env)
return self.type
def analyse_types(self, env): def analyse_types(self, env):
if self.type is None: if self.type is None:
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
...@@ -4103,6 +4249,29 @@ class SizeofVarNode(SizeofNode): ...@@ -4103,6 +4249,29 @@ class SizeofVarNode(SizeofNode):
def generate_result_code(self, code): def generate_result_code(self, code):
pass pass
class TypeofNode(ExprNode):
# Compile-time type of an expression, as a string.
#
# operand ExprNode
# literal StringNode # internal
literal = None
type = py_object_type
subexprs = ['operand', 'literal']
def analyse_types(self, env):
self.operand.analyse_types(env)
from StringEncoding import EncodedString
self.literal = StringNode(self.pos, value=EncodedString(str(self.operand.type)))
self.literal.analyse_types(env)
self.literal = self.literal.coerce_to_pyobject(env)
def generate_evaluation_code(self, code):
self.literal.generate_evaluation_code(code)
def calculate_result_code(self):
return self.literal.calculate_result_code()
#------------------------------------------------------------------- #-------------------------------------------------------------------
# #
...@@ -4175,7 +4344,11 @@ class BinopNode(ExprNode): ...@@ -4175,7 +4344,11 @@ class BinopNode(ExprNode):
return func(operand1, operand2) return func(operand1, operand2)
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def infer_type(self, env):
return self.result_type(self.operand1.infer_type(env),
self.operand2.infer_type(env))
def analyse_types(self, env): def analyse_types(self, env):
self.operand1.analyse_types(env) self.operand1.analyse_types(env)
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
...@@ -4189,13 +4362,21 @@ class BinopNode(ExprNode): ...@@ -4189,13 +4362,21 @@ class BinopNode(ExprNode):
self.analyse_c_operation(env) self.analyse_c_operation(env)
def is_py_operation(self): def is_py_operation(self):
return (self.operand1.type.is_pyobject return self.is_py_operation_types(self.operand1.type, self.operand2.type)
or self.operand2.type.is_pyobject)
def is_py_operation_types(self, type1, type2):
return type1.is_pyobject or type2.is_pyobject
def result_type(self, type1, type2):
if self.is_py_operation_types(type1, type2):
return py_object_type
else:
return self.compute_c_result_type(type1, type2)
def nogil_check(self, env): def nogil_check(self, env):
if self.is_py_operation(): if self.is_py_operation():
self.gil_error() self.gil_error()
def coerce_operands_to_pyobjects(self, env): def coerce_operands_to_pyobjects(self, env):
self.operand1 = self.operand1.coerce_to_pyobject(env) self.operand1 = self.operand1.coerce_to_pyobject(env)
self.operand2 = self.operand2.coerce_to_pyobject(env) self.operand2 = self.operand2.coerce_to_pyobject(env)
...@@ -4314,12 +4495,11 @@ class IntBinopNode(NumBinopNode): ...@@ -4314,12 +4495,11 @@ class IntBinopNode(NumBinopNode):
class AddNode(NumBinopNode): class AddNode(NumBinopNode):
# '+' operator. # '+' operator.
def is_py_operation(self): def is_py_operation_types(self, type1, type2):
if self.operand1.type.is_string \ if type1.is_string and type2.is_string:
and self.operand2.type.is_string: return 1
return 1
else: else:
return NumBinopNode.is_py_operation(self) return NumBinopNode.is_py_operation_types(self, type1, type2)
def compute_c_result_type(self, type1, type2): def compute_c_result_type(self, type1, type2):
#print "AddNode.compute_c_result_type:", type1, self.operator, type2 ### #print "AddNode.compute_c_result_type:", type1, self.operator, type2 ###
...@@ -4348,14 +4528,12 @@ class SubNode(NumBinopNode): ...@@ -4348,14 +4528,12 @@ class SubNode(NumBinopNode):
class MulNode(NumBinopNode): class MulNode(NumBinopNode):
# '*' operator. # '*' operator.
def is_py_operation(self): def is_py_operation_types(self, type1, type2):
type1 = self.operand1.type
type2 = self.operand2.type
if (type1.is_string and type2.is_int) \ if (type1.is_string and type2.is_int) \
or (type2.is_string and type1.is_int): or (type2.is_string and type1.is_int):
return 1 return 1
else: else:
return NumBinopNode.is_py_operation(self) return NumBinopNode.is_py_operation_types(self, type1, type2)
class DivNode(NumBinopNode): class DivNode(NumBinopNode):
...@@ -4492,10 +4670,10 @@ class DivNode(NumBinopNode): ...@@ -4492,10 +4670,10 @@ class DivNode(NumBinopNode):
class ModNode(DivNode): class ModNode(DivNode):
# '%' operator. # '%' operator.
def is_py_operation(self): def is_py_operation_types(self, type1, type2):
return (self.operand1.type.is_string return (type1.is_string
or self.operand2.type.is_string or type2.is_string
or NumBinopNode.is_py_operation(self)) or NumBinopNode.is_py_operation_types(self, type1, type2))
def zero_division_message(self): def zero_division_message(self):
if self.type.is_int: if self.type.is_int:
...@@ -4571,6 +4749,13 @@ class BoolBinopNode(ExprNode): ...@@ -4571,6 +4749,13 @@ class BoolBinopNode(ExprNode):
# operand2 ExprNode # operand2 ExprNode
subexprs = ['operand1', 'operand2'] subexprs = ['operand1', 'operand2']
def infer_type(self, env):
if (self.operand1.infer_type(env).is_pyobject or
self.operand2.infer_type(env).is_pyobject):
return py_object_type
else:
return PyrexTypes.c_bint_type
def calculate_constant_result(self): def calculate_constant_result(self):
if self.operator == 'and': if self.operator == 'and':
...@@ -4685,6 +4870,13 @@ class CondExprNode(ExprNode): ...@@ -4685,6 +4870,13 @@ class CondExprNode(ExprNode):
false_val = None false_val = None
subexprs = ['test', 'true_val', 'false_val'] subexprs = ['test', 'true_val', 'false_val']
def type_dependencies(self, env):
return self.true_val.type_dependencies(env) + self.false_val.type_dependencies(env)
def infer_type(self, env):
return self.compute_result_type(self.true_val.infer_type(env),
self.false_val.infer_type(env))
def calculate_constant_result(self): def calculate_constant_result(self):
if self.test.constant_result: if self.test.constant_result:
...@@ -4769,6 +4961,10 @@ richcmp_constants = { ...@@ -4769,6 +4961,10 @@ richcmp_constants = {
class CmpNode(object): class CmpNode(object):
# Mixin class containing code common to PrimaryCmpNodes # Mixin class containing code common to PrimaryCmpNodes
# and CascadedCmpNodes. # and CascadedCmpNodes.
def infer_types(self, env):
# TODO: Actually implement this (after merging with -unstable).
return py_object_type
def calculate_cascaded_constant_result(self, operand1_result): def calculate_cascaded_constant_result(self, operand1_result):
func = compile_time_binary_operators[self.operator] func = compile_time_binary_operators[self.operator]
...@@ -4932,6 +5128,13 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -4932,6 +5128,13 @@ class PrimaryCmpNode(ExprNode, CmpNode):
cascade = None cascade = None
def infer_type(self, env):
# TODO: Actually implement this (after merging with -unstable).
return py_object_type
def type_dependencies(self, env):
return ()
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.calculate_cascaded_constant_result( self.constant_result = self.calculate_cascaded_constant_result(
self.operand1.constant_result) self.operand1.constant_result)
...@@ -5066,6 +5269,13 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -5066,6 +5269,13 @@ class CascadedCmpNode(Node, CmpNode):
cascade = None cascade = None
constant_result = constant_value_not_set # FIXME: where to calculate this? constant_result = constant_value_not_set # FIXME: where to calculate this?
def infer_type(self, env):
# TODO: Actually implement this (after merging with -unstable).
return py_object_type
def type_dependencies(self, env):
return ()
def analyse_types(self, env, operand1): def analyse_types(self, env, operand1):
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
...@@ -5287,11 +5497,12 @@ class NoneCheckNode(CoercionNode): ...@@ -5287,11 +5497,12 @@ class NoneCheckNode(CoercionNode):
class CoerceToPyTypeNode(CoercionNode): class CoerceToPyTypeNode(CoercionNode):
# This node is used to convert a C data type # This node is used to convert a C data type
# to a Python object. # to a Python object.
type = py_object_type
is_temp = 1
def __init__(self, arg, env): def __init__(self, arg, env):
CoercionNode.__init__(self, arg) CoercionNode.__init__(self, arg)
self.type = py_object_type
self.is_temp = 1
if not arg.type.create_to_py_utility_code(env): if not arg.type.create_to_py_utility_code(env):
error(arg.pos, error(arg.pos,
"Cannot convert '%s' to Python object" % arg.type) "Cannot convert '%s' to Python object" % arg.type)
...@@ -5359,9 +5570,10 @@ class CoerceToBooleanNode(CoercionNode): ...@@ -5359,9 +5570,10 @@ class CoerceToBooleanNode(CoercionNode):
# This node is used when a result needs to be used # This node is used when a result needs to be used
# in a boolean context. # in a boolean context.
type = PyrexTypes.c_bint_type
def __init__(self, arg, env): def __init__(self, arg, env):
CoercionNode.__init__(self, arg) CoercionNode.__init__(self, arg)
self.type = PyrexTypes.c_bint_type
if arg.type.is_pyobject: if arg.type.is_pyobject:
self.is_temp = 1 self.is_temp = 1
...@@ -5462,10 +5674,16 @@ class CloneNode(CoercionNode): ...@@ -5462,10 +5674,16 @@ class CloneNode(CoercionNode):
self.result_ctype = arg.result_ctype self.result_ctype = arg.result_ctype
if hasattr(arg, 'entry'): if hasattr(arg, 'entry'):
self.entry = arg.entry self.entry = arg.entry
def result(self): def result(self):
return self.arg.result() return self.arg.result()
def type_dependencies(self, env):
return self.arg.type_dependencies(env)
def infer_type(self, env):
return self.arg.infer_type(env)
def analyse_types(self, env): def analyse_types(self, env):
self.type = self.arg.type self.type = self.arg.type
self.result_ctype = self.arg.result_ctype self.result_ctype = self.arg.result_ctype
......
...@@ -87,6 +87,7 @@ class Context(object): ...@@ -87,6 +87,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from TypeInference import MarkAssignments
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform from AnalysedTreeTransforms import AutoTestDictTransform
from AutoDocTransforms import EmbedSignature from AutoDocTransforms import EmbedSignature
...@@ -129,6 +130,7 @@ class Context(object): ...@@ -129,6 +130,7 @@ class Context(object):
AnalyseDeclarationsTransform(self), AnalyseDeclarationsTransform(self),
AutoTestDictTransform(self), AutoTestDictTransform(self),
EmbedSignature(self), EmbedSignature(self),
MarkAssignments(self),
TransformBuiltinMethods(self), TransformBuiltinMethods(self),
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
......
...@@ -4,6 +4,12 @@ ...@@ -4,6 +4,12 @@
import sys, os, time, copy import sys, os, time, copy
try:
set
except NameError:
# Python 2.3
from sets import Set as set
import Code import Code
import Builtin import Builtin
from Errors import error, warning, InternalError from Errors import error, warning, InternalError
...@@ -3197,6 +3203,10 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -3197,6 +3203,10 @@ class InPlaceAssignmentNode(AssignmentNode):
self.lhs.annotate(code) self.lhs.annotate(code)
self.rhs.annotate(code) self.rhs.annotate(code)
self.dup.annotate(code) self.dup.annotate(code)
def create_binop_node(self):
import ExprNodes
return ExprNodes.binop_node(self.pos, self.operator, self.lhs, self.rhs)
class PrintStatNode(StatNode): class PrintStatNode(StatNode):
......
...@@ -68,6 +68,7 @@ option_defaults = { ...@@ -68,6 +68,7 @@ option_defaults = {
'c99_complex' : False, # Don't use macro wrappers for complex arith, not sure what to name this... 'c99_complex' : False, # Don't use macro wrappers for complex arith, not sure what to name this...
'callspec' : "", 'callspec' : "",
'profile': False, 'profile': False,
'infer_types': False,
'autotestdict': True, 'autotestdict': True,
# test support # test support
......
...@@ -328,7 +328,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -328,7 +328,7 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
duplication of functionality has to occur: We manually track cimports duplication of functionality has to occur: We manually track cimports
and which names the "cython" module may have been imported to. and which names the "cython" module may have been imported to.
""" """
special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', 'cast', 'address', 'pointer', 'compiled', 'NULL']) special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', 'typeof', 'cast', 'address', 'pointer', 'compiled', 'NULL'])
def __init__(self, context, compilation_option_overrides): def __init__(self, context, compilation_option_overrides):
super(InterpretCompilerDirectives, self).__init__(context) super(InterpretCompilerDirectives, self).__init__(context)
...@@ -785,11 +785,13 @@ property NAME: ...@@ -785,11 +785,13 @@ property NAME:
class AnalyseExpressionsTransform(CythonTransform): class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.scope.infer_types()
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope) node.body.analyse_expressions(node.local_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -1007,6 +1009,11 @@ class TransformBuiltinMethods(EnvTransform): ...@@ -1007,6 +1009,11 @@ class TransformBuiltinMethods(EnvTransform):
node = SizeofTypeNode(node.function.pos, arg_type=type) node = SizeofTypeNode(node.function.pos, arg_type=type)
else: else:
node = SizeofVarNode(node.function.pos, operand=node.args[0]) node = SizeofVarNode(node.function.pos, operand=node.args[0])
elif function == 'typeof':
if len(node.args) != 1:
error(node.function.pos, u"sizeof takes exactly one argument" % function)
else:
node = TypeofNode(node.function.pos, operand=node.args[0])
elif function == 'address': elif function == 'address':
if len(node.args) != 1: if len(node.args) != 1:
error(node.function.pos, u"sizeof takes exactly one argument" % function) error(node.function.pos, u"sizeof takes exactly one argument" % function)
......
...@@ -77,6 +77,7 @@ class PyrexType(BaseType): ...@@ -77,6 +77,7 @@ class PyrexType(BaseType):
# #
is_pyobject = 0 is_pyobject = 0
is_unspecified = 0
is_extension_type = 0 is_extension_type = 0
is_builtin_type = 0 is_builtin_type = 0
is_numeric = 0 is_numeric = 0
...@@ -849,9 +850,17 @@ class CComplexType(CNumericType): ...@@ -849,9 +850,17 @@ class CComplexType(CNumericType):
def __hash__(self): def __hash__(self):
return ~hash(self.real_type) return ~hash(self.real_type)
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
if for_display:
base = public_decl(self.real_type.sign_and_name() + " complex", dll_linkage)
else:
base = public_decl(self.sign_and_name(), dll_linkage)
return self.base_declaration_code(base, entity_code)
def sign_and_name(self): def sign_and_name(self):
return Naming.type_prefix + self.real_type.specalization_name() + "_complex" return Naming.type_prefix + self.real_type.specalization_name() + "_complex"
def assignable_from_resolved_type(self, src_type): def assignable_from_resolved_type(self, src_type):
return (src_type.is_complex and self.real_type.assignable_from_resolved_type(src_type.real_type) return (src_type.is_complex and self.real_type.assignable_from_resolved_type(src_type.real_type)
or src_type.is_numeric and self.real_type.assignable_from_resolved_type(src_type) or src_type.is_numeric and self.real_type.assignable_from_resolved_type(src_type)
...@@ -1591,6 +1600,8 @@ class CUCharPtrType(CStringType, CPtrType): ...@@ -1591,6 +1600,8 @@ class CUCharPtrType(CStringType, CPtrType):
class UnspecifiedType(PyrexType): class UnspecifiedType(PyrexType):
# Used as a placeholder until the type can be determined. # Used as a placeholder until the type can be determined.
is_unspecified = 1
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
...@@ -1788,6 +1799,23 @@ def widest_numeric_type(type1, type2): ...@@ -1788,6 +1799,23 @@ def widest_numeric_type(type1, type2):
return sign_and_rank_to_type[min(type1.signed, type2.signed), max(type1.rank, type2.rank)] return sign_and_rank_to_type[min(type1.signed, type2.signed), max(type1.rank, type2.rank)]
return widest_type return widest_type
def spanning_type(type1, type2):
# Return a type assignable from both type1 and type2.
if type1 is py_object_type or type2 is py_object_type:
return py_object_type
elif type1 == type2:
return type1
elif type1.is_numeric and type2.is_numeric:
return widest_numeric_type(type1, type2)
elif type1.is_pyobject ^ type2.is_pyobject:
return py_object_type
elif type1.assignable_from(type2):
return type1
elif type2.assignable_from(type1):
return type2
else:
return py_object_type
def simple_c_type(signed, longness, name): def simple_c_type(signed, longness, name):
# Find type descriptor for simple type given name and modifiers. # Find type descriptor for simple type given name and modifiers.
# Returns None if arguments don't make sense. # Returns None if arguments don't make sense.
......
...@@ -8,7 +8,7 @@ from Errors import warning, error, InternalError ...@@ -8,7 +8,7 @@ from Errors import warning, error, InternalError
from StringEncoding import EncodedString from StringEncoding import EncodedString
import Options, Naming import Options, Naming
import PyrexTypes import PyrexTypes
from PyrexTypes import py_object_type from PyrexTypes import py_object_type, unspecified_type
import TypeSlots import TypeSlots
from TypeSlots import \ from TypeSlots import \
pyfunction_signature, pymethod_signature, \ pyfunction_signature, pymethod_signature, \
...@@ -114,9 +114,10 @@ class Entry(object): ...@@ -114,9 +114,10 @@ class Entry(object):
# api boolean Generate C API for C class or function # api boolean Generate C API for C class or function
# utility_code string Utility code needed when this entry is used # utility_code string Utility code needed when this entry is used
# #
# buffer_aux BufferAux or None Extra information needed for buffer variables # buffer_aux BufferAux or None Extra information needed for buffer variables
# inline_func_in_pxd boolean Hacky special case for inline function in pxd file. # inline_func_in_pxd boolean Hacky special case for inline function in pxd file.
# Ideally this should not be necesarry. # Ideally this should not be necesarry.
# assignments [ExprNode] List of expressions that get assigned to this entry.
inline_func_in_pxd = False inline_func_in_pxd = False
borrowed = 0 borrowed = 0
...@@ -171,6 +172,10 @@ class Entry(object): ...@@ -171,6 +172,10 @@ class Entry(object):
self.type = type self.type = type
self.pos = pos self.pos = pos
self.init = init self.init = init
self.assignments = []
def __repr__(self):
return "Entry(name=%s, type=%s)" % (self.name, self.type)
def redeclared(self, pos): def redeclared(self, pos):
error(pos, "'%s' does not match previous declaration" % self.name) error(pos, "'%s' does not match previous declaration" % self.name)
...@@ -542,6 +547,10 @@ class Scope(object): ...@@ -542,6 +547,10 @@ class Scope(object):
if name in self.entries: if name in self.entries:
return 1 return 1
return 0 return 0
def infer_types(self):
from TypeInference import get_type_inferer
get_type_inferer().infer_types(self)
class PreImportScope(Scope): class PreImportScope(Scope):
...@@ -814,6 +823,8 @@ class ModuleScope(Scope): ...@@ -814,6 +823,8 @@ class ModuleScope(Scope):
if not visibility in ('private', 'public', 'extern'): if not visibility in ('private', 'public', 'extern'):
error(pos, "Module-level variable cannot be declared %s" % visibility) error(pos, "Module-level variable cannot be declared %s" % visibility)
if not is_cdef: if not is_cdef:
if type is unspecified_type:
type = py_object_type
if not (type.is_pyobject and not type.is_extension_type): if not (type.is_pyobject and not type.is_extension_type):
raise InternalError( raise InternalError(
"Non-cdef global variable is not a generic Python object") "Non-cdef global variable is not a generic Python object")
...@@ -1043,6 +1054,10 @@ class ModuleScope(Scope): ...@@ -1043,6 +1054,10 @@ class ModuleScope(Scope):
var_entry.is_cglobal = 1 var_entry.is_cglobal = 1
var_entry.is_readonly = 1 var_entry.is_readonly = 1
entry.as_variable = var_entry entry.as_variable = var_entry
def infer_types(self):
from TypeInference import PyObjectTypeInferer
PyObjectTypeInferer().infer_types(self)
class LocalScope(Scope): class LocalScope(Scope):
...@@ -1074,7 +1089,7 @@ class LocalScope(Scope): ...@@ -1074,7 +1089,7 @@ class LocalScope(Scope):
cname, visibility, is_cdef) cname, visibility, is_cdef)
if type.is_pyobject and not Options.init_local_none: if type.is_pyobject and not Options.init_local_none:
entry.init = "0" entry.init = "0"
entry.init_to_none = type.is_pyobject and Options.init_local_none entry.init_to_none = (type.is_pyobject or type.is_unspecified) and Options.init_local_none
entry.is_local = 1 entry.is_local = 1
self.var_entries.append(entry) self.var_entries.append(entry)
return entry return entry
...@@ -1189,6 +1204,8 @@ class PyClassScope(ClassScope): ...@@ -1189,6 +1204,8 @@ class PyClassScope(ClassScope):
def declare_var(self, name, type, pos, def declare_var(self, name, type, pos,
cname = None, visibility = 'private', is_cdef = 0): cname = None, visibility = 'private', is_cdef = 0):
if type is unspecified_type:
type = py_object_type
# Add an entry for a class attribute. # Add an entry for a class attribute.
entry = Scope.declare_var(self, name, type, pos, entry = Scope.declare_var(self, name, type, pos,
cname, visibility, is_cdef) cname, visibility, is_cdef)
...@@ -1275,6 +1292,8 @@ class CClassScope(ClassScope): ...@@ -1275,6 +1292,8 @@ class CClassScope(ClassScope):
"Non-generic Python attribute cannot be exposed for writing from Python") "Non-generic Python attribute cannot be exposed for writing from Python")
return entry return entry
else: else:
if type is unspecified_type:
type = py_object_type
# Add an entry for a class attribute. # Add an entry for a class attribute.
entry = Scope.declare_var(self, name, type, pos, entry = Scope.declare_var(self, name, type, pos,
cname, visibility, is_cdef) cname, visibility, is_cdef)
......
import ExprNodes
from PyrexTypes import py_object_type, unspecified_type, spanning_type
from Visitor import CythonTransform
try:
set
except NameError:
# Python 2.3
from sets import Set as set
class TypedExprNode(ExprNodes.ExprNode):
# Used for declaring assignments of a specified type whithout a known entry.
def __init__(self, type):
self.type = type
object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform):
def mark_assignment(self, lhs, rhs):
if isinstance(lhs, ExprNodes.NameNode):
if lhs.entry is None:
# TODO: This shouldn't happen...
# It looks like comprehension loop targets are not declared soon enough.
return
lhs.entry.assignments.append(rhs)
elif isinstance(lhs, ExprNodes.SequenceNode):
for arg in lhs.args:
self.mark_assignment(arg, object_expr)
else:
# Could use this info to infer cdef class attributes...
pass
def visit_SingleAssignmentNode(self, node):
self.mark_assignment(node.lhs, node.rhs)
self.visitchildren(node)
return node
def visit_CascadedAssignmentNode(self, node):
for lhs in node.lhs_list:
self.mark_assignment(lhs, node.rhs)
self.visitchildren(node)
return node
def visit_InPlaceAssignmentNode(self, node):
self.mark_assignment(node.lhs, node.create_binop_node())
self.visitchildren(node)
return node
def visit_ForInStatNode(self, node):
# TODO: Remove redundancy with range optimization...
is_range = False
sequence = node.iterator.sequence
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and \
isinstance(function, ExprNodes.NameNode) and \
function.name in ('range', 'xrange'):
is_range = True
self.mark_assignment(node.target, sequence.args[0])
if len(sequence.args) > 1:
self.mark_assignment(node.target, sequence.args[1])
if len(sequence.args) > 2:
self.mark_assignment(node.target,
ExprNodes.binop_node(node.pos,
'+',
sequence.args[0],
sequence.args[2]))
if not is_range:
self.mark_assignment(node.target, object_expr)
self.visitchildren(node)
return node
def visit_ForFromStatNode(self, node):
self.mark_assignment(node.target, node.bound1)
if node.step is not None:
self.mark_assignment(node.target,
ExprNodes.binop_node(node.pos,
'+',
node.bound1,
node.step))
self.visitchildren(node)
return node
def visit_ExceptClauseNode(self, node):
if node.target is not None:
self.mark_assignment(node.target, object_expr)
self.visitchildren(node)
return node
def visit_FromCImportStatNode(self, node):
pass # Can't be assigned to...
def visit_FromImportStatNode(self, node):
for name, target in node.items:
if name != "*":
self.mark_assignment(target, object_expr)
self.visitchildren(node)
return node
class PyObjectTypeInferer:
"""
If it's not declared, it's a PyObject.
"""
def infer_types(self, scope):
"""
Given a dict of entries, map all unspecified types to a specified type.
"""
for name, entry in scope.entries.items():
if entry.type is unspecified_type:
entry.type = py_object_type
class SimpleAssignmentTypeInferer:
"""
Very basic type inference.
"""
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
def infer_types(self, scope):
dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = []
for name, entry in scope.entries.items():
if entry.type is unspecified_type:
all = set()
for expr in entry.assignments:
all.update(expr.type_dependencies(scope))
if all:
dependancies_by_entry[entry] = all
for dep in all:
if dep not in entries_by_dependancy:
entries_by_dependancy[dep] = set([entry])
else:
entries_by_dependancy[dep].add(entry)
else:
ready_to_infer.append(entry)
def resolve_dependancy(dep):
if dep in entries_by_dependancy:
for entry in entries_by_dependancy[dep]:
entry_deps = dependancies_by_entry[entry]
entry_deps.remove(dep)
if not entry_deps and entry != dep:
del dependancies_by_entry[entry]
ready_to_infer.append(entry)
# Try to infer things in order...
while True:
while ready_to_infer:
entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments]
if types:
entry.type = reduce(spanning_type, types)
else:
# List comprehension?
# print "No assignments", entry.pos, entry
entry.type = py_object_type
resolve_dependancy(entry)
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
if types:
entry.type = reduce(spanning_type, types)
types = [expr.infer_type(scope) for expr in entry.assignments]
entry.type = reduce(spanning_type, types) # might be wider...
resolve_dependancy(entry)
del dependancies_by_entry[entry]
if ready_to_infer:
break
if not ready_to_infer:
break
# We can't figure out the rest with this algorithm, let them be objects.
for entry in dependancies_by_entry:
entry.type = py_object_type
def get_type_inferer():
return SimpleAssignmentTypeInferer()
...@@ -30,6 +30,9 @@ def cast(type, arg): ...@@ -30,6 +30,9 @@ def cast(type, arg):
def sizeof(arg): def sizeof(arg):
return 1 return 1
def typeof(arg):
return type(arg)
def address(arg): def address(arg):
return pointer(type(arg))([arg]) return pointer(type(arg))([arg])
......
#include "Python.h"
#include "embedded.h"
int main(int argc, char *argv) {
Py_Initialize();
initembedded();
spam();
Py_Finalize();
}
print "starting" print "starting"
def primes(int kmax): def primes(int kmax):
cdef int n, k, i # cdef int n, k, i
cdef int p[1000] cdef int p[1000]
result = [] result = []
if kmax > 1000: if kmax > 1000:
......
...@@ -8,7 +8,10 @@ import shutil ...@@ -8,7 +8,10 @@ import shutil
import unittest import unittest
import doctest import doctest
import operator import operator
from StringIO import StringIO try:
from StringIO import StringIO
except ImportError:
from io import StringIO
try: try:
import cPickle as pickle import cPickle as pickle
...@@ -394,7 +397,7 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -394,7 +397,7 @@ class CythonRunTestCase(CythonCompileTestCase):
try: try:
partial_result = PartialTestResult(result) partial_result = PartialTestResult(result)
doctest.DocTestSuite(module_name).run(partial_result) doctest.DocTestSuite(module_name).run(partial_result)
except Exception, e: except Exception:
partial_result.addError(module_name, sys.exc_info()) partial_result.addError(module_name, sys.exc_info())
result_code = 1 result_code = 1
pickle.dump(partial_result.data(), output) pickle.dump(partial_result.data(), output)
......
...@@ -2,6 +2,7 @@ cdef void foo(): ...@@ -2,6 +2,7 @@ cdef void foo():
cdef int i1, i2=0 cdef int i1, i2=0
cdef char c1=0, c2 cdef char c1=0, c2
cdef char *p1, *p2=NULL cdef char *p1, *p2=NULL
cdef object obj1
i1 = i2 i1 = i2
i1 = c1 i1 = c1
p1 = p2 p1 = p2
......
...@@ -576,15 +576,15 @@ def test_DefSInt(defs.SInt x): ...@@ -576,15 +576,15 @@ def test_DefSInt(defs.SInt x):
""" """
return x return x
def test_DefUInt(defs.UInt x): def test_DefUChar(defs.UChar x):
u""" u"""
>>> test_DefUInt(-1) #doctest: +ELLIPSIS >>> test_DefUChar(-1) #doctest: +ELLIPSIS
Traceback (most recent call last): Traceback (most recent call last):
... ...
OverflowError: ... OverflowError: ...
>>> test_DefUInt(0) >>> test_DefUChar(0)
0 0
>>> test_DefUInt(1) >>> test_DefUChar(1)
1 1
""" """
return x return x
......
...@@ -35,7 +35,7 @@ def test_all(): ...@@ -35,7 +35,7 @@ def test_all():
assert not isinstance(u"foo", int) assert not isinstance(u"foo", int)
# Non-optimized # Non-optimized
foo = A cdef object foo = A
assert isinstance(A(), foo) assert isinstance(A(), foo)
assert isinstance(0, (int, long)) assert isinstance(0, (int, long))
assert not isinstance(u"xyz", (int, long)) assert not isinstance(u"xyz", (int, long))
......
...@@ -5,6 +5,7 @@ __doc__ = u""" ...@@ -5,6 +5,7 @@ __doc__ = u"""
def f(): def f():
cdef int bool, int1, int2 cdef int bool, int1, int2
cdef object obj1, obj2
int1 = 0 int1 = 0
int2 = 0 int2 = 0
obj1 = 1 obj1 = 1
......
# cython: infer_types = True
__doc__ = u"""
>>> simple()
>>> multiple_assignments()
>>> arithmatic()
>>> cascade()
>>> increment()
>>> loop()
"""
from cython cimport typeof
def simple():
i = 3
assert typeof(i) == "long"
x = 1.41
assert typeof(x) == "double"
xptr = &x
assert typeof(xptr) == "double *"
xptrptr = &xptr
assert typeof(xptrptr) == "double **"
s = "abc"
assert typeof(s) == "char *"
u = u"xyz"
assert typeof(u) == "unicode object"
L = [1,2,3]
assert typeof(L) == "list object"
t = (4,5,6)
assert typeof(t) == "tuple object"
def multiple_assignments():
a = 3
a = 4
a = 5
assert typeof(a) == "long"
b = a
b = 3.1
b = 3.14159
assert typeof(b) == "double"
c = a
c = b
c = [1,2,3]
assert typeof(c) == "Python object"
def arithmatic():
a = 1 + 2
assert typeof(a) == "long"
b = 1 + 1.5
assert typeof(b) == "double"
c = 1 + <object>2
assert typeof(c) == "Python object"
d = "abc %s" % "x"
assert typeof(d) == "Python object"
def cascade():
a = 1.0
b = a + 2
c = b + 3
d = c + 4
assert typeof(d) == "double"
e = a + b + c + d
assert typeof(e) == "double"
def increment():
a = 5
a += 1
assert typeof(a) == "long"
def loop():
for a in range(10):
pass
assert typeof(a) == "long"
b = 1.0
for b in range(5):
pass
assert typeof(b) == "double"
for c from 0 <= c < 10 by .5:
pass
assert typeof(c) == "double"
__doc__ = u"""
>>> simple()
int
long
long long
int *
int **
A
B
X
Python object
>>> expression()
double
double complex
int
unsigned int
"""
from cython cimport typeof
cdef class A:
pass
cdef class B(A):
pass
cdef struct X:
double a
double complex b
def simple():
cdef int i = 0
cdef long l = 0
cdef long long ll = 0
cdef int* iptr = &i
cdef int** iptrptr = &iptr
cdef A a = None
cdef B b = None
cdef X x = X(a=1, b=2)
print typeof(i)
print typeof(l)
print typeof(ll)
print typeof(iptr)
print typeof(iptrptr)
print typeof(a)
print typeof(b)
print typeof(x)
print typeof(None)
used = i, l, ll, <long>iptr, <long>iptrptr, a, b, x
def expression():
cdef X x = X(a=1, b=2)
cdef X *xptr = &x
cdef short s = 0
cdef int i = 0
cdef unsigned int ui = 0
print typeof(x.a)
print typeof(xptr.b)
print typeof(s + i)
print typeof(i + ui)
used = x, <long>xptr, s, i, ui
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment