Commit 5aef3d90 authored by Xavier Thompson's avatar Xavier Thompson

Refactor the whole cypclass locking framework into a single Visitor

parent 18dee5dc
...@@ -6,9 +6,11 @@ ...@@ -6,9 +6,11 @@
from __future__ import absolute_import from __future__ import absolute_import
import cython import cython
cython.declare(Naming=object, PyrexTypes=object, EncodedString=object) cython.declare(Naming=object, PyrexTypes=object, EncodedString=object, error=object)
from collections import defaultdict from collections import defaultdict
from contextlib import ExitStack
from itertools import chain
from . import Naming from . import Naming
from . import Nodes from . import Nodes
...@@ -19,6 +21,7 @@ from . import TreeFragment ...@@ -19,6 +21,7 @@ from . import TreeFragment
from .StringEncoding import EncodedString from .StringEncoding import EncodedString
from .ParseTreeTransforms import NormalizeTree, InterpretCompilerDirectives, DecoratorTransform, AnalyseDeclarationsTransform from .ParseTreeTransforms import NormalizeTree, InterpretCompilerDirectives, DecoratorTransform, AnalyseDeclarationsTransform
from .Errors import error
# #
# Visitor for wrapper cclass injection # Visitor for wrapper cclass injection
...@@ -444,3 +447,347 @@ def NAME(ARGDECLS): ...@@ -444,3 +447,347 @@ def NAME(ARGDECLS):
return method_wrapper return method_wrapper
class CypclassLockTransform(Visitor.EnvTransform):
class StackLock:
def __init__(self, transform, obj_entry, state):
self.transform = transform
self.state = state
self.entry = obj_entry
def __enter__(self):
state = self.state
entry = self.entry
self.old_rlocked = self.transform.rlocked[entry]
self.old_wlocked = self.transform.wlocked[entry]
if state == 'rlocked':
self.transform.rlocked[entry] += 1
elif state == 'wlocked':
self.transform.wlocked[entry] += 1
elif state == 'unlocked':
if self.rlocked > 0:
self.transform.rlocked[entry] -= 1
elif self.wlocked > 0:
self.transform.wlocked[entry] -= 1
def __exit__(self, *args):
entry = self.entry
self.transform.rlocked[entry] = self.old_rlocked
self.transform.wlocked[entry] = self.old_wlocked
def stacklock(self, obj_entry, state):
return self.StackLock(self, obj_entry, state)
class AccessContext:
def __init__(self, collector, reading=False, writing=False, deleting=False):
self.collector = collector
self.reading = reading
self.writing = writing
self.deleting = deleting
def __enter__(self):
self.reading, self.collector.reading = self.collector.reading, self.reading
self.writing, self.collector.writing = self.collector.writing, self.writing
self.deleting, self.collector.deleting = self.collector.deleting, self.deleting
def __exit__(self, *args):
self.collector.reading = self.reading
self.collector.writing = self.writing
self.collector.deleting = self.deleting
def accesscontext(self, reading=False, writing=False, deleting=False):
return self.AccessContext(self, reading=reading, writing=writing, deleting=deleting)
def __call__(self, root):
self.rlocked = defaultdict(int)
self.wlocked = defaultdict(int)
self.reading = False
self.writing = False
self.deleting = False
return super(CypclassLockTransform, self).__call__(root)
def reference_identifier(self, node):
while isinstance(node, ExprNodes.CoerceToTempNode): # works for CoerceToLockedTempNode too
node = node.arg
if node.is_name:
return node.entry
return None
def id_to_name(self, id):
return id.name
def lockcheck_on_context(self, node):
if self.writing or self.deleting:
return self.lockcheck_written(node)
elif self.reading:
return self.lockcheck_read(node)
return node
def lockcheck_read(self, read_node):
lock_mode = read_node.type.lock_mode
if lock_mode == "nolock":
return read_node
ref_id = self.reference_identifier(read_node)
if ref_id:
if not (self.rlocked[ref_id] > 0 or self.wlocked[ref_id] > 0):
if lock_mode == "checklock":
error(read_node.pos, (
"Reference '%s' is not correctly locked in this expression "
"(read lock required)"
) % self.id_to_name(ref_id) )
elif lock_mode == "autolock":
# for now, lock a temporary for each expression
return ExprNodes.CoerceToLockedTempNode(read_node, self.current_env(), rlock_only=True)
else:
if lock_mode == "checklock":
error(read_node.pos, "This expression is not correctly locked (read lock required)")
elif lock_mode == "autolock":
if not isinstance(read_node, ExprNodes.CoerceToLockedTempNode):
return ExprNodes.CoerceToLockedTempNode(read_node, self.current_env(), rlock_only=True)
return read_node
def lockcheck_written(self, written_node):
lock_mode = written_node.type.lock_mode
if lock_mode == "nolock":
return written_node
ref_id = self.reference_identifier(written_node)
if ref_id:
if not self.wlocked[ref_id] > 0:
if lock_mode == "checklock":
error(written_node.pos, (
"Reference '%s' is not correctly locked in this expression "
"(write lock required)"
) % self.id_to_name(ref_id) )
elif lock_mode == "autolock":
# for now, lock a temporary for each expression
return ExprNodes.CoerceToLockedTempNode(written_node, self.current_env(), rlock_only=False)
else:
if lock_mode == "checklock":
error(written_node.pos, "This expression is not correctly locked (write lock required)")
elif lock_mode == "autolock":
if isinstance(written_node, ExprNodes.CoerceToLockedTempNode):
written_node.rlock_only = False
else:
return ExprNodes.CoerceToLockedTempNode(written_node, self.current_env())
return written_node
def lockcheck_written_or_read(self, node, reading=False):
if reading:
return self.lockcheck_read(node)
else:
return self.lockcheck_written(node)
return node
def lockcheck_if_subscript_rhs(self, lhs, rhs):
if lhs.is_subscript and lhs.base.type.is_cyp_class:
setitem = lhs.base.type.scope.lookup("__setitem__")
if setitem and len(setitem.type.args) == 2:
arg_type = setitem.type.args[1].type
if arg_type.is_cyp_class:
return self.lockcheck_written_or_read(rhs, reading=arg_type.is_const)
# else: should have caused a previous error
return rhs
def visit_CFuncDefNode(self, node):
cyp_class_args = (e for e in node.local_scope.arg_entries if e.type.is_cyp_class)
with ExitStack() as locked_args_stack:
for arg in cyp_class_args:
is_rlocked = arg.type.is_const or arg.is_self_arg and node.entry.type.is_const_method
arg_id = arg
locked_args_stack.enter_context(self.stacklock(arg_id, "rlocked" if is_rlocked else "wlocked"))
self.visit(node.body)
return node
def visit_LockCypclassNode(self, node):
obj_ref_id = self.reference_identifier(node.obj)
if not obj_ref_id:
error(node.obj.pos, "Locking an unnamed reference")
return node
if not node.obj.type.is_cyp_class:
error(node.obj.pos, "Locking non-cypclass reference")
return node
with self.stacklock(obj_ref_id, node.state):
self.visit(node.body)
return node
def visit_Node(self, node):
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_DelStatNode(self, node):
for arg in node.args:
arg_ref_id = self.reference_identifier(arg)
if self.rlocked[arg_ref_id] > 0 or self.wlocked[arg_ref_id] > 0:
error(arg.pos, "Deleting a locked cypclass reference")
return node
with self.accesscontext(deleting=True):
self.visitchildren(node)
return node
def visit_SingleAssignmentNode(self, node):
lhs_ref_id = self.reference_identifier(node.lhs)
if self.rlocked[lhs_ref_id] > 0 or self.wlocked[lhs_ref_id] > 0:
error(node.lhs.pos, "Assigning to a locked cypclass reference")
return node
node.rhs = self.lockcheck_if_subscript_rhs(node.lhs, node.rhs)
with self.accesscontext(writing=True):
self.visit(node.lhs)
with self.accesscontext(reading=True):
self.visit(node.rhs)
return node
def visit_CascadedAssignmentNode(self, node):
for lhs in node.lhs_list:
lhs_ref_id = self.reference_identifier(lhs)
if self.rlocked[lhs_ref_id] > 0 or self.wlocked[lhs_ref_id] > 0:
error(lhs.pos, "Assigning to a locked cypclass reference")
return node
for lhs in node.lhs_list:
node.rhs = self.lockcheck_if_subscript_rhs(lhs, node.rhs)
with self.accesscontext(writing=True):
for lhs in node.lhs_list:
self.visit(lhs)
with self.accesscontext(reading=True):
self.visit(node.rhs)
return node
def visit_WithTargetAssignmentStatNode(self, node):
target_id = self.reference_identifier(node.lhs)
if self.rlocked[target_id] > 0 or self.wlocked[target_id] > 0:
error(node.lhs.pos, "With expression target is a locked cypclass reference")
return node
node.rhs = self.lockcheck_if_subscript_rhs(node.lhs, node.rhs)
with self.accesscontext(writing=True):
self.visit(node.lhs)
with self.accesscontext(reading=True):
self.visit(node.rhs)
return node
def visit__ForInStatNode(self, node):
target_id = self.reference_identifier(node.target)
if self.rlocked[target_id] > 0 or self.wlocked[target_id] > 0:
error(node.target.pos, "For-Loop target is a locked cypclass reference")
return node
node.item = self.lockcheck_if_subscript_rhs(node.target, node.item)
with self.accesscontext(writing=True):
self.visit(node.target)
with self.accesscontext(reading=True):
self.visit(node.item)
self.visit(node.body)
self.visit(node.iterator)
if node.else_clause:
self.visit(node.else_clause)
return node
def visit_ExceptClauseNode(self, node):
if not node.target:
self.visitchildren(node)
else:
target_id = self.reference_identifier(node.target)
if self.rlocked[target_id] > 0 or self.wlocked[target_id] > 0:
error(node.target.pos, "Except clause target is a locked cypclass reference")
return node
with self.accesscontext(writing=True):
self.visit(node.target)
for p in node.pattern:
self.visit(p)
self.visit(node.body)
return node
def visit_AttributeNode(self, node):
if node.obj.type and node.obj.type.is_cyp_class:
if node.is_called:
if not node.type.is_static_method:
node.obj = self.lockcheck_written_or_read(node.obj, reading=node.type.is_const_method)
else:
node.obj = self.lockcheck_on_context(node.obj)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_SimpleCallNode(self, node):
for i, arg in enumerate(node.args or ()): # provide an empty tuple fallback in case node.args is None
if arg.type.is_cyp_class:
node.args[i] = self.lockcheck_written_or_read(arg, reading=arg.type.is_const)
# TODO: lock callable objects
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_IndexNode(self, node):
if node.base.type.is_cyp_class:
func_entry = None
if self.deleting:
func_entry = node.base.type.scope.lookup("__delitem__")
elif self.writing:
func_entry = node.base.type.scope.lookup("__setitem__")
elif self.reading:
func_entry = node.base.type.scope.lookup("__getitem__")
if func_entry:
func_type = func_entry.type
node.base = self.lockcheck_written_or_read(node.base, reading=func_type.is_const_method)
if len(func_type.args):
node.index = self.lockcheck_written_or_read(node.index, reading=func_type.args[0].type.is_const)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def _visit_binop(self, node, func_type):
if func_type is not None:
if node.operand1.type.is_cyp_class and len(func_type.args) == 1:
node.operand1 = self.lockcheck_written_or_read(node.operand1, reading=func_type.is_const_method)
arg_type = func_type.args[0].type
if arg_type.is_cyp_class:
node.operand2 = self.lockcheck_written_or_read(node.operand2, reading=arg_type.is_const)
elif len(func_type.args) == 2:
arg1_type = func_type.args[0].type
if arg1_type.is_cyp_class:
node.operand1 = self.lockcheck_written_or_read(node.operand1, reading=arg1_type.is_const)
arg2_type = func_type.args[1].type
if arg2_type.is_cyp_class:
node.operand2 = self.lockcheck_written_or_read(node.operand2, reading=arg2_type.is_const)
def visit_BinopNode(self, node):
func_type = node.op_func_type
self._visit_binop(node, func_type)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_PrimaryCmpNode(self, node):
func_type = node.cmp_func_type
self._visit_binop(node, func_type)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_InPlaceAssignmentNode(self, node):
# operator = "operator%s="% node.operator
# if node.lhs.type.is_cyp_class:
# TODO: get operator function type and treat it like a binop with lhs and rhs
with self.accesscontext(reading=True, writing=True):
self.visit(node.lhs)
with self.accesscontext(reading=True):
self.visit(node.rhs)
return node
def _visit_unop(self, node, func_type):
if func_type is not None:
if node.operand.type.is_cyp_class and len(func_type.args) == 0:
node.operand = self.lockcheck_written_or_read(node.operand, reading=func_type.is_const_method)
def visit_UnopNode(self, node):
func_type = node.op_func_type
self._visit_unop(node, func_type)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
def visit_TypecastNode(self, node):
func_type = node.op_func_type
self._visit_unop(node, func_type)
with self.accesscontext(reading=True):
self.visitchildren(node)
return node
...@@ -320,7 +320,6 @@ class ExprNode(Node): ...@@ -320,7 +320,6 @@ class ExprNode(Node):
use_managed_ref = True # can be set by optimisation transforms use_managed_ref = True # can be set by optimisation transforms
result_is_used = True result_is_used = True
is_numpy_attribute = False is_numpy_attribute = False
tracked_state = None
# The Analyse Expressions phase for expressions is split # The Analyse Expressions phase for expressions is split
# into two sub-phases: # into two sub-phases:
...@@ -724,89 +723,6 @@ class ExprNode(Node): ...@@ -724,89 +723,6 @@ class ExprNode(Node):
def addr_not_const(self): def addr_not_const(self):
error(self.pos, "Address is not constant") error(self.pos, "Address is not constant")
def set_autorlock(self, env):
self.tracked_state.is_rlocked = True
self.tracked_state.needs_rlock = True
def set_autowlock(self, env):
self.tracked_state.is_wlocked = True
self.tracked_state.needs_wlock = True
def needs_rlock(self):
if self.tracked_state is None:
return False
return self.tracked_state.needs_rlock
def needs_wlock(self):
if self.tracked_state is None:
return False
return self.tracked_state.needs_wlock
def is_autolock(self):
return self.type is not None and self.type.is_cyp_class and self.type.lock_mode == "autolock"
def is_checklock(self):
return self.type is not None and self.type.is_cyp_class and self.type.lock_mode == "checklock"
def get_tracked_state(self, env):
if not hasattr(self, 'entry') or not self.entry or not self.entry.type.is_cyp_class:
return
self.tracked_state = env.lookup_tracked(self.entry)
if self.tracked_state is None:
self.tracked_state = env.declare_tracked(self.entry)
if self.is_autolock() and self.entry.is_variable:
env.declare_autolocked(self)
def is_rhs_locked(self, env):
if not hasattr(self, 'entry') or self.entry.type is None or not self.entry.type.is_cyp_class:
# These nodes couldn't be tracked (because it is for example a constant),
# so we let them pass silently
return True
return self.tracked_state.is_rlocked or self.tracked_state.is_wlocked
def is_lhs_locked(self, env):
if not hasattr(self, 'entry') or self.entry.type is None or not self.entry.type.is_cyp_class:
# These nodes couldn't be tracked (because it is for example a constant),
# so we let them pass silently
return True
return self.tracked_state.is_wlocked
def ensure_subexpr_rhs_locked(self, env):
for node in self.subexpr_nodes():
node.ensure_rhs_locked(env)
def ensure_subexpr_lhs_locked(self, env):
for node in self.subexpr_nodes():
node.ensure_lhs_locked(env)
def ensure_rhs_locked(self, env, is_dereferenced = False):
self.ensure_subexpr_rhs_locked(env)
if not self.tracked_state:
self.get_tracked_state(env)
if is_dereferenced and self.tracked_state:
if not self.is_rhs_locked(env):
if self.is_checklock():
error(self.pos, "This expression is not correctly locked (read lock needed)")
elif self.is_autolock():
self.set_autorlock(env)
def ensure_lhs_locked(self, env, is_dereferenced = False, is_top_lhs = False):
if not is_dereferenced:
self.ensure_subexpr_lhs_locked(env)
else:
self.ensure_subexpr_rhs_locked(env)
if not self.tracked_state:
self.get_tracked_state(env)
if self.is_autolock() and is_top_lhs:
#env.declare_autolocked(self)
self.tracked_as_lhs = True
if is_dereferenced and self.tracked_state:
if not self.is_lhs_locked(env):
if self.is_checklock():
error(self.pos, "This expression is not correctly locked (write lock needed)")
elif self.is_autolock():
self.set_autowlock(env)
# ----------------- Result Allocation ----------------- # ----------------- Result Allocation -----------------
def result_in_temp(self): def result_in_temp(self):
...@@ -2438,7 +2354,6 @@ class NameNode(AtomicExprNode): ...@@ -2438,7 +2354,6 @@ class NameNode(AtomicExprNode):
exception_check=None, exception_value=None): exception_check=None, exception_value=None):
#print "NameNode.generate_assignment_code:", self.name ### #print "NameNode.generate_assignment_code:", self.name ###
entry = self.entry entry = self.entry
tracked_state = self.tracked_state
if entry is None: if entry is None:
return # There was an error earlier return # There was an error earlier
...@@ -2446,9 +2361,6 @@ class NameNode(AtomicExprNode): ...@@ -2446,9 +2361,6 @@ class NameNode(AtomicExprNode):
and not self.lhs_of_first_assignment and not rhs.in_module_scope): and not self.lhs_of_first_assignment and not rhs.in_module_scope):
error(self.pos, "Literal list must be assigned to pointer at time of declaration") error(self.pos, "Literal list must be assigned to pointer at time of declaration")
if self.is_autolock() and tracked_state and (tracked_state.needs_wlock or tracked_state.needs_rlock):
code.putln("Cy_UNLOCK(%s);" % self.result())
# is_pyglobal seems to be True for module level-globals only. # is_pyglobal seems to be True for module level-globals only.
# We use this to access class->tp_dict if necessary. # We use this to access class->tp_dict if necessary.
if entry.is_pyglobal: if entry.is_pyglobal:
...@@ -2546,11 +2458,6 @@ class NameNode(AtomicExprNode): ...@@ -2546,11 +2458,6 @@ class NameNode(AtomicExprNode):
code.putln('new (&%s) decltype(%s){%s};' % (self.result(), self.result(), result)) code.putln('new (&%s) decltype(%s){%s};' % (self.result(), self.result(), result))
elif result != self.result(): elif result != self.result():
code.putln('%s = %s;' % (self.result(), result)) code.putln('%s = %s;' % (self.result(), result))
if self.is_autolock():
if tracked_state.needs_wlock:
code.putln("Cy_WLOCK(%s);" % self.result())
elif tracked_state.needs_rlock:
code.putln("Cy_RLOCK(%s);" % self.result())
if debug_disposal_code: if debug_disposal_code:
print("NameNode.generate_assignment_code:") print("NameNode.generate_assignment_code:")
print("...generating post-assignment code for %s" % rhs) print("...generating post-assignment code for %s" % rhs)
...@@ -3981,16 +3888,6 @@ class IndexNode(_IndexingBaseNode): ...@@ -3981,16 +3888,6 @@ class IndexNode(_IndexingBaseNode):
error(self.pos, "Invalid index type '%s'" % self.index.type) error(self.pos, "Invalid index type '%s'" % self.index.type)
return self return self
def ensure_base_and_index_locked(self, env, func_type):
if func_type.is_const_method:
self.base.ensure_rhs_locked(env, is_dereferenced=True)
else:
self.base.ensure_lhs_locked(env, is_dereferenced=True)
if func_type.args[0].type.is_const:
self.index.ensure_rhs_locked(env, is_dereferenced=True)
else:
self.index.ensure_lhs_locked(env, is_dereferenced=True)
def analyse_as_cpp(self, env, setting): def analyse_as_cpp(self, env, setting):
base_type = self.base.type base_type = self.base.type
function = env.lookup_operator("[]", [self.base, self.index]) function = env.lookup_operator("[]", [self.base, self.index])
...@@ -4009,7 +3906,6 @@ class IndexNode(_IndexingBaseNode): ...@@ -4009,7 +3906,6 @@ class IndexNode(_IndexingBaseNode):
self.is_temp = True self.is_temp = True
if self.exception_value is None: if self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp")) env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
self.ensure_base_and_index_locked(env, func_type)
self.index = self.index.coerce_to(func_type.args[0].type, env) self.index = self.index.coerce_to(func_type.args[0].type, env)
self.type = func_type.return_type self.type = func_type.return_type
if setting and not func_type.return_type.is_reference: if setting and not func_type.return_type.is_reference:
...@@ -4065,7 +3961,6 @@ class IndexNode(_IndexingBaseNode): ...@@ -4065,7 +3961,6 @@ class IndexNode(_IndexingBaseNode):
self.result_code = "<error>" self.result_code = "<error>"
return self return self
self.type = setitem_type.args[1].type self.type = setitem_type.args[1].type
self.ensure_base_and_index_locked(env, func_type)
return self return self
def analyse_as_c_function(self, env): def analyse_as_c_function(self, env):
...@@ -5164,13 +5059,6 @@ class SliceIndexNode(ExprNode): ...@@ -5164,13 +5059,6 @@ class SliceIndexNode(ExprNode):
index = not_a_constant index = not_a_constant
return self.base.inferable_item_node(index) return self.base.inferable_item_node(index)
def ensure_subexpr_lhs_locked(self, env):
self.base.ensure_lhs_locked(env)
if self.start:
self.start.ensure_rhs_locked(env)
elif self.stop:
self.stop.ensure_rhs_locked(env)
def may_be_none(self): def may_be_none(self):
base_type = self.base.type base_type = self.base.type
if base_type: if base_type:
...@@ -5926,9 +5814,6 @@ class SimpleCallNode(CallNode): ...@@ -5926,9 +5814,6 @@ class SimpleCallNode(CallNode):
# analysed bool used internally # analysed bool used internally
# overflowcheck bool used internally # overflowcheck bool used internally
# explicit_cpp_self bool used internally # explicit_cpp_self bool used internally
# rlocked bool used internally
# wlocked bool used internally
# tracked_state bool used internally
# needs_deref bool used internally # needs_deref bool used internally
subexprs = ['self', 'coerced_self', 'function', 'args', 'arg_tuple'] subexprs = ['self', 'coerced_self', 'function', 'args', 'arg_tuple']
...@@ -5942,9 +5827,6 @@ class SimpleCallNode(CallNode): ...@@ -5942,9 +5827,6 @@ class SimpleCallNode(CallNode):
analysed = False analysed = False
overflowcheck = False overflowcheck = False
explicit_cpp_self = None explicit_cpp_self = None
rlocked = False
wlocked = False
tracked_state = True # Something random, anything that is not None
needs_deref = False needs_deref = False
def compile_time_value(self, denv): def compile_time_value(self, denv):
...@@ -6276,58 +6158,6 @@ class SimpleCallNode(CallNode): ...@@ -6276,58 +6158,6 @@ class SimpleCallNode(CallNode):
self.overflowcheck = env.directives['overflowcheck'] self.overflowcheck = env.directives['overflowcheck']
def ensure_subexpr_rhs_locked(self, env):
func_type = self.function_type()
if func_type.is_pyobject:
self.arg_tuple.ensure_rhs_locked(env)
elif func_type.is_cfunction:
max_nargs = len(func_type.args)
actual_nargs = len(self.args)
# Check for args locks: read-lock for const args, write-locks for other
for i in range(min(max_nargs, actual_nargs)):
formal_arg = func_type.args[i]
actual_arg = self.args[i]
deref_flag = formal_arg.type.is_cyp_class
wlock_flag = deref_flag and not formal_arg.type.is_const
if wlock_flag:
actual_arg.ensure_lhs_locked(env, is_dereferenced = True)
else:
actual_arg.ensure_rhs_locked(env, is_dereferenced = deref_flag)
# XXX - Should we do something in a pyfunc case ?
if func_type.is_static_method:
pass # no need to lock the object on which a static method is called
elif func_type.is_const_method:
self.function.ensure_rhs_locked(env)
else:
self.function.ensure_lhs_locked(env)
def ensure_subexpr_lhs_locked(self, env):
# This may be seen a bit weird
# In fact, the only thing that changes between lhs & rhs analysis for function
# calls is that the result should be locked, but the subexpr analysis is
# exactly the same, because the result is not explicitely tied to args
# and base object (in case of a method call).
self.ensure_subexpr_rhs_locked(env)
def is_lhs_locked(self, env):
return self.wlocked
def is_rhs_locked(self, env):
return self.rlocked
def set_autorlock(self, env):
self.rlocked = True
def set_autowlock(self, env):
self.wlocked = True
def needs_rlock(self):
return self.rlocked
def needs_wlock(self):
return self.wlocked
def calculate_result_code(self): def calculate_result_code(self):
return self.c_call_code() return self.c_call_code()
...@@ -6505,10 +6335,6 @@ class SimpleCallNode(CallNode): ...@@ -6505,10 +6335,6 @@ class SimpleCallNode(CallNode):
else: else:
goto_error = "" goto_error = ""
code.putln("%s%s; %s" % (lhs, rhs, goto_error)) code.putln("%s%s; %s" % (lhs, rhs, goto_error))
if self.wlocked:
code.putln("Cy_WLOCK(%s);" % self.result())
elif self.rlocked:
code.putln("Cy_RLOCK(%s);" % self.result())
if self.type.is_pyobject and self.result(): if self.type.is_pyobject and self.result():
self.generate_gotref(code) self.generate_gotref(code)
elif self.type.is_cyp_class and self.result(): elif self.type.is_cyp_class and self.result():
...@@ -6516,10 +6342,6 @@ class SimpleCallNode(CallNode): ...@@ -6516,10 +6342,6 @@ class SimpleCallNode(CallNode):
if self.has_optional_args: if self.has_optional_args:
code.funcstate.release_temp(self.opt_arg_struct) code.funcstate.release_temp(self.opt_arg_struct)
def generate_disposal_code(self, code):
if self.wlocked or self.rlocked:
code.putln("Cy_UNLOCK(%s);" % self.result())
ExprNode.generate_disposal_code(self, code)
class NumPyMethodCallNode(ExprNode): class NumPyMethodCallNode(ExprNode):
# Pythran call to a NumPy function or method. # Pythran call to a NumPy function or method.
...@@ -7648,22 +7470,6 @@ class AttributeNode(ExprNode): ...@@ -7648,22 +7470,6 @@ class AttributeNode(ExprNode):
gil_message = "Accessing Python attribute" gil_message = "Accessing Python attribute"
def ensure_subexpr_rhs_locked(self, env):
if not self.entry:
self.obj.ensure_lhs_locked(env, is_dereferenced = True)
elif self.entry.is_cfunction:
if self.entry.type.is_static_method:
pass
elif self.entry.type.is_const_method:
self.obj.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.obj.ensure_lhs_locked(env, is_dereferenced = True)
else:
self.obj.ensure_rhs_locked(env, is_dereferenced = True)
def ensure_subexpr_lhs_locked(self, env):
self.obj.ensure_lhs_locked(env, is_dereferenced = True)
def is_cimported_module_without_shadow(self, env): def is_cimported_module_without_shadow(self, env):
return self.obj.is_cimported_module_without_shadow(env) return self.obj.is_cimported_module_without_shadow(env)
...@@ -7806,10 +7612,6 @@ class AttributeNode(ExprNode): ...@@ -7806,10 +7612,6 @@ class AttributeNode(ExprNode):
rhs.free_temps(code) rhs.free_temps(code)
else: else:
select_code = self.result() select_code = self.result()
# XXX - Greater to have a getter, right ?
tracked_state = self.tracked_state
if self.is_autolock() and tracked_state and (tracked_state.needs_rlock or tracked_state.needs_wlock):
code.putln("Cy_UNLOCK(%s);" % select_code)
if self.type.is_pyobject and self.use_managed_ref: if self.type.is_pyobject and self.use_managed_ref:
rhs.make_owned_reference(code) rhs.make_owned_reference(code)
rhs.generate_giveref(code) rhs.generate_giveref(code)
...@@ -7831,11 +7633,6 @@ class AttributeNode(ExprNode): ...@@ -7831,11 +7633,6 @@ class AttributeNode(ExprNode):
select_code, select_code,
rhs.move_result_rhs_as(self.ctype()))) rhs.move_result_rhs_as(self.ctype())))
#rhs.result())) #rhs.result()))
if self.is_autolock():
if self.needs_wlock():
code.putln("Cy_WLOCK(%s);" % select_code)
elif self.needs_rlock():
code.putln("Cy_RLOCK(%s);" % select_code)
rhs.generate_post_assignment_code(code) rhs.generate_post_assignment_code(code)
rhs.free_temps(code) rhs.free_temps(code)
...@@ -10552,6 +10349,7 @@ compile_time_unary_operators = { ...@@ -10552,6 +10349,7 @@ compile_time_unary_operators = {
class UnopNode(ExprNode): class UnopNode(ExprNode):
# operator string # operator string
# operand ExprNode # operand ExprNode
# op_func_type CFuncType or None
# #
# Processing during analyse_expressions phase: # Processing during analyse_expressions phase:
# #
...@@ -10563,6 +10361,7 @@ class UnopNode(ExprNode): ...@@ -10563,6 +10361,7 @@ class UnopNode(ExprNode):
subexprs = ['operand'] subexprs = ['operand']
infix = True infix = True
op_func_type = None
def calculate_constant_result(self): def calculate_constant_result(self):
func = compile_time_unary_operators[self.operator] func = compile_time_unary_operators[self.operator]
...@@ -10679,6 +10478,7 @@ class UnopNode(ExprNode): ...@@ -10679,6 +10478,7 @@ class UnopNode(ExprNode):
self.type_error() self.type_error()
return return
if entry: if entry:
self.op_func_type = entry.type
self.exception_check = entry.type.exception_check self.exception_check = entry.type.exception_check
self.exception_value = entry.type.exception_value self.exception_value = entry.type.exception_value
if self.exception_check == '+': if self.exception_check == '+':
...@@ -10934,6 +10734,7 @@ class TypecastNode(ExprNode): ...@@ -10934,6 +10734,7 @@ class TypecastNode(ExprNode):
# declarator CDeclaratorNode # declarator CDeclaratorNode
# typecheck boolean # typecheck boolean
# overloaded boolean # overloaded boolean
# op_func_type CFuncType or None
# #
# If used from a transform, one can if wanted specify the attribute # If used from a transform, one can if wanted specify the attribute
# "type" directly and leave base_type and declarator to None # "type" directly and leave base_type and declarator to None
...@@ -10941,6 +10742,7 @@ class TypecastNode(ExprNode): ...@@ -10941,6 +10742,7 @@ class TypecastNode(ExprNode):
subexprs = ['operand'] subexprs = ['operand']
base_type = declarator = type = None base_type = declarator = type = None
overloaded = False overloaded = False
op_func_type = None
def type_dependencies(self, env): def type_dependencies(self, env):
return () return ()
...@@ -11010,6 +10812,8 @@ class TypecastNode(ExprNode): ...@@ -11010,6 +10812,8 @@ class TypecastNode(ExprNode):
operator = 'operator ' + self.type.declaration_code('') operator = 'operator ' + self.type.declaration_code('')
entry = self.operand.type.scope.lookup_here(operator) entry = self.operand.type.scope.lookup_here(operator)
self.overloaded = entry is not None self.overloaded = entry is not None
if entry:
self.op_func_type = entry.type
if self.type.is_cyp_class: if self.type.is_cyp_class:
self.is_temp = True self.is_temp = True
if self.type.is_ptr and self.type.base_type.is_cfunction and self.type.base_type.nogil: if self.type.is_ptr and self.type.base_type.is_cfunction and self.type.base_type.nogil:
...@@ -11584,6 +11388,7 @@ class BinopNode(ExprNode): ...@@ -11584,6 +11388,7 @@ class BinopNode(ExprNode):
# operator string # operator string
# operand1 ExprNode # operand1 ExprNode
# operand2 ExprNode # operand2 ExprNode
# op_func_type CFuncType or None
# #
# Processing during analyse_expressions phase: # Processing during analyse_expressions phase:
# #
...@@ -11595,6 +11400,7 @@ class BinopNode(ExprNode): ...@@ -11595,6 +11400,7 @@ class BinopNode(ExprNode):
subexprs = ['operand1', 'operand2'] subexprs = ['operand1', 'operand2']
inplace = False inplace = False
op_func_type = None
def calculate_constant_result(self): def calculate_constant_result(self):
func = compile_time_binary_operators[self.operator] func = compile_time_binary_operators[self.operator]
...@@ -11704,25 +11510,10 @@ class BinopNode(ExprNode): ...@@ -11704,25 +11510,10 @@ class BinopNode(ExprNode):
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp")) env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
self.op_func_type = func_type
if len(func_type.args) == 1: if len(func_type.args) == 1:
if func_type.is_const_method:
self.operand1.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.operand1.ensure_lhs_locked(env, is_dereferenced = True)
if func_type.args[0].type.is_const:
self.operand2.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.operand2.ensure_lhs_locked(env, is_dereferenced = True)
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env) self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else: else:
if func_type.args[0].type.is_const:
self.operand1.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.operand1.ensure_lhs_locked(env, is_dereferenced = True)
if func_type.args[1].type.is_const:
self.operand2.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.operand2.ensure_lhs_locked(env, is_dereferenced = True)
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env) self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env) self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.type = func_type.return_type self.type = func_type.return_type
...@@ -13200,6 +12991,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -13200,6 +12991,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
# operand1 ExprNode # operand1 ExprNode
# operand2 ExprNode # operand2 ExprNode
# cascade CascadedCmpNode # cascade CascadedCmpNode
# cmp_func_type CFuncType or None
# We don't use the subexprs mechanism, because # We don't use the subexprs mechanism, because
# things here are too complicated for it to handle. # things here are too complicated for it to handle.
...@@ -13211,6 +13003,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -13211,6 +13003,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
cascade = None cascade = None
coerced_operand2 = None coerced_operand2 = None
is_memslice_nonecheck = False is_memslice_nonecheck = False
cmp_func_type = None
def infer_type(self, env): def infer_type(self, env):
type1 = self.operand1.infer_type(env) type1 = self.operand1.infer_type(env)
...@@ -13243,12 +13036,6 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -13243,12 +13036,6 @@ class PrimaryCmpNode(ExprNode, CmpNode):
operand1 = self.operand1.compile_time_value(denv) operand1 = self.operand1.compile_time_value(denv)
return self.cascaded_compile_time_value(operand1, denv) return self.cascaded_compile_time_value(operand1, denv)
#def check_rhs_locked(self, env):
# self.operand1.check_rhs_locked(env)
# self.operand2.check_rhs_locked(env)
# if self.cascade:
# self.cascade.check_rhs_locked(env)
def analyse_types(self, env): def analyse_types(self, env):
self.operand1 = self.operand1.analyse_types(env) self.operand1 = self.operand1.analyse_types(env)
self.operand2 = self.operand2.analyse_types(env) self.operand2 = self.operand2.analyse_types(env)
...@@ -13390,6 +13177,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -13390,6 +13177,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
func_type = entry.type func_type = entry.type
if func_type.is_ptr: if func_type.is_ptr:
func_type = func_type.base_type func_type = func_type.base_type
self.cmp_func_type = func_type
self.exception_check = func_type.exception_check self.exception_check = func_type.exception_check
self.exception_value = func_type.exception_value self.exception_value = func_type.exception_value
if self.exception_check == '+': if self.exception_check == '+':
...@@ -13397,24 +13185,8 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -13397,24 +13185,8 @@ class PrimaryCmpNode(ExprNode, CmpNode):
if self.exception_value is None: if self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp")) env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
if len(func_type.args) == 1: if len(func_type.args) == 1:
if func_type.is_const_method:
self.operand1.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.operand1.ensure_lhs_locked(env, is_dereferenced = True)
if func_type.args[0].type.is_const:
self.operand2.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.operand2.ensure_lhs_locked(env, is_dereferenced = True)
self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env) self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env)
else: else:
if func_type.args[0].type.is_const:
self.operand1.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.operand1.ensure_lhs_locked(env, is_dereferenced = True)
if func_type.args[1].type.is_const:
self.operand2.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.operand2.ensure_lhs_locked(env, is_dereferenced = True)
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env) self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env) self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.type = func_type.return_type self.type = func_type.return_type
...@@ -13460,12 +13232,6 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -13460,12 +13232,6 @@ class PrimaryCmpNode(ExprNode, CmpNode):
else: else:
return self.operand1.check_const() and self.operand2.check_const() return self.operand1.check_const() and self.operand2.check_const()
def ensure_subexpr_rhs_locked(self, env):
self.operand1.ensure_rhs_locked(env)
self.operand2.ensure_rhs_locked(env)
if self.cascade:
self.cascade.ensure_rhs_locked(env)
def calculate_result_code(self): def calculate_result_code(self):
operand1, operand2 = self.operand1, self.operand2 operand1, operand2 = self.operand1, self.operand2
if operand1.type.is_complex: if operand1.type.is_complex:
...@@ -13577,16 +13343,6 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -13577,16 +13343,6 @@ class CascadedCmpNode(Node, CmpNode):
return self.constant_result is not constant_value_not_set and \ return self.constant_result is not constant_value_not_set and \
self.constant_result is not not_a_constant self.constant_result is not not_a_constant
#def check_rhs_locked(self, env):
# self.operand2.check_rhs_locked(env)
# if self.cascade:
# self.cascade.check_rhs_locked(env)
def ensure_rhs_locked(self, env):
self.operand2.ensure_rhs_locked(env)
if self.cascade:
self.cascade.ensure_rhs_locked(env)
def analyse_types(self, env): def analyse_types(self, env):
self.operand2 = self.operand2.analyse_types(env) self.operand2 = self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
...@@ -14190,7 +13946,7 @@ class CoerceToTempNode(CoercionNode): ...@@ -14190,7 +13946,7 @@ class CoerceToTempNode(CoercionNode):
# to be stored in a temporary. It is only used if the # to be stored in a temporary. It is only used if the
# argument node's result is not already in a temporary. # argument node's result is not already in a temporary.
def __init__(self, arg, env): def __init__(self, arg, env=None):
CoercionNode.__init__(self, arg) CoercionNode.__init__(self, arg)
self.type = self.arg.type.as_argument_type() self.type = self.arg.type.as_argument_type()
self.constant_result = self.arg.constant_result self.constant_result = self.arg.constant_result
...@@ -14207,14 +13963,6 @@ class CoerceToTempNode(CoercionNode): ...@@ -14207,14 +13963,6 @@ class CoerceToTempNode(CoercionNode):
def may_be_none(self): def may_be_none(self):
return self.arg.may_be_none() return self.arg.may_be_none()
def ensure_rhs_locked(self, env, is_dereferenced = False):
self.arg.ensure_rhs_locked(env, is_dereferenced)
self.tracked_state = self.arg.tracked_state
def ensure_lhs_locked(self, env, is_dereferenced = False):
self.arg.ensure_lhs_locked(env, is_dereferenced)
self.tracked_state = self.arg.tracked_state
def coerce_to_boolean(self, env): def coerce_to_boolean(self, env):
self.arg = self.arg.coerce_to_boolean(env) self.arg = self.arg.coerce_to_boolean(env)
if self.arg.is_simple(): if self.arg.is_simple():
...@@ -14238,6 +13986,27 @@ class CoerceToTempNode(CoercionNode): ...@@ -14238,6 +13986,27 @@ class CoerceToTempNode(CoercionNode):
code.put_incref_memoryviewslice(self.result(), self.type, code.put_incref_memoryviewslice(self.result(), self.type,
have_gil=not self.in_nogil_context) have_gil=not self.in_nogil_context)
class CoerceToLockedTempNode(CoerceToTempNode):
# rlock_only boolean
def __init__(self, arg, env=None, rlock_only=False):
self.rlock_only = rlock_only
if isinstance(arg, CoerceToTempNode):
arg = arg.arg
super(CoerceToLockedTempNode, self).__init__(arg, env)
def generate_result_code(self, code):
super(CoerceToLockedTempNode, self).generate_result_code(code)
if self.rlock_only:
code.putln("Cy_RLOCK(%s);" % self.result())
else:
code.putln("Cy_WLOCK(%s);" % self.result())
def generate_disposal_code(self, code):
code.putln("Cy_UNLOCK(%s);" % self.result())
super(CoerceToLockedTempNode, self).generate_disposal_code(code)
class ProxyNode(CoercionNode): class ProxyNode(CoercionNode):
""" """
A node that should not be replaced by transforms or other means, A node that should not be replaced by transforms or other means,
......
...@@ -2136,7 +2136,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -2136,7 +2136,7 @@ class FuncDefNode(StatNode, BlockNode):
code.put_var_incref_memoryviewslice(entry, code.put_var_incref_memoryviewslice(entry,
have_gil=code.funcstate.gil_owned) have_gil=code.funcstate.gil_owned)
# We have to Cy_INCREF the nogil classes (ccdef'ed ones) # We have to Cy_INCREF the nogil classes (ccdef'ed ones)
elif entry.type.is_cyp_class and len(entry.cf_assignments) > 1: elif entry.type.is_cyp_class and len(entry.cf_assignments) > 1 and not entry.is_self_arg:
code.put_cyincref(entry.cname) code.put_cyincref(entry.cname)
for entry in lenv.var_entries: for entry in lenv.var_entries:
if entry.is_arg and entry.cf_is_reassigned and not entry.in_closure: if entry.is_arg and entry.cf_is_reassigned and not entry.in_closure:
...@@ -2161,15 +2161,6 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -2161,15 +2161,6 @@ class FuncDefNode(StatNode, BlockNode):
code.put_release_ensured_gil() code.put_release_ensured_gil()
code.funcstate.gil_owned = False code.funcstate.gil_owned = False
for node in lenv.autolocked_nodes:
if node.entry.is_variable and not node.entry.is_local and (node.needs_wlock() or node.needs_rlock()):
node_result = node.result()
code.putln("if (%s != NULL)" % node_result)
if node.needs_wlock():
code.putln(" Cy_WLOCK(%s);" % node_result)
elif node.needs_rlock():
code.putln(" Cy_RLOCK(%s);" % node_result)
# ------------------------- # -------------------------
# ----- Function body ----- # ----- Function body -----
# ------------------------- # -------------------------
...@@ -2325,23 +2316,6 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -2325,23 +2316,6 @@ class FuncDefNode(StatNode, BlockNode):
align_error_path_gil_to_success_path() align_error_path_gil_to_success_path()
code.put_label(code.return_from_error_cleanup_label) code.put_label(code.return_from_error_cleanup_label)
for node in reversed(lenv.autolocked_nodes):
# We iterate in the reverse order to properly unlock
# nested locked objects (aka most nested first).
# For example, if we have the following situation:
# obj.sub_obj.attr = some_value
# If obj and sub_obj are both of autolocked types,
# the obj (name)node is declared before the sub_obj (attribute)node.
# If we unlock first obj, another thread could immediately acquire
# a write lock and change where sub_obj points to.
# We would then try to unlock the new sub_obj reference,
# which leads to a dangling lock on the previous reference
# (and attempt to unlock a non-locked ref).
if node.needs_wlock() or node.needs_rlock():
code.putln("Cy_UNLOCK(%s);" % node.result())
for entry in lenv.var_entries: for entry in lenv.var_entries:
if not entry.used or entry.in_closure: if not entry.used or entry.in_closure:
continue continue
...@@ -2371,7 +2345,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -2371,7 +2345,7 @@ class FuncDefNode(StatNode, BlockNode):
continue continue
if entry.type.needs_refcounting: if entry.type.needs_refcounting:
assure_gil('success') assure_gil('success')
if entry.type.is_cyp_class: if entry.type.is_cyp_class and not entry.is_self_arg:
# We must check for NULL because it is possible to have # We must check for NULL because it is possible to have
# NULL as a valid cypclass (with a typecast) # NULL as a valid cypclass (with a typecast)
code.put_cyxdecref(entry.cname) code.put_cyxdecref(entry.cname)
...@@ -2741,11 +2715,8 @@ class CFuncDefNode(FuncDefNode): ...@@ -2741,11 +2715,8 @@ class CFuncDefNode(FuncDefNode):
_cname = "this" _cname = "this"
entry = self.local_scope.declare(_name, _cname, _type, _pos, 'private') entry = self.local_scope.declare(_name, _cname, _type, _pos, 'private')
entry.is_variable = 1 entry.is_variable = 1
# Even if it is checklock it should be OK to mess with self without locking entry.is_self_arg = 1
self_locking_state = self.local_scope.declare_tracked(entry) self.local_scope.arg_entries.append(entry)
self_locking_state.is_rlocked = self.is_const_method
self_locking_state.is_wlocked = not self.is_const_method
def declare_cpdef_wrapper(self, env): def declare_cpdef_wrapper(self, env):
if self.overridable: if self.overridable:
...@@ -5787,9 +5758,6 @@ class ExprStatNode(StatNode): ...@@ -5787,9 +5758,6 @@ class ExprStatNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.expr.result_is_used = False # hint that .result() may safely be left empty self.expr.result_is_used = False # hint that .result() may safely be left empty
self.expr = self.expr.analyse_expressions(env) self.expr = self.expr.analyse_expressions(env)
from . import ExprNodes
if isinstance(self.expr, ExprNodes.ExprNode):
self.expr.ensure_rhs_locked(env)
# Repeat in case of node replacement. # Repeat in case of node replacement.
self.expr.result_is_used = False # hint that .result() may safely be left empty self.expr.result_is_used = False # hint that .result() may safely be left empty
return self return self
...@@ -5861,9 +5829,6 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5861,9 +5829,6 @@ class SingleAssignmentNode(AssignmentNode):
first = False first = False
is_overloaded_assignment = False is_overloaded_assignment = False
declaration_only = False declaration_only = False
needs_unlock = False
needs_rlock = False
needs_wlock = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
from . import ExprNodes from . import ExprNodes
...@@ -5960,14 +5925,7 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5960,14 +5925,7 @@ class SingleAssignmentNode(AssignmentNode):
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
if self.lhs.is_subscript and self.lhs.base.type.is_cyp_class:
if self.lhs.type.is_const: # type of the formal 'value' argument of __setitem__
self.rhs.ensure_rhs_locked(env, is_dereferenced = True)
else:
self.rhs.ensure_lhs_locked(env, is_dereferenced = True)
else:
self.rhs.ensure_rhs_locked(env)
self.lhs.ensure_lhs_locked(env, is_top_lhs = True)
unrolled_assignment = self.unroll_lhs(env) unrolled_assignment = self.unroll_lhs(env)
if unrolled_assignment: if unrolled_assignment:
return unrolled_assignment return unrolled_assignment
...@@ -6191,11 +6149,10 @@ class CascadedAssignmentNode(AssignmentNode): ...@@ -6191,11 +6149,10 @@ class CascadedAssignmentNode(AssignmentNode):
for i, lhs in enumerate(self.lhs_list): for i, lhs in enumerate(self.lhs_list):
lhs = self.lhs_list[i] = lhs.analyse_target_types(env) lhs = self.lhs_list[i] = lhs.analyse_target_types(env)
lhs.gil_assignment_check(env) lhs.gil_assignment_check(env)
lhs.ensure_lhs_locked(env, is_top_lhs = True)
lhs_types.add(lhs.type) lhs_types.add(lhs.type)
rhs = self.rhs.analyse_types(env) rhs = self.rhs.analyse_types(env)
rhs.ensure_rhs_locked(env)
# common special case: only one type needed on the LHS => coerce only once # common special case: only one type needed on the LHS => coerce only once
if len(lhs_types) == 1: if len(lhs_types) == 1:
# Avoid coercion for overloaded assignment operators. # Avoid coercion for overloaded assignment operators.
...@@ -6336,9 +6293,7 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -6336,9 +6293,7 @@ class InPlaceAssignmentNode(AssignmentNode):
def analyse_types(self, env): def analyse_types(self, env):
self.rhs = self.rhs.analyse_types(env) self.rhs = self.rhs.analyse_types(env)
self.rhs.ensure_rhs_locked(env)
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
self.rhs.ensure_lhs_locked(env)
# When assigning to a fully indexed buffer or memoryview, coerce the rhs # When assigning to a fully indexed buffer or memoryview, coerce the rhs
if self.lhs.is_memview_index or self.lhs.is_buffer_access: if self.lhs.is_memview_index or self.lhs.is_buffer_access:
...@@ -6401,7 +6356,6 @@ class PrintStatNode(StatNode): ...@@ -6401,7 +6356,6 @@ class PrintStatNode(StatNode):
stream = self.stream.analyse_expressions(env) stream = self.stream.analyse_expressions(env)
self.stream = stream.coerce_to_pyobject(env) self.stream = stream.coerce_to_pyobject(env)
arg_tuple = self.arg_tuple.analyse_expressions(env) arg_tuple = self.arg_tuple.analyse_expressions(env)
arg_tuple.ensure_rhs_locked(env)
self.arg_tuple = arg_tuple.coerce_to_pyobject(env) self.arg_tuple = arg_tuple.coerce_to_pyobject(env)
env.use_utility_code(printing_utility_code) env.use_utility_code(printing_utility_code)
if len(self.arg_tuple.args) == 1 and self.append_newline: if len(self.arg_tuple.args) == 1 and self.append_newline:
...@@ -6503,7 +6457,6 @@ class DelStatNode(StatNode): ...@@ -6503,7 +6457,6 @@ class DelStatNode(StatNode):
child_attrs = ["args"] child_attrs = ["args"]
ignore_nonexisting = False ignore_nonexisting = False
was_locked = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
for arg in self.args: for arg in self.args:
...@@ -6515,14 +6468,11 @@ class DelStatNode(StatNode): ...@@ -6515,14 +6468,11 @@ class DelStatNode(StatNode):
arg = self.args[i] = arg.analyse_del_expression(env) arg = self.args[i] = arg.analyse_del_expression(env)
else: else:
arg = self.args[i] = arg.analyse_target_expression(env, None) arg = self.args[i] = arg.analyse_target_expression(env, None)
arg.ensure_lhs_locked(env)
if arg.type.is_pyobject or (arg.is_name and arg.type.is_memoryviewslice): if arg.type.is_pyobject or (arg.is_name and arg.type.is_memoryviewslice):
if arg.is_name and arg.entry.is_cglobal: if arg.is_name and arg.entry.is_cglobal:
error(arg.pos, "Deletion of global C variable") error(arg.pos, "Deletion of global C variable")
elif arg.type.is_ptr and arg.type.base_type.is_cpp_class or arg.type.is_cyp_class: elif arg.type.is_ptr and arg.type.base_type.is_cpp_class or arg.type.is_cyp_class:
self.cpp_check(env) self.cpp_check(env)
if arg.type.is_cyp_class:
self.was_locked = arg.needs_rlock() or arg.needs_wlock()
elif arg.type.is_cpp_class: elif arg.type.is_cpp_class:
error(arg.pos, "Deletion of non-heap C++ object") error(arg.pos, "Deletion of non-heap C++ object")
elif arg.is_subscript and arg.base.type is Builtin.bytearray_type: elif arg.is_subscript and arg.base.type is Builtin.bytearray_type:
...@@ -6555,11 +6505,8 @@ class DelStatNode(StatNode): ...@@ -6555,11 +6505,8 @@ class DelStatNode(StatNode):
arg.free_temps(code) arg.free_temps(code)
elif arg.type.is_cyp_class: elif arg.type.is_cyp_class:
arg.generate_evaluation_code(code) arg.generate_evaluation_code(code)
if arg.type.lock_mode == "autolock" and self.was_locked: code.putln("Cy_DECREF(%s);" % arg.result())
code.putln("Cy_UNLOCK(%s);" % arg.result()) code.putln("%s = NULL;" % arg.result())
else:
code.putln("Cy_DECREF(%s);" % arg.result())
code.putln("%s = NULL;" % arg.result())
arg.generate_disposal_code(code) arg.generate_disposal_code(code)
# else error reported earlier # else error reported earlier
...@@ -6648,7 +6595,6 @@ class ReturnStatNode(StatNode): ...@@ -6648,7 +6595,6 @@ class ReturnStatNode(StatNode):
if self.in_async_gen: if self.in_async_gen:
error(self.pos, "Return with value in async generator") error(self.pos, "Return with value in async generator")
self.value = self.value.analyse_types(env) self.value = self.value.analyse_types(env)
self.value.ensure_rhs_locked(env)
if return_type.is_void or return_type.is_returncode: if return_type.is_void or return_type.is_returncode:
error(self.value.pos, "Return with value in void function") error(self.value.pos, "Return with value in void function")
else: else:
...@@ -6993,7 +6939,6 @@ class IfClauseNode(Node): ...@@ -6993,7 +6939,6 @@ class IfClauseNode(Node):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.condition = self.condition.analyse_temp_boolean_expression(env) self.condition = self.condition.analyse_temp_boolean_expression(env)
self.condition.ensure_rhs_locked(env)
self.body = self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
return self return self
...@@ -7125,7 +7070,6 @@ class WhileStatNode(LoopNode, StatNode): ...@@ -7125,7 +7070,6 @@ class WhileStatNode(LoopNode, StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
if self.condition: if self.condition:
self.condition = self.condition.analyse_temp_boolean_expression(env) self.condition = self.condition.analyse_temp_boolean_expression(env)
self.condition.ensure_rhs_locked(env)
self.body = self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
if self.else_clause: if self.else_clause:
self.else_clause = self.else_clause.analyse_expressions(env) self.else_clause = self.else_clause.analyse_expressions(env)
...@@ -7357,9 +7301,7 @@ class _ForInStatNode(LoopNode, StatNode): ...@@ -7357,9 +7301,7 @@ class _ForInStatNode(LoopNode, StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.target = self.target.analyse_target_types(env) self.target = self.target.analyse_target_types(env)
self.target.ensure_lhs_locked(env)
self.iterator = self.iterator.analyse_expressions(env) self.iterator = self.iterator.analyse_expressions(env)
self.iterator.ensure_rhs_locked(env)
self._create_item_node() # must rewrap self.item after analysis self._create_item_node() # must rewrap self.item after analysis
self.item = self.item.analyse_expressions(env) self.item = self.item.analyse_expressions(env)
if (not self.is_async and if (not self.is_async and
...@@ -7501,17 +7443,13 @@ class ForFromStatNode(LoopNode, StatNode): ...@@ -7501,17 +7443,13 @@ class ForFromStatNode(LoopNode, StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
from . import ExprNodes from . import ExprNodes
self.target = self.target.analyse_target_types(env) self.target = self.target.analyse_target_types(env)
self.target.ensure_lhs_locked(env)
self.bound1 = self.bound1.analyse_types(env) self.bound1 = self.bound1.analyse_types(env)
self.bound1.ensure_rhs_locked(env)
self.bound2 = self.bound2.analyse_types(env) self.bound2 = self.bound2.analyse_types(env)
self.bound2.ensure_rhs_locked(env)
if self.step is not None: if self.step is not None:
if isinstance(self.step, ExprNodes.UnaryMinusNode): if isinstance(self.step, ExprNodes.UnaryMinusNode):
warning(self.step.pos, "Probable infinite loop in for-from-by statement. " warning(self.step.pos, "Probable infinite loop in for-from-by statement. "
"Consider switching the directions of the relations.", 2) "Consider switching the directions of the relations.", 2)
self.step = self.step.analyse_types(env) self.step = self.step.analyse_types(env)
self.step.ensure_rhs_locked(env)
self.set_up_loop(env) self.set_up_loop(env)
target_type = self.target.type target_type = self.target.type
...@@ -8622,54 +8560,7 @@ class LockCypclassNode(StatNode): ...@@ -8622,54 +8560,7 @@ class LockCypclassNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.obj = self.obj.analyse_types(env) self.obj = self.obj.analyse_types(env)
self.obj.ensure_rhs_locked(env)
if not hasattr(self.obj, 'entry'):
error(self.pos, "The (un)locking target has no entry")
return
if not self.obj.type.is_cyp_class:
error(self.pos, "Cannot (un)lock a non-cypclass variable !")
return
# FIXME: this is a bit redundant here
self.obj.get_tracked_state(env)
is_rlocked = self.obj.is_rhs_locked(env)
is_wlocked = self.obj.is_lhs_locked(env)
if self.obj.type.lock_mode != "nolock":
if self.state == "unclocked" and not (is_rlocked or is_wlocked):
error(self.pos, "Cannot unlock an already unlocked object !")
elif self.state == "rlocked" and is_rlocked:
error(self.pos, "Double read lock !")
elif self.state == "wlocked" and is_wlocked:
error(self.pos, "Double write lock !")
# We need to save states because in case of 'with unlocked' statement,
# we must know which lock has to be restored after the with body.
self.was_wlocked = is_wlocked
self.was_rlocked = is_rlocked and not is_wlocked
tracked_state = self.obj.tracked_state
if self.state == "rlocked":
tracked_state.is_rlocked = True
tracked_state.is_wlocked = False
elif self.state == "wlocked":
tracked_state.is_rlocked = False
tracked_state.is_wlocked = True
else:
tracked_state.is_rlocked = False
tracked_state.is_wlocked = False
self.body = self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
#self.obj.entry.is_rlocked = self.was_rlocked
#self.obj.entry.is_wlocked = self.was_wlocked
tracked_state.is_rlocked = self.was_rlocked
tracked_state.is_wlocked = self.was_wlocked
return self return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
......
...@@ -141,7 +141,7 @@ def inject_utility_code_stage_factory(context): ...@@ -141,7 +141,7 @@ def inject_utility_code_stage_factory(context):
def create_pipeline(context, mode, exclude_classes=()): def create_pipeline(context, mode, exclude_classes=()):
assert mode in ('pyx', 'py', 'pxd') assert mode in ('pyx', 'py', 'pxd')
from .Visitor import PrintTree from .Visitor import PrintTree
from .CypclassTransforms import CypclassWrapperInjection from .CypclassTransforms import CypclassWrapperInjection, CypclassLockTransform
from .ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse from .ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
from .ParseTreeTransforms import ForwardDeclareTypes, InjectGilHandling, AnalyseDeclarationsTransform from .ParseTreeTransforms import ForwardDeclareTypes, InjectGilHandling, AnalyseDeclarationsTransform
from .ParseTreeTransforms import AnalyseExpressionsTransform, FindInvalidUseOfFusedTypes from .ParseTreeTransforms import AnalyseExpressionsTransform, FindInvalidUseOfFusedTypes
...@@ -212,6 +212,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -212,6 +212,7 @@ def create_pipeline(context, mode, exclude_classes=()):
_check_c_declarations, _check_c_declarations,
InlineDefNodeCalls(context), InlineDefNodeCalls(context),
AnalyseExpressionsTransform(context), AnalyseExpressionsTransform(context),
CypclassLockTransform(context),
FindInvalidUseOfFusedTypes(context), FindInvalidUseOfFusedTypes(context),
ExpandInplaceOperators(context), ExpandInplaceOperators(context),
IterationTransform(context), IterationTransform(context),
......
...@@ -158,10 +158,6 @@ class Entry(object): ...@@ -158,10 +158,6 @@ class Entry(object):
# is_fused_specialized boolean Whether this entry of a cdef or def function # is_fused_specialized boolean Whether this entry of a cdef or def function
# is a specialization # is a specialization
# is_cgetter boolean Is a c-level getter function # is_cgetter boolean Is a c-level getter function
# is_wlocked boolean Is locked with a write lock (used for cypclass)
# is_rlocked boolean Is locked with a read lock (used for cypclass)
# needs_rlock boolean The entry needs a read lock (used in autolock mode)
# needs_wlock boolean The entry needs a write lock (used in autolock mode)
# #
# is_default boolean This entry is a compiler-generated default and # is_default boolean This entry is a compiler-generated default and
# is not user-defined (e.g default contructor) # is not user-defined (e.g default contructor)
...@@ -173,7 +169,7 @@ class Entry(object): ...@@ -173,7 +169,7 @@ class Entry(object):
# mro_index integer The index of the type where this entry was originally # mro_index integer The index of the type where this entry was originally
# declared in the mro of the cypclass where it is now # declared in the mro of the cypclass where it is now
# #
# defining_classes [CypClassType or CppClassType or CStructOrUnionType] # defining_classes [CypClassType or CppClassType or CStructOrUnionType]
# All the base classes that define an entry that this entry # All the base classes that define an entry that this entry
# overrides, if this entry represents a cypclass method # overrides, if this entry represents a cypclass method
# #
...@@ -251,10 +247,6 @@ class Entry(object): ...@@ -251,10 +247,6 @@ class Entry(object):
cf_used = True cf_used = True
outer_entry = None outer_entry = None
is_cgetter = False is_cgetter = False
is_wlocked = False
is_rlocked = False
needs_rlock = False
needs_wlock = False
is_default = False is_default = False
mro_index = 0 mro_index = 0
from_type = None from_type = None
...@@ -335,14 +327,6 @@ class InnerEntry(Entry): ...@@ -335,14 +327,6 @@ class InnerEntry(Entry):
def all_entries(self): def all_entries(self):
return self.defining_entry.all_entries() return self.defining_entry.all_entries()
class TrackedLockedEntry:
def __init__(self, entry, scope):
self.entry = entry
self.scope = scope
self.is_wlocked = False
self.is_rlocked = False
self.needs_wlock = False
self.needs_rlock = False
class Scope(object): class Scope(object):
# name string Unqualified name # name string Unqualified name
...@@ -357,7 +341,6 @@ class Scope(object): ...@@ -357,7 +341,6 @@ class Scope(object):
# cfunc_entries [Entry] C function entries # cfunc_entries [Entry] C function entries
# c_class_entries [Entry] All extension type entries # c_class_entries [Entry] All extension type entries
# cypclass_entries [Entry] All cypclass entries # cypclass_entries [Entry] All cypclass entries
# autolocked_nodes [ExprNodes] All autolocked nodes that needs unlocking
# cname_to_entry {string : Entry} Temp cname to entry mapping # cname_to_entry {string : Entry} Temp cname to entry mapping
# return_type PyrexType or None Return type of function owning scope # return_type PyrexType or None Return type of function owning scope
# is_builtin_scope boolean Is the builtin scope of Python/Cython # is_builtin_scope boolean Is the builtin scope of Python/Cython
...@@ -420,7 +403,6 @@ class Scope(object): ...@@ -420,7 +403,6 @@ class Scope(object):
self.c_class_entries = [] self.c_class_entries = []
self.cypclass_entries = [] self.cypclass_entries = []
self.defined_c_classes = [] self.defined_c_classes = []
self.autolocked_nodes = []
self.imported_c_classes = {} self.imported_c_classes = {}
self.cname_to_entry = {} self.cname_to_entry = {}
self.identifier_to_entry = {} self.identifier_to_entry = {}
...@@ -429,7 +411,6 @@ class Scope(object): ...@@ -429,7 +411,6 @@ class Scope(object):
self.buffer_entries = [] self.buffer_entries = []
self.lambda_defs = [] self.lambda_defs = []
self.id_counters = {} self.id_counters = {}
self.tracked_entries = {}
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
...@@ -525,18 +506,6 @@ class Scope(object): ...@@ -525,18 +506,6 @@ class Scope(object):
for e, s in cypclass_scope.iter_cypclass_entries_and_scopes(): for e, s in cypclass_scope.iter_cypclass_entries_and_scopes():
yield e, s yield e, s
def declare_tracked(self, entry):
# Keying only with the name is wrong: if we have multiple attributes
# with the same name in different cypclass, this will conflict.
key = entry
self.tracked_entries[key] = TrackedLockedEntry(entry, self)
return self.tracked_entries[key]
def lookup_tracked(self, entry):
# We don't chain up the scopes on purpose: we want to keep things local
key = entry
return self.tracked_entries.get(key, None)
def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0, create_wrapper = 0, from_type = None): def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0, create_wrapper = 0, from_type = None):
# Create new entry, and add to dictionary if # Create new entry, and add to dictionary if
# name is not None. Reports a warning if already # name is not None. Reports a warning if already
...@@ -2093,12 +2062,6 @@ class ModuleScope(Scope): ...@@ -2093,12 +2062,6 @@ class ModuleScope(Scope):
from .TypeInference import PyObjectTypeInferer from .TypeInference import PyObjectTypeInferer
PyObjectTypeInferer().infer_types(self) PyObjectTypeInferer().infer_types(self)
def declare_autolocked(self, node):
# Add an entry for autolocked cypclass
if not (node.type.is_cyp_class and node.type.lock_mode == "autolock"):
error(node.pos, "Trying to autolock a non (autolocked) cypclass object !")
self.autolocked_nodes.append(node)
class LocalScope(Scope): class LocalScope(Scope):
...@@ -2125,20 +2088,10 @@ class LocalScope(Scope): ...@@ -2125,20 +2088,10 @@ class LocalScope(Scope):
if type.is_pyobject: if type.is_pyobject:
entry.init = "0" entry.init = "0"
entry.is_arg = 1 entry.is_arg = 1
if type.is_cyp_class and type.lock_mode != "nolock":
arg_lock_state = self.declare_tracked(entry)
arg_lock_state.is_rlocked = type.is_const
arg_lock_state.is_wlocked = not type.is_const
#entry.borrowed = 1 # Not using borrowed arg refs for now #entry.borrowed = 1 # Not using borrowed arg refs for now
self.arg_entries.append(entry) self.arg_entries.append(entry)
return entry return entry
def declare_autolocked(self, node):
# Add an entry for autolocked cypclass
if not (node.type.is_cyp_class and node.type.lock_mode == "autolock"):
error(node.pos, "Trying to autolock a non (autolocked) cypclass object !")
self.autolocked_nodes.append(node)
def declare_var(self, name, type, pos, def declare_var(self, name, type, pos,
cname = None, visibility = 'private', cname = None, visibility = 'private',
api = 0, in_pxd = 0, is_cdef = 0): api = 0, in_pxd = 0, is_cdef = 0):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment