Commit 430ec7cf authored by Mark Florisson's avatar Mark Florisson

Change semantics of fused types

parent 82c13a65
......@@ -2332,9 +2332,7 @@ class IndexNode(ExprNode):
"""
base_type = self.base.type
def err(msg, pos=None):
error(pos or self.pos, msg)
self.type = PyrexTypes.error_type
self.type = PyrexTypes.error_type
specific_types = []
positions = []
......@@ -2347,11 +2345,15 @@ class IndexNode(ExprNode):
positions.append(arg.pos)
specific_types.append(arg.analyse_as_type(env))
else:
return err("Can only index fused functions with types")
return error(self.pos, "Can only index fused functions with types")
fused_types = base_type.get_fused_types()
if len(specific_types) > len(fused_types):
return err("Too many types specified")
return error(self.pos, "Too many types specified")
elif len(specific_types) < len(fused_types):
t = fused_types[len(specific_types)]
return error(self.pos, "Not enough types specified to specialize "
"the function, %s is still fused" % t)
# See if our index types form valid specializations
for pos, specific_type, fused_type in zip(positions,
......@@ -2359,27 +2361,19 @@ class IndexNode(ExprNode):
fused_types):
if not Utils.any([specific_type.same_as(t)
for t in fused_type.types]):
return err("Type not in fused type", pos=pos)
return error(pos, "Type not in fused type")
if specific_type is None or specific_type.is_error:
return
fused_to_specific = dict(zip(fused_types, specific_types))
# If we are only partially fused, specialize accordingly
for fused_type in fused_types:
if fused_type not in fused_to_specific:
fused_to_specific[fused_type] = fused_type
type = base_type.specialize(fused_to_specific)
if type is not base_type:
import copy
e = copy.copy(base_type.entry)
e.type = type
type.entry = e
if not type.is_fused:
if type.is_fused:
# Only partially specific, this is invalid
error(self.pos,
"Index operation makes function only partially specific")
else:
# Fully specific, find the signature with the specialized entry
for signature in self.base.type.get_all_specific_function_types():
if type.same_as(signature):
......@@ -2387,9 +2381,6 @@ class IndexNode(ExprNode):
break
else:
assert False
else:
# Only partially specific
self.type = type
gil_message = "Indexing Python object"
......@@ -3117,8 +3108,10 @@ class SimpleCallNode(CallNode):
return
elif hasattr(self.function, 'entry'):
overloaded_entry = self.function.entry
elif isinstance(self.function, IndexNode) and self.function.type.is_fused:
elif (isinstance(self.function, IndexNode) and
self.function.base.type.is_fused):
overloaded_entry = self.function.type.entry
self.function.entry = self.function.type.entry
else:
overloaded_entry = None
......
......@@ -937,29 +937,24 @@ class FusedTypeNode(CBaseTypeNode):
child_attrs = []
def analyse(self, env):
self.types = [type.analyse_as_type(env) for type in self.types]
# Note: this list may still contain multiple of the same entries
types = [type.analyse_as_type(env) for type in self.types]
if len(self.types) == 1:
return self.types[0]
return types[0]
types = []
seen = cython.set()
for type_node, type in zip(self.types, types):
if type in seen:
error(type_node.pos, "Type specified multiple times")
else:
seen.add(type)
if type.is_fused:
error(type_node.pos, "Cannot fuse a fused type")
for type in self.types:
self.add_type(type, types, seen)
self.types = types
return PyrexTypes.FusedType(types)
def add_type(self, type, types, seen):
if type not in seen:
seen.add(type)
if type.is_fused:
for specific_type in PyrexTypes.get_specific_types(type):
self.add_type(specific_type, types, seen)
else:
types.append(type)
class CVarDefNode(StatNode):
# C variable definition or forward/extern function declaration.
......@@ -1202,14 +1197,11 @@ class CTypeDefNode(StatNode):
child_attrs = ["base_type", "declarator"]
def analyse_declarations(self, env):
"""
If we are a fused type, do a normal type declaration, as we want
declared variables to have a FusedType type, not a CTypeDefType.
"""
base = self.base_type.analyse(env)
name_declarator, type = self.declarator.analyse(base, env)
name = name_declarator.name
cname = name_declarator.cname
entry = env.declare_typedef(name, type, self.pos,
cname = cname, visibility = self.visibility, api = self.api)
......@@ -2040,6 +2032,9 @@ class FusedCFuncDefNode(StatListNode):
from Cython.Compiler import ParseTreeTransforms
permutations = self.node.type.get_all_specific_permutations()
# print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations])
for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node)
......
......@@ -1347,24 +1347,17 @@ class AnalyseExpressionsTransform(CythonTransform):
argument types with a NameNode referring to the function with
specialized entry and type.
"""
was_nested = self.nested_index_node
self.nested_index_node = True
self.visit_Node(node)
self.nested_index_node = was_nested
type = node.type
if type.is_cfunction and type.is_fused and not self.nested_index_node:
error(node.pos, "Not enough types were specified to indicate a "
"specialized function")
elif type.is_cfunction and node.base.type.is_fused:
while not node.is_name:
node = node.base
node.type = type
node.entry = type.entry
print node.entry.cname
return node
if type.is_cfunction and node.base.type.is_fused:
node = node.base
if not node.is_name:
error(node.pos, "Can only index a fused function once")
node.type = PyrexTypes.error_type
else:
node.type = type
node.entry = type.entry
return node
......@@ -1905,6 +1898,12 @@ class ReplaceFusedTypeChecks(VisitorTransform):
...
"""
# Defer the import until now to avoid circularity...
from Cython.Compiler import Optimize
transform = Optimize.ConstantFolding()
transform.check_constant_value_not_set = False
def __init__(self, local_scope):
super(ReplaceFusedTypeChecks, self).__init__()
self.local_scope = local_scope
......@@ -1914,12 +1913,8 @@ class ReplaceFusedTypeChecks(VisitorTransform):
Filters out any if clauses with false compile time type check
expression.
"""
from Cython.Compiler import Optimize
self.visitchildren(node)
transform = Optimize.ConstantFolding()
transform.check_constant_value_not_set = False
return transform(node)
return self.transform(node)
def visit_PrimaryCmpNode(self, node):
type1 = node.operand1.analyse_as_type(self.local_scope)
......@@ -1932,7 +1927,7 @@ class ReplaceFusedTypeChecks(VisitorTransform):
type1 = self.specialize_type(type1, node.operand1.pos)
op = node.operator
if op in ('is', 'is not', '==', '!='):
if op in ('is', 'is_not', '==', '!='):
type2 = self.specialize_type(type2, node.operand2.pos)
is_same = type1.same_as(type2)
......
......@@ -677,6 +677,7 @@ class FusedType(PyrexType):
"""
is_fused = 1
name = None
def __init__(self, types):
self.types = types
......
......@@ -12,10 +12,28 @@ dtype4 = cython.typedef(cython.fused_type(int, long, kw=None))
ctypedef public cython.fused_type(int, long) dtype7
ctypedef api cython.fused_type(int, long) dtype8
ctypedef cython.fused_type(short, short int, int) int_t
ctypedef cython.fused_type(int, long) int2_t
ctypedef cython.fused_type(int2_t, int) dtype9
ctypedef cython.fused_type(float, double) floating
cdef func(floating x, int2_t y):
print x, y
cdef float x = 10.0
cdef int y = 10
func[float](x, y)
func[float][int](x, y)
func[float, int](x)
func[float, int](x, y, y)
func(x, y=y)
# This is all valid
ctypedef fused_type(int, long, float) dtype5
ctypedef cython.fused_type(int, long) dtype6
func[float, int](x, y)
func(x, y)
_ERRORS = u"""
fused_types.pyx:7:13: Can only fuse types with cython.fused_type()
......@@ -24,4 +42,11 @@ fused_types.pyx:9:20: 'foo' is not a type identifier
fused_types.pyx:10:23: fused_type does not take keyword arguments
fused_types.pyx:12:0: Fused types cannot be public or api
fused_types.pyx:13:0: Fused types cannot be public or api
fused_types.pyx:15:34: Type specified multiple times
fused_types.pyx:17:27: Cannot fuse a fused type
fused_types.pyx:26:4: Not enough types specified to specialize the function, int2_t is still fused
fused_types.pyx:27:4: Not enough types specified to specialize the function, int2_t is still fused
fused_types.pyx:28:16: Call with wrong number of arguments (expected 2, got 1)
fused_types.pyx:29:16: Call with wrong number of arguments (expected 2, got 3)
fused_types.pyx:30:4: Keyword and starred arguments not allowed in cdef functions.
"""
......@@ -5,12 +5,15 @@ ctypedef char *string_t
ctypedef cython.fused_type(int, long, float, string_t) fused_t
ctypedef cython.fused_type(int, long) other_t
ctypedef cython.fused_type(short, short int, short, int) base_t
ctypedef cython.fused_type(short int, int) base_t
ctypedef cython.fused_type(float complex, double complex,
int complex, long complex) complex_t
ctypedef base_t **base_t_p_p
ctypedef cython.fused_type(char, base_t_p_p, fused_t, complex_t) composed_t
# ctypedef cython.fused_type(char, base_t_p_p, fused_t, complex_t) composed_t
ctypedef cython.fused_type(char, int, float, string_t, float complex,
double complex, int complex, long complex,
cython.p_p_int) composed_t
cdef func(fused_t a, other_t b):
......@@ -160,8 +163,8 @@ def test_composed_types():
(0.9+0.4j)
<BLANKLINE>
not a complex number
9 10
19
7 8
15
<BLANKLINE>
7 8
<BLANKLINE>
......@@ -177,7 +180,7 @@ def test_composed_types():
print result
print
print composed(c + 2, d + 2)
print composed(c, d)
print
composed(&cp, &dp)
......
# mode: run
cimport cython
#from cython cimport p_double, p_int
from cpython cimport Py_INCREF
from Cython import Shadow as pure_cython
ctypedef char * string_t
ctypedef cython.fused_type(float, double) floating
ctypedef cython.fused_type(int, long) integral
ctypedef cython.fused_type(int, long, float, double, string_t) fused_type1
ctypedef cython.fused_type(string_t) fused_type2
ctypedef fused_type1 *composed_t
ctypedef cython.fused_type(int, long, float, double) other_t
ctypedef double *p_double
ctypedef int *p_int
def test_pure():
......@@ -101,3 +108,76 @@ def test_fused_with_pointer():
print fused_with_pointer(float_array)
print
print fused_with_pointer(string_array)
cdef test_specialize(fused_type1 x, fused_type1 *y, composed_t z, other_t *a):
cdef fused_type1 result
if composed_t is p_double:
print "double pointer"
if fused_type1 in floating:
result = x + y[0] + z[0] + a[0]
return result
def test_specializations():
"""
>>> test_specializations()
double pointer
double pointer
double pointer
double pointer
double pointer
"""
cdef object (*f)(double, double *, double *, int *)
cdef double somedouble = 2.2
cdef double otherdouble = 3.3
cdef int someint = 4
cdef p_double somedouble_p = &somedouble
cdef p_double otherdouble_p = &otherdouble
cdef p_int someint_p = &someint
f = test_specialize
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
f = <object (*)(double, double *, double *, int *)> test_specialize
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
assert (<object (*)(double, double *, double *, int *)>
test_specialize)(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
f = test_specialize[double, int]
assert f(1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
assert test_specialize[double, int](1.1, somedouble_p, otherdouble_p, someint_p) == 10.6
# The following cases are not supported
# f = test_specialize[double][p_int]
# print f(1.1, somedouble_p, otherdouble_p)
# print
# print test_specialize[double][p_int](1.1, somedouble_p, otherdouble_p)
# print
# print test_specialize[double](1.1, somedouble_p, otherdouble_p)
# print
#cdef opt_args(integral x, floating y = 4.0):
# print x, y
def test_opt_args():
"""
ToDO: enable and fix
test_opt_args()
3 4.0
3 4.0
3 4.0
3 4.0
"""
#opt_args[int, float](3)
#opt_args[int, double](3)
#opt_args[int, float](3, 4.0)
#opt_args[int, double](3, 4.0)
......@@ -31,7 +31,7 @@ cdef public class MyExt [ type MyExtType, object MyExtObject ]:
ctypedef char *string_t
ctypedef cython.fused_type(int, float) simple_t
ctypedef cython.fused_type(simple_t, string_t) less_simple_t
ctypedef cython.fused_type(int, float, string_t) less_simple_t
ctypedef cython.fused_type(mystruct_t, myunion_t, MyExt) object_t
ctypedef cython.fused_type(str, unicode, bytes) builtin_t
......@@ -82,6 +82,3 @@ assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t, int]
assert f(mystruct, 5).a == 10
f = add_simple[mystruct_t][int]
assert f(mystruct, 5).a == 10
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment