Commit ab433f3d authored by Stefan Behnel's avatar Stefan Behnel

'safe' mode for type inference: only infer types that are very unlikely to break code

parent 52679b28
......@@ -62,7 +62,7 @@ directive_defaults = {
'ccomplex' : False, # use C99/C++ for complex types and arith
'callspec' : "",
'profile': False,
'infer_types': False,
'infer_types': 'none', # 'none', 'safe', 'all'
'autotestdict': True,
# test support
......@@ -87,7 +87,7 @@ directive_scopes = { # defaults to available everywhere
def parse_directive_value(name, value):
"""
Parses value as an option value for the given name and returns
the interpreted value. None is returned if the option does not exist.
the interpreted value. None is returned if the option does not exist.
>>> print parse_directive_value('nonexisting', 'asdf asdfd')
None
......@@ -110,6 +110,8 @@ def parse_directive_value(name, value):
return int(value)
except ValueError:
raise ValueError("%s directive must be set to an integer" % name)
elif type is str:
return str(value)
else:
assert False
......
import ExprNodes
import PyrexTypes
from PyrexTypes import py_object_type, unspecified_type, spanning_type
from Visitor import CythonTransform
......@@ -119,6 +120,7 @@ class SimpleAssignmentTypeInferer:
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
def infer_types(self, scope):
which_types_to_infer = scope.directives['infer_types']
dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = []
......@@ -150,11 +152,12 @@ class SimpleAssignmentTypeInferer:
entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments]
if types:
entry.type = reduce(spanning_type, types)
result_type = reduce(spanning_type, types)
else:
# List comprehension?
# print "No assignments", entry.pos, entry
entry.type = py_object_type
result_type = py_object_type
entry.type = find_safe_type(result_type, which_types_to_infer)
resolve_dependancy(entry)
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
......@@ -164,6 +167,7 @@ class SimpleAssignmentTypeInferer:
entry.type = reduce(spanning_type, types)
types = [expr.infer_type(scope) for expr in entry.assignments]
entry.type = reduce(spanning_type, types) # might be wider...
entry.type = find_safe_type(entry.type, which_types_to_infer)
resolve_dependancy(entry)
del dependancies_by_entry[entry]
if ready_to_infer:
......@@ -175,5 +179,18 @@ class SimpleAssignmentTypeInferer:
for entry in dependancies_by_entry:
entry.type = py_object_type
def find_safe_type(result_type, which_types_to_infer):
if which_types_to_infer == 'all':
return result_type
elif which_types_to_infer == 'safe':
if result_type.is_pyobject:
# any specific Python type is always safe to infer
return result_type
elif result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type):
# Python's float type is just a C double, so it's safe to
# use the C type instead
return PyrexTypes.c_double_type
return py_object_type
def get_type_inferer():
return SimpleAssignmentTypeInferer()
# cython: infer_types = True
# cython: infer_types = all
from cython cimport typeof
from cython cimport typeof, infer_types
cdef class MyType:
pass
def simple():
"""
......@@ -26,6 +29,23 @@ def simple():
t = (4,5,6)
assert typeof(t) == "tuple object", typeof(t)
def builtin_types():
"""
>>> builtin_types()
"""
b = bytes()
assert typeof(b) == "bytes object", typeof(b)
u = unicode()
assert typeof(u) == "unicode object", typeof(u)
L = list()
assert typeof(L) == "list object", typeof(L)
t = tuple()
assert typeof(t) == "tuple object", typeof(t)
d = dict()
assert typeof(d) == "dict object", typeof(d)
B = bool()
assert typeof(B) == "bool object", typeof(B)
def multiple_assignments():
"""
>>> multiple_assignments()
......@@ -43,9 +63,9 @@ def multiple_assignments():
c = [1,2,3]
assert typeof(c) == "Python object"
def arithmatic():
def arithmetic():
"""
>>> arithmatic()
>>> arithmetic()
"""
a = 1 + 2
assert typeof(a) == "long"
......@@ -105,3 +125,15 @@ def loop():
for d in range(0, 10L, 2):
pass
assert typeof(a) == "long"
@infer_types('safe')
def safe_only():
"""
>>> safe_only()
"""
a = 1.0
assert typeof(a) == "double", typeof(c)
b = 1
assert typeof(b) == "Python object", typeof(c)
c = MyType()
assert typeof(c) == "MyType", typeof(c)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment