Commit d2697b13 authored by da-woods's avatar da-woods Committed by GitHub

Implement PEP-560 inheritance (__mro_entries__) (GH-4005)

Fixes https://github.com/cython/cython/issues/3537

Both the C code and the tests are largely copied from CPython.

Note that this currently only applies to Python classes, not to cdef classes. Maybe it could be applied to the second+subsequent bases but I don't think it's needed for the initial implementation.
parent 89f881b5
......@@ -4752,6 +4752,7 @@ class PyClassDefNode(ClassDefNode):
# entry Symtab.Entry
# scope PyClassScope
# decorators [DecoratorNode] list of decorators or None
# bases ExprNode Expression that evaluates to a tuple of base classes
#
# The following subnodes are constructed internally:
#
......@@ -4759,15 +4760,18 @@ class PyClassDefNode(ClassDefNode):
# dict DictNode Class dictionary or Py3 namespace
# classobj ClassNode Class object
# target NameNode Variable to assign class object to
# orig_bases None or ExprNode "bases" before transformation by PEP560 __mro_entries__,
# used to create the __orig_bases__ attribute
child_attrs = ["doc_node", "body", "dict", "metaclass", "mkw", "bases", "class_result",
"target", "class_cell", "decorators"]
"target", "class_cell", "decorators", "orig_bases"]
decorators = None
class_result = None
is_py3_style_class = False # Python3 style class (kwargs)
metaclass = None
mkw = None
doc_node = None
orig_bases = None
def __init__(self, pos, name, bases, doc, body, decorators=None,
keyword_args=None, force_py3_semantics=False):
......@@ -4893,7 +4897,22 @@ class PyClassDefNode(ClassDefNode):
self.body.analyse_declarations(cenv)
self.class_result.analyse_annotations(cenv)
update_bases_functype = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("bases", PyrexTypes.py_object_type, None)
])
def analyse_expressions(self, env):
if self.bases and not (self.bases.is_sequence_constructor and len(self.bases.args) == 0):
from .ExprNodes import PythonCapiCallNode, CloneNode
# handle the Python 3.7 __mro_entries__ transformation
orig_bases = self.bases.analyse_expressions(env)
self.bases = PythonCapiCallNode(orig_bases.pos,
function_name="__Pyx_PEP560_update_bases",
func_type=self.update_bases_functype,
utility_code=UtilityCode.load_cached('Py3UpdateBases', 'ObjectHandling.c'),
args=[CloneNode(orig_bases)])
self.orig_bases = orig_bases
if self.bases:
self.bases = self.bases.analyse_expressions(env)
if self.mkw:
......@@ -4916,6 +4935,8 @@ class PyClassDefNode(ClassDefNode):
code.mark_pos(self.pos)
code.pyclass_stack.append(self)
cenv = self.scope
if self.orig_bases:
self.orig_bases.generate_evaluation_code(code)
if self.bases:
self.bases.generate_evaluation_code(code)
if self.mkw:
......@@ -4923,6 +4944,17 @@ class PyClassDefNode(ClassDefNode):
if self.metaclass:
self.metaclass.generate_evaluation_code(code)
self.dict.generate_evaluation_code(code)
if self.orig_bases:
# update __orig_bases__ if needed
code.putln("if (%s != %s) {" % (self.bases.result(), self.orig_bases.result()))
code.putln(
code.error_goto_if_neg('PyDict_SetItemString(%s, "__orig_bases__", %s)' % (
self.dict.result(), self.orig_bases.result()),
self.pos
))
code.putln("}")
self.orig_bases.generate_disposal_code(code)
self.orig_bases.free_temps(code)
cenv.namespace_cname = cenv.class_obj_cname = self.dict.result()
class_cell = self.class_cell
......
......@@ -1024,6 +1024,94 @@ static PyObject *__Pyx_CreateClass(PyObject *bases, PyObject *dict, PyObject *na
return result;
}
/////////////// Py3UpdateBases.proto ///////////////
static PyObject* __Pyx_PEP560_update_bases(PyObject *bases); /* proto */
/////////////// Py3UpdateBases /////////////////////
//@requires: PyObjectCallOneArg
//@requires: PyObjectGetAttrStrNoError
/* Shamelessly adapted from cpython/bltinmodule.c update_bases */
static PyObject*
__Pyx_PEP560_update_bases(PyObject *bases)
{
Py_ssize_t i, j, size_bases;
PyObject *base, *meth, *new_base, *result, *new_bases = NULL;
/*assert(PyTuple_Check(bases));*/
size_bases = PyTuple_GET_SIZE(bases);
for (i = 0; i < size_bases; i++) {
// original code in CPython: base = args[i];
base = PyTuple_GET_ITEM(bases, i);
if (PyType_Check(base)) {
if (new_bases) {
// If we already have made a replacement, then we append every normal base,
// otherwise just skip it.
if (PyList_Append(new_bases, base) < 0) {
goto error;
}
}
continue;
}
// original code in CPython:
// if (_PyObject_LookupAttrId(base, &PyId___mro_entries__, &meth) < 0) {
meth = __Pyx_PyObject_GetAttrStrNoError(base, PYIDENT("__mro_entries__"));
if (!meth && PyErr_Occurred()) {
goto error;
}
if (!meth) {
if (new_bases) {
if (PyList_Append(new_bases, base) < 0) {
goto error;
}
}
continue;
}
new_base = __Pyx_PyObject_CallOneArg(meth, bases);
Py_DECREF(meth);
if (!new_base) {
goto error;
}
if (!PyTuple_Check(new_base)) {
PyErr_SetString(PyExc_TypeError,
"__mro_entries__ must return a tuple");
Py_DECREF(new_base);
goto error;
}
if (!new_bases) {
// If this is a first successful replacement, create new_bases list and
// copy previously encountered bases.
if (!(new_bases = PyList_New(i))) {
goto error;
}
for (j = 0; j < i; j++) {
// original code in CPython: base = args[j];
base = PyTuple_GET_ITEM(bases, j);
PyList_SET_ITEM(new_bases, j, base);
Py_INCREF(base);
}
}
j = PyList_GET_SIZE(new_bases);
if (PyList_SetSlice(new_bases, j, j, new_base) < 0) {
goto error;
}
Py_DECREF(new_base);
}
if (!new_bases) {
// unlike the CPython implementation, always return a new reference
Py_INCREF(bases);
return bases;
}
result = PyList_AsTuple(new_bases);
Py_DECREF(new_bases);
return result;
error:
Py_XDECREF(new_bases);
return NULL;
}
/////////////// Py3ClassCreate.proto ///////////////
static PyObject *__Pyx_Py3MetaclassPrepare(PyObject *metaclass, PyObject *bases, PyObject *name, PyObject *qualname,
......
......@@ -4,9 +4,159 @@
# COPIED FROM CPython 3.7
import contextlib
import unittest
import sys
class TestMROEntry(unittest.TestCase):
def test_mro_entry_signature(self):
tested = []
class B: ...
class C:
def __mro_entries__(self, *args, **kwargs):
tested.extend([args, kwargs])
return (C,)
c = C()
self.assertEqual(tested, [])
class D(B, c): ...
self.assertEqual(tested[0], ((B, c),))
self.assertEqual(tested[1], {})
def test_mro_entry(self):
tested = []
class A: ...
class B: ...
class C:
def __mro_entries__(self, bases):
tested.append(bases)
return (self.__class__,)
c = C()
self.assertEqual(tested, [])
class D(A, c, B): ...
self.assertEqual(tested[-1], (A, c, B))
self.assertEqual(D.__bases__, (A, C, B))
self.assertEqual(D.__orig_bases__, (A, c, B))
self.assertEqual(D.__mro__, (D, A, C, B, object))
d = D()
class E(d): ...
self.assertEqual(tested[-1], (d,))
self.assertEqual(E.__bases__, (D,))
def test_mro_entry_none(self):
tested = []
class A: ...
class B: ...
class C:
def __mro_entries__(self, bases):
tested.append(bases)
return ()
c = C()
self.assertEqual(tested, [])
class D(A, c, B): ...
self.assertEqual(tested[-1], (A, c, B))
self.assertEqual(D.__bases__, (A, B))
self.assertEqual(D.__orig_bases__, (A, c, B))
self.assertEqual(D.__mro__, (D, A, B, object))
class E(c): ...
self.assertEqual(tested[-1], (c,))
if sys.version_info[0] > 2:
# not all of it works on Python 2
self.assertEqual(E.__bases__, (object,))
self.assertEqual(E.__orig_bases__, (c,))
if sys.version_info[0] > 2:
# not all of it works on Python 2
self.assertEqual(E.__mro__, (E, object))
def test_mro_entry_with_builtins(self):
tested = []
class A: ...
class C:
def __mro_entries__(self, bases):
tested.append(bases)
return (dict,)
c = C()
self.assertEqual(tested, [])
class D(A, c): ...
self.assertEqual(tested[-1], (A, c))
self.assertEqual(D.__bases__, (A, dict))
self.assertEqual(D.__orig_bases__, (A, c))
self.assertEqual(D.__mro__, (D, A, dict, object))
def test_mro_entry_with_builtins_2(self):
tested = []
class C:
def __mro_entries__(self, bases):
tested.append(bases)
return (C,)
c = C()
self.assertEqual(tested, [])
class D(c, dict): ...
self.assertEqual(tested[-1], (c, dict))
self.assertEqual(D.__bases__, (C, dict))
self.assertEqual(D.__orig_bases__, (c, dict))
self.assertEqual(D.__mro__, (D, C, dict, object))
def test_mro_entry_errors(self):
class C_too_many:
def __mro_entries__(self, bases, something, other):
return ()
c = C_too_many()
with self.assertRaises(TypeError):
class D(c): ...
class C_too_few:
def __mro_entries__(self):
return ()
d = C_too_few()
with self.assertRaises(TypeError):
class D(d): ...
def test_mro_entry_errors_2(self):
class C_not_callable:
__mro_entries__ = "Surprise!"
c = C_not_callable()
with self.assertRaises(TypeError):
class D(c): ...
class C_not_tuple:
def __mro_entries__(self):
return object
c = C_not_tuple()
with self.assertRaises(TypeError):
class D(c): ...
def test_mro_entry_metaclass(self):
meta_args = []
class Meta(type):
def __new__(mcls, name, bases, ns):
meta_args.extend([mcls, name, bases, ns])
return super().__new__(mcls, name, bases, ns)
class A: ...
class C:
def __mro_entries__(self, bases):
return (A,)
c = C()
class D(c, metaclass=Meta):
x = 1
self.assertEqual(meta_args[0], Meta)
self.assertEqual(meta_args[1], 'D')
self.assertEqual(meta_args[2], (A,))
self.assertEqual(meta_args[3]['x'], 1)
self.assertEqual(D.__bases__, (A,))
self.assertEqual(D.__orig_bases__, (c,))
self.assertEqual(D.__mro__, (D, A, object))
self.assertEqual(D.__class__, Meta)
@unittest.skipIf(sys.version_info < (3, 7), "'type' checks for __mro_entries__ not implemented")
def test_mro_entry_type_call(self):
# Substitution should _not_ happen in direct type call
class C:
def __mro_entries__(self, bases):
return ()
c = C()
with self.assertRaisesRegex(TypeError,
"MRO entry resolution; "
"use types.new_class()"):
type('Bad', (c,), {})
class TestClassGetitem(unittest.TestCase):
# BEGIN - Additional tests from cython
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment