Commit 90fe4c8a authored by Grégory Wisniewski's avatar Grégory Wisniewski

Storage node check if all objects are stored before set write locks.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2102 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 2dc8f28a
...@@ -217,7 +217,7 @@ class EventHandler(object): ...@@ -217,7 +217,7 @@ class EventHandler(object):
def answerTransactionFinished(self, conn, tid): def answerTransactionFinished(self, conn, tid):
raise UnexpectedPacketError raise UnexpectedPacketError
def askLockInformation(self, conn, tid): def askLockInformation(self, conn, tid, oid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerInformationLocked(self, conn, tid): def answerInformationLocked(self, conn, tid):
......
...@@ -75,7 +75,7 @@ class ClientServiceHandler(MasterHandler): ...@@ -75,7 +75,7 @@ class ClientServiceHandler(MasterHandler):
# Request locking data. # Request locking data.
# build a new set as we may not send the message to all nodes as some # build a new set as we may not send the message to all nodes as some
# might be not reachable at that time # might be not reachable at that time
p = Packets.AskLockInformation(tid) p = Packets.AskLockInformation(tid, oid_list)
used_uuid_set = set() used_uuid_set = set()
for node in self.app.nm.getIdentifiedList(pool_set=uuid_set): for node in self.app.nm.getIdentifiedList(pool_set=uuid_set):
node.ask(p, timeout=60) node.ask(p, timeout=60)
......
...@@ -783,12 +783,28 @@ class AskLockInformation(Packet): ...@@ -783,12 +783,28 @@ class AskLockInformation(Packet):
""" """
Lock information on a transaction. PM -> S. Lock information on a transaction. PM -> S.
""" """
def _encode(self, tid): # XXX: Identical to InvalidateObjects and AskFinishTransaction
return _encodeTID(tid) _header_format = '!8sL'
_list_entry_format = '8s'
_list_entry_len = calcsize(_list_entry_format)
def _encode(self, tid, oid_list):
body = [pack(self._header_format, tid, len(oid_list))]
body.extend(oid_list)
return ''.join(body)
def _decode(self, body): def _decode(self, body):
(tid, ) = unpack('8s', body) offset = self._header_len
return (_decodeTID(tid), ) (tid, n) = unpack(self._header_format, body[:offset])
oid_list = []
list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len
for _ in xrange(n):
next_offset = offset + list_entry_len
oid = unpack(list_entry_format, body[offset:next_offset])[0]
offset = next_offset
oid_list.append(oid)
return (tid, oid_list)
class AnswerInformationLocked(Packet): class AnswerInformationLocked(Packet):
""" """
......
...@@ -143,6 +143,13 @@ class PartitionTable(object): ...@@ -143,6 +143,13 @@ class PartitionTable(object):
return self.getCellList(self._getPartitionFromIndex(u64(oid)), return self.getCellList(self._getPartitionFromIndex(u64(oid)),
readable, writable) readable, writable)
def isAssigned(self, oid, uuid):
""" Check if the oid is assigned to the given node """
for cell in self.partition_list[u64(oid) % self.np]:
if cell.getUUID() == uuid:
return True
return False
def _getPartitionFromIndex(self, index): def _getPartitionFromIndex(self, index):
return index % self.np return index % self.np
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from neo import logging from neo import logging
from neo.util import dump
from neo.protocol import CellStates, Packets, ProtocolError from neo.protocol import CellStates, Packets, ProtocolError
from neo.storage.handlers import BaseMasterHandler from neo.storage.handlers import BaseMasterHandler
...@@ -52,10 +52,10 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -52,10 +52,10 @@ class MasterOperationHandler(BaseMasterHandler):
elif state == CellStates.OUT_OF_DATE: elif state == CellStates.OUT_OF_DATE:
app.replicator.addPartition(offset) app.replicator.addPartition(offset)
def askLockInformation(self, conn, tid): def askLockInformation(self, conn, tid, oid_list):
if not tid in self.app.tm: if not tid in self.app.tm:
raise ProtocolError('Unknown transaction') raise ProtocolError('Unknown transaction')
self.app.tm.lock(tid) self.app.tm.lock(tid, oid_list)
conn.answer(Packets.AnswerInformationLocked(tid)) conn.answer(Packets.AnswerInformationLocked(tid))
def notifyUnlockInformation(self, conn, tid): def notifyUnlockInformation(self, conn, tid):
......
...@@ -154,7 +154,7 @@ class TransactionManager(object): ...@@ -154,7 +154,7 @@ class TransactionManager(object):
self._load_lock_dict.clear() self._load_lock_dict.clear()
self._uuid_dict.clear() self._uuid_dict.clear()
def lock(self, tid): def lock(self, tid, oid_list):
""" """
Lock a transaction Lock a transaction
""" """
...@@ -163,6 +163,12 @@ class TransactionManager(object): ...@@ -163,6 +163,12 @@ class TransactionManager(object):
transaction.lock() transaction.lock()
for oid in transaction.getOIDList(): for oid in transaction.getOIDList():
self._load_lock_dict[oid] = tid self._load_lock_dict[oid] = tid
# check every object that should be locked
uuid = transaction.getUUID()
is_assigned = self._app.pt.isAssigned
for oid in oid_list:
if is_assigned(oid, uuid) and self._load_lock_dict.get(oid) != tid:
raise ValueError, 'Some locks are not held'
object_list = transaction.getObjectList() object_list = transaction.getObjectList()
# txn_info is None is the transaction information is not stored on # txn_info is None is the transaction information is not stored on
# this storage. # this storage.
......
...@@ -132,19 +132,22 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -132,19 +132,22 @@ class StorageMasterHandlerTests(NeoTestBase):
""" Unknown transaction """ """ Unknown transaction """
self.app.tm = Mock({'__contains__': False}) self.app.tm = Mock({'__contains__': False})
conn = self._getConnection() conn = self._getConnection()
oid_list = [self.getOID(1), self.getOID(2)]
tid = self.getNextTID() tid = self.getNextTID()
handler = self.operation handler = self.operation
self.assertRaises(ProtocolError, handler.askLockInformation, conn, tid) self.assertRaises(ProtocolError, handler.askLockInformation, conn, tid,
oid_list)
def test_askLockInformation2(self): def test_askLockInformation2(self):
""" Lock transaction """ """ Lock transaction """
self.app.tm = Mock({'__contains__': True}) self.app.tm = Mock({'__contains__': True})
conn = self._getConnection() conn = self._getConnection()
tid = self.getNextTID() tid = self.getNextTID()
self.operation.askLockInformation(conn, tid) oid_list = [self.getOID(1), self.getOID(2)]
self.operation.askLockInformation(conn, tid, oid_list)
calls = self.app.tm.mockGetNamedCalls('lock') calls = self.app.tm.mockGetNamedCalls('lock')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid) calls[0].checkArgs(tid, oid_list)
self.checkAnswerInformationLocked(conn) self.checkAnswerInformationLocked(conn)
def test_notifyUnlockInformation1(self): def test_notifyUnlockInformation1(self):
...@@ -153,7 +156,7 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -153,7 +156,7 @@ class StorageMasterHandlerTests(NeoTestBase):
conn = self._getConnection() conn = self._getConnection()
tid = self.getNextTID() tid = self.getNextTID()
handler = self.operation handler = self.operation
self.assertRaises(ProtocolError, handler.notifyUnlockInformation, self.assertRaises(ProtocolError, handler.notifyUnlockInformation,
conn, tid) conn, tid)
def test_notifyUnlockInformation2(self): def test_notifyUnlockInformation2(self):
......
...@@ -78,6 +78,7 @@ class TransactionManagerTests(NeoTestBase): ...@@ -78,6 +78,7 @@ class TransactionManagerTests(NeoTestBase):
self.app = Mock() self.app = Mock()
# no history # no history
self.app.dm = Mock({'getObjectHistory': []}) self.app.dm = Mock({'getObjectHistory': []})
self.app.pt = Mock({'isAssigned': True})
self.manager = TransactionManager(self.app) self.manager = TransactionManager(self.app)
self.ltid = None self.ltid = None
...@@ -86,6 +87,11 @@ class TransactionManagerTests(NeoTestBase): ...@@ -86,6 +87,11 @@ class TransactionManagerTests(NeoTestBase):
oid_list = [self.getOID(1), self.getOID(2)] oid_list = [self.getOID(1), self.getOID(2)]
return (tid, (oid_list, 'USER', 'DESC', 'EXT', False)) return (tid, (oid_list, 'USER', 'DESC', 'EXT', False))
def _storeTransactionObjects(self, tid, txn):
for i, oid in enumerate(txn[0]):
self.manager.storeObject(self.getNewUUID(), tid, None,
oid, 1, str(i), '0' + str(i), None)
def _getObject(self, value): def _getObject(self, value):
oid = self.getOID(value) oid = self.getOID(value)
serial = self.getNextTID() serial = self.getNextTID()
...@@ -115,7 +121,7 @@ class TransactionManagerTests(NeoTestBase): ...@@ -115,7 +121,7 @@ class TransactionManagerTests(NeoTestBase):
self.manager.storeObject(uuid, tid, serial1, *object1) self.manager.storeObject(uuid, tid, serial1, *object1)
self.manager.storeObject(uuid, tid, serial2, *object2) self.manager.storeObject(uuid, tid, serial2, *object2)
self.assertTrue(tid in self.manager) self.assertTrue(tid in self.manager)
self.manager.lock(tid) self.manager.lock(tid, txn[0])
self._checkTransactionStored(tid, [object1, object2], txn) self._checkTransactionStored(tid, [object1, object2], txn)
self.manager.unlock(tid) self.manager.unlock(tid)
self.assertFalse(tid in self.manager) self.assertFalse(tid in self.manager)
...@@ -130,8 +136,8 @@ class TransactionManagerTests(NeoTestBase): ...@@ -130,8 +136,8 @@ class TransactionManagerTests(NeoTestBase):
# first transaction lock the object # first transaction lock the object
self.manager.storeTransaction(uuid, tid1, *txn1) self.manager.storeTransaction(uuid, tid1, *txn1)
self.assertTrue(tid1 in self.manager) self.assertTrue(tid1 in self.manager)
self.manager.storeObject(uuid, tid1, serial, *obj) self._storeTransactionObjects(tid1, txn1)
self.manager.lock(tid1) self.manager.lock(tid1, txn1[0])
# the second is delayed # the second is delayed
self.manager.storeTransaction(uuid, tid2, *txn2) self.manager.storeTransaction(uuid, tid2, *txn2)
self.assertTrue(tid2 in self.manager) self.assertTrue(tid2 in self.manager)
...@@ -148,7 +154,8 @@ class TransactionManagerTests(NeoTestBase): ...@@ -148,7 +154,8 @@ class TransactionManagerTests(NeoTestBase):
self.manager.storeTransaction(uuid, tid2, *txn2) self.manager.storeTransaction(uuid, tid2, *txn2)
self.manager.storeObject(uuid, tid2, serial, *obj) self.manager.storeObject(uuid, tid2, serial, *obj)
self.assertTrue(tid2 in self.manager) self.assertTrue(tid2 in self.manager)
self.manager.lock(tid2) self._storeTransactionObjects(tid2, txn2)
self.manager.lock(tid2, txn2[0])
# the previous it's not using the latest version # the previous it's not using the latest version
self.manager.storeTransaction(uuid, tid1, *txn1) self.manager.storeTransaction(uuid, tid1, *txn1)
self.assertTrue(tid1 in self.manager) self.assertTrue(tid1 in self.manager)
...@@ -167,8 +174,8 @@ class TransactionManagerTests(NeoTestBase): ...@@ -167,8 +174,8 @@ class TransactionManagerTests(NeoTestBase):
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
uuid, tid, serial, *obj) uuid, tid, serial, *obj)
def testConflictWithTwoNodes(self): def testLockDelayed(self):
""" Ensure conflict/delaytion is working with different nodes""" """ Check lock delaytion"""
uuid1 = self.getNewUUID() uuid1 = self.getNewUUID()
uuid2 = self.getNewUUID() uuid2 = self.getNewUUID()
self.assertNotEqual(uuid1, uuid2) self.assertNotEqual(uuid1, uuid2)
...@@ -176,25 +183,41 @@ class TransactionManagerTests(NeoTestBase): ...@@ -176,25 +183,41 @@ class TransactionManagerTests(NeoTestBase):
tid2, txn2 = self._getTransaction() tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1) serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2) serial2, obj2 = self._getObject(2)
# first transaction lock the object # first transaction lock objects
self.manager.storeTransaction(uuid1, tid1, *txn1) self.manager.storeTransaction(uuid1, tid1, *txn1)
self.assertTrue(tid1 in self.manager) self.assertTrue(tid1 in self.manager)
self.manager.storeObject(uuid1, tid1, serial1, *obj1) self.manager.storeObject(uuid1, tid1, serial1, *obj1)
self.manager.lock(tid1) self.manager.storeObject(uuid1, tid1, serial1, *obj2)
self.manager.lock(tid1, txn1[0])
# second transaction is delayed # second transaction is delayed
self.manager.storeTransaction(uuid2, tid2, *txn2) self.manager.storeTransaction(uuid2, tid2, *txn2)
self.assertTrue(tid2 in self.manager) self.assertTrue(tid2 in self.manager)
self.assertRaises(DelayedError, self.manager.storeObject, self.assertRaises(DelayedError, self.manager.storeObject,
uuid2, tid2, serial1, *obj1) uuid2, tid2, serial1, *obj1)
# the second transaction lock another object self.assertRaises(DelayedError, self.manager.storeObject,
uuid2, tid2, serial2, *obj2)
def testLockConflict(self):
""" Check lock conflict """
uuid1 = self.getNewUUID()
uuid2 = self.getNewUUID()
self.assertNotEqual(uuid1, uuid2)
tid1, txn1 = self._getTransaction()
tid2, txn2 = self._getTransaction()
serial1, obj1 = self._getObject(1)
serial2, obj2 = self._getObject(2)
# the second transaction lock objects
self.manager.storeTransaction(uuid2, tid2, *txn2) self.manager.storeTransaction(uuid2, tid2, *txn2)
self.manager.storeObject(uuid2, tid2, serial1, *obj1)
self.manager.storeObject(uuid2, tid2, serial2, *obj2) self.manager.storeObject(uuid2, tid2, serial2, *obj2)
self.assertTrue(tid2 in self.manager) self.assertTrue(tid2 in self.manager)
self.manager.lock(tid2) self.manager.lock(tid2, txn1[0])
# the first get a conflict # the first get a conflict
self.manager.storeTransaction(uuid1, tid1, *txn1) self.manager.storeTransaction(uuid1, tid1, *txn1)
self.assertTrue(tid1 in self.manager) self.assertTrue(tid1 in self.manager)
self.assertRaises(ConflictError, self.manager.storeObject, self.assertRaises(ConflictError, self.manager.storeObject,
uuid1, tid1, serial1, *obj1)
self.assertRaises(ConflictError, self.manager.storeObject,
uuid1, tid1, serial2, *obj2) uuid1, tid1, serial2, *obj2)
def testAbortUnlocked(self): def testAbortUnlocked(self):
...@@ -215,15 +238,15 @@ class TransactionManagerTests(NeoTestBase): ...@@ -215,15 +238,15 @@ class TransactionManagerTests(NeoTestBase):
""" Try to abort a locked transaction """ """ Try to abort a locked transaction """
uuid = self.getNewUUID() uuid = self.getNewUUID()
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
serial, obj = self._getObject(1)
self.manager.storeTransaction(uuid, tid, *txn) self.manager.storeTransaction(uuid, tid, *txn)
self.manager.storeObject(uuid, tid, serial, *obj) self._storeTransactionObjects(tid, txn)
# lock transaction # lock transaction
self.manager.lock(tid) self.manager.lock(tid, txn[0])
self.assertTrue(tid in self.manager) self.assertTrue(tid in self.manager)
self.manager.abort(tid, even_if_locked=False) self.manager.abort(tid, even_if_locked=False)
self.assertTrue(tid in self.manager) self.assertTrue(tid in self.manager)
self.assertTrue(self.manager.loadLocked(obj[0])) for oid in txn[0]:
self.assertTrue(self.manager.loadLocked(oid))
self._checkQueuedEventExecuted(number=0) self._checkQueuedEventExecuted(number=0)
def testAbortForNode(self): def testAbortForNode(self):
...@@ -238,7 +261,8 @@ class TransactionManagerTests(NeoTestBase): ...@@ -238,7 +261,8 @@ class TransactionManagerTests(NeoTestBase):
# node 2 owns tid2 & tid3 and lock tid2 only # node 2 owns tid2 & tid3 and lock tid2 only
self.manager.storeTransaction(uuid2, tid2, *txn2) self.manager.storeTransaction(uuid2, tid2, *txn2)
self.manager.storeTransaction(uuid2, tid3, *txn3) self.manager.storeTransaction(uuid2, tid3, *txn3)
self.manager.lock(tid2) self._storeTransactionObjects(tid2, txn2)
self.manager.lock(tid2, txn2[0])
self.assertTrue(tid1 in self.manager) self.assertTrue(tid1 in self.manager)
self.assertTrue(tid2 in self.manager) self.assertTrue(tid2 in self.manager)
self.assertTrue(tid3 in self.manager) self.assertTrue(tid3 in self.manager)
...@@ -253,14 +277,14 @@ class TransactionManagerTests(NeoTestBase): ...@@ -253,14 +277,14 @@ class TransactionManagerTests(NeoTestBase):
""" Reset the manager """ """ Reset the manager """
uuid = self.getNewUUID() uuid = self.getNewUUID()
tid, txn = self._getTransaction() tid, txn = self._getTransaction()
serial, obj = self._getObject(1)
self.manager.storeTransaction(uuid, tid, *txn) self.manager.storeTransaction(uuid, tid, *txn)
self.manager.storeObject(uuid, tid, serial, *obj) self._storeTransactionObjects(tid, txn)
self.manager.lock(tid) self.manager.lock(tid, txn[0])
self.assertTrue(tid in self.manager) self.assertTrue(tid in self.manager)
self.manager.reset() self.manager.reset()
self.assertFalse(tid in self.manager) self.assertFalse(tid in self.manager)
self.assertFalse(self.manager.loadLocked(obj[0])) for oid in txn[0]:
self.assertFalse(self.manager.loadLocked(oid))
def test_getObjectFromTransaction(self): def test_getObjectFromTransaction(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
......
...@@ -290,10 +290,14 @@ class ProtocolTests(NeoTestBase): ...@@ -290,10 +290,14 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
def test_38_askLockInformation(self): def test_38_askLockInformation(self):
oid1 = self.getNextTID()
oid2 = self.getNextTID()
oid_list = [oid1, oid2]
tid = self.getNextTID() tid = self.getNextTID()
p = Packets.AskLockInformation(tid) p = Packets.AskLockInformation(tid, oid_list)
ptid = p.decode()[0] ptid, p_oid_list = p.decode()
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertEqual(oid_list, p_oid_list)
def test_39_answerInformationLocked(self): def test_39_answerInformationLocked(self):
tid = self.getNextTID() tid = self.getNextTID()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment