Commit bc3b8ae5 authored by Xavier Thompson's avatar Xavier Thompson

Fix cypclass lock transform for attribute locking

parent b96c2f36
...@@ -429,32 +429,39 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -429,32 +429,39 @@ class CypclassLockTransform(Visitor.EnvTransform):
Check that cypclass objects are properly locked and insert locks if required. Check that cypclass objects are properly locked and insert locks if required.
""" """
def lock(self, node, exclusive=True): def id(self, obj):
if not obj:
return None
while isinstance(obj, ExprNodes.CoerceToTempNode):
obj = obj.arg
if obj.is_name:
return obj.entry
return None
def lock(self, obj, exclusive=True):
try: try:
lock = self.locked[node] lock = self.locked[self.id(obj)]
if exclusive and lock.state != "rlocked": if exclusive and lock.state == "rlocked":
error(node.pos, "A writelock is required, but a readlock is manually acquired") error(lock.pos, "A writelock is required, but a readlock is manually acquired")
return node return obj
except: except:
return ExprNodes.CoerceToLockedTempNode(node, self.current_env(), rlock_only = not exclusive) return ExprNodes.CoerceToLockedNode(obj, self.current_env(), rlock_only = not exclusive)
def __call__(self, root): def __call__(self, root):
self.locked = None self.locked = {}
return super(CypclassLockTransform, self).__call__(root) return super(CypclassLockTransform, self).__call__(root)
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node)
return node return node
def visit_LockCypclassNode(self, node): def visit_LockCypclassNode(self, node):
if node.nested: if node.nested:
self.visit(node.body) self.visit(node.body)
return node return node
if self.locked is not None:
error(node.pos, "A lock is already acquired")
return node
node.objs = [] node.objs = []
self.locked = {}
lock_node = node lock_node = node
locked = {}
while True: while True:
obj = lock_node.obj obj = lock_node.obj
if not obj.is_name: if not obj.is_name:
...@@ -463,31 +470,34 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -463,31 +470,34 @@ class CypclassLockTransform(Visitor.EnvTransform):
elif not obj.type.is_cyp_class: elif not obj.type.is_cyp_class:
error(obj.pos, "Locking non-cypclass reference") error(obj.pos, "Locking non-cypclass reference")
return node return node
if self.locked.set_default(obj.entry, locked_node) is not locked_node: if locked.setdefault(obj.entry, lock_node) is not lock_node:
error(obj.pos, "Locking the same name twice") error(obj.pos, "Locking the same name twice")
return node return node
try: try:
lock_node = lock_node.body.stats[0] lock_node = lock_node.body.stats[0]
except: except:
return node return node
if not isinstance(lock_node, LockCypclassNode): if not isinstance(lock_node, Nodes.LockCypclassNode):
break break
lock_node.nested = True lock_node.nested = True
node.objs.append(obj) node.objs.append(obj)
self.locked.update(locked)
self.visit(node.body) self.visit(node.body)
for key in locked:
self.locked.pop(key)
return node return node
def visit_AttributeNode(self, node): def visit_AttributeNode(self, node):
self.visitchildren(node)
obj = node.obj obj = node.obj
if obj.type and obj.type.is_cyp_class: if obj.type and obj.type.is_cyp_class:
if not node.is_called: if not node.type.is_cfunction:
node.obj = self.lock(node.obj, exclusive=node.is_target) node.obj = self.lock(node.obj, exclusive=node.is_target)
self.visitchildren(node)
return node return node
def visit_DelStatNode(self, node): def visit_DelStatNode(self, node):
for arg in node.args: for arg in node.args:
arg_entry = arg.entry arg_entry = self.id(arg)
if arg_entry in self.locked: if arg_entry in self.locked:
# Disallow unbinding a locked name # Disallow unbinding a locked name
error(arg.pos, "Deleting a locked cypclass reference") error(arg.pos, "Deleting a locked cypclass reference")
...@@ -497,7 +507,7 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -497,7 +507,7 @@ class CypclassLockTransform(Visitor.EnvTransform):
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
lhs = node.lhs lhs = node.lhs
if lhs.entry in self.locked: if self.id(lhs) in self.locked:
# Disallow re-binding a locked name # Disallow re-binding a locked name
error(lhs.pos, "Assigning to a locked cypclass reference") error(lhs.pos, "Assigning to a locked cypclass reference")
return node return node
...@@ -506,7 +516,7 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -506,7 +516,7 @@ class CypclassLockTransform(Visitor.EnvTransform):
def visit_CascadedAssignmentNode(self, node): def visit_CascadedAssignmentNode(self, node):
for lhs in node.lhs_list: for lhs in node.lhs_list:
if lhs.entry in self.locked: if self.id(lhs) in self.locked:
# Disallow re-binding a locked name # Disallow re-binding a locked name
error(lhs.pos, "Assigning to a locked cypclass reference") error(lhs.pos, "Assigning to a locked cypclass reference")
return node return node
...@@ -514,8 +524,7 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -514,8 +524,7 @@ class CypclassLockTransform(Visitor.EnvTransform):
return node return node
def visit_WithTargetAssignmentStatNode(self, node): def visit_WithTargetAssignmentStatNode(self, node):
target_entry = node.target.entry if self.id(node.target) in self.locked:
if target_entry in self.locked:
# Disallow re-binding a locked name # Disallow re-binding a locked name
error(node.lhs.pos, "With expression target is a locked cypclass reference") error(node.lhs.pos, "With expression target is a locked cypclass reference")
return node return node
...@@ -523,8 +532,7 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -523,8 +532,7 @@ class CypclassLockTransform(Visitor.EnvTransform):
return node return node
def visit__ForInStatNode(self, node): def visit__ForInStatNode(self, node):
target_entry = node.target.entry if self.id(node.target) in self.locked:
if target_entry in self.locked:
# Disallow re-binding a locked name # Disallow re-binding a locked name
error(node.target.pos, "For-Loop target is a locked cypclass reference") error(node.target.pos, "For-Loop target is a locked cypclass reference")
return node return node
...@@ -532,9 +540,7 @@ class CypclassLockTransform(Visitor.EnvTransform): ...@@ -532,9 +540,7 @@ class CypclassLockTransform(Visitor.EnvTransform):
return node return node
def visit_ExceptClauseNode(self, node): def visit_ExceptClauseNode(self, node):
if node.target: if self.id(node.target) in self.locked:
target_entry = node.target.entry
if target_entry in self.locked:
# Disallow re-binding a locked name # Disallow re-binding a locked name
error(node.target.pos, "For-Loop target is a locked cypclass reference") error(node.target.pos, "For-Loop target is a locked cypclass reference")
return node return node
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment