Commit e750b7ff authored by Xavier Thompson's avatar Xavier Thompson

Dispatch cypclass static methods correctly

parent 61ad43c6
...@@ -1041,7 +1041,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1041,7 +1041,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
continue continue
elif attr.type.is_cyp_class: elif attr.type.is_cyp_class:
cname = "%s = NULL" % cname cname = "%s = NULL" % cname
if type.is_cyp_class and attr.type.is_cfunction and attr.type.is_static_method and attr.static_cname is not None:
code.putln("%s;" % attr.type.declaration_code(attr.static_cname))
self.generate_cyp_class_static_method_resolution(attr, code)
else:
code.putln("%s;" % attr.type.declaration_code(cname)) code.putln("%s;" % attr.type.declaration_code(cname))
is_implementing = 'init_module' in code.globalstate.parts is_implementing = 'init_module' in code.globalstate.parts
for reified in scope.reifying_entries: for reified in scope.reifying_entries:
...@@ -1168,8 +1173,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1168,8 +1173,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
and e.from_type and e.from_type
and e.mro_index > 0 and e.mro_index > 0
and e.from_type.is_cyp_class # avoid dealing with methods inherited from non-cypclass bases for now and e.from_type.is_cyp_class # avoid dealing with methods inherited from non-cypclass bases for now
and not e.type.is_static_method # avoid dealing with static methods for now
and e.name not in ("<init>", "<del>") and e.name not in ("<init>", "<del>")
and (not e.type.is_static_method
or e.static_cname is not None) # mro-resolve the virtual methods used to dispatch static methods
and not e.type.has_varargs # avoid dealing with varargs for now (is this ever required anyway ?) and not e.type.has_varargs # avoid dealing with varargs for now (is this ever required anyway ?)
] ]
if inherited_methods: if inherited_methods:
...@@ -1198,6 +1204,32 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1198,6 +1204,32 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
if inherited_methods: if inherited_methods:
code.putln("") code.putln("")
def generate_cyp_class_static_method_resolution(self, static_method, code):
"""
Generate a virtual method in cypclass that just forward calls to associated static method.
The virtual version will serve to correctly dispatch static methods.
"""
func_type = static_method.type
modifiers = code.build_function_modifiers(static_method.func_modifiers)
arg_names = ["%s_%d" % (arg.cname, i) for i, arg in enumerate(func_type.args)]
arg_decls = [arg.type.declaration_code(arg_name) for arg, arg_name in zip(func_type.args, arg_names)]
if func_type.optional_arg_count:
opt_name = Naming.optional_args_cname
arg_decls.append(func_type.op_arg_struct.declaration_code(opt_name))
arg_names.append(opt_name)
header = func_type.function_header_code(static_method.cname, ", ".join(arg_decls))
if not static_method.name.startswith("operator "):
header = func_type.return_type.declaration_code(header)
return_code = "" if func_type.return_type.is_void else "return "
body = "%s%s(%s);" % (return_code, static_method.static_cname, ", ".join(arg_names))
code.putln("virtual %s%s {%s}" % (modifiers, header, body))
def generate_enum_definition(self, entry, code): def generate_enum_definition(self, entry, code):
code.mark_pos(entry.pos) code.mark_pos(entry.pos)
type = entry.type type = entry.type
......
...@@ -170,10 +170,6 @@ api_name = pyrex_prefix + "capi__" ...@@ -170,10 +170,6 @@ api_name = pyrex_prefix + "capi__"
# cname for the type that defines the essential memory layout of a cypclass wrapper. # cname for the type that defines the essential memory layout of a cypclass wrapper.
cypclass_wrapper_layout_type = "CyPyObject" cypclass_wrapper_layout_type = "CyPyObject"
# cname for the underlying cypclass attribute in the memory layout of a cypclass wrapper.
cypclass_wrapper_underlying_attr = "nogil_cyobject"
# the h and api guards get changed to: # the h and api guards get changed to:
# __PYX_HAVE__FILENAME (for ascii filenames) # __PYX_HAVE__FILENAME (for ascii filenames)
# __PYX_HAVE_U_PUNYCODEFILENAME (for non-ascii filenames) # __PYX_HAVE_U_PUNYCODEFILENAME (for non-ascii filenames)
......
...@@ -176,6 +176,8 @@ class Entry(object): ...@@ -176,6 +176,8 @@ class Entry(object):
# defining_classes [CypClassType or CppClassType or CStructOrUnionType] # defining_classes [CypClassType or CppClassType or CStructOrUnionType]
# All the base classes that define an entry that this entry # All the base classes that define an entry that this entry
# overrides, if this entry represents a cypclass method # overrides, if this entry represents a cypclass method
#
# static_cname string The cname of a static method in a cypclass
# TODO: utility_code and utility_code_definition serves the same purpose... # TODO: utility_code and utility_code_definition serves the same purpose...
...@@ -254,6 +256,7 @@ class Entry(object): ...@@ -254,6 +256,7 @@ class Entry(object):
is_default = False is_default = False
mro_index = 0 mro_index = 0
from_type = None from_type = None
static_cname = None
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -2820,6 +2823,8 @@ class CppClassScope(Scope): ...@@ -2820,6 +2823,8 @@ class CppClassScope(Scope):
cname = None, visibility = 'extern', cname = None, visibility = 'extern',
api = 0, in_pxd = 0, is_cdef = 0, defining = 0): api = 0, in_pxd = 0, is_cdef = 0, defining = 0):
# Add an entry for an attribute. # Add an entry for an attribute.
if name.startswith(Naming.func_prefix):
error(pos, "Names starting with %s are reserved inside cppclass and cypclass" % Naming.func_prefix)
if not cname: if not cname:
cname = name cname = name
entry = self.lookup_here(name) entry = self.lookup_here(name)
...@@ -2839,6 +2844,9 @@ class CppClassScope(Scope): ...@@ -2839,6 +2844,9 @@ class CppClassScope(Scope):
entry.is_cfunction = type.is_cfunction entry.is_cfunction = type.is_cfunction
if type.is_cfunction and self.type: if type.is_cfunction and self.type:
if not self.type.get_fused_types(): if not self.type.get_fused_types():
if (self.parent_type.is_cyp_class and type.is_static_method and name not in ("<alloc>", "__new__")):
cname = "%s__static__%s" % (Naming.func_prefix, cname)
entry.static_cname = cname
entry.func_cname = "%s::%s" % (self.type.empty_declaration_code(), cname) entry.func_cname = "%s::%s" % (self.type.empty_declaration_code(), cname)
if name != "this" and (defining or name != "<init>" or self.parent_type.is_cyp_class): if name != "this" and (defining or name != "<init>" or self.parent_type.is_cyp_class):
self.var_entries.append(entry) self.var_entries.append(entry)
...@@ -3120,6 +3128,7 @@ class CppClassScope(Scope): ...@@ -3120,6 +3128,7 @@ class CppClassScope(Scope):
entry.is_cfunction = base_entry.is_cfunction entry.is_cfunction = base_entry.is_cfunction
if entry.is_cfunction: if entry.is_cfunction:
entry.func_cname = base_entry.func_cname entry.func_cname = base_entry.func_cname
entry.static_cname = base_entry.static_cname
for base_entry in base_scope.type_entries: for base_entry in base_scope.type_entries:
if base_entry.name not in base_templates: if base_entry.name not in base_templates:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment