Commit f5984db3 authored by Xavier Thompson's avatar Xavier Thompson

Implement runtime isolation check on consuming non 'iso' cypclass

parent 44056f33
...@@ -11356,10 +11356,11 @@ class ConsumeNode(ExprNode): ...@@ -11356,10 +11356,11 @@ class ConsumeNode(ExprNode):
# Consume expression # Consume expression
# #
# operand ExprNode # operand ExprNode
#
# generate_runtime_check boolean used internally
# operand_is_named boolean used internally
subexprs = ['operand'] subexprs = ['operand']
generate_runtime_check = True
operand_is_named = True
def infer_type(self, env): def infer_type(self, env):
operand_type = self.operand.infer_type(env) operand_type = self.operand.infer_type(env)
...@@ -11375,10 +11376,6 @@ class ConsumeNode(ExprNode): ...@@ -11375,10 +11376,6 @@ class ConsumeNode(ExprNode):
error(self.pos, "Can only consume cypclass") error(self.pos, "Can only consume cypclass")
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
return self return self
if self.operand.is_name or self.operand.is_attribute:
self.is_temp = self.operand_is_named = True
# We steal the reference of the operand.
self.use_managed_ref = False
if operand_type.is_qualified_cyp_class: if operand_type.is_qualified_cyp_class:
if operand_type.qualifier == 'iso!': if operand_type.qualifier == 'iso!':
error(self.pos, "Cannot consume iso!") error(self.pos, "Cannot consume iso!")
...@@ -11390,7 +11387,13 @@ class ConsumeNode(ExprNode): ...@@ -11390,7 +11387,13 @@ class ConsumeNode(ExprNode):
else: else:
self.type = PyrexTypes.cyp_class_qualified_type(operand_type.qual_base_type, 'iso~') self.type = PyrexTypes.cyp_class_qualified_type(operand_type.qual_base_type, 'iso~')
else: else:
self.generate_runtime_check = True
self.type = PyrexTypes.cyp_class_qualified_type(operand_type, 'iso~') self.type = PyrexTypes.cyp_class_qualified_type(operand_type, 'iso~')
self.operand_is_named = self.operand.is_name or self.operand.is_attribute
self.is_temp = self.operand_is_named or self.generate_runtime_check
if self.is_temp:
# We steal the reference of the operand.
self.use_managed_ref = False
return self return self
def may_be_none(self): def may_be_none(self):
...@@ -11404,19 +11407,16 @@ class ConsumeNode(ExprNode): ...@@ -11404,19 +11407,16 @@ class ConsumeNode(ExprNode):
pass pass
def calculate_result_code(self): def calculate_result_code(self):
if self.generate_runtime_check: return self.operand.result()
# TODO: generate runtime check for isolation
return self.operand.result()
else:
return self.operand.result()
def generate_result_code(self, code): def generate_result_code(self, code):
if self.is_temp: if self.is_temp:
operand_result = self.operand.result() operand_result = self.operand.result()
code.putln("%s = %s;" % (self.result(), operand_result)) code.putln("%s = %s;" % (self.result(), operand_result))
if self.generate_runtime_check: if self.generate_runtime_check:
# TODO: generate runtime check for isolation code.putln("if (!%s->CyObject_iso()) {" % self.result())
pass code.putln("std::terminate();")
code.putln("}")
# We steal the reference of the operand. # We steal the reference of the operand.
code.putln("%s = NULL;" % operand_result) code.putln("%s = NULL;" % operand_result)
if self.operand.is_temp: if self.operand.is_temp:
......
...@@ -937,6 +937,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -937,6 +937,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_cyp_class_activated_methods(entry, code) self.generate_cyp_class_activated_methods(entry, code)
# Generate cypclass attr destructor # Generate cypclass attr destructor
self.generate_cyp_class_attrs_destructor_definition(entry, code) self.generate_cyp_class_attrs_destructor_definition(entry, code)
# Generate cypclass traverse method and isolation check method
self.generate_cyp_class_traverse_and_iso_definition(entry, code)
# Generate wrapper constructor # Generate wrapper constructor
wrapper = scope.lookup_here("<constructor>") wrapper = scope.lookup_here("<constructor>")
constructor = scope.lookup_here("<init>") constructor = scope.lookup_here("<init>")
...@@ -967,6 +969,41 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -967,6 +969,41 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("Cy_XDECREF(this->%s);" % attr.cname) code.putln("Cy_XDECREF(this->%s);" % attr.cname)
code.putln("}") code.putln("}")
def generate_cyp_class_traverse_and_iso_definition(self, entry, code):
"""
Generate traverse method and isolation check method definition for the given cypclass entry.
"""
scope = entry.type.scope
all_cypclass_attrs = [e for e in scope.entries.values()
if e.type.is_cyp_class and not e.name == "this"
and not e.is_type]
# potential template
if entry.type.templates:
templates_code = "template <typename %s>" % ", typename ".join(t.name for t in entry.type.templates)
else:
templates_code = None
# traverse method
namespace = entry.type.empty_declaration_code()
if templates_code:
code.putln(templates_code)
code.putln("int %s::CyObject_traverse(void *(*visit)(const CyObject *o, void *arg), void *arg) const" % namespace)
code.putln("{")
for attr in all_cypclass_attrs:
code.putln("if (void *ret = visit(this->%s, arg)) return (int) (intptr_t) ret;" % attr.cname)
code.putln("return 0;")
code.putln("}")
# isolation check method
if templates_code:
code.putln(templates_code)
code.putln("int %s::CyObject_iso() const" % namespace)
code.putln("{")
if all_cypclass_attrs:
code.putln("return __Pyx_CyObject_owning(this) == 1;")
else:
code.putln("return this->CyObject_GETREF() == 1;")
code.putln("}")
def generate_cyp_class_activated_methods(self, entry, code): def generate_cyp_class_activated_methods(self, entry, code):
""" """
Generate activated cypclass methods. Generate activated cypclass methods.
...@@ -1459,10 +1496,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1459,10 +1496,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("Py_TYPE(wrapper) = %s;" % wrapper_type.typeptr_cname) code.putln("Py_TYPE(wrapper) = %s;" % wrapper_type.typeptr_cname)
code.putln("}") code.putln("}")
def generate_typedef(self, entry, code): def generate_typedef(self, entry, code):
base_type = entry.type.typedef_base_type base_type = entry.type.typedef_base_type
enclosing_scope = entry.scope enclosing_scope = entry.scope
...@@ -1632,6 +1665,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1632,6 +1665,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
arg_names = [] arg_names = []
generate_cpp_constructor_code(arg_decls, arg_names, is_implementing, py_attrs, constructor) generate_cpp_constructor_code(arg_decls, arg_names, is_implementing, py_attrs, constructor)
if type.is_cyp_class:
# Declare the method to check isolation
code.putln("virtual int CyObject_iso() const;")
# Declare the traverse method
code.putln("virtual int CyObject_traverse(void *(*visit)(const CyObject *o, void *arg), void *arg) const;")
if type.is_cyp_class and cypclass_attrs: if type.is_cyp_class and cypclass_attrs:
# Declaring a small destruction handler which will always try to Cy_XDECREF # Declaring a small destruction handler which will always try to Cy_XDECREF
# every cypclass attribute. This handler is defined after all class definition. # every cypclass attribute. This handler is defined after all class definition.
......
...@@ -83,10 +83,21 @@ ...@@ -83,10 +83,21 @@
private: private:
mutable std::atomic_int nogil_ob_refcnt; mutable std::atomic_int nogil_ob_refcnt;
mutable CyLock ob_lock; mutable CyLock ob_lock;
public: public:
CyObject(): nogil_ob_refcnt(1) {} mutable const CyObject * __next;
mutable int __refcnt;
CyObject(): nogil_ob_refcnt(1), __next(NULL), __refcnt(0) {}
virtual ~CyObject() {} virtual ~CyObject() {}
/* Object graph inspection methods */
virtual int CyObject_iso() const {
return this->nogil_ob_refcnt == 1;
}
virtual int CyObject_traverse(void *(*visit)(const CyObject *o, void *arg), void *arg) const {
return 0;
}
/* Locking methods */ /* Locking methods */
void CyObject_RLOCK(const char * context) const; void CyObject_RLOCK(const char * context) const;
void CyObject_WLOCK(const char * context) const; void CyObject_WLOCK(const char * context) const;
...@@ -475,6 +486,67 @@ ...@@ -475,6 +486,67 @@
return ob; return ob;
} }
/*
* Visit callback to collect reachable fields.
*/
static void *__Pyx_CyObject_visit_collect(const CyObject *ob, void *arg) {
if (!ob)
return 0;
if (ob->__refcnt)
return 0;
ob->__refcnt = ob->CyObject_GETREF();
const CyObject *head = reinterpret_cast<CyObject *>(arg);
const CyObject *tmp = head->__next;
ob->__next = tmp;
head->__next = ob;
return 0;
}
/*
* Visit callback to decref reachable fields.
*/
static void *__Pyx_CyObject_visit_decref(const CyObject *ob, void *arg) {
(void) arg;
if (!ob)
return 0;
ob->__refcnt -= 1;
return 0;
}
/*
* Check if a CyObject is owning.
*/
static inline int __Pyx_CyObject_owning(const CyObject *root) {
const CyObject *current;
bool owning = true;
int owners;
/* Mark the root as already visited */
root->__refcnt = root->CyObject_GETREF();
/* Collect the reachable objects */
for(current = root; current != NULL; current = current->__next) {
current->CyObject_traverse(__Pyx_CyObject_visit_collect, (void*)current);
}
/* Decref the reachable objects */
for(current = root; current != NULL; current = current->__next) {
current->CyObject_traverse(__Pyx_CyObject_visit_decref, (void*)current);
}
/* Search for externally reachable object */
for(current = root->__next; current != NULL; current = current->__next) {
if (current->__refcnt)
owning = false;
}
/* Count external potential owners */
owners = root->__refcnt;
/* Cleanup */
for(current = root; current != NULL;) {
current->__refcnt = 0;
const CyObject *next = current->__next;
current->__next = NULL;
current = next;
}
return owning ? owners : 0;
}
/* /*
* Cast from CyObject to PyObject: * Cast from CyObject to PyObject:
* - borrow an atomic reference * - borrow an atomic reference
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment