Commit ba7df0bd authored by Robert Bradshaw's avatar Robert Bradshaw

Merge remote-tracking branch 'main/master'

parents 383f4776 28eebe43
...@@ -32,6 +32,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -32,6 +32,8 @@ class FusedCFuncDefNode(StatListNode):
specializations specializations
code_object CodeObjectNode shared by all specializations and the code_object CodeObjectNode shared by all specializations and the
fused function fused function
fused_compound_types All fused (compound) types (e.g. floating[:])
""" """
__signatures__ = None __signatures__ = None
...@@ -76,6 +78,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -76,6 +78,8 @@ class FusedCFuncDefNode(StatListNode):
[arg.type for arg in self.node.args if arg.type.is_fused]) [arg.type for arg in self.node.args if arg.type.is_fused])
permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types) permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types)
self.fused_compound_types = fused_compound_types
if self.node.entry in env.pyfunc_entries: if self.node.entry in env.pyfunc_entries:
env.pyfunc_entries.remove(self.node.entry) env.pyfunc_entries.remove(self.node.entry)
...@@ -121,6 +125,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -121,6 +125,7 @@ class FusedCFuncDefNode(StatListNode):
env.pyfunc_entries.remove(orig_py_func.entry) env.pyfunc_entries.remove(orig_py_func.entry)
fused_types = self.node.type.get_fused_types() fused_types = self.node.type.get_fused_types()
self.fused_compound_types = fused_types
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node) copied_node = copy.deepcopy(self.node)
...@@ -656,6 +661,12 @@ class FusedCFuncDefNode(StatListNode): ...@@ -656,6 +661,12 @@ class FusedCFuncDefNode(StatListNode):
Analyse the expressions. Take care to only evaluate default arguments Analyse the expressions. Take care to only evaluate default arguments
once and clone the result for all specializations once and clone the result for all specializations
""" """
for fused_compound_type in self.fused_compound_types:
for fused_type in fused_compound_type.get_fused_types():
for specialization_type in fused_type.types:
if specialization_type.is_complex:
specialization_type.create_declaration_utility_code(env)
if self.py_func: if self.py_func:
self.__signatures__.analyse_expressions(env) self.__signatures__.analyse_expressions(env)
self.py_func.analyse_expressions(env) self.py_func.analyse_expressions(env)
......
...@@ -1074,6 +1074,11 @@ class CVarDefNode(StatNode): ...@@ -1074,6 +1074,11 @@ class CVarDefNode(StatNode):
self.dest_scope = dest_scope self.dest_scope = dest_scope
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
if base_type.is_fused and not self.in_pxd and (env.is_c_class_scope or
env.is_module_scope):
error(self.pos, "Fused types not allowed here")
return error_type
self.entry = None self.entry = None
visibility = self.visibility visibility = self.visibility
......
...@@ -2734,7 +2734,7 @@ def p_c_func_or_var_declaration(s, pos, ctx): ...@@ -2734,7 +2734,7 @@ def p_c_func_or_var_declaration(s, pos, ctx):
visibility = ctx.visibility, visibility = ctx.visibility,
base_type = base_type, base_type = base_type,
declarators = declarators, declarators = declarators,
in_pxd = ctx.level == 'module_pxd', in_pxd = ctx.level in ('module_pxd', 'c_class_pxd'),
api = ctx.api, api = ctx.api,
overridable = ctx.overridable) overridable = ctx.overridable)
return result return result
......
...@@ -38,6 +38,16 @@ def f(memslice_dtype_t[:, :] a): ...@@ -38,6 +38,16 @@ def f(memslice_dtype_t[:, :] a):
lambda cython.integral i: i lambda cython.integral i: i
cdef cython.floating x
cdef class Foo(object):
cdef cython.floating attr
def outer(cython.floating f):
def inner():
cdef cython.floating g
# This is all valid # This is all valid
dtype5 = fused_type(int, long, float) dtype5 = fused_type(int, long, float)
dtype6 = cython.fused_type(int, long) dtype6 = cython.fused_type(int, long)
...@@ -53,18 +63,21 @@ ctypedef fused fused2: ...@@ -53,18 +63,21 @@ ctypedef fused fused2:
func(x, y) func(x, y)
_ERRORS = u""" _ERRORS = u"""
fused_types.pyx:10:15: fused_type does not take keyword arguments 10:15: fused_type does not take keyword arguments
fused_types.pyx:15:38: Type specified multiple times 15:38: Type specified multiple times
fused_types.pyx:17:33: Cannot fuse a fused type 17:33: Cannot fuse a fused type
fused_types.pyx:26:4: Invalid use of fused types, type cannot be specialized 26:4: Invalid use of fused types, type cannot be specialized
fused_types.pyx:26:4: Not enough types specified to specialize the function, int2_t is still fused 26:4: Not enough types specified to specialize the function, int2_t is still fused
fused_types.pyx:27:4: Invalid use of fused types, type cannot be specialized 27:4: Invalid use of fused types, type cannot be specialized
fused_types.pyx:27:4: Not enough types specified to specialize the function, int2_t is still fused 27:4: Not enough types specified to specialize the function, int2_t is still fused
fused_types.pyx:28:16: Call with wrong number of arguments (expected 2, got 1) 28:16: Call with wrong number of arguments (expected 2, got 1)
fused_types.pyx:29:16: Call with wrong number of arguments (expected 2, got 3) 29:16: Call with wrong number of arguments (expected 2, got 3)
fused_types.pyx:30:4: Invalid use of fused types, type cannot be specialized 30:4: Invalid use of fused types, type cannot be specialized
fused_types.pyx:30:4: Keyword and starred arguments not allowed in cdef functions. 30:4: Keyword and starred arguments not allowed in cdef functions.
fused_types.pyx:36:6: Invalid base type for memoryview slice: int * 36:6: Invalid base type for memoryview slice: int *
fused_types.pyx:39:0: Fused lambdas not allowed 39:0: Fused lambdas not allowed
42:5: Fused types not allowed here
45:9: Fused types not allowed here
""" """
...@@ -272,3 +272,13 @@ def test_fused_memslice_dtype(cython.floating[:] array): ...@@ -272,3 +272,13 @@ def test_fused_memslice_dtype(cython.floating[:] array):
cdef cython.floating[:] otherarray = array[0:100:1] cdef cython.floating[:] otherarray = array[0:100:1]
print cython.typeof(array), cython.typeof(otherarray), \ print cython.typeof(array), cython.typeof(otherarray), \
array[5], otherarray[6] array[5], otherarray[6]
def test_cython_numeric(cython.numeric arg):
"""
Test to see whether complex numbers have their utility code declared
properly.
>>> test_cython_numeric(10.0 + 1j)
double complex (10+1j)
"""
print cython.typeof(arg), arg
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment