Commit 52d40881 authored by Xavier Thompson's avatar Xavier Thompson

Fix and refactor cypclass typecasts

parent 7677b120
...@@ -11004,11 +11004,11 @@ class TypecastNode(ExprNode): ...@@ -11004,11 +11004,11 @@ class TypecastNode(ExprNode):
imag_part) imag_part)
else: else:
operand_type = self.operand.type operand_type = self.operand.type
if operand_type.is_cyp_class: if self.overloaded:
if self.overloaded: if operand_type.is_cyp_class:
operand_result = '(*%s)' % operand_result operand_result = '(*%s)' % operand_result
# use dynamic cast when dowcasting from a base to a cypclass elif self.type.is_cyp_class and operand_type in self.type.mro() and not self.type.same_as(operand_type):
if self.type.is_cyp_class and operand_type in self.type.mro() and not self.type.same_as(operand_type): # use dynamic cast when dowcasting from a base to a cypclass
return self.type.dynamic_cast_code(operand_result) return self.type.dynamic_cast_code(operand_result)
return self.type.cast_code(operand_result) return self.type.cast_code(operand_result)
...@@ -11033,22 +11033,17 @@ class TypecastNode(ExprNode): ...@@ -11033,22 +11033,17 @@ class TypecastNode(ExprNode):
self.operand.result())) self.operand.result()))
code.put_incref(self.result(), self.ctype()) code.put_incref(self.result(), self.ctype())
elif self.type.is_cyp_class: elif self.type.is_cyp_class:
star = "*" if self.overloaded else "" # self.overloaded is True
operand_result = "%s%s" % (star, self.operand.result()) operand_type = self.operand.type
# use dynamic cast when dowcasting from a base to a cypclass operand_result = self.operand.result()
if self.operand.type in self.type.mro() and not self.type.same_as(self.operand.type): if operand_type.is_cyp_class:
code.putln( operand_result = "*%s" % operand_result
"%s = dynamic_cast<%s>(%s);" % ( code.putln(
self.result(), "%s = (%s)(%s);" % (
self.type.declaration_code(''), self.result(),
operand_result)) self.type.declaration_code(''),
else: operand_result))
code.putln( # the result is already a new reference
"%s = (%s)(%s);" % (
self.result(),
self.type.declaration_code(''),
operand_result))
code.put_incref(self.result(), self.type)
ERR_START = "Start may not be given" ERR_START = "Start may not be given"
......
# mode: run
# tag: cpp, cpp11, pthread
# cython: experimental_cpp_class_def=True, language_level=2
cdef cypclass Base:
__dealloc__(self) with gil:
print("Base destroyed")
cdef cypclass Derived(Base):
__dealloc__(self) with gil:
print("Derived destroyed")
def test_upcast_name():
"""
>>> test_upcast_name()
Derived destroyed
Base destroyed
0
"""
d = Derived()
b = <Base> d
if Cy_GETREF(d) != 3:
return -1
del b
if Cy_GETREF(d) != 2:
return -2
return 0
def test_upcast_and_drop_name():
"""
>>> test_upcast_and_drop_name()
Derived destroyed
Base destroyed
0
"""
d = Derived()
<Base> d
if Cy_GETREF(d) != 2:
return -1
return 0
def test_upcast_constructed():
"""
>>> test_upcast_constructed()
Derived destroyed
Base destroyed
0
"""
d = <Base> Derived()
if Cy_GETREF(d) != 2:
return -1
return 0
def test_upcast_and_drop_constructed():
"""
>>> test_upcast_and_drop_constructed()
Derived destroyed
Base destroyed
0
"""
<Base> Derived()
return 0
def test_downcast_name():
"""
>>> test_downcast_name()
Derived destroyed
Base destroyed
0
"""
b = <Base> Derived()
d = <Derived> b
if Cy_GETREF(b) != 3:
return -1
del b
if Cy_GETREF(d) != 2:
return -2
return 0
def test_downcast_and_drop_name():
"""
>>> test_downcast_and_drop_name()
Derived destroyed
Base destroyed
0
"""
b = <Base> Derived()
<Derived> b
if Cy_GETREF(b) != 2:
return -1
return 0
def test_downcast_constructed():
"""
>>> test_downcast_constructed()
Derived destroyed
Base destroyed
0
"""
d = <Derived> <Base> Derived()
if Cy_GETREF(d) != 2:
return -1
return 0
def test_downcast_and_drop_constructed():
"""
>>> test_downcast_and_drop_constructed()
Derived destroyed
Base destroyed
0
"""
<Derived> <Base> Derived()
return 0
def test_failed_downcast():
"""
>>> test_failed_downcast()
Base destroyed
0
"""
d = <Derived> Base()
if d is not NULL:
return -1
return 0
cdef cypclass Convertible:
Derived __Derived__(self) with gil:
print("Convertible -> Derived")
return Derived()
__dealloc__(self) with gil:
print("Convertible destroyed")
def test_convert_name():
"""
>>> test_convert_name()
Convertible -> Derived
Convertible destroyed
Derived destroyed
Base destroyed
0
"""
c = Convertible()
d = <Derived> c
if Cy_GETREF(c) != 2:
return -1
if Cy_GETREF(d) != 2:
return -2
del c
return 0
def test_convert_and_drop_name():
"""
>>> test_convert_and_drop_name()
Convertible -> Derived
Derived destroyed
Base destroyed
converted
Convertible destroyed
0
"""
c = Convertible()
<Derived> c
print("converted")
if Cy_GETREF(c) != 2:
return -1
return 0
def test_convert_constructed():
"""
>>> test_convert_constructed()
Convertible -> Derived
Convertible destroyed
converted
Derived destroyed
Base destroyed
0
"""
d = <Derived> Convertible()
print("converted")
if Cy_GETREF(d) != 2:
return -1
return 0
def test_convert_and_drop_constructed():
"""
>>> test_convert_and_drop_constructed()
Convertible -> Derived
Convertible destroyed
Derived destroyed
Base destroyed
converted
0
"""
<Derived> Convertible()
print("converted")
return 0
cdef cypclass DerivedConvertible(Base):
Base __Base__(self) with gil:
print("DerivedConvertible -> Base")
return Base()
__dealloc__(self) with gil:
print("DerivedConvertible destroyed")
def test_overloaded_upcast():
"""
>>> test_overloaded_upcast()
DerivedConvertible -> Base
converted
DerivedConvertible destroyed
Base destroyed
Base destroyed
0
"""
d = DerivedConvertible()
b = <Base> d
print("converted")
if Cy_GETREF(d) != 2:
return -1
if Cy_GETREF(b) != 2:
return -2
del d
return 0
cdef cypclass BaseConvertible
cdef cypclass DerivedConverted(BaseConvertible):
__dealloc__(self) with gil:
print("DerivedConverted destroyed")
cdef cypclass BaseConvertible:
DerivedConverted __DerivedConverted__(self) with gil:
print("BaseConvertible -> DerivedConverted")
return DerivedConverted()
__dealloc__(self) with gil:
print("BaseConvertible destroyed")
def test_overloaded_downcast():
"""
>>> test_overloaded_downcast()
BaseConvertible -> DerivedConverted
converted
BaseConvertible destroyed
DerivedConverted destroyed
BaseConvertible destroyed
0
"""
b = BaseConvertible()
d = <DerivedConverted> b
print("converted")
if Cy_GETREF(d) != 2:
return -1
if Cy_GETREF(b) != 2:
return -2
del b
return 0
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment