Commit 47fa3717 authored by Xavier Thompson's avatar Xavier Thompson

Change cypclass lock transform to only lock around attribute accesses

parent 03d7c50c
......@@ -429,379 +429,114 @@ class CypclassLockTransform(Visitor.EnvTransform):
Check that cypclass objects are properly locked and insert locks if required.
"""
class StackLock:
"""
Context manager for tracking nested locks.
"""
def __init__(self, transform, obj_entry, state):
self.transform = transform
self.state = state
self.entry = obj_entry
def __enter__(self):
state = self.state
entry = self.entry
self.old_rlocked = self.transform.rlocked[entry]
self.old_wlocked = self.transform.wlocked[entry]
if state == 'rlocked':
self.transform.rlocked[entry] += 1
elif state == 'wlocked':
self.transform.wlocked[entry] += 1
def __exit__(self, *args):
entry = self.entry
self.transform.rlocked[entry] = self.old_rlocked
self.transform.wlocked[entry] = self.old_wlocked
def stacklock(self, obj_entry, state):
return self.StackLock(self, obj_entry, state)
def with_nested_stacklocks(self, stacklocks_iterator, body_callback):
# Poor mans's nested context managers
def lock(self, node, exclusive=True):
try:
stacklock = next(stacklocks_iterator)
except StopIteration:
return body_callback()
with stacklock:
return self.with_nested_stacklocks(stacklocks_iterator, body_callback)
class AccessContext:
"""
Context manager to track the kind of access (reading, writing ...).
"""
def __init__(self, collector, reading=False, writing=False, deleting=False):
self.collector = collector
self.reading = reading
self.writing = writing
self.deleting = deleting
def __enter__(self):
self.reading, self.collector.reading = self.collector.reading, self.reading
self.writing, self.collector.writing = self.collector.writing, self.writing
self.deleting, self.collector.deleting = self.collector.deleting, self.deleting
def __exit__(self, *args):
self.collector.reading = self.reading
self.collector.writing = self.writing
self.collector.deleting = self.deleting
def accesscontext(self, reading=False, writing=False, deleting=False):
return self.AccessContext(self, reading=reading, writing=writing, deleting=deleting)
lock = self.locked[node]
if exclusive and lock.state != "rlocked":
error(node.pos, "A writelock is required, but a readlock is manually acquired")
return node
except:
return ExprNodes.CoerceToLockedTempNode(node, self.current_env(), rlock_only = not exclusive)
def __call__(self, root):
self.rlocked = defaultdict(int)
self.wlocked = defaultdict(int)
self.reading = False
self.writing = False
self.deleting = False
self.locked = None
return super(CypclassLockTransform, self).__call__(root)
def reference_identifier(self, node):
while isinstance(node, ExprNodes.CoerceToTempNode):
node = node.arg
if node.is_name:
return node.entry
return None
def id_to_name(self, id):
return id.name
def lockcheck_on_context(self, node):
if self.writing or self.deleting:
return self.lockcheck_written(node)
elif self.reading:
return self.lockcheck_read(node)
return node
def lockcheck_read(self, read_node):
lock_mode = read_node.type.lock_mode
if lock_mode == "nolock":
return read_node
ref_id = self.reference_identifier(read_node)
if ref_id:
if not (self.rlocked[ref_id] > 0 or self.wlocked[ref_id] > 0):
if lock_mode == "checklock":
error(read_node.pos, (
"Reference '%s' is not correctly locked in this expression (read lock required)"
) % self.id_to_name(ref_id) )
elif lock_mode == "autolock":
# for now, lock a temporary for each expression
return ExprNodes.CoerceToLockedNode(read_node, self.current_env(), rlock_only=True)
else:
if lock_mode == "checklock":
error(read_node.pos, "This expression is not correctly locked (read lock required)")
elif lock_mode == "autolock":
if not isinstance(read_node, ExprNodes.CoerceToLockedNode):
return ExprNodes.CoerceToLockedNode(read_node, self.current_env(), rlock_only=True)
return read_node
def lockcheck_written(self, written_node):
lock_mode = written_node.type.lock_mode
if lock_mode == "nolock":
return written_node
ref_id = self.reference_identifier(written_node)
if ref_id:
if not self.wlocked[ref_id] > 0:
if lock_mode == "checklock":
error(written_node.pos, (
"Reference '%s' is not correctly locked in this expression (write lock required)"
) % self.id_to_name(ref_id) )
elif lock_mode == "autolock":
# for now, lock a temporary for each expression
return ExprNodes.CoerceToLockedNode(written_node, self.current_env(), rlock_only=False)
else:
if lock_mode == "checklock":
error(written_node.pos, "This expression is not correctly locked (write lock required)")
elif lock_mode == "autolock":
if isinstance(written_node, ExprNodes.CoerceToLockedNode):
written_node.rlock_only = False
else:
return ExprNodes.CoerceToLockedNode(written_node, self.current_env())
return written_node
def lockcheck_written_or_read(self, node, reading=False):
if reading:
return self.lockcheck_read(node)
else:
return self.lockcheck_written(node)
return node
def lockcheck_if_subscript_rhs(self, lhs, rhs):
if lhs.is_subscript and lhs.base.type.is_cyp_class:
setitem = lhs.base.type.scope.lookup("__setitem__")
if setitem and len(setitem.type.args) == 2:
arg_type = setitem.type.args[1].type
if arg_type.is_cyp_class:
return self.lockcheck_written_or_read(rhs, reading=arg_type.is_const_cyp_class)
# else: should have caused a previous error
return rhs
def visit_CFuncDefNode(self, node):
cyp_class_args = (e for e in node.local_scope.arg_entries if e.type.is_cyp_class)
arg_locks = []
for arg in cyp_class_args:
# Mark each cypclass arguments as locked within the function body
arg_locks.append(self.stacklock(arg, "rlocked" if arg.type.is_const_cyp_class else "wlocked"))
with_body = lambda: self.visit(node.body)
self.with_nested_stacklocks(iter(arg_locks), with_body)
def visit_Node(self, node):
return node
def visit_LockCypclassNode(self, node):
obj_ref_id = self.reference_identifier(node.obj)
if not obj_ref_id:
error(node.obj.pos, "Locking an unnamed reference")
return node
if not node.obj.type.is_cyp_class:
error(node.obj.pos, "Locking non-cypclass reference")
if node.nested:
self.visit(node.body)
return node
if not (obj_ref_id.is_local or obj_ref_id.is_arg or obj_ref_id.is_self_arg):
error(node.obj.pos, "Can only lock local variables or arguments")
if self.locked is not None:
error(node.pos, "A lock is already acquired")
return node
with self.stacklock(obj_ref_id, node.state):
self.visit(node.body)
node.objs = []
self.locked = {}
lock_node = node
while True:
obj = lock_node.obj
if not obj.is_name:
error(obj.pos, "Locking an unnamed reference")
return node
elif not obj.type.is_cyp_class:
error(obj.pos, "Locking non-cypclass reference")
return node
if self.locked.set_default(obj.entry, locked_node) is not locked_node:
error(obj.pos, "Locking the same name twice")
return node
try:
lock_node = lock_node.body.stats[0]
except:
return node
if not isinstance(lock_node, LockCypclassNode):
break
lock_node.nested = True
node.objs.append(obj)
self.visit(node.body)
return node
def visit_Node(self, node):
with self.accesscontext(reading=True):
self.visitchildren(node)
def visit_AttributeNode(self, node):
obj = node.obj
if obj.type and obj.type.is_cyp_class:
if not node.is_called:
node.obj = self.lock(node.obj, exclusive=node.is_target)
self.visitchildren(node)
return node
def visit_DelStatNode(self, node):
for arg in node.args:
arg_ref_id = self.reference_identifier(arg)
if self.rlocked[arg_ref_id] > 0 or self.wlocked[arg_ref_id] > 0:
arg_entry = arg.entry
if arg_entry in self.locked:
# Disallow unbinding a locked name
error(arg.pos, "Deleting a locked cypclass reference")
return node
with self.accesscontext(deleting=True):
self.visitchildren(node)
self.visitchildren(node)
return node
def visit_SingleAssignmentNode(self, node):
lhs_ref_id = self.reference_identifier(node.lhs)
if self.rlocked[lhs_ref_id] > 0 or self.wlocked[lhs_ref_id] > 0:
lhs = node.lhs
if lhs.entry in self.locked:
# Disallow re-binding a locked name
error(node.lhs.pos, "Assigning to a locked cypclass reference")
error(lhs.pos, "Assigning to a locked cypclass reference")
return node
node.rhs = self.lockcheck_if_subscript_rhs(node.lhs, node.rhs)
with self.accesscontext(writing=True):
self.visit(node.lhs)
with self.accesscontext(reading=True):
self.visit(node.rhs)
self.visitchildren(node)
return node
def visit_CascadedAssignmentNode(self, node):
for lhs in node.lhs_list:
lhs_ref_id = self.reference_identifier(lhs)
if self.rlocked[lhs_ref_id] > 0 or self.wlocked[lhs_ref_id] > 0:
if lhs.entry in self.locked:
# Disallow re-binding a locked name
error(lhs.pos, "Assigning to a locked cypclass reference")
return node
for lhs in node.lhs_list:
node.rhs = self.lockcheck_if_subscript_rhs(lhs, node.rhs)
with self.accesscontext(writing=True):
for lhs in node.lhs_list:
self.visit(lhs)
with self.accesscontext(reading=True):
self.visit(node.rhs)
self.visitchildren(node)
return node
def visit_WithTargetAssignmentStatNode(self, node):
target_id = self.reference_identifier(node.lhs)
if self.rlocked[target_id] > 0 or self.wlocked[target_id] > 0:
target_entry = node.target.entry
if target_entry in self.locked:
# Disallow re-binding a locked name
error(node.lhs.pos, "With expression target is a locked cypclass reference")
return node
node.rhs = self.lockcheck_if_subscript_rhs(node.lhs, node.rhs)
with self.accesscontext(writing=True):
self.visit(node.lhs)
with self.accesscontext(reading=True):
self.visit(node.rhs)
self.visitchildren(node)
return node
def visit__ForInStatNode(self, node):
target_id = self.reference_identifier(node.target)
if self.rlocked[target_id] > 0 or self.wlocked[target_id] > 0:
target_entry = node.target.entry
if target_entry in self.locked:
# Disallow re-binding a locked name
error(node.target.pos, "For-Loop target is a locked cypclass reference")
return node
node.item = self.lockcheck_if_subscript_rhs(node.target, node.item)
with self.accesscontext(writing=True):
self.visit(node.target)
with self.accesscontext(reading=True):
self.visit(node.item)
self.visit(node.body)
self.visit(node.iterator)
if node.else_clause:
self.visit(node.else_clause)
self.visitchildren(node)
return node
def visit_ExceptClauseNode(self, node):
if not node.target:
self.visitchildren(node)
else:
target_id = self.reference_identifier(node.target)
if self.rlocked[target_id] > 0 or self.wlocked[target_id] > 0:
if node.target:
target_entry = node.target.entry
if target_entry in self.locked:
# Disallow re-binding a locked name
error(node.target.pos, "Except clause target is a locked cypclass reference")
error(node.target.pos, "For-Loop target is a locked cypclass reference")
return node
with self.accesscontext(writing=True):
self.visit(node.target)
for p in node.pattern:
self.visit(p)
self.visit(node.body)
return node
def visit_AttributeNode(self, node):
if node.obj.type and node.obj.type.is_cyp_class:
if node.is_called and node.type.is_cfunction:
if not node.type.is_static_method:
node.obj = self.lockcheck_written_or_read(node.obj, reading=node.type.is_const_method)
else:
node.obj = self.lockcheck_on_context(node.obj)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_SimpleCallNode(self, node):
if node.type.is_error:
return node
func_type = node.function_type()
if func_type.is_cfunction:
formal_nargs = len(func_type.args)
actual_nargs = len(node.args)
for i, formal_arg, actual_arg in zip(range(actual_nargs), func_type.args, node.args):
if formal_arg.type.is_cyp_class and actual_arg.type.is_cyp_class:
node.args[i] = self.lockcheck_written_or_read(actual_arg, reading=formal_arg.type.is_const_cyp_class)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_CoerceFromCallable(self, node):
if node.arg.type.is_cyp_class:
node.arg = self.lockcheck_written_or_read(node.arg, reading=node.type.is_const_method)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_IndexNode(self, node):
if node.base.type and node.base.type.is_cyp_class:
func_entry = None
if self.deleting:
func_entry = node.base.type.scope.lookup("__delitem__")
elif self.writing:
func_entry = node.base.type.scope.lookup("__setitem__")
elif self.reading:
func_entry = node.base.type.scope.lookup("__getitem__")
if func_entry:
func_type = func_entry.type
node.base = self.lockcheck_written_or_read(node.base, reading=func_type.is_const_method)
if len(func_type.args):
if func_type.args[0].type.is_cyp_class:
node.index = self.lockcheck_written_or_read(node.index, reading=func_type.args[0].type.is_const_cyp_class)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def _visit_binop(self, node, func_type):
if func_type is not None:
if node.operand1.type.is_cyp_class and len(func_type.args) == 1:
node.operand1 = self.lockcheck_written_or_read(node.operand1, reading=func_type.is_const_method)
arg_type = func_type.args[0].type
if arg_type.is_cyp_class:
node.operand2 = self.lockcheck_written_or_read(node.operand2, reading=arg_type.is_const_cyp_class)
elif len(func_type.args) == 2:
arg1_type = func_type.args[0].type
if arg1_type.is_cyp_class:
node.operand1 = self.lockcheck_written_or_read(node.operand1, reading=arg1_type.is_const_cyp_class)
arg2_type = func_type.args[1].type
if arg2_type.is_cyp_class:
node.operand2 = self.lockcheck_written_or_read(node.operand2, reading=arg2_type.is_const_cyp_class)
def visit_BinopNode(self, node):
func_type = node.op_func_type
self._visit_binop(node, func_type)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_PrimaryCmpNode(self, node):
func_type = node.cmp_func_type
self._visit_binop(node, func_type)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_InPlaceAssignmentNode(self, node):
# operator = "operator%s="% node.operator
# if node.lhs.type.is_cyp_class:
# TODO: get operator function type and treat it like a binop with lhs and rhs
with self.accesscontext(reading=True, writing=True):
self.visit(node.lhs)
with self.accesscontext(reading=True):
self.visit(node.rhs)
return node
def _visit_unop(self, node, func_type):
if func_type is not None:
if node.operand.type.is_cyp_class and len(func_type.args) == 0:
node.operand = self.lockcheck_written_or_read(node.operand, reading=func_type.is_const_method)
def visit_UnopNode(self, node):
func_type = node.op_func_type
self._visit_unop(node, func_type)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_TypecastNode(self, node):
func_type = node.op_func_type
self._visit_unop(node, func_type)
with self.accesscontext(reading=True):
self.visitchildren(node)
self.visitchildren(node)
return node
......@@ -7295,6 +7295,7 @@ class AttributeNode(ExprNode):
return node
def analyse_types(self, env, target = 0):
self.is_target = target
self.initialized_check = env.directives['initializedcheck']
node = self.analyse_as_cimported_attribute_node(env, target)
if node is None and not target:
......
......@@ -8597,6 +8597,9 @@ class LockCypclassNode(StatNode):
child_attrs = ["body", "obj"]
nested = False
objs = None
def analyse_declarations(self, env):
self.body.analyse_declarations(env)
self.obj.analyse_declarations(env)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment