Commit 9614a25e authored by Stefan Behnel's avatar Stefan Behnel

use True/None/False as infer_types() option values, make 'bint' type inference safe in safe mode

parent 9c44f22f
...@@ -1146,7 +1146,7 @@ class NameNode(AtomicExprNode): ...@@ -1146,7 +1146,7 @@ class NameNode(AtomicExprNode):
if not self.entry: if not self.entry:
self.entry = env.lookup_here(self.name) self.entry = env.lookup_here(self.name)
if not self.entry: if not self.entry:
if env.directives['infer_types'] != 'none': if env.directives['infer_types'] != False:
type = unspecified_type type = unspecified_type
else: else:
type = py_object_type type = py_object_type
......
...@@ -62,7 +62,7 @@ directive_defaults = { ...@@ -62,7 +62,7 @@ directive_defaults = {
'ccomplex' : False, # use C99/C++ for complex types and arith 'ccomplex' : False, # use C99/C++ for complex types and arith
'callspec' : "", 'callspec' : "",
'profile': False, 'profile': False,
'infer_types': 'none', # 'none', 'safe', 'all' 'infer_types': False,
'autotestdict': True, 'autotestdict': True,
# test support # test support
...@@ -71,7 +71,9 @@ directive_defaults = { ...@@ -71,7 +71,9 @@ directive_defaults = {
} }
# Override types possibilities above, if needed # Override types possibilities above, if needed
directive_types = {} directive_types = {
'infer_types' : bool, # values can be True/None/False
}
for key, val in directive_defaults.items(): for key, val in directive_defaults.items():
if key not in directive_types: if key not in directive_types:
......
...@@ -440,7 +440,18 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -440,7 +440,18 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
directivetype = Options.directive_types.get(optname) directivetype = Options.directive_types.get(optname)
if directivetype: if directivetype:
args, kwds = node.explicit_args_kwds() args, kwds = node.explicit_args_kwds()
if directivetype is bool: if optname == 'infer_types':
if kwds is not None or len(args) != 1:
raise PostParseError(node.function.pos,
'The %s directive takes one compile-time boolean argument' % optname)
elif isinstance(args[0], BoolNode):
return (optname, args[0].value)
elif isinstance(args[0], NoneNode):
return (optname, None)
else:
raise PostParseError(node.function.pos,
'The %s directive takes one compile-time boolean argument' % optname)
elif directivetype is bool:
if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode): if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
raise PostParseError(node.function.pos, raise PostParseError(node.function.pos,
'The %s directive takes one compile-time boolean argument' % optname) 'The %s directive takes one compile-time boolean argument' % optname)
......
# cython: auto_cpdef=True, infer_types=all # cython: auto_cpdef=True, infer_types=True
# #
# Pyrex Parser # Pyrex Parser
# #
......
...@@ -2,7 +2,7 @@ import ExprNodes ...@@ -2,7 +2,7 @@ import ExprNodes
import Nodes import Nodes
import Builtin import Builtin
import PyrexTypes import PyrexTypes
from PyrexTypes import py_object_type, unspecified_type, spanning_type from PyrexTypes import py_object_type, unspecified_type
from Visitor import CythonTransform from Visitor import CythonTransform
try: try:
...@@ -131,7 +131,17 @@ class SimpleAssignmentTypeInferer: ...@@ -131,7 +131,17 @@ class SimpleAssignmentTypeInferer:
# TODO: Implement a real type inference algorithm. # TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...) # (Something more powerful than just extending this one...)
def infer_types(self, scope): def infer_types(self, scope):
which_types_to_infer = scope.directives['infer_types'] enabled = scope.directives['infer_types']
if enabled == True:
spanning_type = aggressive_spanning_type
elif enabled is None: # safe mode
spanning_type = safe_spanning_type
else:
for entry in scope.entries.values():
if entry.type is unspecified_type:
entry.type = py_object_type
return
dependancies_by_entry = {} # entry -> dependancies dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = [] ready_to_infer = []
...@@ -163,22 +173,20 @@ class SimpleAssignmentTypeInferer: ...@@ -163,22 +173,20 @@ class SimpleAssignmentTypeInferer:
entry = ready_to_infer.pop() entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments] types = [expr.infer_type(scope) for expr in entry.assignments]
if types: if types:
result_type = reduce(spanning_type, types) entry.type = spanning_type(types)
else: else:
# FIXME: raise a warning? # FIXME: raise a warning?
# print "No assignments", entry.pos, entry # print "No assignments", entry.pos, entry
result_type = py_object_type entry.type = py_object_type
entry.type = find_safe_type(result_type, which_types_to_infer)
resolve_dependancy(entry) resolve_dependancy(entry)
# Deal with simple circular dependancies... # Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items(): for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]): if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()] types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
if types: if types:
entry.type = reduce(spanning_type, types) entry.type = spanning_type(types)
types = [expr.infer_type(scope) for expr in entry.assignments] types = [expr.infer_type(scope) for expr in entry.assignments]
entry.type = reduce(spanning_type, types) # might be wider... entry.type = spanning_type(types) # might be wider...
entry.type = find_safe_type(entry.type, which_types_to_infer)
resolve_dependancy(entry) resolve_dependancy(entry)
del dependancies_by_entry[entry] del dependancies_by_entry[entry]
if ready_to_infer: if ready_to_infer:
...@@ -190,25 +198,39 @@ class SimpleAssignmentTypeInferer: ...@@ -190,25 +198,39 @@ class SimpleAssignmentTypeInferer:
for entry in dependancies_by_entry: for entry in dependancies_by_entry:
entry.type = py_object_type entry.type = py_object_type
def find_safe_type(result_type, which_types_to_infer): def find_spanning_type(type1, type2):
if which_types_to_infer == 'none': if type1 is type2:
return type1
elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
# type inference can break the coercion back to a Python bool
# if it returns an arbitrary int type here
return py_object_type return py_object_type
result_type = PyrexTypes.spanning_type(type1, type2)
if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type, Builtin.float_type): if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type, Builtin.float_type):
# Python's float type is just a C double, so it's safe to # Python's float type is just a C double, so it's safe to
# use the C type instead # use the C type instead
return PyrexTypes.c_double_type return PyrexTypes.c_double_type
return result_type
def aggressive_spanning_type(types):
result_type = reduce(find_spanning_type, types)
return result_type
if which_types_to_infer == 'all': def safe_spanning_type(types):
result_type = reduce(find_spanning_type, types)
if result_type.is_pyobject:
# any specific Python type is always safe to infer
return result_type
elif result_type is PyrexTypes.c_double_type:
# Python's float type is just a C double, so it's safe to use
# the C type instead
return result_type
elif result_type is PyrexTypes.c_bint_type:
# find_spanning_type() only returns 'bint' for clean boolean
# operations without other int types, so this is safe, too
return result_type return result_type
elif which_types_to_infer == 'safe':
if result_type.is_pyobject:
# any specific Python type is always safe to infer
return result_type
elif result_type is PyrexTypes.c_bint_type:
# 'bint' should behave exactly like Python's bool type ...
return PyrexTypes.c_bint_type
return py_object_type return py_object_type
def get_type_inferer(): def get_type_inferer():
return SimpleAssignmentTypeInferer() return SimpleAssignmentTypeInferer()
# cython: infer_types = all # cython: infer_types = True
cimport cython cimport cython
from cython cimport typeof, infer_types from cython cimport typeof, infer_types
##################################################
# type inference tests in 'full' mode
cdef class MyType: cdef class MyType:
pass pass
...@@ -148,8 +151,29 @@ def loop(): ...@@ -148,8 +151,29 @@ def loop():
pass pass
assert typeof(a) == "long" assert typeof(a) == "long"
cdef unicode retu():
return u"12345"
cdef bytes retb():
return b"12345"
def conditional(x):
"""
>>> conditional(True)
(True, 'Python object')
>>> conditional(False)
(False, 'Python object')
"""
if x:
a = retu()
else:
a = retb()
return type(a) is unicode, typeof(a)
##################################################
# type inference tests that work in 'safe' mode
@infer_types('safe') @infer_types(None)
def double_inference(): def double_inference():
""" """
>>> values, types = double_inference() >>> values, types = double_inference()
...@@ -172,7 +196,7 @@ cdef object some_float_value(): ...@@ -172,7 +196,7 @@ cdef object some_float_value():
@cython.test_assert_path_exists('//InPlaceAssignmentNode/NameNode', @cython.test_assert_path_exists('//InPlaceAssignmentNode/NameNode',
'//NameNode[@type.is_pyobject]', '//NameNode[@type.is_pyobject]',
'//NameNode[@type.is_pyobject = False]') '//NameNode[@type.is_pyobject = False]')
@infer_types('safe') @infer_types(None)
def double_loop(): def double_loop():
""" """
>>> double_loop() == 1.0 * 10 >>> double_loop() == 1.0 * 10
...@@ -184,26 +208,7 @@ def double_loop(): ...@@ -184,26 +208,7 @@ def double_loop():
d += 1.0 d += 1.0
return d return d
cdef unicode retu(): @infer_types(None)
return u"12345"
cdef bytes retb():
return b"12345"
def conditional(x):
"""
>>> conditional(True)
(True, 'Python object')
>>> conditional(False)
(False, 'Python object')
"""
if x:
a = retu()
else:
a = retb()
return type(a) is unicode, typeof(a)
@infer_types('safe')
def safe_only(): def safe_only():
""" """
>>> safe_only() >>> safe_only()
...@@ -215,7 +220,7 @@ def safe_only(): ...@@ -215,7 +220,7 @@ def safe_only():
c = MyType() c = MyType()
assert typeof(c) == "MyType", typeof(c) assert typeof(c) == "MyType", typeof(c)
@infer_types('safe') @infer_types(None)
def args_tuple_keywords(*args, **kwargs): def args_tuple_keywords(*args, **kwargs):
""" """
>>> args_tuple_keywords(1,2,3, a=1, b=2) >>> args_tuple_keywords(1,2,3, a=1, b=2)
...@@ -223,7 +228,7 @@ def args_tuple_keywords(*args, **kwargs): ...@@ -223,7 +228,7 @@ def args_tuple_keywords(*args, **kwargs):
assert typeof(args) == "tuple object", typeof(args) assert typeof(args) == "tuple object", typeof(args)
assert typeof(kwargs) == "dict object", typeof(kwargs) assert typeof(kwargs) == "dict object", typeof(kwargs)
@infer_types('safe') @infer_types(None)
def args_tuple_keywords_reassign_same(*args, **kwargs): def args_tuple_keywords_reassign_same(*args, **kwargs):
""" """
>>> args_tuple_keywords_reassign_same(1,2,3, a=1, b=2) >>> args_tuple_keywords_reassign_same(1,2,3, a=1, b=2)
...@@ -234,7 +239,7 @@ def args_tuple_keywords_reassign_same(*args, **kwargs): ...@@ -234,7 +239,7 @@ def args_tuple_keywords_reassign_same(*args, **kwargs):
args = () args = ()
kwargs = {} kwargs = {}
@infer_types('safe') @infer_types(None)
def args_tuple_keywords_reassign_pyobjects(*args, **kwargs): def args_tuple_keywords_reassign_pyobjects(*args, **kwargs):
""" """
>>> args_tuple_keywords_reassign_pyobjects(1,2,3, a=1, b=2) >>> args_tuple_keywords_reassign_pyobjects(1,2,3, a=1, b=2)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment