Commit 71a0de50 authored by Vincent Pelletier's avatar Vincent Pelletier

Implement revision-aware caching.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2531 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 6b799777
This diff is collapsed.
...@@ -123,10 +123,7 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -123,10 +123,7 @@ class PrimaryNotificationsHandler(BaseHandler):
app._cache_lock_acquire() app._cache_lock_acquire()
try: try:
# ZODB required a dict with oid as key, so create it # ZODB required a dict with oid as key, so create it
mq_cache = app.mq_cache app.cache_revision_index.invalidate(oid_list, tid)
for oid in oid_list:
if oid in mq_cache:
del mq_cache[oid]
db = app.getDB() db = app.getDB()
if db is not None: if db is not None:
db.invalidate(tid, dict.fromkeys(oid_list, tid)) db.invalidate(tid, dict.fromkeys(oid_list, tid))
......
...@@ -21,7 +21,7 @@ from cPickle import dumps ...@@ -21,7 +21,7 @@ from cPickle import dumps
from mock import Mock, ReturnValues from mock import Mock, ReturnValues
from ZODB.POSException import StorageTransactionError, UndoError, ConflictError from ZODB.POSException import StorageTransactionError, UndoError, ConflictError
from neo.tests import NeoUnitTestBase from neo.tests import NeoUnitTestBase
from neo.client.app import Application from neo.client.app import Application, RevisionIndex
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError from neo.client.exception import NEOStorageError, NEOStorageNotFoundError
from neo.client.exception import NEOStorageDoesNotExistError from neo.client.exception import NEOStorageDoesNotExistError
from neo.protocol import Packet, Packets, Errors, INVALID_TID, INVALID_SERIAL from neo.protocol import Packet, Packets, Errors, INVALID_TID, INVALID_SERIAL
...@@ -208,7 +208,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -208,7 +208,8 @@ class ClientApplicationTests(NeoUnitTestBase):
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None) an_object = (1, oid, tid1, tid2, 0, makeChecksum('OBJ'), 'OBJ', None)
# connection to SN close # connection to SN close
self.assertTrue(oid not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq)
packet = Errors.OidNotFound('') packet = Errors.OidNotFound('')
packet.setId(0) packet.setId(0)
cell = Mock({ 'getUUID': '\x00' * 16}) cell = Mock({ 'getUUID': '\x00' * 16})
...@@ -224,7 +225,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -224,7 +225,8 @@ class ClientApplicationTests(NeoUnitTestBase):
self.checkAskObject(conn) self.checkAskObject(conn)
Application._waitMessage = _waitMessage Application._waitMessage = _waitMessage
# object not found in NEO -> NEOStorageNotFoundError # object not found in NEO -> NEOStorageNotFoundError
self.assertTrue(oid not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq)
packet = Errors.OidNotFound('') packet = Errors.OidNotFound('')
packet.setId(0) packet.setId(0)
cell = Mock({ 'getUUID': '\x00' * 16}) cell = Mock({ 'getUUID': '\x00' * 16})
...@@ -254,7 +256,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -254,7 +256,7 @@ class ClientApplicationTests(NeoUnitTestBase):
result = app.load(oid) result = app.load(oid)
self.assertEquals(result, ('OBJ', tid1)) self.assertEquals(result, ('OBJ', tid1))
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue((oid, tid1) in mq)
# object is now cached, try to reload it # object is now cached, try to reload it
conn = Mock({ conn = Mock({
'getAddress': ('127.0.0.1', 0), 'getAddress': ('127.0.0.1', 0),
...@@ -272,7 +274,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -272,7 +274,8 @@ class ClientApplicationTests(NeoUnitTestBase):
tid1 = self.makeTID(1) tid1 = self.makeTID(1)
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
# object not found in NEO -> NEOStorageNotFoundError # object not found in NEO -> NEOStorageNotFoundError
self.assertTrue(oid not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq)
packet = Errors.OidNotFound('') packet = Errors.OidNotFound('')
packet.setId(0) packet.setId(0)
cell = Mock({ 'getUUID': '\x00' * 16}) cell = Mock({ 'getUUID': '\x00' * 16})
...@@ -285,10 +288,10 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -285,10 +288,10 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertRaises(NEOStorageNotFoundError, app.loadSerial, oid, tid2) self.assertRaises(NEOStorageNotFoundError, app.loadSerial, oid, tid2)
self.checkAskObject(conn) self.checkAskObject(conn)
# object should not have been cached # object should not have been cached
self.assertFalse(oid in mq) self.assertFalse((oid, tid2) in mq)
# now a cached version ewxists but should not be hit # now a cached version ewxists but should not be hit
mq.store(oid, (tid2, 'WRONG')) mq.store((oid, tid2), ('WRONG', None))
self.assertTrue(oid in mq) self.assertTrue((oid, tid2) in mq)
another_object = (1, oid, tid2, INVALID_SERIAL, 0, another_object = (1, oid, tid2, INVALID_SERIAL, 0,
makeChecksum('RIGHT'), 'RIGHT', None) makeChecksum('RIGHT'), 'RIGHT', None)
packet = Packets.AnswerObject(*another_object[1:]) packet = Packets.AnswerObject(*another_object[1:])
...@@ -302,7 +305,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -302,7 +305,7 @@ class ClientApplicationTests(NeoUnitTestBase):
result = app.loadSerial(oid, tid1) result = app.loadSerial(oid, tid1)
self.assertEquals(result, 'RIGHT') self.assertEquals(result, 'RIGHT')
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue((oid, tid2) in mq)
def test_loadBefore(self): def test_loadBefore(self):
app = self.getApp() app = self.getApp()
...@@ -313,7 +316,8 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -313,7 +316,8 @@ class ClientApplicationTests(NeoUnitTestBase):
tid2 = self.makeTID(2) tid2 = self.makeTID(2)
tid3 = self.makeTID(3) tid3 = self.makeTID(3)
# object not found in NEO -> NEOStorageDoesNotExistError # object not found in NEO -> NEOStorageDoesNotExistError
self.assertTrue(oid not in mq) self.assertTrue((oid, tid1) not in mq)
self.assertTrue((oid, tid2) not in mq)
packet = Errors.OidDoesNotExist('') packet = Errors.OidDoesNotExist('')
packet.setId(0) packet.setId(0)
cell = Mock({ 'getUUID': '\x00' * 16}) cell = Mock({ 'getUUID': '\x00' * 16})
...@@ -337,11 +341,12 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -337,11 +341,12 @@ class ClientApplicationTests(NeoUnitTestBase):
app.local_var.asked_object = an_object[:-1] app.local_var.asked_object = an_object[:-1]
self.assertRaises(NEOStorageError, app.loadBefore, oid, tid1) self.assertRaises(NEOStorageError, app.loadBefore, oid, tid1)
# object should not have been cached # object should not have been cached
self.assertFalse(oid in mq) self.assertFalse((oid, tid1) in mq)
# as for loadSerial, the object is cached but should be loaded from db # as for loadSerial, the object is cached but should be loaded from db
mq.store(oid, (tid1, 'WRONG')) mq.store((oid, tid1), ('WRONG', tid2))
self.assertTrue(oid in mq) self.assertTrue((oid, tid1) in mq)
another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'), app.cache_revision_index.invalidate([oid], tid2)
another_object = (1, oid, tid2, tid3, 0, makeChecksum('RIGHT'),
'RIGHT', None) 'RIGHT', None)
packet = Packets.AnswerObject(*another_object[1:]) packet = Packets.AnswerObject(*another_object[1:])
packet.setId(0) packet.setId(0)
...@@ -352,9 +357,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -352,9 +357,9 @@ class ClientApplicationTests(NeoUnitTestBase):
app.cp = Mock({ 'getConnForCell' : conn}) app.cp = Mock({ 'getConnForCell' : conn})
app.local_var.asked_object = another_object app.local_var.asked_object = another_object
result = app.loadBefore(oid, tid3) result = app.loadBefore(oid, tid3)
self.assertEquals(result, ('RIGHT', tid1, tid2)) self.assertEquals(result, ('RIGHT', tid2, tid3))
self.checkAskObject(conn) self.checkAskObject(conn)
self.assertTrue(oid in mq) self.assertTrue((oid, tid1) in mq)
def test_tpc_begin(self): def test_tpc_begin(self):
app = self.getApp() app = self.getApp()
...@@ -1156,6 +1161,90 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -1156,6 +1161,90 @@ class ClientApplicationTests(NeoUnitTestBase):
self.assertEqual(marker[0].getType(), Packets.AskPack) self.assertEqual(marker[0].getType(), Packets.AskPack)
# XXX: how to validate packet content ? # XXX: how to validate packet content ?
def test_RevisionIndex_1(self):
# Test add, getLatestSerial, getSerialList and clear
# without invalidations
oid1 = self.getOID(1)
oid2 = self.getOID(2)
tid1 = self.getOID(1)
tid2 = self.getOID(2)
tid3 = self.getOID(3)
ri = RevisionIndex()
# index is empty
self.assertEqual(ri.getSerialList(oid1), [])
ri.add((oid1, tid1))
# now, it knows oid1 at tid1
self.assertEqual(ri.getLatestSerial(oid1), tid1)
self.assertEqual(ri.getSerialList(oid1), [tid1])
self.assertEqual(ri.getSerialList(oid2), [])
ri.add((oid1, tid2))
# and at tid2
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
ri.remove((oid1, tid1))
# oid1 at tid1 was pruned from cache
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialList(oid1), [tid2])
ri.remove((oid1, tid2))
# oid1 is completely priuned from cache
self.assertEqual(ri.getLatestSerial(oid1), None)
self.assertEqual(ri.getSerialList(oid1), [])
ri.add((oid1, tid2))
ri.add((oid1, tid1))
# oid1 is populated, but in non-chronological order, check index
# still answers consistent result.
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
ri.add((oid2, tid3))
# which is not affected by the addition of oid2 at tid3
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialList(oid1), [tid2, tid1])
ri.clear()
# index is empty again
self.assertEqual(ri.getSerialList(oid1), [])
self.assertEqual(ri.getSerialList(oid2), [])
def test_RevisionIndex_2(self):
# Test getLatestSerial & getSerialBefore with invalidations
oid1 = self.getOID(1)
tid1 = self.getOID(1)
tid2 = self.getOID(2)
tid3 = self.getOID(3)
tid4 = self.getOID(4)
tid5 = self.getOID(5)
tid6 = self.getOID(6)
ri = RevisionIndex()
ri.add((oid1, tid1))
ri.add((oid1, tid2))
self.assertEqual(ri.getLatestSerial(oid1), tid2)
self.assertEqual(ri.getSerialBefore(oid1, tid2), tid1)
self.assertEqual(ri.getSerialBefore(oid1, tid3), tid2)
self.assertEqual(ri.getSerialBefore(oid1, tid4), tid2)
ri.invalidate([oid1], tid3)
# We don't have the latest data in cache, return None
self.assertEqual(ri.getLatestSerial(oid1), None)
self.assertEqual(ri.getSerialBefore(oid1, tid2), tid1)
self.assertEqual(ri.getSerialBefore(oid1, tid3), tid2)
# There is a gap between the last version we have and requested one,
# return None
self.assertEqual(ri.getSerialBefore(oid1, tid4), None)
ri.add((oid1, tid3))
# No gap anymore, tid3 found.
self.assertEqual(ri.getLatestSerial(oid1), tid3)
self.assertEqual(ri.getSerialBefore(oid1, tid4), tid3)
ri.invalidate([oid1], tid4)
ri.invalidate([oid1], tid5)
# A bigger gap...
self.assertEqual(ri.getLatestSerial(oid1), None)
self.assertEqual(ri.getSerialBefore(oid1, tid5), None)
self.assertEqual(ri.getSerialBefore(oid1, tid6), None)
# not entirely filled.
ri.add((oid1, tid5))
# Still, we know the latest and what is before tid6
self.assertEqual(ri.getLatestSerial(oid1), tid5)
self.assertEqual(ri.getSerialBefore(oid1, tid5), None)
self.assertEqual(ri.getSerialBefore(oid1, tid6), tid5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -156,17 +156,22 @@ class MasterNotificationsHandlerTests(MasterHandlerTests): ...@@ -156,17 +156,22 @@ class MasterNotificationsHandlerTests(MasterHandlerTests):
def test_invalidateObjects(self): def test_invalidateObjects(self):
conn = self.getConnection() conn = self.getConnection()
tid = self.getNextTID() tid = self.getNextTID()
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2, oid3 = self.getOID(1), self.getOID(2), self.getOID(3)
self.app.mq_cache = { self.app.mq_cache = {
oid1: tid, (oid1, tid): ('bla', None),
oid2: tid, (oid2, tid): ('bla', None),
} }
self.handler.invalidateObjects(conn, tid, [oid1]) self.app.cache_revision_index = Mock({
self.assertFalse(oid1 in self.app.mq_cache) 'invalidate': None,
self.assertTrue(oid2 in self.app.mq_cache) })
self.handler.invalidateObjects(conn, tid, [oid1, oid3])
cache_calls = self.app.cache_revision_index.mockGetNamedCalls(
'invalidate')
self.assertEqual(len(cache_calls), 1)
cache_calls[0].checkArgs([oid1, oid3], tid)
invalidation_calls = self.db.mockGetNamedCalls('invalidate') invalidation_calls = self.db.mockGetNamedCalls('invalidate')
self.assertEqual(len(invalidation_calls), 1) self.assertEqual(len(invalidation_calls), 1)
invalidation_calls[0].checkArgs(tid, {oid1:tid}) invalidation_calls[0].checkArgs(tid, {oid1:tid, oid3:tid})
def test_notifyPartitionChanges(self): def test_notifyPartitionChanges(self):
conn = self.getConnection() conn = self.getConnection()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment