Commit 1472e87b authored by Mark Florisson's avatar Mark Florisson

Support fused cdef methods

parent a8305590
......@@ -595,7 +595,10 @@ class ExprNode(Node):
for signature in src_type.get_all_specific_function_types():
if signature.same_as(dst_type):
return CoerceFusedToSpecific(src, signature)
src.type = signature
src.entry = src.type.entry
src.entry.used = True
return self
error(self.pos, "Type is not specific")
self.type = error_type
......@@ -2330,10 +2333,9 @@ class IndexNode(ExprNode):
NameNode with specific entry just after analysis of expressions by
AnalyseExpressionsTransform.
"""
base_type = self.base.type
self.type = PyrexTypes.error_type
base_type = self.base.type
specific_types = []
positions = []
......@@ -2378,6 +2380,15 @@ class IndexNode(ExprNode):
for signature in self.base.type.get_all_specific_function_types():
if type.same_as(signature):
self.type = signature
if self.base.is_attribute:
# Pretend to be a normal attribute, for cdef extension
# methods
self.entry = signature.entry
self.is_attribute = self.base.is_attribute
self.obj = self.base.obj
self.entry.used = True
break
else:
assert False
......@@ -3026,11 +3037,13 @@ class SimpleCallNode(CallNode):
function = self.function
function.is_called = 1
self.function.analyse_types(env)
if function.is_attribute and function.entry and function.entry.is_cmethod:
# Take ownership of the object from which the attribute
# was obtained, because we need to pass it as 'self'.
self.self = function.obj
function.obj = CloneNode(self.self)
func_type = self.function_type()
if func_type.is_pyobject:
self.arg_tuple = TupleNode(self.pos, args = self.args)
......@@ -3111,15 +3124,13 @@ class SimpleCallNode(CallNode):
elif (isinstance(self.function, IndexNode) and
self.function.base.type.is_fused):
overloaded_entry = self.function.type.entry
self.function.entry = self.function.type.entry
else:
overloaded_entry = None
if overloaded_entry:
if self.function.type.is_fused:
alternatives = []
PyrexTypes.map_with_specific_entries(self.function.entry,
alternatives.append)
functypes = self.function.type.get_all_specific_function_types()
alternatives = [f.entry for f in functypes]
else:
alternatives = overloaded_entry.all_alternatives()
......@@ -3129,6 +3140,8 @@ class SimpleCallNode(CallNode):
self.type = PyrexTypes.error_type
self.result_code = "<error>"
return
entry.used = True
self.function.entry = entry
self.function.type = entry.type
func_type = self.function_type()
......@@ -3812,6 +3825,12 @@ class AttributeNode(ExprNode):
#print "...obj_code =", obj_code ###
if self.entry and self.entry.is_cmethod:
if obj.type.is_extension_type:
# If the attribute was specialized through indexing, make sure
# to get the right fused name, as our entry was replaced by our
# parent index node (AnalyseExpressionsTransform)
if self.type.from_fused:
self.member = self.entry.cname
return "((struct %s *)%s%s%s)->%s" % (
obj.type.vtabstruct_cname, obj_code, self.op,
obj.type.vtabslot_cname, self.member)
......@@ -7379,18 +7398,6 @@ class CoercionNode(ExprNode):
file, line, col = self.pos
code.annotate((file, line, col-1), AnnotationItem(style='coerce', tag='coerce', text='[%s] to [%s]' % (self.arg.type, self.type)))
class CoerceFusedToSpecific(CoercionNode):
def __init__(self, arg, dst_type):
super(CoerceFusedToSpecific, self).__init__(arg)
self.type = dst_type
self.specialized_cname = dst_type.entry.cname
def calculate_result_code(self):
return self.specialized_cname
def generate_result_code(self, code):
pass
class CastNode(CoercionNode):
# Wrap a node in a C type cast.
......
......@@ -156,19 +156,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f.close()
def generate_public_declaration(self, entry, h_code, i_code):
PyrexTypes.map_with_specific_entries(entry,
self._generate_public_declaration,
h_code,
i_code)
def _generate_public_declaration(self, entry, h_code, i_code):
h_code.putln("%s %s;" % (
Naming.extern_c_macro,
entry.type.declaration_code(
entry.cname, dll_linkage = "DL_IMPORT")))
if i_code:
i_code.putln("cdef extern %s" %
entry.type.declaration_code(cname, pyrex = 1))
entry.type.declaration_code(entry.cname, pyrex = 1))
def api_name(self, env):
return env.qualified_name.replace(".", "__")
......@@ -992,13 +986,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
dll_linkage = "DL_EXPORT", definition = definition)
def generate_cfunction_predeclarations(self, env, code, definition):
func = self._generate_cfunction_predeclaration
for entry in env.cfunc_entries:
PyrexTypes.map_with_specific_entries(entry, func, code, definition)
should_declare = (not entry.in_cinclude and
(definition or entry.defined_in_pxd or
entry.visibility == 'extern'))
def _generate_cfunction_predeclaration(self, entry, code, definition):
if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition
or entry.defined_in_pxd or entry.visibility == 'extern')):
if entry.used and (entry.inline_func_in_pxd or should_declare):
if entry.visibility == 'public':
storage_class = "%s " % Naming.extern_c_macro
dll_linkage = "DL_EXPORT"
......@@ -2048,28 +2041,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_c_function_export_code(self, env, code):
# Generate code to create PyCFunction wrappers for exported C functions.
func = self._generate_c_function_export_code
for entry in env.cfunc_entries:
from_fused = entry.type.is_fused
if entry.api or entry.defined_in_pxd:
PyrexTypes.map_with_specific_entries(entry, func, env,
code, from_fused)
def _generate_c_function_export_code(self, entry, env, code, from_fused):
env.use_utility_code(function_export_utility_code)
signature = entry.type.signature_string()
s = 'if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s'
if from_fused:
# Specific version of a fused function. Fused functions can never
# be declared public or api, but they may need to be exported when
# declared in a .pxd. We need to give them a unique name in that
# case
name = entry.cname
else:
name = entry.name
code.putln(s % (name, entry.cname, signature, code.error_goto(self.pos)))
code.putln('if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s' % (
entry.name,
entry.cname,
signature,
code.error_goto(self.pos)))
def generate_type_import_code_for_module(self, module, env, code):
# Generate type import code for all exported extension types in
......@@ -2095,29 +2075,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
module.qualified_name,
temp,
code.error_goto(self.pos)))
for entry in entries:
PyrexTypes.map_with_specific_entries(entry,
self._import_cdef_func,
code,
temp,
entry.type.is_fused)
code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp))
def _import_cdef_func(self, entry, code, temp, from_fused):
if from_fused:
name = entry.cname
else:
name = entry.name
code.putln(
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s' % (
temp,
name,
entry.name,
entry.cname,
entry.type.signature_string(),
code.error_goto(self.pos)))
code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp))
def generate_type_init_code(self, env, code):
# Generate type import code for extern extension types
......
......@@ -1749,7 +1749,6 @@ class CFuncDefNode(FuncDefNode):
inline_in_pxd = False
decorators = None
directive_locals = None
cname_postfix = None
def unqualified_name(self):
return self.entry.name
......@@ -2050,6 +2049,8 @@ class FusedCFuncDefNode(StatListNode):
if n.cfunc_declarator.optional_arg_count:
assert n.type.op_arg_struct
assert n.type.entry
assert node.type.is_fused
node.entry.fused_cfunction = self
......@@ -2065,19 +2066,29 @@ class FusedCFuncDefNode(StatListNode):
# print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations])
env.cfunc_entries.remove(self.node.entry)
for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node)
# Make the types in our CFuncType specific
newtype = copied_node.type.specialize(fused_to_specific)
copied_node.type = newtype
copied_node.entry.type = newtype
newtype.entry = copied_node.entry
type = copied_node.type.specialize(fused_to_specific)
entry = copied_node.entry
copied_node.type = type
entry.type, type.entry = type, entry
entry.used = (entry.used or
self.node.entry.defined_in_pxd or
env.is_c_class_scope or
entry.is_cmethod)
if self.node.cfunc_declarator.optional_arg_count:
self.node.cfunc_declarator.declare_optional_arg_struct(
newtype, env, fused_cname=cname)
type, env, fused_cname=cname)
copied_node.return_type = newtype.return_type
copied_node.return_type = type.return_type
copied_node.create_local_scope(env)
copied_node.local_scope.fused_to_specific = fused_to_specific
......@@ -2090,8 +2101,8 @@ class FusedCFuncDefNode(StatListNode):
for arg in copied_node.cfunc_declarator.args:
arg.type = arg.type.specialize(fused_to_specific)
cname = self.node.type.get_specific_cname(cname)
copied_node.entry.func_cname = copied_node.entry.cname = cname
type.specialize_entry(entry, cname)
env.cfunc_entries.append(entry)
num_errors = Errors.num_errors
transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
......@@ -2101,6 +2112,23 @@ class FusedCFuncDefNode(StatListNode):
if Errors.num_errors > num_errors:
break
def generate_function_definitions(self, env, code):
for stat in self.stats:
# print stat.entry, stat.entry.used
if stat.entry.used:
stat.generate_function_definitions(env, code)
def generate_execution_code(self, code):
for stat in self.stats:
if stat.entry.used:
code.mark_pos(stat.pos)
stat.generate_execution_code(code)
def annotate(self, code):
for stat in self.stats:
if stat.entry.used:
stat.annotate(code)
class PyArgDeclNode(Node):
# Argument which must be a Python object (used
......
......@@ -1344,20 +1344,17 @@ class AnalyseExpressionsTransform(CythonTransform):
def visit_IndexNode(self, node):
"""
Replace index nodes used to specialize cdef functions with fused
argument types with a NameNode referring to the function with
specialized entry and type.
argument types with the Attribute- or NameNode referring to the
function. We then need to copy over the specialization properties to
the attribute or name node.
"""
self.visit_Node(node)
type = node.type
if type.is_cfunction and node.base.type.is_fused:
node.base.type = node.type
node.base.entry = node.type.entry
node = node.base
if not node.is_name:
error(node.pos, "Can only index a fused function once")
node.type = PyrexTypes.error_type
else:
node.type = type
node.entry = type.entry
return node
......
......@@ -39,20 +39,14 @@ class BaseType(object):
"""
return self
def get_fused_types(self, result=None, seen=None):
if self.subtypes:
def add_fused_types(types):
for type in types or ():
if type not in seen:
seen.add(type)
result.append(type)
def get_fused_types(self, result=None, seen=None, subtypes=None):
subtypes = subtypes or self.subtypes
if subtypes:
if result is None:
result = []
seen = cython.set()
for attr in self.subtypes:
for attr in subtypes:
list_or_subtype = getattr(self, attr)
if isinstance(list_or_subtype, BaseType):
......@@ -1763,10 +1757,13 @@ class CFuncType(CType):
# with_gil boolean Acquire gil around function body
# templates [string] or None
# cached_specialized_types [CFuncType] cached specialized versions of the CFuncType if defined in a pxd
# from_fused boolean Indicates whether this is a specialized
# C function
is_cfunction = 1
original_sig = None
cached_specialized_types = None
from_fused = False
subtypes = ['return_type', 'args']
......@@ -1994,7 +1991,7 @@ class CFuncType(CType):
else:
new_templates = [v.specialize(values) for v in self.templates]
return CFuncType(self.return_type.specialize(values),
result = CFuncType(self.return_type.specialize(values),
[arg.specialize(values) for arg in self.args],
has_varargs = 0,
exception_value = self.exception_value,
......@@ -2006,6 +2003,9 @@ class CFuncType(CType):
optional_arg_count = self.optional_arg_count,
templates = new_templates)
result.from_fused = self.is_fused
return result
def opt_arg_cname(self, arg_name):
return self.op_arg_struct.base_type.scope.lookup(arg_name).cname
......@@ -2040,6 +2040,10 @@ class CFuncType(CType):
elif self.cached_specialized_types is not None:
return self.cached_specialized_types
cfunc_entries = self.entry.scope.cfunc_entries
cfunc_entries.remove(self.entry)
result = []
permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations:
......@@ -2050,55 +2054,46 @@ class CFuncType(CType):
self.declare_opt_arg_struct(new_func_type, cname)
new_entry = copy.deepcopy(self.entry)
new_entry.cname = self.get_specific_cname(cname)
new_func_type.specialize_entry(new_entry, cname)
new_entry.type = new_func_type
new_func_type.entry = new_entry
result.append(new_func_type)
cfunc_entries.append(new_entry)
self.cached_specialized_types = result
return result
def get_specific_cname(self, fused_cname):
"""
Given the cname for a permutation of fused types, return the cname
for the corresponding function with specific types.
"""
assert self.is_fused
return get_fused_cname(fused_cname, self.entry.func_cname)
def get_fused_types(self, result=None, seen=None, subtypes=None):
"Return fused types in the order they appear as parameter types"
return super(CFuncType, self).get_fused_types(result, seen,
subtypes=['args'])
def specialize_entry(self, entry, cname):
assert not self.is_fused
entry.name = get_fused_cname(cname, entry.name)
if entry.is_cmethod:
entry.cname = entry.name
if entry.is_inherited:
entry.cname = "%s.%s" % (Naming.obj_base_cname, entry.cname)
else:
entry.cname = get_fused_cname(cname, entry.cname)
if entry.func_cname:
entry.func_cname = get_fused_cname(cname, entry.func_cname)
def get_fused_cname(fused_cname, orig_cname):
"""
Given the fused cname id and an original cname, return a specialized cname
"""
assert fused_cname and orig_cname
return '%s%s%s' % (Naming.fused_func_prefix, fused_cname, orig_cname)
def map_with_specific_entries(entry, func, *args, **kwargs):
"""
Call func for every specific function instance. If this is not a
signature with fused types, call it with the entry for this cdef
function.
"""
type = entry.type
if type.is_cfunction and (entry.fused_cfunction or type.is_fused):
if entry.fused_cfunction:
# cdef with fused types defined in this file
for cfunction in entry.fused_cfunction.nodes:
func(cfunction.entry, *args, **kwargs)
else:
# cdef with fused types defined in another file, create their
# signatures
for func_type in type.get_all_specific_function_types():
func(func_type.entry, *args, **kwargs)
else:
# a normal cdef or not a c function
func(entry, *args, **kwargs)
def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0]
result = []
......
......@@ -734,6 +734,7 @@ class Scope(object):
else:
return outer.is_cpp()
class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname
......@@ -1696,6 +1697,7 @@ class CClassScope(ClassScope):
if defining:
entry.func_cname = self.mangle(Naming.func_prefix, name)
entry.utility_code = utility_code
type.entry = entry
return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
......@@ -1744,6 +1746,14 @@ class CClassScope(ClassScope):
base_entry.type, None, 'private')
entry.is_variable = 1
self.inherited_var_entries.append(entry)
# If the class defined in a pxd, specific entries have not been added.
# Ensure now that the parent (base) scope has specific entries
# Iterate over a copy as get_all_specific_function_types() will mutate
for base_entry in base_scope.cfunc_entries[:]:
if base_entry.type.is_fused:
base_entry.type.get_all_specific_function_types()
for base_entry in base_scope.cfunc_entries:
entry = self.add_cfunction(base_entry.name, base_entry.type,
base_entry.pos, adapt(base_entry.cname),
......@@ -1819,6 +1829,7 @@ class CppClassScope(Scope):
if prev_entry:
entry.overloaded_alternatives = prev_entry.all_alternatives()
entry.utility_code = utility_code
type.entry = entry
return entry
def declare_inherited_cpp_attributes(self, base_scope):
......
......@@ -169,8 +169,7 @@ cdef opt_args(integral x, floating y = 4.0):
def test_opt_args():
"""
ToDO: enable and fix
test_opt_args()
>>> test_opt_args()
3 4.0
3 4.0
3 4.0
......
......@@ -39,6 +39,12 @@ cdef object_t add_simple(object_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
cdef public_optional_args(object_t obj, simple_t simple = *)
ctypedef cython.fused_type(float, double) floating
ctypedef cython.fused_type(int, long) integral
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y)
######## header.h ########
typedef int extern_int;
......@@ -58,6 +64,16 @@ cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
cdef public_optional_args(object_t obj, simple_t simple = 6):
return obj.a, simple
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y):
if integral is int:
x += 1
if floating is double:
y += 2.0
return x + y
######## b.pyx ########
from a cimport *
......@@ -92,3 +108,42 @@ assert public_optional_args[mystruct_t, int](mystruct) == (5, 6)
assert public_optional_args[mystruct_t, float](mystruct) == (5, 6.0)
assert public_optional_args[mystruct_t, float](mystruct, 7.0) == (5, 7.0)
cdef TestFusedExtMethods obj = TestFusedExtMethods()
cdef int x = 4
cdef float y = 5.0
cdef long a = 6
cdef double b = 7.0
cdef double (*func)(TestFusedExtMethods, long, double)
func = obj.method
assert func(obj, a, b) == 15.0
func = <double (*)(TestFusedExtMethods, long, double)> obj.method
assert func(obj, x, y) == 11.0
func = obj.method[long, double]
assert func(obj, a, y) == 13.0
assert obj.method(x, <double> a) == 13.0
assert obj.method[int, double](x, b) == 14.0
# Test inheritance
cdef class Subclass(TestFusedExtMethods):
cdef floating method(self, integral x, floating y):
return -x -y
cdef Subclass myobj = Subclass()
assert myobj.method[int, float](5, 5.0) == -10
cdef float (*meth)(Subclass, int, float)
meth = myobj.method
assert meth(myobj, 5, 5.0) == -10
meth = myobj.method[int, float]
assert meth(myobj, 5, 5.0) == -10
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment