Commit 7846cbbc authored by da-woods's avatar da-woods Committed by GitHub

Include return type in fused types of function pointers (GH-4678)

For fused functions it makes sense that the return type is ignored
(a function can't be specialized based on return type alone) but
for function pointers the return type should be included (since
such a pointer might be an argument to a fused function)

Fixes https://github.com/cython/cython/issues/4644
parent 0b3ccd7f
......@@ -78,7 +78,7 @@ class BaseType(object):
"""
return self
def get_fused_types(self, result=None, seen=None, subtypes=None):
def get_fused_types(self, result=None, seen=None, subtypes=None, include_function_return_type=False):
subtypes = subtypes or self.subtypes
if not subtypes:
return None
......@@ -91,10 +91,10 @@ class BaseType(object):
list_or_subtype = getattr(self, attr)
if list_or_subtype:
if isinstance(list_or_subtype, BaseType):
list_or_subtype.get_fused_types(result, seen)
list_or_subtype.get_fused_types(result, seen, include_function_return_type=include_function_return_type)
else:
for subtype in list_or_subtype:
subtype.get_fused_types(result, seen)
subtype.get_fused_types(result, seen, include_function_return_type=include_function_return_type)
return result
......@@ -1845,7 +1845,7 @@ class FusedType(CType):
else:
raise CannotSpecialize()
def get_fused_types(self, result=None, seen=None):
def get_fused_types(self, result=None, seen=None, include_function_return_type=False):
if result is None:
return [self]
......@@ -2757,6 +2757,11 @@ class CPtrType(CPointerBaseType):
return self.base_type.find_cpp_operation_type(operator, operand_type)
return None
def get_fused_types(self, result=None, seen=None, include_function_return_type=False):
# For function pointers, include the return type - unlike for fused functions themselves,
# where the return type cannot be an independent fused type (i.e. is derived or non-fused).
return super(CPointerBaseType, self).get_fused_types(result, seen, include_function_return_type=True)
class CNullPtrType(CPtrType):
......@@ -3232,10 +3237,13 @@ class CFuncType(CType):
return result
def get_fused_types(self, result=None, seen=None, subtypes=None):
def get_fused_types(self, result=None, seen=None, subtypes=None, include_function_return_type=False):
"""Return fused types in the order they appear as parameter types"""
return super(CFuncType, self).get_fused_types(result, seen,
subtypes=['args'])
return super(CFuncType, self).get_fused_types(
result, seen,
# for function pointer types, we consider the result type; for plain function
# types we don't (because it must be derivable from the arguments)
subtypes=self.subtypes if include_function_return_type else ['args'])
def specialize_entry(self, entry, cname):
assert not self.is_fused
......@@ -3865,7 +3873,7 @@ class CppClassType(CType):
def is_template_type(self):
return self.templates is not None and self.template_type is None
def get_fused_types(self, result=None, seen=None):
def get_fused_types(self, result=None, seen=None, include_function_return_type=False):
if result is None:
result = []
seen = set()
......
......@@ -510,3 +510,48 @@ def convert_to_ptr(cython.floating x):
return handle_float(&x)
elif cython.floating is double:
return handle_double(&x)
cdef double get_double():
return 1.0
cdef float get_float():
return 0.0
cdef call_func_pointer(cython.floating (*f)()):
return f()
def test_fused_func_pointer():
"""
>>> test_fused_func_pointer()
1.0
0.0
"""
print(call_func_pointer(get_double))
print(call_func_pointer(get_float))
cdef double get_double_from_int(int i):
return i
cdef call_func_pointer_with_1(cython.floating (*f)(cython.integral)):
return f(1)
def test_fused_func_pointer2():
"""
>>> test_fused_func_pointer2()
1.0
"""
print(call_func_pointer_with_1(get_double_from_int))
cdef call_function_that_calls_fused_pointer(object (*f)(cython.floating (*)(cython.integral))):
if cython.floating is double and cython.integral is int:
return 5*f(get_double_from_int)
else:
return None # practically it's hard to make this kind of function useful...
def test_fused_func_pointer_multilevel():
"""
>>> test_fused_func_pointer_multilevel()
5.0
None
"""
print(call_function_that_calls_fused_pointer(call_func_pointer_with_1[double, int]))
print(call_function_that_calls_fused_pointer(call_func_pointer_with_1[float, int]))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment