Commit 9614a25e authored by Stefan Behnel's avatar Stefan Behnel

use True/None/False as infer_types() option values, make 'bint' type inference safe in safe mode

parent 9c44f22f
......@@ -1146,7 +1146,7 @@ class NameNode(AtomicExprNode):
if not self.entry:
self.entry = env.lookup_here(self.name)
if not self.entry:
if env.directives['infer_types'] != 'none':
if env.directives['infer_types'] != False:
type = unspecified_type
else:
type = py_object_type
......
......@@ -62,7 +62,7 @@ directive_defaults = {
'ccomplex' : False, # use C99/C++ for complex types and arith
'callspec' : "",
'profile': False,
'infer_types': 'none', # 'none', 'safe', 'all'
'infer_types': False,
'autotestdict': True,
# test support
......@@ -71,7 +71,9 @@ directive_defaults = {
}
# Override types possibilities above, if needed
directive_types = {}
directive_types = {
'infer_types' : bool, # values can be True/None/False
}
for key, val in directive_defaults.items():
if key not in directive_types:
......
......@@ -440,7 +440,18 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
directivetype = Options.directive_types.get(optname)
if directivetype:
args, kwds = node.explicit_args_kwds()
if directivetype is bool:
if optname == 'infer_types':
if kwds is not None or len(args) != 1:
raise PostParseError(node.function.pos,
'The %s directive takes one compile-time boolean argument' % optname)
elif isinstance(args[0], BoolNode):
return (optname, args[0].value)
elif isinstance(args[0], NoneNode):
return (optname, None)
else:
raise PostParseError(node.function.pos,
'The %s directive takes one compile-time boolean argument' % optname)
elif directivetype is bool:
if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
raise PostParseError(node.function.pos,
'The %s directive takes one compile-time boolean argument' % optname)
......
# cython: auto_cpdef=True, infer_types=all
# cython: auto_cpdef=True, infer_types=True
#
# Pyrex Parser
#
......
......@@ -2,7 +2,7 @@ import ExprNodes
import Nodes
import Builtin
import PyrexTypes
from PyrexTypes import py_object_type, unspecified_type, spanning_type
from PyrexTypes import py_object_type, unspecified_type
from Visitor import CythonTransform
try:
......@@ -131,7 +131,17 @@ class SimpleAssignmentTypeInferer:
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
def infer_types(self, scope):
which_types_to_infer = scope.directives['infer_types']
enabled = scope.directives['infer_types']
if enabled == True:
spanning_type = aggressive_spanning_type
elif enabled is None: # safe mode
spanning_type = safe_spanning_type
else:
for entry in scope.entries.values():
if entry.type is unspecified_type:
entry.type = py_object_type
return
dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = []
......@@ -163,22 +173,20 @@ class SimpleAssignmentTypeInferer:
entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments]
if types:
result_type = reduce(spanning_type, types)
entry.type = spanning_type(types)
else:
# FIXME: raise a warning?
# print "No assignments", entry.pos, entry
result_type = py_object_type
entry.type = find_safe_type(result_type, which_types_to_infer)
entry.type = py_object_type
resolve_dependancy(entry)
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
if types:
entry.type = reduce(spanning_type, types)
entry.type = spanning_type(types)
types = [expr.infer_type(scope) for expr in entry.assignments]
entry.type = reduce(spanning_type, types) # might be wider...
entry.type = find_safe_type(entry.type, which_types_to_infer)
entry.type = spanning_type(types) # might be wider...
resolve_dependancy(entry)
del dependancies_by_entry[entry]
if ready_to_infer:
......@@ -190,25 +198,39 @@ class SimpleAssignmentTypeInferer:
for entry in dependancies_by_entry:
entry.type = py_object_type
def find_safe_type(result_type, which_types_to_infer):
if which_types_to_infer == 'none':
def find_spanning_type(type1, type2):
if type1 is type2:
return type1
elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
# type inference can break the coercion back to a Python bool
# if it returns an arbitrary int type here
return py_object_type
result_type = PyrexTypes.spanning_type(type1, type2)
if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type, Builtin.float_type):
# Python's float type is just a C double, so it's safe to
# use the C type instead
return PyrexTypes.c_double_type
return result_type
if which_types_to_infer == 'all':
def aggressive_spanning_type(types):
result_type = reduce(find_spanning_type, types)
return result_type
elif which_types_to_infer == 'safe':
def safe_spanning_type(types):
result_type = reduce(find_spanning_type, types)
if result_type.is_pyobject:
# any specific Python type is always safe to infer
return result_type
elif result_type is PyrexTypes.c_double_type:
# Python's float type is just a C double, so it's safe to use
# the C type instead
return result_type
elif result_type is PyrexTypes.c_bint_type:
# 'bint' should behave exactly like Python's bool type ...
return PyrexTypes.c_bint_type
# find_spanning_type() only returns 'bint' for clean boolean
# operations without other int types, so this is safe, too
return result_type
return py_object_type
def get_type_inferer():
return SimpleAssignmentTypeInferer()
# cython: infer_types = all
# cython: infer_types = True
cimport cython
from cython cimport typeof, infer_types
##################################################
# type inference tests in 'full' mode
cdef class MyType:
pass
......@@ -148,8 +151,29 @@ def loop():
pass
assert typeof(a) == "long"
cdef unicode retu():
return u"12345"
cdef bytes retb():
return b"12345"
def conditional(x):
"""
>>> conditional(True)
(True, 'Python object')
>>> conditional(False)
(False, 'Python object')
"""
if x:
a = retu()
else:
a = retb()
return type(a) is unicode, typeof(a)
##################################################
# type inference tests that work in 'safe' mode
@infer_types('safe')
@infer_types(None)
def double_inference():
"""
>>> values, types = double_inference()
......@@ -172,7 +196,7 @@ cdef object some_float_value():
@cython.test_assert_path_exists('//InPlaceAssignmentNode/NameNode',
'//NameNode[@type.is_pyobject]',
'//NameNode[@type.is_pyobject = False]')
@infer_types('safe')
@infer_types(None)
def double_loop():
"""
>>> double_loop() == 1.0 * 10
......@@ -184,26 +208,7 @@ def double_loop():
d += 1.0
return d
cdef unicode retu():
return u"12345"
cdef bytes retb():
return b"12345"
def conditional(x):
"""
>>> conditional(True)
(True, 'Python object')
>>> conditional(False)
(False, 'Python object')
"""
if x:
a = retu()
else:
a = retb()
return type(a) is unicode, typeof(a)
@infer_types('safe')
@infer_types(None)
def safe_only():
"""
>>> safe_only()
......@@ -215,7 +220,7 @@ def safe_only():
c = MyType()
assert typeof(c) == "MyType", typeof(c)
@infer_types('safe')
@infer_types(None)
def args_tuple_keywords(*args, **kwargs):
"""
>>> args_tuple_keywords(1,2,3, a=1, b=2)
......@@ -223,7 +228,7 @@ def args_tuple_keywords(*args, **kwargs):
assert typeof(args) == "tuple object", typeof(args)
assert typeof(kwargs) == "dict object", typeof(kwargs)
@infer_types('safe')
@infer_types(None)
def args_tuple_keywords_reassign_same(*args, **kwargs):
"""
>>> args_tuple_keywords_reassign_same(1,2,3, a=1, b=2)
......@@ -234,7 +239,7 @@ def args_tuple_keywords_reassign_same(*args, **kwargs):
args = ()
kwargs = {}
@infer_types('safe')
@infer_types(None)
def args_tuple_keywords_reassign_pyobjects(*args, **kwargs):
"""
>>> args_tuple_keywords_reassign_pyobjects(1,2,3, a=1, b=2)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment