Commit 04ab2a4c authored by Julien Muchembled's avatar Julien Muchembled

client: fix cache invalidation during a load from a storage for the same oid

This fixes the following an invalidation bug:

ERROR ZODB.Connection Couldn't load state for 0x504e
Traceback (most recent call last):
  File "ZODB/Connection.py", line 851, in setstate
    self._setstate(obj)
  File "ZODB/Connection.py", line 916, in _setstate
    self._load_before_or_conflict(obj)
  File "ZODB/Connection.py", line 931, in _load_before_or_conflict
    if not self._setstate_noncurrent(obj):
  File "ZODB/Connection.py", line 954, in _setstate_noncurrent
    assert end is not None
AssertionError
parent 5eef8d63
...@@ -18,6 +18,7 @@ from cPickle import dumps, loads ...@@ -18,6 +18,7 @@ from cPickle import dumps, loads
from zlib import compress as real_compress, decompress from zlib import compress as real_compress, decompress
from neo.lib.locking import Empty from neo.lib.locking import Empty
from random import shuffle from random import shuffle
from thread import get_ident
import heapq import heapq
import time import time
import os import os
...@@ -97,6 +98,7 @@ class Application(object): ...@@ -97,6 +98,7 @@ class Application(object):
# no self-assigned UUID, primary master will supply us one # no self-assigned UUID, primary master will supply us one
self.uuid = None self.uuid = None
self._cache = ClientCache() self._cache = ClientCache()
self._loading = {}
self.new_oid_list = [] self.new_oid_list = []
self.last_oid = '\0' * 8 self.last_oid = '\0' * 8
self.storage_event_handler = storage.StorageEventHandler(self) self.storage_event_handler = storage.StorageEventHandler(self)
...@@ -408,17 +410,29 @@ class Application(object): ...@@ -408,17 +410,29 @@ class Application(object):
# TODO: # TODO:
# - rename parameters (here? and in handlers & packet definitions) # - rename parameters (here? and in handlers & packet definitions)
acquire = self._cache_lock_acquire
release = self._cache_lock_release
self._load_lock_acquire() self._load_lock_acquire()
try: try:
result = self._loadFromCache(oid, tid, before_tid) acquire()
if not result: try:
result = self._loadFromStorage(oid, tid, before_tid) result = self._loadFromCache(oid, tid, before_tid)
self._cache_lock_acquire() if result:
return result
loading_key = oid, get_ident()
self._loading[loading_key] = None
release()
try: try:
self._cache.store(oid, *result) result = self._loadFromStorage(oid, tid, before_tid)
finally: finally:
self._cache_lock_release() acquire()
return result invalidated = self._loading.pop(loading_key)
if invalidated and not result[2]:
result = result[0], result[1], invalidated
self._cache.store(oid, *result)
return result
finally:
release()
finally: finally:
self._load_lock_release() self._load_lock_release()
...@@ -450,15 +464,11 @@ class Application(object): ...@@ -450,15 +464,11 @@ class Application(object):
""" """
Load from local cache, return None if not found. Load from local cache, return None if not found.
""" """
self._cache_lock_acquire() if at_tid:
try: result = self._cache.load(oid, at_tid + '*')
if at_tid: assert not result or result[1] == at_tid
result = self._cache.load(oid, at_tid + '*') return result
assert not result or result[1] == at_tid return self._cache.load(oid, before_tid)
return result
return self._cache.load(oid, before_tid)
finally:
self._cache_lock_release()
@profiler_decorator @profiler_decorator
def tpc_begin(self, transaction, tid=None, status=' '): def tpc_begin(self, transaction, tid=None, status=' '):
...@@ -772,6 +782,7 @@ class Application(object): ...@@ -772,6 +782,7 @@ class Application(object):
self._cache_lock_acquire() self._cache_lock_acquire()
try: try:
cache = self._cache cache = self._cache
loading = self._loading
for oid, data in cache_dict.iteritems(): for oid, data in cache_dict.iteritems():
if data is CHECKED_SERIAL: if data is CHECKED_SERIAL:
# this is just a remain of # this is just a remain of
...@@ -779,7 +790,12 @@ class Application(object): ...@@ -779,7 +790,12 @@ class Application(object):
# was modified). # was modified).
continue continue
# Update ex-latest value in cache # Update ex-latest value in cache
cache.invalidate(oid, tid) try:
cache.invalidate(oid, tid)
except KeyError:
for k in loading:
if k[0] == oid and not loading[k]:
loading[k] = tid
if data is not None: if data is not None:
# Store in cache with no next_tid # Store in cache with no next_tid
cache.store(oid, data, tid, None) cache.store(oid, data, tid, None)
......
...@@ -179,8 +179,11 @@ class ClientCache(object): ...@@ -179,8 +179,11 @@ class ClientCache(object):
if size < max_size: if size < max_size:
item = self._load(oid, next_tid) item = self._load(oid, next_tid)
if item: if item:
assert not (item.data or item.level)
assert item.tid == tid and item.next_tid == next_tid assert item.tid == tid and item.next_tid == next_tid
if item.level: # already stored
assert item.data == data
return
assert not item.data
self._history_size -= 1 self._history_size -= 1
else: else:
item = CacheItem() item = CacheItem()
...@@ -221,12 +224,31 @@ class ClientCache(object): ...@@ -221,12 +224,31 @@ class ClientCache(object):
def invalidate(self, oid, tid): def invalidate(self, oid, tid):
"""Mark data record as being valid only up to given tid""" """Mark data record as being valid only up to given tid"""
try: item = self._oid_dict[oid][-1]
item = self._oid_dict[oid][-1] if item.next_tid is None:
except KeyError: item.next_tid = tid
pass
else: else:
if item.next_tid is None: assert item.next_tid <= tid, (item, oid, tid)
item.next_tid = tid
else:
assert item.next_tid <= tid, (item, oid, tid) def test(self):
cache = ClientCache()
self.assertEqual(cache.load(1, 10), None)
self.assertEqual(cache.load(1, None), None)
self.assertRaises(KeyError, cache.invalidate, 1, 10)
data = 'foo', 5, 10
# 2 identical stores happens if 2 threads got a cache miss at the same time
cache.store(1, *data)
cache.store(1, *data)
self.assertEqual(cache.load(1, 10), data)
self.assertEqual(cache.load(1, None), None)
data = 'bar', 10, None
cache.store(1, *data)
self.assertEqual(cache.load(1, None), data)
cache.invalidate(1, 20)
self.assertEqual(cache.load(1, 20), ('bar', 10, 20))
if __name__ == '__main__':
import unittest
unittest.TextTestRunner().run(type('', (unittest.TestCase,), {
'runTest': test})())
...@@ -113,8 +113,14 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -113,8 +113,14 @@ class PrimaryNotificationsHandler(BaseHandler):
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
invalidate = app._cache.invalidate invalidate = app._cache.invalidate
loading = app._loading
for oid in oid_list: for oid in oid_list:
invalidate(oid, tid) try:
invalidate(oid, tid)
except KeyError:
for k in loading:
if k[0] == oid and not loading[k]:
loading[k] = tid
db = app.getDB() db = app.getDB()
if db is not None: if db is not None:
db.invalidate(tid, oid_list) db.invalidate(tid, oid_list)
......
...@@ -20,6 +20,7 @@ from mock import Mock, ReturnValues ...@@ -20,6 +20,7 @@ from mock import Mock, ReturnValues
from ZODB.POSException import StorageTransactionError, UndoError, ConflictError from ZODB.POSException import StorageTransactionError, UndoError, ConflictError
from .. import NeoUnitTestBase, buildUrlFromString, ADDRESS_TYPE from .. import NeoUnitTestBase, buildUrlFromString, ADDRESS_TYPE
from neo.client.app import Application from neo.client.app import Application
from neo.client.cache import test as testCache
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError from neo.client.exception import NEOStorageDoesNotExistError
from neo.lib.protocol import NodeTypes, Packet, Packets, Errors, INVALID_TID, \ from neo.lib.protocol import NodeTypes, Packet, Packets, Errors, INVALID_TID, \
...@@ -160,6 +161,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -160,6 +161,8 @@ class ClientApplicationTests(NeoUnitTestBase):
#self.assertEqual(calls[0].getParam(0), conn) #self.assertEqual(calls[0].getParam(0), conn)
#self.assertTrue(isinstance(calls[0].getParam(2), Queue)) #self.assertTrue(isinstance(calls[0].getParam(2), Queue))
testCache = testCache
def test_registerDB(self): def test_registerDB(self):
app = self.getApp() app = self.getApp()
dummy_db = [] dummy_db = []
......
...@@ -511,26 +511,52 @@ class Test(NEOThreadedTest): ...@@ -511,26 +511,52 @@ class Test(NEOThreadedTest):
l1.release() l1.release()
l2.acquire() l2.acquire()
orig(conn, packet, kw, handler) orig(conn, packet, kw, handler)
def _loadFromStorage(orig, *args):
try:
return orig(*args)
finally:
l1.release()
l2.acquire()
cluster = NEOCluster() cluster = NEOCluster()
try: try:
cluster.start() cluster.start()
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
c1.root()['x'] = x = PCounter() c1.root()['x'] = x1 = PCounter()
t1.commit() t1.commit()
t1.begin() t1.begin()
x.value = 1 x1.value = 1
t2, c2 = cluster.getTransaction() t2, c2 = cluster.getTransaction()
x = c2.root()['x'] x2 = c2.root()['x']
p = Patch(cluster.client, _handlePacket=_handlePacket) p = Patch(cluster.client, _handlePacket=_handlePacket)
try: try:
t = self.newThread(t1.commit) t = self.newThread(t1.commit)
l1.acquire() l1.acquire()
t2.abort() t2.abort()
finally:
del p
l2.release() l2.release()
t.join() t.join()
self.assertEqual(x2.value, 1)
return # Following is disabled due to deadlock
# caused by client load lock
t1.begin()
x1.value = 0
x2._p_deactivate()
cluster.client._cache.clear()
p = Patch(cluster.client, _loadFromStorage=_loadFromStorage)
try:
t = self.newThread(x2._p_activate)
l1.acquire()
t1.commit()
t1.begin()
finally: finally:
del p del p
self.assertEqual(x.value, 1) l2.release()
t.join()
x1._p_deactivate()
self.assertEqual(x2.value, 1)
self.assertEqual(x1.value, 0)
finally: finally:
cluster.stop() cluster.stop()
...@@ -540,12 +566,13 @@ class Test(NEOThreadedTest): ...@@ -540,12 +566,13 @@ class Test(NEOThreadedTest):
cluster.start() cluster.start()
# Initialize objects # Initialize objects
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
c1.root()['x'] = x = PCounter() c1.root()['x'] = x1 = PCounter()
c1.root()['y'] = y = PCounter() c1.root()['y'] = y = PCounter()
y.value = 1 y.value = 1
t1.commit() t1.commit()
# Get pickle of y # Get pickle of y
t1.begin() t1.begin()
x = c1._storage.load(x1._p_oid)[0]
y = c1._storage.load(y._p_oid)[0] y = c1._storage.load(y._p_oid)[0]
# Start the testing transaction # Start the testing transaction
# (at this time, we still have x=0 and y=1) # (at this time, we still have x=0 and y=1)
...@@ -557,30 +584,67 @@ class Test(NEOThreadedTest): ...@@ -557,30 +584,67 @@ class Test(NEOThreadedTest):
client.setPoll(1) client.setPoll(1)
txn = transaction.Transaction() txn = transaction.Transaction()
client.tpc_begin(txn) client.tpc_begin(txn)
client.store(x._p_oid, x._p_serial, y, '', txn) client.store(x1._p_oid, x1._p_serial, y, '', txn)
# Delay invalidation for x # Delay invalidation for x
master_client = cluster.master.filterConnection(cluster.client) master_client = cluster.master.filterConnection(cluster.client)
try: try:
master_client.add(lambda conn, packet: master_client.add(lambda conn, packet:
isinstance(packet, Packets.InvalidateObjects)) isinstance(packet, Packets.InvalidateObjects))
client.tpc_finish(txn, None) tid = client.tpc_finish(txn, None)
client.close()
client.setPoll(0) client.setPoll(0)
cluster.client.setPoll(1) cluster.client.setPoll(1)
# Change to x is committed. Testing connection must ask the # Change to x is committed. Testing connection must ask the
# storage node to return original value of x, even if we # storage node to return original value of x, even if we
# haven't processed yet any invalidation for x. # haven't processed yet any invalidation for x.
x = c2.root()['x'] x2 = c2.root()['x']
cluster.client._cache.clear() # bypass cache cluster.client._cache.clear() # bypass cache
self.assertEqual(x.value, 0) self.assertEqual(x2.value, 0)
finally: finally:
master_client() master_client()
x._p_deactivate() x2._p_deactivate()
t1.abort() # process invalidation and sync connection storage t1.abort() # process invalidation and sync connection storage
self.assertEqual(x.value, 0) self.assertEqual(x2.value, 0)
# New testing transaction. Now we can see the last value of x. # New testing transaction. Now we can see the last value of x.
t2.abort() t2.abort()
self.assertEqual(x.value, 1) self.assertEqual(x2.value, 1)
# Now test cache invalidation during a load from a storage
l1 = threading.Lock(); l1.acquire()
l2 = threading.Lock(); l2.acquire()
def _loadFromStorage(orig, *args):
try:
return orig(*args)
finally:
l1.release()
l2.acquire()
x2._p_deactivate()
cluster.client._cache.clear()
p = Patch(cluster.client, _loadFromStorage=_loadFromStorage)
try:
t = self.newThread(x2._p_activate)
l1.acquire()
# At this point, x could not be found the cache and the result
# from the storage (which is <value=1, next_tid=None>) is about
# to processed.
# Now modify x to receive an invalidation for it.
cluster.client.setPoll(0)
client.setPoll(1)
txn = transaction.Transaction()
client.tpc_begin(txn)
client.store(x2._p_oid, tid, x, '', txn)
tid = client.tpc_finish(txn, None)
client.close()
client.setPoll(0)
cluster.client.setPoll(1)
t1.abort() # make sure invalidation is processed
finally:
del p
# Resume processing of answer from storage. An entry should be
# added in cache for x=1 with a fixed next_tid (i.e. not None)
l2.release()
t.join()
self.assertEqual(x2.value, 1)
self.assertEqual(x1.value, 0)
finally: finally:
cluster.stop() cluster.stop()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment