Commit bfc6469a authored by Mark Florisson's avatar Mark Florisson

tp_traverse/clear for Py_buffer and _memoryviewslice

parent 8a7d02b6
......@@ -20,6 +20,7 @@ import PyrexTypes
import TypeSlots
import Version
import DebugFlags
import PyrexTypes
from Errors import error, warning
from PyrexTypes import py_object_type
......@@ -1158,11 +1159,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
scope = type.scope
if scope: # could be None if there was an error
self.generate_exttype_vtable(scope, code)
self.generate_new_function(scope, code)
self.generate_new_function(scope, code, entry)
self.generate_dealloc_function(scope, code)
if scope.needs_gc():
self.generate_traverse_function(scope, code)
self.generate_clear_function(scope, code)
self.generate_traverse_function(scope, code, entry)
self.generate_clear_function(scope, code, entry)
if scope.defines_any(["__getitem__"]):
self.generate_getitem_int_function(scope, code)
if scope.defines_any(["__setitem__", "__delitem__"]):
......@@ -1202,19 +1203,24 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
type.declaration_code("p"),
type.declaration_code("")))
def generate_new_function(self, scope, code):
def generate_new_function(self, scope, code, cclass_entry):
tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
slot_func = scope.mangle_internal("tp_new")
type = scope.parent_type
base_type = type.base_type
py_attrs = []
memviewslice_attrs = []
py_buffers = []
for entry in scope.var_entries:
if entry.type.is_pyobject:
py_attrs.append(entry)
elif entry.type.is_memoryviewslice:
memviewslice_attrs.append(entry)
need_self_cast = type.vtabslot_cname or py_attrs or memviewslice_attrs
elif entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
need_self_cast = type.vtabslot_cname or py_attrs or memviewslice_attrs or py_buffers
code.putln("")
code.putln(
"static PyObject *%s(PyTypeObject *t, PyObject *a, PyObject *k) {"
......@@ -1251,15 +1257,24 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("p->%s = %s%s;" % (
type.vtabslot_cname,
struct_type_cast, type.vtabptr_cname))
for entry in py_attrs:
if scope.is_internal or entry.name == "__weakref__":
# internal classes do not need None inits
code.putln("p->%s = 0;" % entry.cname)
else:
code.put_init_var_to_py_none(entry, "p->%s", nanny=False)
for entry in memviewslice_attrs:
code.putln("p->%s.data = NULL;" % entry.cname)
code.putln("p->%s.memview = NULL;" % entry.cname)
for entry in py_buffers:
code.putln("p->%s.obj = NULL;" % entry.cname)
if cclass_entry.cname == '__pyx_memoryviewslice':
code.putln("p->from_slice.memview = NULL;")
entry = scope.lookup_here("__new__")
if entry and entry.is_special:
if entry.trivial_signature:
......@@ -1334,7 +1349,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(
"}")
def generate_traverse_function(self, scope, code):
def generate_traverse_function(self, scope, code, cclass_entry):
tp_slot = TypeSlots.GCDependentSlot("tp_traverse")
slot_func = scope.mangle_internal("tp_traverse")
base_type = scope.parent_type.base_type
......@@ -1344,14 +1359,21 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln(
"static int %s(PyObject *o, visitproc v, void *a) {"
% slot_func)
py_attrs = []
py_buffers = []
for entry in scope.var_entries:
if entry.type.is_pyobject and entry.name != "__weakref__":
py_attrs.append(entry)
if entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
if base_type or py_attrs:
code.putln("int e;")
if py_attrs:
if py_attrs or py_buffers:
self.generate_self_cast(scope, code)
if base_type:
# want to call it explicitly if possible so inlining can be performed
static_call = TypeSlots.get_base_slot_function(scope, tp_slot)
......@@ -1363,6 +1385,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"e = %s->tp_traverse(o, v, a); if (e) return e;" %
base_type.typeptr_cname)
code.putln("}")
for entry in py_attrs:
var_code = "p->%s" % entry.cname
code.putln(
......@@ -1375,12 +1398,23 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
% var_code)
code.putln(
"}")
for entry in py_buffers:
code.putln("if (p->%s.obj) {" % entry.cname)
code.putln( "e = (*v)(p->%s.obj, a); if (e) return e;" % entry.cname)
code.putln("}")
if cclass_entry.cname == '__pyx_memoryviewslice':
code.putln("if (p->from_slice.memview) {")
code.putln( "e = (*v)((PyObject *) p->from_slice.memview, a); if (e) return e;")
code.putln("}")
code.putln(
"return 0;")
code.putln(
"}")
def generate_clear_function(self, scope, code):
def generate_clear_function(self, scope, code, cclass_entry):
tp_slot = TypeSlots.GCDependentSlot("tp_clear")
slot_func = scope.mangle_internal("tp_clear")
base_type = scope.parent_type.base_type
......@@ -1388,13 +1422,19 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
return # never used
code.putln("")
code.putln("static int %s(PyObject *o) {" % slot_func)
py_attrs = []
py_buffers = []
for entry in scope.var_entries:
if entry.type.is_pyobject and entry.name != "__weakref__":
py_attrs.append(entry)
if py_attrs:
if entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
if py_attrs or py_buffers:
self.generate_self_cast(scope, code)
code.putln("PyObject* tmp;")
if base_type:
# want to call it explicitly if possible so inlining can be performed
static_call = TypeSlots.get_base_slot_function(scope, tp_slot)
......@@ -1404,6 +1444,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("if (%s->tp_clear) {" % base_type.typeptr_cname)
code.putln("%s->tp_clear(o);" % base_type.typeptr_cname)
code.putln("}")
for entry in py_attrs:
name = "p->%s" % entry.cname
code.putln("tmp = ((PyObject*)%s);" % name)
......@@ -1412,6 +1453,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
else:
code.put_init_to_py_none(name, entry.type, nanny=False)
code.putln("Py_XDECREF(tmp);")
for entry in py_buffers:
code.putln("Py_CLEAR(p->%s.obj);" % entry.cname)
if cclass_entry.cname == '__pyx_memoryviewslice':
code.putln("__PYX_XDEC_MEMVIEW(&p->from_slice, 1);")
code.putln(
"return 0;")
code.putln(
......
......@@ -231,6 +231,15 @@ cdef class ObjectMockBuffer(MockBuffer):
cdef get_itemsize(self): return sizeof(void*)
cdef get_default_format(self): return b"@O"
cdef extern from "Python.h":
ctypedef struct PyObject:
pass
ctypedef int (*visitproc)(PyObject *obj, void *arg)
ctypedef int (*inquiry)(PyObject *self)
void Py_VISIT(object)
void Py_CLEAR(object)
cdef class IntStridedMockBuffer(IntMockBuffer):
cdef __cythonbufferdefaults__ = {"mode" : "strided"}
......
......@@ -367,3 +367,27 @@ def test_memslice_getbuffer():
"""
cdef int[:, :] array = create_array((4, 5), mode="c", use_callback=True)
print np.asarray(array)[::2, ::2]
cdef class DeallocateMe(object):
def __dealloc__(self):
print "deallocated!"
# Disabled! References cycles don't seem to be supported by NumPy
# @testcase
def acquire_release_cycle(obj):
"""
>>> a = np.arange(20, dtype=np.object)
>>> a[10] = DeallocateMe()
>>> acquire_release_cycle(a)
deallocated!
"""
import gc
cdef object[:] buf = obj
buf[1] = buf
gc.collect()
del buf
gc.collect()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment