Commit 1472e87b authored by Mark Florisson's avatar Mark Florisson

Support fused cdef methods

parent a8305590
...@@ -595,7 +595,10 @@ class ExprNode(Node): ...@@ -595,7 +595,10 @@ class ExprNode(Node):
for signature in src_type.get_all_specific_function_types(): for signature in src_type.get_all_specific_function_types():
if signature.same_as(dst_type): if signature.same_as(dst_type):
return CoerceFusedToSpecific(src, signature) src.type = signature
src.entry = src.type.entry
src.entry.used = True
return self
error(self.pos, "Type is not specific") error(self.pos, "Type is not specific")
self.type = error_type self.type = error_type
...@@ -2330,10 +2333,9 @@ class IndexNode(ExprNode): ...@@ -2330,10 +2333,9 @@ class IndexNode(ExprNode):
NameNode with specific entry just after analysis of expressions by NameNode with specific entry just after analysis of expressions by
AnalyseExpressionsTransform. AnalyseExpressionsTransform.
""" """
base_type = self.base.type
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
base_type = self.base.type
specific_types = [] specific_types = []
positions = [] positions = []
...@@ -2378,6 +2380,15 @@ class IndexNode(ExprNode): ...@@ -2378,6 +2380,15 @@ class IndexNode(ExprNode):
for signature in self.base.type.get_all_specific_function_types(): for signature in self.base.type.get_all_specific_function_types():
if type.same_as(signature): if type.same_as(signature):
self.type = signature self.type = signature
if self.base.is_attribute:
# Pretend to be a normal attribute, for cdef extension
# methods
self.entry = signature.entry
self.is_attribute = self.base.is_attribute
self.obj = self.base.obj
self.entry.used = True
break break
else: else:
assert False assert False
...@@ -3026,11 +3037,13 @@ class SimpleCallNode(CallNode): ...@@ -3026,11 +3037,13 @@ class SimpleCallNode(CallNode):
function = self.function function = self.function
function.is_called = 1 function.is_called = 1
self.function.analyse_types(env) self.function.analyse_types(env)
if function.is_attribute and function.entry and function.entry.is_cmethod: if function.is_attribute and function.entry and function.entry.is_cmethod:
# Take ownership of the object from which the attribute # Take ownership of the object from which the attribute
# was obtained, because we need to pass it as 'self'. # was obtained, because we need to pass it as 'self'.
self.self = function.obj self.self = function.obj
function.obj = CloneNode(self.self) function.obj = CloneNode(self.self)
func_type = self.function_type() func_type = self.function_type()
if func_type.is_pyobject: if func_type.is_pyobject:
self.arg_tuple = TupleNode(self.pos, args = self.args) self.arg_tuple = TupleNode(self.pos, args = self.args)
...@@ -3111,15 +3124,13 @@ class SimpleCallNode(CallNode): ...@@ -3111,15 +3124,13 @@ class SimpleCallNode(CallNode):
elif (isinstance(self.function, IndexNode) and elif (isinstance(self.function, IndexNode) and
self.function.base.type.is_fused): self.function.base.type.is_fused):
overloaded_entry = self.function.type.entry overloaded_entry = self.function.type.entry
self.function.entry = self.function.type.entry
else: else:
overloaded_entry = None overloaded_entry = None
if overloaded_entry: if overloaded_entry:
if self.function.type.is_fused: if self.function.type.is_fused:
alternatives = [] functypes = self.function.type.get_all_specific_function_types()
PyrexTypes.map_with_specific_entries(self.function.entry, alternatives = [f.entry for f in functypes]
alternatives.append)
else: else:
alternatives = overloaded_entry.all_alternatives() alternatives = overloaded_entry.all_alternatives()
...@@ -3129,6 +3140,8 @@ class SimpleCallNode(CallNode): ...@@ -3129,6 +3140,8 @@ class SimpleCallNode(CallNode):
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.result_code = "<error>" self.result_code = "<error>"
return return
entry.used = True
self.function.entry = entry self.function.entry = entry
self.function.type = entry.type self.function.type = entry.type
func_type = self.function_type() func_type = self.function_type()
...@@ -3812,6 +3825,12 @@ class AttributeNode(ExprNode): ...@@ -3812,6 +3825,12 @@ class AttributeNode(ExprNode):
#print "...obj_code =", obj_code ### #print "...obj_code =", obj_code ###
if self.entry and self.entry.is_cmethod: if self.entry and self.entry.is_cmethod:
if obj.type.is_extension_type: if obj.type.is_extension_type:
# If the attribute was specialized through indexing, make sure
# to get the right fused name, as our entry was replaced by our
# parent index node (AnalyseExpressionsTransform)
if self.type.from_fused:
self.member = self.entry.cname
return "((struct %s *)%s%s%s)->%s" % ( return "((struct %s *)%s%s%s)->%s" % (
obj.type.vtabstruct_cname, obj_code, self.op, obj.type.vtabstruct_cname, obj_code, self.op,
obj.type.vtabslot_cname, self.member) obj.type.vtabslot_cname, self.member)
...@@ -7379,18 +7398,6 @@ class CoercionNode(ExprNode): ...@@ -7379,18 +7398,6 @@ class CoercionNode(ExprNode):
file, line, col = self.pos file, line, col = self.pos
code.annotate((file, line, col-1), AnnotationItem(style='coerce', tag='coerce', text='[%s] to [%s]' % (self.arg.type, self.type))) code.annotate((file, line, col-1), AnnotationItem(style='coerce', tag='coerce', text='[%s] to [%s]' % (self.arg.type, self.type)))
class CoerceFusedToSpecific(CoercionNode):
def __init__(self, arg, dst_type):
super(CoerceFusedToSpecific, self).__init__(arg)
self.type = dst_type
self.specialized_cname = dst_type.entry.cname
def calculate_result_code(self):
return self.specialized_cname
def generate_result_code(self, code):
pass
class CastNode(CoercionNode): class CastNode(CoercionNode):
# Wrap a node in a C type cast. # Wrap a node in a C type cast.
......
...@@ -156,19 +156,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -156,19 +156,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f.close() f.close()
def generate_public_declaration(self, entry, h_code, i_code): def generate_public_declaration(self, entry, h_code, i_code):
PyrexTypes.map_with_specific_entries(entry,
self._generate_public_declaration,
h_code,
i_code)
def _generate_public_declaration(self, entry, h_code, i_code):
h_code.putln("%s %s;" % ( h_code.putln("%s %s;" % (
Naming.extern_c_macro, Naming.extern_c_macro,
entry.type.declaration_code( entry.type.declaration_code(
entry.cname, dll_linkage = "DL_IMPORT"))) entry.cname, dll_linkage = "DL_IMPORT")))
if i_code: if i_code:
i_code.putln("cdef extern %s" % i_code.putln("cdef extern %s" %
entry.type.declaration_code(cname, pyrex = 1)) entry.type.declaration_code(entry.cname, pyrex = 1))
def api_name(self, env): def api_name(self, env):
return env.qualified_name.replace(".", "__") return env.qualified_name.replace(".", "__")
...@@ -992,39 +986,38 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -992,39 +986,38 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
dll_linkage = "DL_EXPORT", definition = definition) dll_linkage = "DL_EXPORT", definition = definition)
def generate_cfunction_predeclarations(self, env, code, definition): def generate_cfunction_predeclarations(self, env, code, definition):
func = self._generate_cfunction_predeclaration
for entry in env.cfunc_entries: for entry in env.cfunc_entries:
PyrexTypes.map_with_specific_entries(entry, func, code, definition) should_declare = (not entry.in_cinclude and
(definition or entry.defined_in_pxd or
def _generate_cfunction_predeclaration(self, entry, code, definition): entry.visibility == 'extern'))
if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition
or entry.defined_in_pxd or entry.visibility == 'extern')): if entry.used and (entry.inline_func_in_pxd or should_declare):
if entry.visibility == 'public': if entry.visibility == 'public':
storage_class = "%s " % Naming.extern_c_macro storage_class = "%s " % Naming.extern_c_macro
dll_linkage = "DL_EXPORT" dll_linkage = "DL_EXPORT"
elif entry.visibility == 'extern': elif entry.visibility == 'extern':
storage_class = "%s " % Naming.extern_c_macro storage_class = "%s " % Naming.extern_c_macro
dll_linkage = "DL_IMPORT" dll_linkage = "DL_IMPORT"
elif entry.visibility == 'private': elif entry.visibility == 'private':
storage_class = "static " storage_class = "static "
dll_linkage = None dll_linkage = None
else: else:
storage_class = "static " storage_class = "static "
dll_linkage = None dll_linkage = None
type = entry.type type = entry.type
if not definition and entry.defined_in_pxd: if not definition and entry.defined_in_pxd:
type = CPtrType(type) type = CPtrType(type)
header = type.declaration_code(entry.cname, header = type.declaration_code(entry.cname,
dll_linkage = dll_linkage) dll_linkage = dll_linkage)
if entry.func_modifiers: if entry.func_modifiers:
modifiers = "%s " % ' '.join(entry.func_modifiers).upper() modifiers = "%s " % ' '.join(entry.func_modifiers).upper()
else: else:
modifiers = '' modifiers = ''
code.putln("%s%s%s; /*proto*/" % ( code.putln("%s%s%s; /*proto*/" % (
storage_class, storage_class,
modifiers, modifiers,
header)) header))
def generate_typeobj_definitions(self, env, code): def generate_typeobj_definitions(self, env, code):
full_module_name = env.qualified_name full_module_name = env.qualified_name
...@@ -2048,28 +2041,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2048,28 +2041,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_c_function_export_code(self, env, code): def generate_c_function_export_code(self, env, code):
# Generate code to create PyCFunction wrappers for exported C functions. # Generate code to create PyCFunction wrappers for exported C functions.
func = self._generate_c_function_export_code
for entry in env.cfunc_entries: for entry in env.cfunc_entries:
from_fused = entry.type.is_fused
if entry.api or entry.defined_in_pxd: if entry.api or entry.defined_in_pxd:
PyrexTypes.map_with_specific_entries(entry, func, env, env.use_utility_code(function_export_utility_code)
code, from_fused) signature = entry.type.signature_string()
code.putln('if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s' % (
def _generate_c_function_export_code(self, entry, env, code, from_fused): entry.name,
env.use_utility_code(function_export_utility_code) entry.cname,
signature = entry.type.signature_string() signature,
s = 'if (__Pyx_ExportFunction("%s", (void (*)(void))%s, "%s") < 0) %s' code.error_goto(self.pos)))
if from_fused:
# Specific version of a fused function. Fused functions can never
# be declared public or api, but they may need to be exported when
# declared in a .pxd. We need to give them a unique name in that
# case
name = entry.cname
else:
name = entry.name
code.putln(s % (name, entry.cname, signature, code.error_goto(self.pos)))
def generate_type_import_code_for_module(self, module, env, code): def generate_type_import_code_for_module(self, module, env, code):
# Generate type import code for all exported extension types in # Generate type import code for all exported extension types in
...@@ -2095,30 +2075,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2095,30 +2075,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
module.qualified_name, module.qualified_name,
temp, temp,
code.error_goto(self.pos))) code.error_goto(self.pos)))
for entry in entries: for entry in entries:
PyrexTypes.map_with_specific_entries(entry, code.putln(
self._import_cdef_func, 'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s' % (
code, temp,
temp, entry.name,
entry.type.is_fused) entry.cname,
entry.type.signature_string(),
code.error_goto(self.pos)))
code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp)) code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp))
def _import_cdef_func(self, entry, code, temp, from_fused):
if from_fused:
name = entry.cname
else:
name = entry.name
code.putln(
'if (__Pyx_ImportFunction(%s, "%s", (void (**)(void))&%s, "%s") < 0) %s' % (
temp,
name,
entry.cname,
entry.type.signature_string(),
code.error_goto(self.pos)))
def generate_type_init_code(self, env, code): def generate_type_init_code(self, env, code):
# Generate type import code for extern extension types # Generate type import code for extern extension types
# and type ready code for non-extern ones. # and type ready code for non-extern ones.
......
...@@ -1749,7 +1749,6 @@ class CFuncDefNode(FuncDefNode): ...@@ -1749,7 +1749,6 @@ class CFuncDefNode(FuncDefNode):
inline_in_pxd = False inline_in_pxd = False
decorators = None decorators = None
directive_locals = None directive_locals = None
cname_postfix = None
def unqualified_name(self): def unqualified_name(self):
return self.entry.name return self.entry.name
...@@ -2050,6 +2049,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2050,6 +2049,8 @@ class FusedCFuncDefNode(StatListNode):
if n.cfunc_declarator.optional_arg_count: if n.cfunc_declarator.optional_arg_count:
assert n.type.op_arg_struct assert n.type.op_arg_struct
assert n.type.entry
assert node.type.is_fused assert node.type.is_fused
node.entry.fused_cfunction = self node.entry.fused_cfunction = self
...@@ -2065,19 +2066,29 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2065,19 +2066,29 @@ class FusedCFuncDefNode(StatListNode):
# print 'Node %s has %d specializations:' % (self.node.entry.name, # print 'Node %s has %d specializations:' % (self.node.entry.name,
# len(permutations)) # len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations]) # import pprint; pprint.pprint([d for cname, d in permutations])
env.cfunc_entries.remove(self.node.entry)
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node) copied_node = copy.deepcopy(self.node)
# Make the types in our CFuncType specific # Make the types in our CFuncType specific
newtype = copied_node.type.specialize(fused_to_specific) type = copied_node.type.specialize(fused_to_specific)
copied_node.type = newtype entry = copied_node.entry
copied_node.entry.type = newtype
newtype.entry = copied_node.entry copied_node.type = type
entry.type, type.entry = type, entry
self.node.cfunc_declarator.declare_optional_arg_struct( entry.used = (entry.used or
newtype, env, fused_cname=cname) self.node.entry.defined_in_pxd or
env.is_c_class_scope or
entry.is_cmethod)
copied_node.return_type = newtype.return_type if self.node.cfunc_declarator.optional_arg_count:
self.node.cfunc_declarator.declare_optional_arg_struct(
type, env, fused_cname=cname)
copied_node.return_type = type.return_type
copied_node.create_local_scope(env) copied_node.create_local_scope(env)
copied_node.local_scope.fused_to_specific = fused_to_specific copied_node.local_scope.fused_to_specific = fused_to_specific
...@@ -2090,8 +2101,8 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2090,8 +2101,8 @@ class FusedCFuncDefNode(StatListNode):
for arg in copied_node.cfunc_declarator.args: for arg in copied_node.cfunc_declarator.args:
arg.type = arg.type.specialize(fused_to_specific) arg.type = arg.type.specialize(fused_to_specific)
cname = self.node.type.get_specific_cname(cname) type.specialize_entry(entry, cname)
copied_node.entry.func_cname = copied_node.entry.cname = cname env.cfunc_entries.append(entry)
num_errors = Errors.num_errors num_errors = Errors.num_errors
transform = ParseTreeTransforms.ReplaceFusedTypeChecks( transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
...@@ -2101,6 +2112,23 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2101,6 +2112,23 @@ class FusedCFuncDefNode(StatListNode):
if Errors.num_errors > num_errors: if Errors.num_errors > num_errors:
break break
def generate_function_definitions(self, env, code):
for stat in self.stats:
# print stat.entry, stat.entry.used
if stat.entry.used:
stat.generate_function_definitions(env, code)
def generate_execution_code(self, code):
for stat in self.stats:
if stat.entry.used:
code.mark_pos(stat.pos)
stat.generate_execution_code(code)
def annotate(self, code):
for stat in self.stats:
if stat.entry.used:
stat.annotate(code)
class PyArgDeclNode(Node): class PyArgDeclNode(Node):
# Argument which must be a Python object (used # Argument which must be a Python object (used
......
...@@ -1344,20 +1344,17 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1344,20 +1344,17 @@ class AnalyseExpressionsTransform(CythonTransform):
def visit_IndexNode(self, node): def visit_IndexNode(self, node):
""" """
Replace index nodes used to specialize cdef functions with fused Replace index nodes used to specialize cdef functions with fused
argument types with a NameNode referring to the function with argument types with the Attribute- or NameNode referring to the
specialized entry and type. function. We then need to copy over the specialization properties to
the attribute or name node.
""" """
self.visit_Node(node) self.visit_Node(node)
type = node.type type = node.type
if type.is_cfunction and node.base.type.is_fused: if type.is_cfunction and node.base.type.is_fused:
node.base.type = node.type
node.base.entry = node.type.entry
node = node.base node = node.base
if not node.is_name:
error(node.pos, "Can only index a fused function once")
node.type = PyrexTypes.error_type
else:
node.type = type
node.entry = type.entry
return node return node
......
...@@ -39,20 +39,14 @@ class BaseType(object): ...@@ -39,20 +39,14 @@ class BaseType(object):
""" """
return self return self
def get_fused_types(self, result=None, seen=None): def get_fused_types(self, result=None, seen=None, subtypes=None):
if self.subtypes: subtypes = subtypes or self.subtypes
if subtypes:
def add_fused_types(types):
for type in types or ():
if type not in seen:
seen.add(type)
result.append(type)
if result is None: if result is None:
result = [] result = []
seen = cython.set() seen = cython.set()
for attr in self.subtypes: for attr in subtypes:
list_or_subtype = getattr(self, attr) list_or_subtype = getattr(self, attr)
if isinstance(list_or_subtype, BaseType): if isinstance(list_or_subtype, BaseType):
...@@ -1763,10 +1757,13 @@ class CFuncType(CType): ...@@ -1763,10 +1757,13 @@ class CFuncType(CType):
# with_gil boolean Acquire gil around function body # with_gil boolean Acquire gil around function body
# templates [string] or None # templates [string] or None
# cached_specialized_types [CFuncType] cached specialized versions of the CFuncType if defined in a pxd # cached_specialized_types [CFuncType] cached specialized versions of the CFuncType if defined in a pxd
# from_fused boolean Indicates whether this is a specialized
# C function
is_cfunction = 1 is_cfunction = 1
original_sig = None original_sig = None
cached_specialized_types = None cached_specialized_types = None
from_fused = False
subtypes = ['return_type', 'args'] subtypes = ['return_type', 'args']
...@@ -1994,17 +1991,20 @@ class CFuncType(CType): ...@@ -1994,17 +1991,20 @@ class CFuncType(CType):
else: else:
new_templates = [v.specialize(values) for v in self.templates] new_templates = [v.specialize(values) for v in self.templates]
return CFuncType(self.return_type.specialize(values), result = CFuncType(self.return_type.specialize(values),
[arg.specialize(values) for arg in self.args], [arg.specialize(values) for arg in self.args],
has_varargs = 0, has_varargs = 0,
exception_value = self.exception_value, exception_value = self.exception_value,
exception_check = self.exception_check, exception_check = self.exception_check,
calling_convention = self.calling_convention, calling_convention = self.calling_convention,
nogil = self.nogil, nogil = self.nogil,
with_gil = self.with_gil, with_gil = self.with_gil,
is_overridable = self.is_overridable, is_overridable = self.is_overridable,
optional_arg_count = self.optional_arg_count, optional_arg_count = self.optional_arg_count,
templates = new_templates) templates = new_templates)
result.from_fused = self.is_fused
return result
def opt_arg_cname(self, arg_name): def opt_arg_cname(self, arg_name):
return self.op_arg_struct.base_type.scope.lookup(arg_name).cname return self.op_arg_struct.base_type.scope.lookup(arg_name).cname
...@@ -2040,6 +2040,10 @@ class CFuncType(CType): ...@@ -2040,6 +2040,10 @@ class CFuncType(CType):
elif self.cached_specialized_types is not None: elif self.cached_specialized_types is not None:
return self.cached_specialized_types return self.cached_specialized_types
cfunc_entries = self.entry.scope.cfunc_entries
cfunc_entries.remove(self.entry)
result = [] result = []
permutations = self.get_all_specific_permutations() permutations = self.get_all_specific_permutations()
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
...@@ -2050,55 +2054,46 @@ class CFuncType(CType): ...@@ -2050,55 +2054,46 @@ class CFuncType(CType):
self.declare_opt_arg_struct(new_func_type, cname) self.declare_opt_arg_struct(new_func_type, cname)
new_entry = copy.deepcopy(self.entry) new_entry = copy.deepcopy(self.entry)
new_entry.cname = self.get_specific_cname(cname) new_func_type.specialize_entry(new_entry, cname)
new_entry.type = new_func_type new_entry.type = new_func_type
new_func_type.entry = new_entry new_func_type.entry = new_entry
result.append(new_func_type) result.append(new_func_type)
cfunc_entries.append(new_entry)
self.cached_specialized_types = result self.cached_specialized_types = result
return result return result
def get_specific_cname(self, fused_cname): def get_fused_types(self, result=None, seen=None, subtypes=None):
""" "Return fused types in the order they appear as parameter types"
Given the cname for a permutation of fused types, return the cname return super(CFuncType, self).get_fused_types(result, seen,
for the corresponding function with specific types. subtypes=['args'])
"""
assert self.is_fused
return get_fused_cname(fused_cname, self.entry.func_cname)
def specialize_entry(self, entry, cname):
assert not self.is_fused
entry.name = get_fused_cname(cname, entry.name)
if entry.is_cmethod:
entry.cname = entry.name
if entry.is_inherited:
entry.cname = "%s.%s" % (Naming.obj_base_cname, entry.cname)
else:
entry.cname = get_fused_cname(cname, entry.cname)
if entry.func_cname:
entry.func_cname = get_fused_cname(cname, entry.func_cname)
def get_fused_cname(fused_cname, orig_cname): def get_fused_cname(fused_cname, orig_cname):
""" """
Given the fused cname id and an original cname, return a specialized cname Given the fused cname id and an original cname, return a specialized cname
""" """
assert fused_cname and orig_cname
return '%s%s%s' % (Naming.fused_func_prefix, fused_cname, orig_cname) return '%s%s%s' % (Naming.fused_func_prefix, fused_cname, orig_cname)
def map_with_specific_entries(entry, func, *args, **kwargs):
"""
Call func for every specific function instance. If this is not a
signature with fused types, call it with the entry for this cdef
function.
"""
type = entry.type
if type.is_cfunction and (entry.fused_cfunction or type.is_fused):
if entry.fused_cfunction:
# cdef with fused types defined in this file
for cfunction in entry.fused_cfunction.nodes:
func(cfunction.entry, *args, **kwargs)
else:
# cdef with fused types defined in another file, create their
# signatures
for func_type in type.get_all_specific_function_types():
func(func_type.entry, *args, **kwargs)
else:
# a normal cdef or not a c function
func(entry, *args, **kwargs)
def get_all_specific_permutations(fused_types, id="", f2s=()): def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0] fused_type = fused_types[0]
result = [] result = []
......
...@@ -734,6 +734,7 @@ class Scope(object): ...@@ -734,6 +734,7 @@ class Scope(object):
else: else:
return outer.is_cpp() return outer.is_cpp()
class PreImportScope(Scope): class PreImportScope(Scope):
namespace_cname = Naming.preimport_cname namespace_cname = Naming.preimport_cname
...@@ -1696,6 +1697,7 @@ class CClassScope(ClassScope): ...@@ -1696,6 +1697,7 @@ class CClassScope(ClassScope):
if defining: if defining:
entry.func_cname = self.mangle(Naming.func_prefix, name) entry.func_cname = self.mangle(Naming.func_prefix, name)
entry.utility_code = utility_code entry.utility_code = utility_code
type.entry = entry
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
...@@ -1744,6 +1746,14 @@ class CClassScope(ClassScope): ...@@ -1744,6 +1746,14 @@ class CClassScope(ClassScope):
base_entry.type, None, 'private') base_entry.type, None, 'private')
entry.is_variable = 1 entry.is_variable = 1
self.inherited_var_entries.append(entry) self.inherited_var_entries.append(entry)
# If the class defined in a pxd, specific entries have not been added.
# Ensure now that the parent (base) scope has specific entries
# Iterate over a copy as get_all_specific_function_types() will mutate
for base_entry in base_scope.cfunc_entries[:]:
if base_entry.type.is_fused:
base_entry.type.get_all_specific_function_types()
for base_entry in base_scope.cfunc_entries: for base_entry in base_scope.cfunc_entries:
entry = self.add_cfunction(base_entry.name, base_entry.type, entry = self.add_cfunction(base_entry.name, base_entry.type,
base_entry.pos, adapt(base_entry.cname), base_entry.pos, adapt(base_entry.cname),
...@@ -1819,6 +1829,7 @@ class CppClassScope(Scope): ...@@ -1819,6 +1829,7 @@ class CppClassScope(Scope):
if prev_entry: if prev_entry:
entry.overloaded_alternatives = prev_entry.all_alternatives() entry.overloaded_alternatives = prev_entry.all_alternatives()
entry.utility_code = utility_code entry.utility_code = utility_code
type.entry = entry
return entry return entry
def declare_inherited_cpp_attributes(self, base_scope): def declare_inherited_cpp_attributes(self, base_scope):
......
...@@ -169,8 +169,7 @@ cdef opt_args(integral x, floating y = 4.0): ...@@ -169,8 +169,7 @@ cdef opt_args(integral x, floating y = 4.0):
def test_opt_args(): def test_opt_args():
""" """
ToDO: enable and fix >>> test_opt_args()
test_opt_args()
3 4.0 3 4.0
3 4.0 3 4.0
3 4.0 3 4.0
......
...@@ -39,6 +39,12 @@ cdef object_t add_simple(object_t obj, simple_t simple) ...@@ -39,6 +39,12 @@ cdef object_t add_simple(object_t obj, simple_t simple)
cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple) cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple)
cdef public_optional_args(object_t obj, simple_t simple = *) cdef public_optional_args(object_t obj, simple_t simple = *)
ctypedef cython.fused_type(float, double) floating
ctypedef cython.fused_type(int, long) integral
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y)
######## header.h ######## ######## header.h ########
typedef int extern_int; typedef int extern_int;
...@@ -58,6 +64,16 @@ cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple): ...@@ -58,6 +64,16 @@ cdef less_simple_t add_to_simple(object_t obj, less_simple_t simple):
cdef public_optional_args(object_t obj, simple_t simple = 6): cdef public_optional_args(object_t obj, simple_t simple = 6):
return obj.a, simple return obj.a, simple
cdef class TestFusedExtMethods(object):
cdef floating method(self, integral x, floating y):
if integral is int:
x += 1
if floating is double:
y += 2.0
return x + y
######## b.pyx ######## ######## b.pyx ########
from a cimport * from a cimport *
...@@ -92,3 +108,42 @@ assert public_optional_args[mystruct_t, int](mystruct) == (5, 6) ...@@ -92,3 +108,42 @@ assert public_optional_args[mystruct_t, int](mystruct) == (5, 6)
assert public_optional_args[mystruct_t, float](mystruct) == (5, 6.0) assert public_optional_args[mystruct_t, float](mystruct) == (5, 6.0)
assert public_optional_args[mystruct_t, float](mystruct, 7.0) == (5, 7.0) assert public_optional_args[mystruct_t, float](mystruct, 7.0) == (5, 7.0)
cdef TestFusedExtMethods obj = TestFusedExtMethods()
cdef int x = 4
cdef float y = 5.0
cdef long a = 6
cdef double b = 7.0
cdef double (*func)(TestFusedExtMethods, long, double)
func = obj.method
assert func(obj, a, b) == 15.0
func = <double (*)(TestFusedExtMethods, long, double)> obj.method
assert func(obj, x, y) == 11.0
func = obj.method[long, double]
assert func(obj, a, y) == 13.0
assert obj.method(x, <double> a) == 13.0
assert obj.method[int, double](x, b) == 14.0
# Test inheritance
cdef class Subclass(TestFusedExtMethods):
cdef floating method(self, integral x, floating y):
return -x -y
cdef Subclass myobj = Subclass()
assert myobj.method[int, float](5, 5.0) == -10
cdef float (*meth)(Subclass, int, float)
meth = myobj.method
assert meth(myobj, 5, 5.0) == -10
meth = myobj.method[int, float]
assert meth(myobj, 5, 5.0) == -10
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment