Commit a2e278d5 authored by Julien Muchembled

client: fix race condition between Storage.load() and invalidations

This fixes a bug that could manifest as follows:

  Traceback (most recent call last):
    File "neo/client/app.py", line 432, in load
      self._cache.store(oid, data, tid, next_tid)
    File "neo/client/cache.py", line 223, in store
      assert item.tid == tid, (item, tid)
  AssertionError: (<CacheItem oid='\x00\x00\x00\x00\x00\x00\x00\x01' tid='\x03\xcb\xc6\xca\xfd\xc7\xda\xee' next_tid='\x03\xcb\xc6\xca\xfd\xd8\t\x88' data='...' counter=1 level=1 expire=10000 prev=<...> next=<...>>, '\x03\xcb\xc6\xca\xfd\xd8\t\x88')

The big changes in the threaded test framework are required because we need to
reproduce a race condition between client threads, and reproducing it conflicts
with the serialization of epoll events (it would deadlock).
parent 743026d5
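
In short, load() now records in a list every invalidation received for the oid
while the cache lock is released, and fixes next_tid from that list once the
storage answers. A minimal sketch of that selection logic (fix_next_tid is a
hypothetical helper written for this example, not NEO code):

  def fix_next_tid(tid, next_tid, loading_invalidated):
      # `loading_invalidated` is the ordered list of invalidation tids
      # received while the load was in flight. If the storage did not
      # return a next_tid, the loaded revision ends at the first
      # invalidation that is newer than the loaded serial.
      if not next_tid:
          for t in loading_invalidated:
              if tid < t:
                  return t
      return next_tid

  assert fix_next_tid(5, None, [3, 7, 9]) == 7  # 3 predates the loaded serial
  assert fix_next_tid(5, None, [3]) is None     # record is still the current one
  assert fix_next_tid(5, 8, [9]) == 8           # storage already knew next_tid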
@@ -410,8 +410,16 @@ class Application(ThreadedApplication):
                 if result:
                     return result
                 self._loading_oid = oid
+                self._loading_invalidated = []
             finally:
                 release()
+            # While the cache lock is released, an arbitrary number of
+            # invalidations may be processed, for this oid or not. And at this
+            # precise moment, if both tid and before_tid are None (which is
+            # unlikely to happen with recent ZODB), self.last_tid can be any
+            # new tid. Since we can get any serial from storage, fixing
+            # next_tid requires to keep a list of all possible serials.
             # When not bound to a ZODB Connection, load() may be the
             # first method called and last_tid may still be None.
             # This happens, for example, when opening the DB.
@@ -423,12 +431,11 @@ class Application(ThreadedApplication):
             acquire()
             try:
                 if self._loading_oid:
-                    # Common case (no race condition).
-                    self._cache.store(oid, data, tid, next_tid)
-                elif self._loading_invalidated:
-                    # oid has just been invalidated.
                     if not next_tid:
-                        next_tid = self._loading_invalidated
+                        for t in self._loading_invalidated:
+                            if tid < t:
+                                next_tid = t
+                                break
                     self._cache.store(oid, data, tid, next_tid)
                 # Else, we just reconnected to the master.
             finally:
...
@@ -127,8 +127,7 @@ class PrimaryNotificationsHandler(MTEventHandler):
                 for oid in oid_list:
                     invalidate(oid, tid)
                     if oid == loading:
-                        app._loading_oid = None
-                        app._loading_invalidated = tid
+                        app._loading_invalidated.append(tid)
                 db = app.getDB()
                 if db is not None:
                     db.invalidate(tid, oid_list)
...
@@ -26,6 +26,7 @@ from zlib import decompress
 import transaction, ZODB
 import neo.admin.app, neo.master.app, neo.storage.app
 import neo.client.app, neo.neoctl.app
+from neo.admin.handler import MasterEventHandler
 from neo.client import Storage
 from neo.lib import logging
 from neo.lib.connection import BaseConnection, \
@@ -36,6 +37,7 @@ from neo.lib.locking import SimpleQueue
 from neo.lib.protocol import uuid_str, \
     ClusterStates, Enum, NodeStates, NodeTypes, Packets
 from neo.lib.util import cached_property, parseMasterList, p64
+from neo.master.recovery import RecoveryManager
 from .. import (getTempDirectory, setupMySQLdb,
     ImporterConfigParser, NeoTestBase, Patch,
     ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER)
@@ -119,9 +121,12 @@ class Serialized(object):
     detect which node has a readable epoll object.
     """
     check_timeout = False
+    _disabled = False

     @classmethod
     def init(cls):
+        if cls._disabled:
+            return
         cls._busy = set()
         cls._busy_cond = threading.Condition(threading.Lock())
         cls._epoll = select.epoll()
@@ -138,6 +143,8 @@ class Serialized(object):

     @classmethod
     def stop(cls):
+        if cls._disabled:
+            return
         assert not cls._fd_dict, ("file descriptor leak (%r)\nThis may happen"
             " when a test fails, in which case you can see the real exception"
             " by disabling this one." % cls._fd_dict)
@@ -148,6 +155,25 @@ class Serialized(object):
     def _sort_key(cls, fd_event):
         return -cls._fd_dict[fd_event[0]]._last

+    @classmethod
+    @contextmanager
+    def until(cls, patched=None, **patch):
+        if cls._disabled:
+            if patched is None:
+                yield int
+            else:
+                l = threading.Lock()
+                l.acquire()
+                (name, patch), = patch.iteritems()
+                def release():
+                    p.revert()
+                    l.release()
+                with Patch(patched, **{name: lambda *args, **kw:
+                        patch(release, *args, **kw)}) as p:
+                    yield l.acquire
+        else:
+            yield cls.tic
+
     @classmethod
     @contextmanager
     def pdb(cls):
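
When serialization is disabled, Serialized.until() blocks the caller on a lock
that is released from inside a patched handler method, instead of ticking the
scheduler. A simplified, self-contained sketch of that pattern (the helper
below is illustrative only; the real method also reverts the patch as soon as
the lock is released, and yields cls.tic in serialized mode):

  import threading
  from contextlib import contextmanager

  @contextmanager
  def until(obj, name, wrapper):
      # Patch obj.<name> so that `wrapper` runs in its place and may call
      # release(); yield the lock's acquire as a blocking "tic".
      lock = threading.Lock()
      lock.acquire()
      orig = getattr(obj, name)
      def patched(*args, **kw):
          return wrapper(lock.release, orig, *args, **kw)
      setattr(obj, name, patched)
      try:
          yield lock.acquire
      finally:
          setattr(obj, name, orig)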
@@ -174,6 +200,10 @@ class Serialized(object):
             # We also increase SocketConnector.SOMAXCONN in tests so that
             # a connection attempt is never delayed inside the kernel.
             timeout=0):
+        if cls._disabled:
+            if timeout:
+                time.sleep(timeout)
+            return
         # If you're in a pdb here, 'n' switches to another thread
         # (the following lines are not supposed to be debugged into)
         with cls._tic_lock, cls.pdb():
@@ -208,6 +238,8 @@ class Serialized(object):
         cls._sched_lock.acquire()

     def __init__(self, app, busy=True):
+        if self._disabled:
+            return
         self._epoll = app.em.epoll
         app.em.epoll = self
         # XXX: It may have been initialized before the SimpleQueue is patched.
@@ -360,6 +392,7 @@ class ServerNode(Node):
         finally:
             self._afterRun()
             logging.debug('stopping %r', self)
-            self.em.epoll.exit()
+            if isinstance(self.em.epoll, Serialized):
+                self.em.epoll.exit()

     def _afterRun(self):
@@ -427,6 +460,7 @@ class ClientApplication(Node, neo.client.app.Application):
         try:
             super(ClientApplication, self)._run()
         finally:
-            self.em.epoll.exit()
+            if isinstance(self.em.epoll, Serialized):
+                self.em.epoll.exit()

     def start(self):
@@ -616,6 +650,8 @@ class NEOCluster(object):
         def __init__(orig, self): # temporary definition for SimpleQueue patch
             orig(self)
+            if Serialized._disabled:
+                return
             lock = self._lock
             def _lock(blocking=True):
                 if blocking:
@@ -765,22 +801,41 @@ class NEOCluster(object):
         self.started = True
         self._patch()
         self.neoctl = NeoCTL(self.admin.getVirtualAddress(), ssl=self.SSL)
-        for node in self.master_list if master_list is None else master_list:
-            node.start()
-        for node in self.admin_list:
-            node.start()
-        Serialized.tic()
-        if storage_list is None:
-            storage_list = self.storage_list
-        for node in storage_list:
-            node.start()
-        Serialized.tic()
-        if recovering:
-            expected_state = ClusterStates.RECOVERING
-        else:
-            self.startCluster()
-            Serialized.tic()
-            expected_state = ClusterStates.RUNNING, ClusterStates.BACKINGUP
+        if master_list is None:
+            master_list = self.master_list
+        if storage_list is None:
+            storage_list = self.storage_list
+        def answerPartitionTable(release, orig, *args):
+            orig(*args)
+            release()
+        def dispatch(release, orig, handler, *args):
+            orig(handler, *args)
+            node_list = handler.app.nm.getStorageList(only_identified=True)
+            if len(node_list) == len(storage_list) and not any(
+                    node.getConnection().isPending() for node in node_list):
+                release()
+        expected_state = (ClusterStates.RECOVERING,) if recovering else (
+            ClusterStates.RUNNING, ClusterStates.BACKINGUP)
+        def notifyClusterInformation(release, orig, handler, conn, state):
+            orig(handler, conn, state)
+            if state in expected_state:
+                release()
+        with Serialized.until(MasterEventHandler,
+                answerPartitionTable=answerPartitionTable) as tic1, \
+             Serialized.until(RecoveryManager, dispatch=dispatch) as tic2, \
+             Serialized.until(MasterEventHandler,
+                notifyClusterInformation=notifyClusterInformation) as tic3:
+            for node in master_list:
+                node.start()
+            for node in self.admin_list:
+                node.start()
+            tic1()
+            for node in storage_list:
+                node.start()
+            tic2()
+            if not recovering:
+                self.startCluster()
+                tic3()
         self.checkStarted(expected_state, storage_list)

     def checkStarted(self, expected_state, storage_list=None):
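
Reading the rewritten start(): each Serialized.until() barrier replaces one of
the old Serialized.tic() calls. tic1 returns once the admin has received the
partition table, tic2 once the recovery manager has seen all expected storage
nodes identified with no pending connection, and tic3 once the cluster state
notification reaches the expected state, so startup stays deterministic even
when the epoll serialization is disabled.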
@@ -1120,12 +1175,16 @@ def predictable_random(seed=None):
         return wraps(wrapped)(wrapper)
     return decorator

-def with_cluster(start_cluster=True, **cluster_kw):
+def with_cluster(serialized=True, start_cluster=True, **cluster_kw):
     def decorator(wrapped):
         def wrapper(self, *args, **kw):
-            with NEOCluster(**cluster_kw) as cluster:
-                if start_cluster:
-                    cluster.start()
-                return wrapped(self, cluster, *args, **kw)
+            try:
+                Serialized._disabled = not serialized
+                with NEOCluster(**cluster_kw) as cluster:
+                    if start_cluster:
+                        cluster.start()
+                    return wrapped(self, cluster, *args, **kw)
+            finally:
+                Serialized._disabled = False
         return wraps(wrapped)(wrapper)
     return decorator
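
With the new keyword, an individual test can opt out of the serialized
scheduler; a minimal usage sketch (the test name is made up):

  @with_cluster(serialized=False)
  def testSomeRealRace(self, cluster):
      # Client threads run against real epoll objects here, so they can
      # actually race; Serialized.tic() degrades to a no-op (or a sleep).
      ...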
@@ -37,7 +37,7 @@ from neo.lib.protocol import (CellStates, ClusterStates, NodeStates, NodeTypes,
     Packets, Packet, uuid_str, ZERO_OID, ZERO_TID, MAX_TID)
 from .. import unpickle_state, Patch, TransactionalResource
 from . import ClientApplication, ConnectionFilter, LockLock, NEOCluster, \
-    NEOThreadedTest, RandomConflictDict, ThreadId, with_cluster
+    NEOThreadedTest, RandomConflictDict, Serialized, ThreadId, with_cluster
 from neo.lib.util import add64, makeChecksum, p64, u64
 from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
 from neo.client.transactions import Transaction
@@ -979,6 +979,45 @@ class Test(NEOThreadedTest):
         self.assertFalse(invalidations(c1))
         self.assertEqual(x1.value, 1)

+    @with_cluster(serialized=False)
+    def testExternalInvalidation2(self, cluster):
+        t, c = cluster.getTransaction()
+        r = c.root()
+        x = r[''] = PCounter()
+        t.commit()
+        tid1 = x._p_serial
+        nonlocal_ = [0, 1]
+        l1 = threading.Lock(); l1.acquire()
+        l2 = threading.Lock(); l2.acquire()
+        def invalidateObjects(orig, *args):
+            if not nonlocal_[0]:
+                l1.acquire()
+            orig(*args)
+            nonlocal_[0] += 1
+            if nonlocal_[0] == 2:
+                l2.release()
+        def _cache_lock_release(orig):
+            orig()
+            if nonlocal_[1]:
+                nonlocal_[1] = 0
+                l1.release()
+                l2.acquire()
+        with cluster.newClient() as client, \
+             Patch(client.notifications_handler,
+                   invalidateObjects=invalidateObjects):
+            client.sync()
+            with cluster.master.filterConnection(client) as mc2:
+                mc2.delayInvalidateObjects()
+                x._p_changed = 1
+                t.commit()
+                tid2 = x._p_serial
+            self.assertEqual((tid1, tid2), client.load(x._p_oid)[1:])
+            r._p_changed = 1
+            t.commit()
+            with Patch(client, _cache_lock_release=_cache_lock_release):
+                self.assertEqual((tid2, None), client.load(x._p_oid)[1:])
+            self.assertEqual(nonlocal_, [2, 0])
+
     @with_cluster(storage_count=2, partitions=2)
     def testReadVerifyingStorage(self, cluster):
         s1, s2 = cluster.sortStorageList()
...
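
The new test forces the interleaving with nothing but two locks: the loader
releases l1 inside the window where the cache lock is dropped, then blocks on
l2 until the invalidations have been delivered. A self-contained toy version
of that choreography (plain threads, illustrative names only, not NEO code):

  import threading

  l1 = threading.Lock(); l1.acquire()
  l2 = threading.Lock(); l2.acquire()
  log = []

  def loader():
      log.append('loader: in the unlocked window')
      l1.release()   # let the invalidation thread in
      l2.acquire()   # block until the invalidation was processed
      log.append('loader: resuming with the invalidation known')

  def invalidator():
      l1.acquire()   # wait for the loader to reach the window
      log.append('invalidator: invalidation delivered during load')
      l2.release()   # unblock the loader

  a = threading.Thread(target=loader)
  b = threading.Thread(target=invalidator)
  b.start(); a.start(); a.join(); b.join()
  assert log == ['loader: in the unlocked window',
                 'invalidator: invalidation delivered during load',
                 'loader: resuming with the invalidation known']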