Commit baa6c7f1 authored by Xavier Thompson's avatar Xavier Thompson

Emulate Python MRO by adding virtual methods in derived cypclasses forwarding to correct method

parent 89cd064a
...@@ -7581,8 +7581,6 @@ class AttributeNode(ExprNode): ...@@ -7581,8 +7581,6 @@ class AttributeNode(ExprNode):
if obj.type.is_cpp_class and self.entry and self.entry.is_cfunction: if obj.type.is_cpp_class and self.entry and self.entry.is_cfunction:
# the entry might have been resolved to an overladed alternative in the meantime # the entry might have been resolved to an overladed alternative in the meantime
self.member = self.entry.cname self.member = self.entry.cname
if obj.type.is_cyp_class and self.entry and self.entry.from_type:
self.member = "%s::%s" % (self.entry.from_type.empty_declaration_code(), self.member)
return "%s%s%s" % (obj_code, self.op, self.member) return "%s%s%s" % (obj_code, self.op, self.member)
def generate_result_code(self, code): def generate_result_code(self, code):
......
...@@ -1066,6 +1066,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1066,6 +1066,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
constructor = scope.lookup_here("<constructor>") constructor = scope.lookup_here("<constructor>")
for constructor_alternative in constructor.all_alternatives(): for constructor_alternative in constructor.all_alternatives():
code.putln("static %s;" % constructor_alternative.type.declaration_code(constructor_alternative.cname)) code.putln("static %s;" % constructor_alternative.type.declaration_code(constructor_alternative.cname))
self.generate_cyp_class_mro_method_resolution(scope, code)
elif constructor or py_attrs: elif constructor or py_attrs:
if constructor: if constructor:
for constructor_alternative in constructor.all_alternatives(): for constructor_alternative in constructor.all_alternatives():
...@@ -1155,6 +1156,47 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1155,6 +1156,47 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.globalstate.use_utility_code( code.globalstate.use_utility_code(
UtilityCode.load("CyObjects", "CyObjects.cpp", proto_block="utility_code_proto_before_types")) UtilityCode.load("CyObjects", "CyObjects.cpp", proto_block="utility_code_proto_before_types"))
def generate_cyp_class_mro_method_resolution(self, scope, code):
"""
Generate overriding methods in derived cypclasses to forward calls to the correct method according
to the MRO, regardless of the type of the pointer to the object through which the call is made.
In other words: emulate Python MRO lookup rules using only C++ virtual methods.
"""
inherited_methods = [
e for entry in scope.entries.values() for e in entry.all_alternatives()
if e.is_cfunction
and e.from_type
and e.mro_index > 0
and not e.type.is_static_method # avoid dealing with static methods for now
and e.name not in ("<init>", "<del>")
and not e.type.has_varargs # avoid dealing with varargs for now (is this ever required anyway ?)
]
if inherited_methods:
code.putln("")
code.putln("/* make all inherited (non overriden) methods resolve correctly according to the MRO */")
for e in inherited_methods:
modifiers = code.build_function_modifiers(e.func_modifiers)
arg_decls = [arg.declaration_code() for arg in e.type.args]
arg_names = [arg.cname for arg in e.type.args]
if e.type.optional_arg_count:
opt_name = Naming.optional_args_cname
arg_decls.append(e.type.op_arg_struct.declaration_code(opt_name))
arg_names.append(opt_name)
header = e.type.function_header_code(e.cname, ", ".join(arg_decls))
if not e.name.startswith("operator "):
header = e.type.return_type.declaration_code(header)
return_code = "" if e.type.return_type.is_void else "return "
resolution = e.from_type.empty_declaration_code()
body = "%s%s::%s(%s);" % (return_code, resolution, e.cname, ", ".join(arg_names))
code.putln("virtual %s%s {%s}" % (modifiers, header, body))
if inherited_methods:
code.putln("")
def generate_cyp_class_attrs_destructor_definition(self, entry, code): def generate_cyp_class_attrs_destructor_definition(self, entry, code):
scope = entry.type.scope scope = entry.type.scope
cypclass_attrs = [e for e in scope.var_entries cypclass_attrs = [e for e in scope.var_entries
......
...@@ -2951,10 +2951,10 @@ class CFuncType(CType): ...@@ -2951,10 +2951,10 @@ class CFuncType(CType):
return 0 return 0
return 1 return 1
def compatible_signature_with(self, other_type, as_cmethod = 0): def compatible_signature_with(self, other_type, as_cmethod = 0, ignore_return_type=0):
return self.compatible_signature_with_resolved_type(other_type.resolve(), as_cmethod) return self.compatible_signature_with_resolved_type(other_type.resolve(), as_cmethod, ignore_return_type)
def compatible_signature_with_resolved_type(self, other_type, as_cmethod): def compatible_signature_with_resolved_type(self, other_type, as_cmethod, ignore_return_type=0):
#print "CFuncType.same_c_signature_as_resolved_type:", \ #print "CFuncType.same_c_signature_as_resolved_type:", \
# self, other_type, "as_cmethod =", as_cmethod ### # self, other_type, "as_cmethod =", as_cmethod ###
if other_type is error_type: if other_type is error_type:
...@@ -2977,8 +2977,9 @@ class CFuncType(CType): ...@@ -2977,8 +2977,9 @@ class CFuncType(CType):
return 0 return 0
if self.has_varargs != other_type.has_varargs: if self.has_varargs != other_type.has_varargs:
return 0 return 0
if not self.return_type.subtype_of_resolved_type(other_type.return_type): if not ignore_return_type:
return 0 if not self.return_type.subtype_of_resolved_type(other_type.return_type):
return 0
if not self.same_calling_convention_as(other_type): if not self.same_calling_convention_as(other_type):
return 0 return 0
if self.nogil != other_type.nogil: if self.nogil != other_type.nogil:
......
...@@ -534,7 +534,12 @@ class Scope(object): ...@@ -534,7 +534,12 @@ class Scope(object):
if type.is_cfunction and old_entry.type.is_cfunction and self.is_cpp(): if type.is_cfunction and old_entry.type.is_cfunction and self.is_cpp():
cpp_override_allowed = True cpp_override_allowed = True
for index, alt_entry in enumerate(old_entry.all_alternatives()): for index, alt_entry in enumerate(old_entry.all_alternatives()):
if type.compatible_signature_with(alt_entry.type): # in a cypclass, a method can hide a method inherited from a different class
# regardless of their return types
ignore_return_type = (self.is_cyp_class_scope
and alt_entry.is_inherited
and alt_entry.from_type is not from_type)
if type.compatible_signature_with(alt_entry.type, ignore_return_type=ignore_return_type):
cpp_override_allowed = False cpp_override_allowed = False
...@@ -3033,7 +3038,6 @@ class CppClassScope(Scope): ...@@ -3033,7 +3038,6 @@ class CppClassScope(Scope):
entry.is_cfunction = base_entry.is_cfunction entry.is_cfunction = base_entry.is_cfunction
if entry.is_cfunction: if entry.is_cfunction:
entry.func_cname = base_entry.func_cname entry.func_cname = base_entry.func_cname
self.inherited_var_entries.append(entry)
for base_entry in base_scope.type_entries: for base_entry in base_scope.type_entries:
if base_entry.name not in base_templates: if base_entry.name not in base_templates:
...@@ -3041,7 +3045,6 @@ class CppClassScope(Scope): ...@@ -3041,7 +3045,6 @@ class CppClassScope(Scope):
base_entry.pos, base_entry.cname, base_entry.pos, base_entry.cname,
base_entry.visibility, defining=0) base_entry.visibility, defining=0)
entry.is_inherited = 1 entry.is_inherited = 1
self.inherited_type_entries.append(entry)
def specialize(self, values, type_entry): def specialize(self, values, type_entry):
scope = CppClassScope(self.name, self.outer_scope) scope = CppClassScope(self.name, self.outer_scope)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment