Commit bfc6469a authored by Mark Florisson's avatar Mark Florisson

tp_traverse/clear for Py_buffer and _memoryviewslice

parent 8a7d02b6
...@@ -20,6 +20,7 @@ import PyrexTypes ...@@ -20,6 +20,7 @@ import PyrexTypes
import TypeSlots import TypeSlots
import Version import Version
import DebugFlags import DebugFlags
import PyrexTypes
from Errors import error, warning from Errors import error, warning
from PyrexTypes import py_object_type from PyrexTypes import py_object_type
...@@ -1158,11 +1159,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1158,11 +1159,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
scope = type.scope scope = type.scope
if scope: # could be None if there was an error if scope: # could be None if there was an error
self.generate_exttype_vtable(scope, code) self.generate_exttype_vtable(scope, code)
self.generate_new_function(scope, code) self.generate_new_function(scope, code, entry)
self.generate_dealloc_function(scope, code) self.generate_dealloc_function(scope, code)
if scope.needs_gc(): if scope.needs_gc():
self.generate_traverse_function(scope, code) self.generate_traverse_function(scope, code, entry)
self.generate_clear_function(scope, code) self.generate_clear_function(scope, code, entry)
if scope.defines_any(["__getitem__"]): if scope.defines_any(["__getitem__"]):
self.generate_getitem_int_function(scope, code) self.generate_getitem_int_function(scope, code)
if scope.defines_any(["__setitem__", "__delitem__"]): if scope.defines_any(["__setitem__", "__delitem__"]):
...@@ -1202,19 +1203,24 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1202,19 +1203,24 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
type.declaration_code("p"), type.declaration_code("p"),
type.declaration_code(""))) type.declaration_code("")))
def generate_new_function(self, scope, code): def generate_new_function(self, scope, code, cclass_entry):
tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__') tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
slot_func = scope.mangle_internal("tp_new") slot_func = scope.mangle_internal("tp_new")
type = scope.parent_type type = scope.parent_type
base_type = type.base_type base_type = type.base_type
py_attrs = [] py_attrs = []
memviewslice_attrs = [] memviewslice_attrs = []
py_buffers = []
for entry in scope.var_entries: for entry in scope.var_entries:
if entry.type.is_pyobject: if entry.type.is_pyobject:
py_attrs.append(entry) py_attrs.append(entry)
elif entry.type.is_memoryviewslice: elif entry.type.is_memoryviewslice:
memviewslice_attrs.append(entry) memviewslice_attrs.append(entry)
need_self_cast = type.vtabslot_cname or py_attrs or memviewslice_attrs elif entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
need_self_cast = type.vtabslot_cname or py_attrs or memviewslice_attrs or py_buffers
code.putln("") code.putln("")
code.putln( code.putln(
"static PyObject *%s(PyTypeObject *t, PyObject *a, PyObject *k) {" "static PyObject *%s(PyTypeObject *t, PyObject *a, PyObject *k) {"
...@@ -1251,15 +1257,24 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1251,15 +1257,24 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("p->%s = %s%s;" % ( code.putln("p->%s = %s%s;" % (
type.vtabslot_cname, type.vtabslot_cname,
struct_type_cast, type.vtabptr_cname)) struct_type_cast, type.vtabptr_cname))
for entry in py_attrs: for entry in py_attrs:
if scope.is_internal or entry.name == "__weakref__": if scope.is_internal or entry.name == "__weakref__":
# internal classes do not need None inits # internal classes do not need None inits
code.putln("p->%s = 0;" % entry.cname) code.putln("p->%s = 0;" % entry.cname)
else: else:
code.put_init_var_to_py_none(entry, "p->%s", nanny=False) code.put_init_var_to_py_none(entry, "p->%s", nanny=False)
for entry in memviewslice_attrs: for entry in memviewslice_attrs:
code.putln("p->%s.data = NULL;" % entry.cname) code.putln("p->%s.data = NULL;" % entry.cname)
code.putln("p->%s.memview = NULL;" % entry.cname) code.putln("p->%s.memview = NULL;" % entry.cname)
for entry in py_buffers:
code.putln("p->%s.obj = NULL;" % entry.cname)
if cclass_entry.cname == '__pyx_memoryviewslice':
code.putln("p->from_slice.memview = NULL;")
entry = scope.lookup_here("__new__") entry = scope.lookup_here("__new__")
if entry and entry.is_special: if entry and entry.is_special:
if entry.trivial_signature: if entry.trivial_signature:
...@@ -1334,7 +1349,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1334,7 +1349,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln( code.putln(
"}") "}")
def generate_traverse_function(self, scope, code): def generate_traverse_function(self, scope, code, cclass_entry):
tp_slot = TypeSlots.GCDependentSlot("tp_traverse") tp_slot = TypeSlots.GCDependentSlot("tp_traverse")
slot_func = scope.mangle_internal("tp_traverse") slot_func = scope.mangle_internal("tp_traverse")
base_type = scope.parent_type.base_type base_type = scope.parent_type.base_type
...@@ -1344,14 +1359,21 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1344,14 +1359,21 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln( code.putln(
"static int %s(PyObject *o, visitproc v, void *a) {" "static int %s(PyObject *o, visitproc v, void *a) {"
% slot_func) % slot_func)
py_attrs = [] py_attrs = []
py_buffers = []
for entry in scope.var_entries: for entry in scope.var_entries:
if entry.type.is_pyobject and entry.name != "__weakref__": if entry.type.is_pyobject and entry.name != "__weakref__":
py_attrs.append(entry) py_attrs.append(entry)
if entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
if base_type or py_attrs: if base_type or py_attrs:
code.putln("int e;") code.putln("int e;")
if py_attrs:
if py_attrs or py_buffers:
self.generate_self_cast(scope, code) self.generate_self_cast(scope, code)
if base_type: if base_type:
# want to call it explicitly if possible so inlining can be performed # want to call it explicitly if possible so inlining can be performed
static_call = TypeSlots.get_base_slot_function(scope, tp_slot) static_call = TypeSlots.get_base_slot_function(scope, tp_slot)
...@@ -1363,6 +1385,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1363,6 +1385,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"e = %s->tp_traverse(o, v, a); if (e) return e;" % "e = %s->tp_traverse(o, v, a); if (e) return e;" %
base_type.typeptr_cname) base_type.typeptr_cname)
code.putln("}") code.putln("}")
for entry in py_attrs: for entry in py_attrs:
var_code = "p->%s" % entry.cname var_code = "p->%s" % entry.cname
code.putln( code.putln(
...@@ -1375,12 +1398,23 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1375,12 +1398,23 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
% var_code) % var_code)
code.putln( code.putln(
"}") "}")
for entry in py_buffers:
code.putln("if (p->%s.obj) {" % entry.cname)
code.putln( "e = (*v)(p->%s.obj, a); if (e) return e;" % entry.cname)
code.putln("}")
if cclass_entry.cname == '__pyx_memoryviewslice':
code.putln("if (p->from_slice.memview) {")
code.putln( "e = (*v)((PyObject *) p->from_slice.memview, a); if (e) return e;")
code.putln("}")
code.putln( code.putln(
"return 0;") "return 0;")
code.putln( code.putln(
"}") "}")
def generate_clear_function(self, scope, code): def generate_clear_function(self, scope, code, cclass_entry):
tp_slot = TypeSlots.GCDependentSlot("tp_clear") tp_slot = TypeSlots.GCDependentSlot("tp_clear")
slot_func = scope.mangle_internal("tp_clear") slot_func = scope.mangle_internal("tp_clear")
base_type = scope.parent_type.base_type base_type = scope.parent_type.base_type
...@@ -1388,13 +1422,19 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1388,13 +1422,19 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
return # never used return # never used
code.putln("") code.putln("")
code.putln("static int %s(PyObject *o) {" % slot_func) code.putln("static int %s(PyObject *o) {" % slot_func)
py_attrs = [] py_attrs = []
py_buffers = []
for entry in scope.var_entries: for entry in scope.var_entries:
if entry.type.is_pyobject and entry.name != "__weakref__": if entry.type.is_pyobject and entry.name != "__weakref__":
py_attrs.append(entry) py_attrs.append(entry)
if py_attrs: if entry.type == PyrexTypes.c_py_buffer_type:
py_buffers.append(entry)
if py_attrs or py_buffers:
self.generate_self_cast(scope, code) self.generate_self_cast(scope, code)
code.putln("PyObject* tmp;") code.putln("PyObject* tmp;")
if base_type: if base_type:
# want to call it explicitly if possible so inlining can be performed # want to call it explicitly if possible so inlining can be performed
static_call = TypeSlots.get_base_slot_function(scope, tp_slot) static_call = TypeSlots.get_base_slot_function(scope, tp_slot)
...@@ -1404,6 +1444,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1404,6 +1444,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("if (%s->tp_clear) {" % base_type.typeptr_cname) code.putln("if (%s->tp_clear) {" % base_type.typeptr_cname)
code.putln("%s->tp_clear(o);" % base_type.typeptr_cname) code.putln("%s->tp_clear(o);" % base_type.typeptr_cname)
code.putln("}") code.putln("}")
for entry in py_attrs: for entry in py_attrs:
name = "p->%s" % entry.cname name = "p->%s" % entry.cname
code.putln("tmp = ((PyObject*)%s);" % name) code.putln("tmp = ((PyObject*)%s);" % name)
...@@ -1412,6 +1453,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1412,6 +1453,13 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
else: else:
code.put_init_to_py_none(name, entry.type, nanny=False) code.put_init_to_py_none(name, entry.type, nanny=False)
code.putln("Py_XDECREF(tmp);") code.putln("Py_XDECREF(tmp);")
for entry in py_buffers:
code.putln("Py_CLEAR(p->%s.obj);" % entry.cname)
if cclass_entry.cname == '__pyx_memoryviewslice':
code.putln("__PYX_XDEC_MEMVIEW(&p->from_slice, 1);")
code.putln( code.putln(
"return 0;") "return 0;")
code.putln( code.putln(
......
...@@ -231,6 +231,15 @@ cdef class ObjectMockBuffer(MockBuffer): ...@@ -231,6 +231,15 @@ cdef class ObjectMockBuffer(MockBuffer):
cdef get_itemsize(self): return sizeof(void*) cdef get_itemsize(self): return sizeof(void*)
cdef get_default_format(self): return b"@O" cdef get_default_format(self): return b"@O"
cdef extern from "Python.h":
ctypedef struct PyObject:
pass
ctypedef int (*visitproc)(PyObject *obj, void *arg)
ctypedef int (*inquiry)(PyObject *self)
void Py_VISIT(object)
void Py_CLEAR(object)
cdef class IntStridedMockBuffer(IntMockBuffer): cdef class IntStridedMockBuffer(IntMockBuffer):
cdef __cythonbufferdefaults__ = {"mode" : "strided"} cdef __cythonbufferdefaults__ = {"mode" : "strided"}
......
...@@ -367,3 +367,27 @@ def test_memslice_getbuffer(): ...@@ -367,3 +367,27 @@ def test_memslice_getbuffer():
""" """
cdef int[:, :] array = create_array((4, 5), mode="c", use_callback=True) cdef int[:, :] array = create_array((4, 5), mode="c", use_callback=True)
print np.asarray(array)[::2, ::2] print np.asarray(array)[::2, ::2]
cdef class DeallocateMe(object):
def __dealloc__(self):
print "deallocated!"
# Disabled! References cycles don't seem to be supported by NumPy
# @testcase
def acquire_release_cycle(obj):
"""
>>> a = np.arange(20, dtype=np.object)
>>> a[10] = DeallocateMe()
>>> acquire_release_cycle(a)
deallocated!
"""
import gc
cdef object[:] buf = obj
buf[1] = buf
gc.collect()
del buf
gc.collect()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment