Commit af6626ae authored by gsamain's avatar gsamain Committed by Xavier Thompson

Tracked states don't go outside of scope

parent 7611da16
...@@ -318,6 +318,8 @@ class ExprNode(Node): ...@@ -318,6 +318,8 @@ class ExprNode(Node):
use_managed_ref = True # can be set by optimisation transforms use_managed_ref = True # can be set by optimisation transforms
result_is_used = True result_is_used = True
is_numpy_attribute = False is_numpy_attribute = False
tracked_state = None
was_locked = True
# The Analyse Expressions phase for expressions is split # The Analyse Expressions phase for expressions is split
# into two sub-phases: # into two sub-phases:
...@@ -722,25 +724,45 @@ class ExprNode(Node): ...@@ -722,25 +724,45 @@ class ExprNode(Node):
error(self.pos, "Address is not constant") error(self.pos, "Address is not constant")
def set_autorlock(self, env): def set_autorlock(self, env):
self.entry.is_rlocked = True self.tracked_state.was_locked = True
self.entry.needs_rlock = True self.tracked_state.is_rlocked = True
self.tracked_state.needs_rlock = True
def set_autowlock(self, env): def set_autowlock(self, env):
print "Setting wlock" self.tracked_state.was_locked = True
self.entry.is_wlocked = True self.tracked_state.is_wlocked = True
self.entry.needs_wlock = True self.tracked_state.needs_wlock = True
def is_autolock(self, env): def is_autolock(self):
return self.type.is_cyp_class and self.type.lock_mode == "autolock" return self.type.is_cyp_class and self.type.lock_mode == "autolock"
def is_checklock(self, env): def is_checklock(self):
return self.type.is_cyp_class and self.type.lock_mode == "checklock" return self.type.is_cyp_class and self.type.lock_mode == "checklock"
def get_tracked_state(self, env):
if not hasattr(self, 'entry') or not self.entry.type.is_cyp_class:
return
self.tracked_state = env.lookup_tracked(self.entry)
if self.tracked_state is None:
self.tracked_state = env.declare_tracked(self.entry)
if self.is_autolock():
env.declare_autolocked(self)
self.was_locked = self.tracked_state.was_locked
self.tracked_state.was_locked = True
def is_rhs_locked(self, env): def is_rhs_locked(self, env):
return not(hasattr(self, 'entry') and self.entry.type.is_cyp_class and not (self.entry.is_rlocked or self.entry.is_wlocked)) if not hasattr(self, 'entry') or not self.entry.type.is_cyp_class:
# These nodes couldn't be tracked (because it is for example a constant),
# so we let them pass silently
return True
return self.tracked_state.is_rlocked or self.tracked_state.is_wlocked
def is_lhs_locked(self, env): def is_lhs_locked(self, env):
return not(hasattr(self, 'entry') and self.entry.type.is_cyp_class and not self.entry.is_wlocked) if not hasattr(self, 'entry') or not self.entry.type.is_cyp_class:
# These nodes couldn't be tracked (because it is for example a constant),
# so we let them pass silently
return True
return self.tracked_state.is_wlocked
def ensure_subexpr_rhs_locked(self, env): def ensure_subexpr_rhs_locked(self, env):
for node in self.subexpr_nodes(): for node in self.subexpr_nodes():
...@@ -752,20 +774,24 @@ class ExprNode(Node): ...@@ -752,20 +774,24 @@ class ExprNode(Node):
def ensure_rhs_locked(self, env, is_dereferenced = False): def ensure_rhs_locked(self, env, is_dereferenced = False):
self.ensure_subexpr_rhs_locked(env) self.ensure_subexpr_rhs_locked(env)
if not self.tracked_state:
self.get_tracked_state(env)
if is_dereferenced: if is_dereferenced:
if not self.is_rhs_locked(env): if not self.is_rhs_locked(env):
if self.is_checklock(env): if self.is_checklock():
error(self.pos, "This expression is not correctly locked (read lock needed)") error(self.pos, "This expression is not correctly locked (read lock needed)")
elif self.is_autolock(env): elif self.is_autolock():
self.set_autorlock(env) self.set_autorlock(env)
def ensure_lhs_locked(self, env, is_dereferenced = False): def ensure_lhs_locked(self, env, is_dereferenced = False):
self.ensure_subexpr_lhs_locked(env) self.ensure_subexpr_lhs_locked(env)
if not self.tracked_state:
self.get_tracked_state(env)
if is_dereferenced: if is_dereferenced:
if not self.is_lhs_locked(env): if not self.is_lhs_locked(env):
if self.is_checklock(env): if self.is_checklock():
error(self.pos, "This expression is not correctly locked (write lock needed)") error(self.pos, "This expression is not correctly locked (write lock needed)")
elif self.is_autolock(env): elif self.is_autolock():
self.set_autowlock(env) self.set_autowlock(env)
# ----------------- Result Allocation ----------------- # ----------------- Result Allocation -----------------
...@@ -2372,8 +2398,14 @@ class NameNode(AtomicExprNode): ...@@ -2372,8 +2398,14 @@ class NameNode(AtomicExprNode):
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
self.generate_gotref(code) self.generate_gotref(code)
elif entry.is_local and entry.type.is_cyp_class: elif entry.type.is_cyp_class:
code.put_cygotref(self.result()) code.put_cygotref(self.result())
if not self.was_locked and self.is_autolock():
tracked_state = self.tracked_state
if tracked_state.needs_wlock:
code.putln("Cy_WLOCK(%s);" % self.result())
elif tracked_state.needs_rlock:
code.putln("Cy_RLOCK(%s);" % self.result())
#pass #pass
# code.putln(entry.cname) # code.putln(entry.cname)
elif entry.is_local or entry.in_closure or entry.from_closure or entry.type.is_memoryviewslice: elif entry.is_local or entry.in_closure or entry.from_closure or entry.type.is_memoryviewslice:
...@@ -2391,6 +2423,7 @@ class NameNode(AtomicExprNode): ...@@ -2391,6 +2423,7 @@ class NameNode(AtomicExprNode):
exception_check=None, exception_value=None): exception_check=None, exception_value=None):
#print "NameNode.generate_assignment_code:", self.name ### #print "NameNode.generate_assignment_code:", self.name ###
entry = self.entry entry = self.entry
tracked_state = self.tracked_state
if entry is None: if entry is None:
return # There was an error earlier return # There was an error earlier
...@@ -2398,7 +2431,7 @@ class NameNode(AtomicExprNode): ...@@ -2398,7 +2431,7 @@ class NameNode(AtomicExprNode):
and not self.lhs_of_first_assignment and not rhs.in_module_scope): and not self.lhs_of_first_assignment and not rhs.in_module_scope):
error(self.pos, "Literal list must be assigned to pointer at time of declaration") error(self.pos, "Literal list must be assigned to pointer at time of declaration")
if entry.needs_wlock or entry.needs_rlock: if self.is_autolock() and tracked_state and (tracked_state.needs_wlock or tracked_state.needs_rlock):
code.putln("Cy_UNLOCK(%s);" % self.result()) code.putln("Cy_UNLOCK(%s);" % self.result())
# is_pyglobal seems to be True for module level-globals only. # is_pyglobal seems to be True for module level-globals only.
...@@ -2506,10 +2539,11 @@ class NameNode(AtomicExprNode): ...@@ -2506,10 +2539,11 @@ class NameNode(AtomicExprNode):
code.putln('new (&%s) decltype(%s){%s};' % (self.result(), self.result(), result)) code.putln('new (&%s) decltype(%s){%s};' % (self.result(), self.result(), result))
elif result != self.result(): elif result != self.result():
code.putln('%s = %s;' % (self.result(), result)) code.putln('%s = %s;' % (self.result(), result))
if entry.needs_wlock: if self.is_autolock():
code.putln("Cy_WLOCK(%s);" % self.result()) if tracked_state.needs_wlock:
elif entry.needs_rlock: code.putln("Cy_WLOCK(%s);" % self.result())
code.putln("Cy_RLOCK(%s);" % self.result()) elif tracked_state.needs_rlock:
code.putln("Cy_RLOCK(%s);" % self.result())
if debug_disposal_code: if debug_disposal_code:
print("NameNode.generate_assignment_code:") print("NameNode.generate_assignment_code:")
print("...generating post-assignment code for %s" % rhs) print("...generating post-assignment code for %s" % rhs)
...@@ -5921,10 +5955,11 @@ class SimpleCallNode(CallNode): ...@@ -5921,10 +5955,11 @@ class SimpleCallNode(CallNode):
for i in range(min(max_nargs, actual_nargs)): for i in range(min(max_nargs, actual_nargs)):
formal_arg = func_type.args[i] formal_arg = func_type.args[i]
actual_arg = args[i] actual_arg = args[i]
if formal_arg.type.is_const: if formal_arg.type.is_cyp_class:
actual_arg.ensure_rhs_locked(env, is_dereferenced = True) if formal_arg.type.is_const:
else: actual_arg.ensure_rhs_locked(env, is_dereferenced = True)
actual_arg.ensure_lhs_locked(env, is_dereferenced = True) else:
actual_arg.ensure_lhs_locked(env, is_dereferenced = True)
# Coerce arguments # Coerce arguments
some_args_in_temps = False some_args_in_temps = False
for i in range(min(max_nargs, actual_nargs)): for i in range(min(max_nargs, actual_nargs)):
...@@ -7472,6 +7507,12 @@ class AttributeNode(ExprNode): ...@@ -7472,6 +7507,12 @@ class AttributeNode(ExprNode):
'"Memoryview is not initialized");' '"Memoryview is not initialized");'
'%s' '%s'
'}' % (self.result(), code.error_goto(self.pos))) '}' % (self.result(), code.error_goto(self.pos)))
elif self.is_autolock():
if not self.was_locked:
if self.tracked_state.needs_wlock:
code.putln("Cy_WLOCK(%s);" % self.result())
elif self.tracked_state.needs_rlock:
code.putln("Cy_RLOCK(%s);" % self.result())
else: else:
# result_code contains what is needed, but we may need to insert # result_code contains what is needed, but we may need to insert
# a check and raise an exception # a check and raise an exception
...@@ -7508,7 +7549,9 @@ class AttributeNode(ExprNode): ...@@ -7508,7 +7549,9 @@ class AttributeNode(ExprNode):
rhs.result_as(self.ctype()))) rhs.result_as(self.ctype())))
else: else:
select_code = self.result() select_code = self.result()
if self.entry.needs_rlock or self.entry.needs_wlock: # XXX - Greater to have a getter, right ?
tracked_state = self.tracked_state
if self.is_autolock() and tracked_state and (tracked_state.needs_rlock or tracked_state.needs_wlock):
code.putln("Cy_UNLOCK(%s);" % select_code) code.putln("Cy_UNLOCK(%s);" % select_code)
if self.type.is_pyobject and self.use_managed_ref: if self.type.is_pyobject and self.use_managed_ref:
rhs.make_owned_reference(code) rhs.make_owned_reference(code)
...@@ -7531,10 +7574,11 @@ class AttributeNode(ExprNode): ...@@ -7531,10 +7574,11 @@ class AttributeNode(ExprNode):
select_code, select_code,
rhs.move_result_rhs_as(self.ctype()))) rhs.move_result_rhs_as(self.ctype())))
#rhs.result())) #rhs.result()))
if self.entry.needs_wlock: if self.is_autolock():
code.putln("Cy_WLOCK(%s);" % select_code) if tracked_state.needs_wlock:
elif self.entry.needs_rlock: code.putln("Cy_WLOCK(%s);" % select_code)
code.putln("Cy_RLOCK(%s);" % select_code) elif tracked_state.needs_rlock:
code.putln("Cy_RLOCK(%s);" % select_code)
rhs.generate_post_assignment_code(code) rhs.generate_post_assignment_code(code)
rhs.free_temps(code) rhs.free_temps(code)
......
...@@ -2200,6 +2200,23 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -2200,6 +2200,23 @@ class FuncDefNode(StatNode, BlockNode):
align_error_path_gil_to_success_path() align_error_path_gil_to_success_path()
code.put_label(code.return_from_error_cleanup_label) code.put_label(code.return_from_error_cleanup_label)
for node in reversed(lenv.autolocked_nodes):
# We iterate in the reverse order to properly unlock
# nested locked objects (aka most nested first).
# For example, if we have the following situation:
# obj.sub_obj.attr = some_value
# If obj and sub_obj are both of autolocked types,
# the obj (name)node is declared before the sub_obj (attribute)node.
# If we unlock first obj, another thread could immediately acquire
# a write lock and change where sub_obj points to.
# We would then try to unlock the new sub_obj reference,
# which leads to a dangling lock on the previous reference
# (and attempt to unlock a non-locked ref).
if not node.was_locked and (node.tracked_state.needs_wlock or node.tracked_state.needs_rlock):
code.putln("Cy_UNLOCK(%s);" % node.result())
for entry in lenv.var_entries: for entry in lenv.var_entries:
if not entry.used or entry.in_closure: if not entry.used or entry.in_closure:
continue continue
...@@ -2215,10 +2232,6 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -2215,10 +2232,6 @@ class FuncDefNode(StatNode, BlockNode):
# FIXME ideally use entry.xdecref_cleanup but this currently isn't reliable # FIXME ideally use entry.xdecref_cleanup but this currently isn't reliable
code.put_var_xdecref(entry, have_gil=gil_owned['success']) code.put_var_xdecref(entry, have_gil=gil_owned['success'])
for node in lenv.autolocked_nodes:
if node.entry.needs_rlock or node.entry.needs_wlock:
code.putln("Cy_UNLOCK(%s);" % node.result())
# Decref any increfed args # Decref any increfed args
for entry in lenv.arg_entries: for entry in lenv.arg_entries:
if entry.type.is_memoryviewslice: if entry.type.is_memoryviewslice:
...@@ -2604,7 +2617,8 @@ class CFuncDefNode(FuncDefNode): ...@@ -2604,7 +2617,8 @@ class CFuncDefNode(FuncDefNode):
entry = self.local_scope.declare(_name, _cname, _type, _pos, 'private') entry = self.local_scope.declare(_name, _cname, _type, _pos, 'private')
entry.is_variable = 1 entry.is_variable = 1
# Even if it is checklock it should be OK to mess with self without locking # Even if it is checklock it should be OK to mess with self without locking
entry.is_wlocked = True self_locking_state = self.local_scope.declare_tracked(entry)
self_locking_state.is_wlocked = True
def declare_cpdef_wrapper(self, env): def declare_cpdef_wrapper(self, env):
if self.overridable: if self.overridable:
...@@ -5767,11 +5781,6 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5767,11 +5781,6 @@ class SingleAssignmentNode(AssignmentNode):
self.lhs = self.lhs.analyse_target_types(env) self.lhs = self.lhs.analyse_target_types(env)
self.lhs.gil_assignment_check(env) self.lhs.gil_assignment_check(env)
if hasattr(self.lhs, 'entry'):
entry = self.lhs.entry
if entry.type.is_cyp_class and entry.type.lock_mode == "autolock"\
and not (entry.needs_rlock or entry.needs_wlock):
env.declare_autolocked(self.lhs)
self.rhs.ensure_rhs_locked(env) self.rhs.ensure_rhs_locked(env)
self.lhs.ensure_lhs_locked(env) self.lhs.ensure_lhs_locked(env)
unrolled_assignment = self.unroll_lhs(env) unrolled_assignment = self.unroll_lhs(env)
...@@ -8403,13 +8412,16 @@ class LockCypclassNode(StatNode): ...@@ -8403,13 +8412,16 @@ class LockCypclassNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.obj = self.obj.analyse_types(env) self.obj = self.obj.analyse_types(env)
self.obj.ensure_rhs_locked(env)
if not hasattr(self.obj, 'entry'): if not hasattr(self.obj, 'entry'):
error(self.pos, "The (un)locking target has no entry") error(self.pos, "The (un)locking target has no entry")
if not self.obj.type.is_cyp_class: if not self.obj.type.is_cyp_class:
error(self.pos, "Cannot (un)lock a non-cypclass variable !") error(self.pos, "Cannot (un)lock a non-cypclass variable !")
is_rlocked = self.obj.entry.is_rlocked # FIXME: this is a bit redundant here
is_wlocked = self.obj.entry.is_wlocked self.obj.get_tracked_state(env)
is_rlocked = self.obj.is_rhs_locked(env)
is_wlocked = self.obj.is_lhs_locked(env)
if self.state == "unclocked" and not (is_rlocked or is_wlocked): if self.state == "unclocked" and not (is_rlocked or is_wlocked):
error(self.pos, "Cannot unlock an already unlocked object !") error(self.pos, "Cannot unlock an already unlocked object !")
...@@ -8426,23 +8438,29 @@ class LockCypclassNode(StatNode): ...@@ -8426,23 +8438,29 @@ class LockCypclassNode(StatNode):
self.was_rlocked = is_rlocked self.was_rlocked = is_rlocked
self.was_wlocked = is_wlocked self.was_wlocked = is_wlocked
tracked_state = self.obj.tracked_state
if self.state == "rlocked": if self.state == "rlocked":
self.obj.entry.is_rlocked = True tracked_state.is_rlocked = True
self.obj.entry.is_wlocked = False tracked_state.is_wlocked = False
elif self.state == "wlocked": elif self.state == "wlocked":
self.obj.entry.is_rlocked = False tracked_state.is_rlocked = False
self.obj.entry.is_wlocked = True tracked_state.is_wlocked = True
else: else:
self.obj.entry.is_rlocked = False tracked_state.is_rlocked = False
self.obj.entry.is_wlocked = False tracked_state.is_wlocked = False
self.body = self.body.analyse_expressions(env) self.body = self.body.analyse_expressions(env)
self.obj.entry.is_rlocked = self.was_rlocked #self.obj.entry.is_rlocked = self.was_rlocked
self.obj.entry.is_wlocked = self.was_wlocked #self.obj.entry.is_wlocked = self.was_wlocked
tracked_state.is_rlocked = self.was_rlocked
tracked_state.is_wlocked = self.was_wlocked
return self return self
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.obj.generate_evaluation_code(code)
# We must unlock if it's a 'with unlocked' statement, # We must unlock if it's a 'with unlocked' statement,
# or if we're changing lock type. # or if we're changing lock type.
if self.was_rlocked or self.was_wlocked: if self.was_rlocked or self.was_wlocked:
......
...@@ -162,6 +162,7 @@ class Entry(object): ...@@ -162,6 +162,7 @@ class Entry(object):
# is_rlocked boolean Is locked with a read lock (used for cypclass) # is_rlocked boolean Is locked with a read lock (used for cypclass)
# needs_rlock boolean The entry needs a read lock (used in autolock mode) # needs_rlock boolean The entry needs a read lock (used in autolock mode)
# needs_wlock boolean The entry needs a write lock (used in autolock mode) # needs_wlock boolean The entry needs a write lock (used in autolock mode)
# was_locked boolean Indicates to nodes falling through that the first lock already took place
# TODO: utility_code and utility_code_definition serves the same purpose... # TODO: utility_code and utility_code_definition serves the same purpose...
...@@ -237,6 +238,7 @@ class Entry(object): ...@@ -237,6 +238,7 @@ class Entry(object):
is_rlocked = False is_rlocked = False
needs_rlock = False needs_rlock = False
needs_wlock = False needs_wlock = False
was_locked = False
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -311,6 +313,15 @@ class InnerEntry(Entry): ...@@ -311,6 +313,15 @@ class InnerEntry(Entry):
def all_entries(self): def all_entries(self):
return self.defining_entry.all_entries() return self.defining_entry.all_entries()
class TrackedLockedEntry:
def __init__(self, entry, scope):
self.entry = entry
self.scope = scope
self.is_wlocked = False
self.is_rlocked = False
self.needs_wlock = False
self.needs_rlock = False
self.was_locked = False
class Scope(object): class Scope(object):
# name string Unqualified name # name string Unqualified name
...@@ -393,6 +404,8 @@ class Scope(object): ...@@ -393,6 +404,8 @@ class Scope(object):
self.buffer_entries = [] self.buffer_entries = []
self.lambda_defs = [] self.lambda_defs = []
self.id_counters = {} self.id_counters = {}
self.tracked_entries = {}
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
return self return self
...@@ -475,6 +488,18 @@ class Scope(object): ...@@ -475,6 +488,18 @@ class Scope(object):
for scope in sorted(self.subscopes, key=operator.attrgetter('scope_prefix')): for scope in sorted(self.subscopes, key=operator.attrgetter('scope_prefix')):
yield scope yield scope
def declare_tracked(self, entry):
# Keying only with the name is wrong: if we have multiple attributes
# with the same name in different cypclass, this will conflict.
key = entry
self.tracked_entries[key] = TrackedLockedEntry(entry, self)
return self.tracked_entries[key]
def lookup_tracked(self, entry):
# We don't chain up the scopes on purpose: we want to keep things local
key = entry
return self.tracked_entries.get(key, None)
def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0, create_wrapper = 0): def declare(self, name, cname, type, pos, visibility, shadow = 0, is_type = 0, create_wrapper = 0):
# Create new entry, and add to dictionary if # Create new entry, and add to dictionary if
# name is not None. Reports a warning if already # name is not None. Reports a warning if already
...@@ -1875,7 +1900,8 @@ class LocalScope(Scope): ...@@ -1875,7 +1900,8 @@ class LocalScope(Scope):
entry.init = "0" entry.init = "0"
entry.is_arg = 1 entry.is_arg = 1
if type.is_cyp_class and type.lock_mode == "autolock": if type.is_cyp_class and type.lock_mode == "autolock":
entry.is_wlocked = True arg_lock_state = self.declare_tracked(entry)
arg_lock_state.is_wlocked = True
#entry.borrowed = 1 # Not using borrowed arg refs for now #entry.borrowed = 1 # Not using borrowed arg refs for now
self.arg_entries.append(entry) self.arg_entries.append(entry)
return entry return entry
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment