Commit 7b74fa53 authored by Vincent Pelletier's avatar Vincent Pelletier

Make it possible for client to send data_serial to storage nodes.

This makes it possible to implement an undo on client side, as a flaw has
been found in it when used in parallel with replication.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2273 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 7a3c8b2c
...@@ -594,32 +594,37 @@ class Application(object): ...@@ -594,32 +594,37 @@ class Application(object):
raise StorageTransactionError(self, transaction) raise StorageTransactionError(self, transaction)
logging.debug('storing oid %s serial %s', logging.debug('storing oid %s serial %s',
dump(oid), dump(serial)) dump(oid), dump(serial))
self._store(oid, serial, data)
return None
def _store(self, oid, serial, data, data_serial=None):
# Find which storage node to use # Find which storage node to use
cell_list = self._getCellListForOID(oid, writable=True) cell_list = self._getCellListForOID(oid, writable=True)
if len(cell_list) == 0: if len(cell_list) == 0:
raise NEOStorageError raise NEOStorageError
if data is None: if data is None or data_serial is not None:
assert data is None or data_serial is None, data_serial
# this is a George Bailey object, stored as an empty string # this is a George Bailey object, stored as an empty string
data = '' compressed_data = ''
if self.compress:
compressed_data = compress(data)
if len(compressed_data) > len(data):
compressed_data = data
compression = 0
else:
compression = 1
else:
compressed_data = data
compression = 0 compression = 0
else:
assert data_serial is None
if self.compress:
compressed_data = compress(data)
if len(compressed_data) > len(data):
compressed_data = data
compression = 0
else:
compression = 1
checksum = makeChecksum(compressed_data) checksum = makeChecksum(compressed_data)
p = Packets.AskStoreObject(oid, serial, compression, p = Packets.AskStoreObject(oid, serial, compression,
checksum, compressed_data, self.local_var.tid) checksum, compressed_data, data_serial, self.local_var.tid)
on_timeout = OnTimeout(self.onStoreTimeout, self.local_var.tid, oid) on_timeout = OnTimeout(self.onStoreTimeout, self.local_var.tid, oid)
# Store object in tmp cache # Store object in tmp cache
self.local_var.data_dict[oid] = data self.local_var.data_dict[oid] = data
# Store data on each node # Store data on each node
self.local_var.object_stored_counter_dict[oid] = {} self.local_var.object_stored_counter_dict[oid] = {}
self.local_var.object_serial_dict[oid] = (serial, version) self.local_var.object_serial_dict[oid] = serial
getConnForCell = self.cp.getConnForCell getConnForCell = self.cp.getConnForCell
queue = self.local_var.queue queue = self.local_var.queue
add_involved_nodes = self.local_var.involved_nodes.add add_involved_nodes = self.local_var.involved_nodes.add
...@@ -634,7 +639,6 @@ class Application(object): ...@@ -634,7 +639,6 @@ class Application(object):
continue continue
self._waitAnyMessage(False) self._waitAnyMessage(False)
return None
def onStoreTimeout(self, conn, msg_id, tid, oid): def onStoreTimeout(self, conn, msg_id, tid, oid):
# NOTE: this method is called from poll thread, don't use # NOTE: this method is called from poll thread, don't use
...@@ -664,7 +668,7 @@ class Application(object): ...@@ -664,7 +668,7 @@ class Application(object):
# A later serial has already been resolved, skip. # A later serial has already been resolved, skip.
resolved_serial_set.update(conflict_serial_dict.pop(oid)) resolved_serial_set.update(conflict_serial_dict.pop(oid))
continue continue
serial, version = object_serial_dict[oid] serial = object_serial_dict[oid]
data = data_dict[oid] data = data_dict[oid]
tid = local_var.tid tid = local_var.tid
resolved = False resolved = False
...@@ -677,8 +681,7 @@ class Application(object): ...@@ -677,8 +681,7 @@ class Application(object):
# Mark this conflict as resolved # Mark this conflict as resolved
resolved_serial_set.update(conflict_serial_dict.pop(oid)) resolved_serial_set.update(conflict_serial_dict.pop(oid))
# Try to store again # Try to store again
self.store(oid, conflict_serial, new_data, version, self._store(oid, conflict_serial, new_data)
local_var.txn)
append(oid) append(oid)
resolved = True resolved = True
else: else:
...@@ -939,7 +942,7 @@ class Application(object): ...@@ -939,7 +942,7 @@ class Application(object):
raise UndoError('Some data were modified by a later ' \ raise UndoError('Some data were modified by a later ' \
'transaction', oid) 'transaction', oid)
else: else:
self.store(oid, data_tid, new_data, '', self.local_var.txn) self._store(oid, data_tid, new_data)
oid_list = self.local_var.txn_info['oids'] oid_list = self.local_var.txn_info['oids']
# Consistency checking: all oids of the transaction must have been # Consistency checking: all oids of the transaction must have been
......
...@@ -227,7 +227,7 @@ class EventHandler(object): ...@@ -227,7 +227,7 @@ class EventHandler(object):
raise UnexpectedPacketError raise UnexpectedPacketError
def askStoreObject(self, conn, oid, serial, def askStoreObject(self, conn, oid, serial,
compression, checksum, data, tid): compression, checksum, data, data_serial, tid):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerStoreObject(self, conn, conflicting, oid, serial): def answerStoreObject(self, conn, conflicting, oid, serial):
......
...@@ -883,22 +883,26 @@ class AskStoreObject(Packet): ...@@ -883,22 +883,26 @@ class AskStoreObject(Packet):
Ask to store an object. Send an OID, an original serial, a current Ask to store an object. Send an OID, an original serial, a current
transaction ID, and data. C -> S. transaction ID, and data. C -> S.
""" """
_header_format = '!8s8s8sBL' _header_format = '!8s8s8sBL8s'
@profiler_decorator @profiler_decorator
def _encode(self, oid, serial, compression, checksum, data, tid): def _encode(self, oid, serial, compression, checksum, data, data_serial,
tid):
if serial is None: if serial is None:
serial = INVALID_TID serial = INVALID_TID
if data_serial is None:
data_serial = INVALID_TID
return pack(self._header_format, oid, serial, tid, compression, return pack(self._header_format, oid, serial, tid, compression,
checksum) + _encodeString(data) checksum, data_serial) + _encodeString(data)
def _decode(self, body): def _decode(self, body):
header_len = self._header_len header_len = self._header_len
r = unpack(self._header_format, body[:header_len]) r = unpack(self._header_format, body[:header_len])
oid, serial, tid, compression, checksum = r oid, serial, tid, compression, checksum, data_serial = r
serial = _decodeTID(serial) serial = _decodeTID(serial)
data_serial = _decodeTID(data_serial)
(data, _) = _decodeString(body, 'data', offset=header_len) (data, _) = _decodeString(body, 'data', offset=header_len)
return (oid, serial, compression, checksum, data, tid) return (oid, serial, compression, checksum, data, data_serial, tid)
class AnswerStoreObject(Packet): class AnswerStoreObject(Packet):
""" """
......
...@@ -47,7 +47,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -47,7 +47,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
conn.answer(Packets.AnswerStoreTransaction(tid)) conn.answer(Packets.AnswerStoreTransaction(tid))
def _askStoreObject(self, conn, oid, serial, compression, checksum, data, def _askStoreObject(self, conn, oid, serial, compression, checksum, data,
tid, request_time): data_serial, tid, request_time):
if tid not in self.app.tm: if tid not in self.app.tm:
# transaction was aborted, cancel this event # transaction was aborted, cancel this event
logging.info('Forget store of %s:%s by %s delayed by %s', logging.info('Forget store of %s:%s by %s delayed by %s',
...@@ -58,7 +58,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -58,7 +58,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
return return
try: try:
self.app.tm.storeObject(tid, serial, oid, compression, self.app.tm.storeObject(tid, serial, oid, compression,
checksum, data, None) checksum, data, data_serial)
except ConflictError, err: except ConflictError, err:
# resolvable or not # resolvable or not
tid_or_serial = err.getTID() tid_or_serial = err.getTID()
...@@ -75,11 +75,11 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -75,11 +75,11 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
conn.answer(Packets.AnswerStoreObject(0, oid, serial)) conn.answer(Packets.AnswerStoreObject(0, oid, serial))
def askStoreObject(self, conn, oid, serial, def askStoreObject(self, conn, oid, serial,
compression, checksum, data, tid): compression, checksum, data, data_serial, tid):
# register the transaction # register the transaction
self.app.tm.register(conn.getUUID(), tid) self.app.tm.register(conn.getUUID(), tid)
self._askStoreObject(conn, oid, serial, compression, checksum, data, self._askStoreObject(conn, oid, serial, compression, checksum, data,
tid, time.time()) data_serial, tid, time.time())
def askTIDs(self, conn, first, last, partition): def askTIDs(self, conn, first, last, partition):
# This method is complicated, because I must return TIDs only # This method is complicated, because I must return TIDs only
......
...@@ -215,10 +215,11 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -215,10 +215,11 @@ class StorageClientHandlerTests(NeoTestBase):
conn = self._getConnection(uuid=uuid) conn = self._getConnection(uuid=uuid)
tid = self.getNextTID() tid = self.getNextTID()
oid, serial, comp, checksum, data = self._getObject() oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, checksum, self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, tid) data, data_tid, tid)
self._checkStoreObjectCalled(tid, serial, oid, comp, self._checkStoreObjectCalled(tid, serial, oid, comp,
checksum, data, None) checksum, data, data_tid)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn, pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True) decode=True)
self.assertEqual(pconflicting, 0) self.assertEqual(pconflicting, 0)
...@@ -235,8 +236,9 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -235,8 +236,9 @@ class StorageClientHandlerTests(NeoTestBase):
raise ConflictError(locking_tid) raise ConflictError(locking_tid)
self.app.tm.storeObject = fakeStoreObject self.app.tm.storeObject = fakeStoreObject
oid, serial, comp, checksum, data = self._getObject() oid, serial, comp, checksum, data = self._getObject()
data_tid = self.getNextTID()
self.operation.askStoreObject(conn, oid, serial, comp, checksum, self.operation.askStoreObject(conn, oid, serial, comp, checksum,
data, tid) data, data_tid, tid)
pconflicting, poid, pserial = self.checkAnswerStoreObject(conn, pconflicting, poid, pserial = self.checkAnswerStoreObject(conn,
decode=True) decode=True)
self.assertEqual(pconflicting, 1) self.assertEqual(pconflicting, 1)
......
...@@ -348,11 +348,13 @@ class ProtocolTests(NeoTestBase): ...@@ -348,11 +348,13 @@ class ProtocolTests(NeoTestBase):
oid = self.getNextTID() oid = self.getNextTID()
serial = self.getNextTID() serial = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
p = Packets.AskStoreObject(oid, serial, 1, 55, "to", tid) tid2 = self.getNextTID()
poid, pserial, compression, checksum, data, ptid = p.decode() p = Packets.AskStoreObject(oid, serial, 1, 55, "to", tid2, tid)
poid, pserial, compression, checksum, data, ptid2, ptid = p.decode()
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
self.assertEqual(serial, pserial) self.assertEqual(serial, pserial)
self.assertEqual(tid, ptid) self.assertEqual(tid, ptid)
self.assertEqual(tid2, ptid2)
self.assertEqual(compression, 1) self.assertEqual(compression, 1)
self.assertEqual(checksum, 55) self.assertEqual(checksum, 55)
self.assertEqual(data, "to") self.assertEqual(data, "to")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment