Commit 17b3b44a authored by Xavier Thompson's avatar Xavier Thompson

Enforce cypclass method overriding rules

parent 6ba13637
......@@ -1619,6 +1619,26 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
func.template_declaration = "template <typename %s>" % ", typename ".join(template_names)
self.body = StatListNode(self.pos, stats=defined_funcs)
# check that all overloaded alternatives for cypclass methods come from the same cypclass
if self.cypclass and scope is not None:
for method_entry in scope.entries.values():
if (method_entry.is_cfunction
and not method_entry.type.is_static_method
and method_entry.name not in ("<init>", "<alloc>", "<constructor>", "<del>")):
from_type = method_entry.from_type
for alternative in method_entry.all_alternatives():
if alternative.from_type is not from_type:
error(self.pos,
(
"Cypclass %s's method %s comes from %s but method %s comes from %s\n"
"Cypclass %s must either inherit all overload alternatives for %s from"
"the same base class or override all inherited alternatives itself"
)
% (self.name, str(method_entry.type), from_type.name, str(alternative.type),
alternative.from_type.name, self.name, method_entry.name)
)
# check for illegal implicit conversion paths between method arguments
if self.cypclass and scope is not None:
for method_entry in scope.entries.values():
......
# mode: run
# tag: cpp, cpp11, pthread
# cython: experimental_cpp_class_def=True, language_level=2
cdef cypclass A:
int foo(self, int a):
return a + 42
cdef cypclass B(A):
int foo(self, int a, int b):
return a + b
def test_resolve_unhidden_method():
"""
>>> test_resolve_unhidden_method()
43
"""
cdef B b = B()
# should resolve to A.foo
return b.foo(1)
cdef cypclass C:
int a
__init__(self, int a):
self.a = a
C foo(self, int other):
return C(a + other)
cdef cypclass D(C):
int b
__init__(self, int b):
self.b = 10 + b
D foo(self, int other):
return D(b + other)
def test_resolve_overriden_method():
"""
>>> test_resolve_overriden_method()
21
"""
cdef D d1 = D(0)
# should not resolve to D.foo
cdef D d2 = d1.foo(1)
return d2.b
cdef cypclass Left:
int foo(self):
return 1
cdef cypclass Right:
int foo(self):
return 2
cdef cypclass Derived(Left, Right):
pass
def test_resolve_multiple_inherited_methods():
"""
>>> test_resolve_multiple_inherited_methods()
1
"""
cdef Derived d = Derived()
# should resolve to Left.foo
cdef int r = d.foo()
return r
cdef cypclass Top:
int foo(self, int a, int b):
return 1
cdef cypclass Middle(Top):
int foo(self):
return 2
cdef cypclass Bottom(Middle):
int foo(self, int a):
return a + 10
def test_inherited_overloaded_method():
"""
>>> test_inherited_overloaded_method()
2
"""
cdef Bottom b = Bottom()
# should resolve to Middle.foo
cdef int r = b.foo()
return r
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment