Commit 30a33295 authored by Vincent Pelletier's avatar Vincent Pelletier

Use partition's critical TID to avoid unneeded replications.

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2296 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent b0753366
...@@ -256,7 +256,7 @@ class EventHandler(object): ...@@ -256,7 +256,7 @@ class EventHandler(object):
def answerTIDs(self, conn, tid_list): def answerTIDs(self, conn, tid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askTIDsFrom(self, conn, min_tid, length, partition): def askTIDsFrom(self, conn, min_tid, max_tid, length, partition):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
...@@ -275,7 +275,8 @@ class EventHandler(object): ...@@ -275,7 +275,8 @@ class EventHandler(object):
def answerObjectHistory(self, conn, oid, history_list): def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askObjectHistoryFrom(self, conn, oid, min_serial, length, partition): def askObjectHistoryFrom(self, conn, oid, min_serial, max_serial, length,
partition):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerObjectHistoryFrom(self, conn, object_dict): def answerObjectHistoryFrom(self, conn, object_dict):
......
...@@ -114,6 +114,7 @@ ZERO_TID = '\0' * 8 ...@@ -114,6 +114,7 @@ ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8 ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID) OID_LEN = len(INVALID_OID)
TID_LEN = len(INVALID_TID) TID_LEN = len(INVALID_TID)
MAX_TID = '\xff' * 8
UUID_NAMESPACES = { UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S', NodeTypes.STORAGE: 'S',
...@@ -1067,10 +1068,10 @@ class AskTIDsFrom(Packet): ...@@ -1067,10 +1068,10 @@ class AskTIDsFrom(Packet):
Ask for length TIDs starting at min_tid. The order of TIDs is ascending. Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
S -> S. S -> S.
""" """
_header_format = '!8sLL' _header_format = '!8s8sLL'
def _encode(self, min_tid, length, partition): def _encode(self, min_tid, max_tid, length, partition):
return pack(self._header_format, min_tid, length, partition) return pack(self._header_format, min_tid, max_tid, length, partition)
def _decode(self, body): def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition return unpack(self._header_format, body) # min_tid, length, partition
...@@ -1170,11 +1171,11 @@ class AskObjectHistoryFrom(Packet): ...@@ -1170,11 +1171,11 @@ class AskObjectHistoryFrom(Packet):
Ask history information for a given object. The order of serials is Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial for min_oid. S -> S. ascending, and starts at (or above) min_serial for min_oid. S -> S.
""" """
_header_format = '!8s8sLL' _header_format = '!8s8s8sLL'
def _encode(self, min_oid, min_serial, length, partition): def _encode(self, min_oid, min_serial, max_serial, length, partition):
return pack(self._header_format, min_oid, min_serial, length, return pack(self._header_format, min_oid, min_serial, max_serial,
partition) length, partition)
def _decode(self, body): def _decode(self, body):
# min_oid, min_serial, length, partition # min_oid, min_serial, length, partition
......
...@@ -294,11 +294,11 @@ class DatabaseManager(object): ...@@ -294,11 +294,11 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None.""" If there is no such object ID in a database, return None."""
raise NotImplementedError raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, length, num_partitions, def getObjectHistoryFrom(self, oid, min_serial, max_serial, length,
partition): num_partitions, partition):
"""Return a dict of length serials grouped by oid at (or above) """Return a dict of length serials grouped by oid at (or above)
min_oid and min_serial, for given partition, sorted in ascending min_oid and min_serial and below max_serial, for given partition,
order.""" sorted in ascending order."""
raise NotImplementedError raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
...@@ -307,11 +307,11 @@ class DatabaseManager(object): ...@@ -307,11 +307,11 @@ class DatabaseManager(object):
to filter out non-applicable TIDs.""" to filter out non-applicable TIDs."""
raise NotImplementedError raise NotImplementedError
def getReplicationTIDList(self, min_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions,
partition): partition):
"""Return a list of TIDs in ascending order from an initial tid value, """Return a list of TIDs in ascending order from an initial tid value,
at most the specified length. The partition number is passed to filter at most the specified length up to max_tid. The partition number is
out non-applicable TIDs.""" passed to filter out non-applicable TIDs."""
raise NotImplementedError raise NotImplementedError
def pack(self, tid, updateObjectDataForPack): def pack(self, tid, updateObjectDataForPack):
......
...@@ -649,20 +649,23 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -649,20 +649,23 @@ class MySQLDatabaseManager(DatabaseManager):
return result return result
return None return None
def getObjectHistoryFrom(self, min_oid, min_serial, length, num_partitions, def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition): num_partitions, partition):
q = self.query q = self.query
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
min_oid = u64(min_oid) min_oid = u64(min_oid)
min_serial = u64(min_serial) min_serial = u64(min_serial)
max_serial = u64(max_serial)
r = q('SELECT oid, serial FROM obj ' r = q('SELECT oid, serial FROM obj '
'WHERE ((oid = %(min_oid)d AND serial >= %(min_serial)d) OR ' 'WHERE ((oid = %(min_oid)d AND serial >= %(min_serial)d) OR '
'oid > %(min_oid)d) AND ' 'oid > %(min_oid)d) AND '
'MOD(oid, %(num_partitions)d) = %(partition)s ' 'MOD(oid, %(num_partitions)d) = %(partition)s AND '
'serial <= %(max_serial)d '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % { 'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': min_oid, 'min_oid': min_oid,
'min_serial': min_serial, 'min_serial': min_serial,
'max_serial': max_serial,
'length': length, 'length': length,
'num_partitions': num_partitions, 'num_partitions': num_partitions,
'partition': partition, 'partition': partition,
...@@ -685,19 +688,24 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -685,19 +688,24 @@ class MySQLDatabaseManager(DatabaseManager):
offset, length)) offset, length))
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, max_tid, length, num_partitions,
partition): partition):
q = self.query q = self.query
u64 = util.u64
p64 = util.p64
min_tid = u64(min_tid)
max_tid = u64(max_tid)
r = q("""SELECT tid FROM trans WHERE r = q("""SELECT tid FROM trans WHERE
MOD(tid, %(num_partitions)d) = %(partition)d MOD(tid, %(num_partitions)d) = %(partition)d
AND tid >= %(min_tid)d AND tid >= %(min_tid)d AND tid <= %(max_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % { ORDER BY tid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions, 'num_partitions': num_partitions,
'partition': partition, 'partition': partition,
'min_tid': util.u64(min_tid), 'min_tid': min_tid,
'max_tid': max_tid,
'length': length, 'length': length,
}) })
return [util.p64(t[0]) for t in r] return [p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial, def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack): updateObjectDataForPack):
......
...@@ -165,16 +165,21 @@ class ReplicationHandler(EventHandler): ...@@ -165,16 +165,21 @@ class ReplicationHandler(EventHandler):
def _doAskTIDsFrom(self, min_tid, length): def _doAskTIDsFrom(self, min_tid, length):
replicator = self.app.replicator replicator = self.app.replicator
partition = replicator.current_partition.getRID() partition = replicator.current_partition
replicator.getTIDsFrom(min_tid, length, partition) partition_id = partition.getRID()
return Packets.AskTIDsFrom(min_tid, length, partition) max_tid = partition.getCriticalTID()
replicator.getTIDsFrom(min_tid, max_tid, length, partition_id)
return Packets.AskTIDsFrom(min_tid, max_tid, length, partition_id)
def _doAskObjectHistoryFrom(self, min_oid, min_serial, length): def _doAskObjectHistoryFrom(self, min_oid, min_serial, length):
replicator = self.app.replicator replicator = self.app.replicator
partition = replicator.current_partition.getRID() partition = replicator.current_partition
replicator.getObjectHistoryFrom(min_oid, min_serial, length, partition) partition_id = partition.getRID()
return Packets.AskObjectHistoryFrom(min_oid, min_serial, length, max_serial = partition.getCriticalTID()
partition) replicator.getObjectHistoryFrom(min_oid, min_serial, max_serial,
length, partition_id)
return Packets.AskObjectHistoryFrom(min_oid, min_serial, max_serial,
length, partition_id)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum, def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
...@@ -200,7 +205,8 @@ class ReplicationHandler(EventHandler): ...@@ -200,7 +205,8 @@ class ReplicationHandler(EventHandler):
p = self._doAskCheckTIDRange(min_tid, min(length / 2, p = self._doAskCheckTIDRange(min_tid, min(length / 2,
count + 1)) count + 1))
if p is None: if p is None:
if count == length: if count == length and \
max_tid < replicator.current_partition.getCriticalTID():
# Go on with next chunk # Go on with next chunk
p = self._doAskCheckTIDRange(add64(max_tid, 1)) p = self._doAskCheckTIDRange(add64(max_tid, 1))
else: else:
......
...@@ -30,17 +30,17 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -30,17 +30,17 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID() tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID())) conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askTIDsFrom(self, conn, min_tid, length, partition): def askTIDsFrom(self, conn, min_tid, max_tid, length, partition):
app = self.app app = self.app
tid_list = app.dm.getReplicationTIDList(min_tid, length, tid_list = app.dm.getReplicationTIDList(min_tid, max_tid, length,
app.pt.getPartitions(), partition) app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerTIDsFrom(tid_list)) conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, min_oid, min_serial, length, def askObjectHistoryFrom(self, conn, min_oid, min_serial, max_serial,
partition): length, partition):
app = self.app app = self.app
object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, length, object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, max_serial,
app.pt.getPartitions(), partition) length, app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict)) conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def askCheckTIDRange(self, conn, min_tid, length, partition): def askCheckTIDRange(self, conn, min_tid, length, partition):
......
...@@ -361,17 +361,18 @@ class Replicator(object): ...@@ -361,17 +361,18 @@ class Replicator(object):
app.dm.checkSerialRange, (min_oid, min_serial, length, app.dm.checkSerialRange, (min_oid, min_serial, length,
app.pt.getPartitions(), partition)) app.pt.getPartitions(), partition))
def getTIDsFrom(self, min_tid, length, partition): def getTIDsFrom(self, min_tid, max_tid, length, partition):
app = self.app app = self.app
self._addTask('TIDsFrom', self._addTask('TIDsFrom',
app.dm.getReplicationTIDList, (min_tid, length, app.dm.getReplicationTIDList, (min_tid, max_tid, length,
app.pt.getPartitions(), partition)) app.pt.getPartitions(), partition))
def getObjectHistoryFrom(self, min_oid, min_serial, length, partition): def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition):
app = self.app app = self.app
self._addTask('ObjectHistoryFrom', self._addTask('ObjectHistoryFrom',
app.dm.getObjectHistoryFrom, (min_oid, min_serial, length, app.dm.getObjectHistoryFrom, (min_oid, min_serial, max_serial,
app.pt.getPartitions(), partition)) length, app.pt.getPartitions(), partition))
def _getCheckResult(self, key): def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult() return self.task_dict.pop(key).getResult()
......
...@@ -240,11 +240,13 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -240,11 +240,13 @@ class StorageReplicationHandlerTests(NeoTestBase):
def test_answerCheckTIDRangeIdenticalChunkWithNext(self): def test_answerCheckTIDRangeIdenticalChunkWithNext(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
max_tid = self.getNextTID() max_tid = self.getNextTID()
critical_tid = self.getNextTID()
assert max_tid < critical_tid
length = RANGE_LENGTH / 2 length = RANGE_LENGTH / 2
rid = 12 rid = 12
conn = self.getFakeConnection() conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid, app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn) conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app) handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid # Peer has the same data as we have: length, checksum and max_tid
# match. # match.
...@@ -259,6 +261,31 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -259,6 +261,31 @@ class StorageReplicationHandlerTests(NeoTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition) calls[0].checkArgs(pmin_tid, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkAboveCriticalTID(self):
critical_tid = self.getNextTID()
min_tid = self.getNextTID()
max_tid = self.getNextTID()
assert critical_tid < max_tid
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with object range checks
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self): def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
max_tid = self.getNextTID() max_tid = self.getNextTID()
...@@ -307,11 +334,12 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -307,11 +334,12 @@ class StorageReplicationHandlerTests(NeoTestBase):
def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self): def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
max_tid = self.getNextTID() max_tid = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1 length = MIN_RANGE_LENGTH - 1
rid = 12 rid = 12
conn = self.getFakeConnection() conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid, app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn) conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app) handler = ReplicationHandler(app)
# Peer has different data # Peer has different data
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid) handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
...@@ -322,13 +350,14 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -322,13 +350,14 @@ class StorageReplicationHandlerTests(NeoTestBase):
tid_packet = tid_call.getParam(0) tid_packet = tid_call.getParam(0)
next_packet = next_call.getParam(0) next_packet = next_call.getParam(0)
self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom) self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
pmin_tid, plength, ppartition = tid_packet.decode() pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid) self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length) self.assertEqual(plength, length)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getTIDsFrom') calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition) calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
self.assertEqual(next_packet.getType(), Packets.AskCheckTIDRange) self.assertEqual(next_packet.getType(), Packets.AskCheckTIDRange)
pmin_tid, plength, ppartition = next_packet.decode() pmin_tid, plength, ppartition = next_packet.decode()
self.assertEqual(pmin_tid, add64(max_tid, 1)) self.assertEqual(pmin_tid, add64(max_tid, 1))
...@@ -341,11 +370,12 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -341,11 +370,12 @@ class StorageReplicationHandlerTests(NeoTestBase):
def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self): def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
max_tid = self.getNextTID() max_tid = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1 length = MIN_RANGE_LENGTH - 1
rid = 12 rid = 12
conn = self.getFakeConnection() conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid, app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn) conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app) handler = ReplicationHandler(app)
# Peer has different data, and less than length # Peer has different data, and less than length
handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0, handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
...@@ -357,13 +387,14 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -357,13 +387,14 @@ class StorageReplicationHandlerTests(NeoTestBase):
tid_packet = tid_call.getParam(0) tid_packet = tid_call.getParam(0)
next_packet = next_call.getParam(0) next_packet = next_call.getParam(0)
self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom) self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
pmin_tid, plength, ppartition = tid_packet.decode() pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid) self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length - 1) self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getTIDsFrom') calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition) calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange) self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange)
pmin_oid, pmin_serial, plength, ppartition = next_packet.decode() pmin_oid, pmin_serial, plength, ppartition = next_packet.decode()
self.assertEqual(pmin_oid, ZERO_OID) self.assertEqual(pmin_oid, ZERO_OID)
...@@ -448,11 +479,12 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -448,11 +479,12 @@ class StorageReplicationHandlerTests(NeoTestBase):
max_oid = self.getOID(10) max_oid = self.getOID(10)
min_serial = self.getNextTID() min_serial = self.getNextTID()
max_serial = self.getNextTID() max_serial = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1 length = MIN_RANGE_LENGTH - 1
rid = 12 rid = 12
conn = self.getFakeConnection() conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, 1, app = self.getApp(tid_check_result=(length - 5, 0, max_oid, 1,
max_serial), rid=rid, conn=conn) max_serial), rid=rid, conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app) handler = ReplicationHandler(app)
# Peer has different data # Peer has different data
handler.answerCheckSerialRange(conn, min_oid, min_serial, length, handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
...@@ -464,14 +496,17 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -464,14 +496,17 @@ class StorageReplicationHandlerTests(NeoTestBase):
serial_packet = serial_call.getParam(0) serial_packet = serial_call.getParam(0)
next_packet = next_call.getParam(0) next_packet = next_call.getParam(0)
self.assertEqual(serial_packet.getType(), Packets.AskObjectHistoryFrom) self.assertEqual(serial_packet.getType(), Packets.AskObjectHistoryFrom)
pmin_oid, pmin_serial, plength, ppartition = serial_packet.decode() pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \
serial_packet.decode()
self.assertEqual(pmin_oid, min_oid) self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial) self.assertEqual(pmin_serial, min_serial)
self.assertEqual(pmax_serial, critical_tid)
self.assertEqual(plength, length) self.assertEqual(plength, length)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom') calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition) calls[0].checkArgs(pmin_oid, pmin_serial, pmax_serial, plength,
ppartition)
self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange) self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange)
pmin_oid, pmin_serial, plength, ppartition = next_packet.decode() pmin_oid, pmin_serial, plength, ppartition = next_packet.decode()
self.assertEqual(pmin_oid, max_oid) self.assertEqual(pmin_oid, max_oid)
...@@ -487,25 +522,29 @@ class StorageReplicationHandlerTests(NeoTestBase): ...@@ -487,25 +522,29 @@ class StorageReplicationHandlerTests(NeoTestBase):
max_oid = self.getOID(10) max_oid = self.getOID(10)
min_serial = self.getNextTID() min_serial = self.getNextTID()
max_serial = self.getNextTID() max_serial = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1 length = MIN_RANGE_LENGTH - 1
rid = 12 rid = 12
conn = self.getFakeConnection() conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, app = self.getApp(tid_check_result=(length - 5, 0, max_oid,
1, max_serial), rid=rid, conn=conn) 1, max_serial), rid=rid, conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app) handler = ReplicationHandler(app)
# Peer has different data, and less than length # Peer has different data, and less than length
handler.answerCheckSerialRange(conn, min_oid, min_serial, length, handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length - 1, 0, max_oid, 1, max_serial) length - 1, 0, max_oid, 1, max_serial)
# Result: ask tid list, and mark replication as done # Result: ask tid list, and mark replication as done
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn, pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \
Packets.AskObjectHistoryFrom, decode=True) self.checkAskPacket(conn, Packets.AskObjectHistoryFrom,
decode=True)
self.assertEqual(pmin_oid, min_oid) self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial) self.assertEqual(pmin_serial, min_serial)
self.assertEqual(pmax_serial, critical_tid)
self.assertEqual(plength, length - 1) self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, rid) self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom') calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition) calls[0].checkArgs(pmin_oid, pmin_serial, pmax_serial, plength,
ppartition)
self.assertTrue(app.replicator.replication_done) self.assertTrue(app.replicator.replication_done)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -119,15 +119,17 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -119,15 +119,17 @@ class StorageStorageHandlerTests(NeoTestBase):
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
tid = self.getNextTID() tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, 1) tid2 = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, tid2, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList') calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, 1) calls[0].checkArgs(tid, tid2, 2, 1, 1)
self.checkAnswerTidsFrom(conn) self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self): def test_26_askObjectHistoryFrom(self):
min_oid = self.getOID(2) min_oid = self.getOID(2)
min_serial = self.getNextTID() min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = 4 length = 4
partition = 8 partition = 8
num_partitions = 16 num_partitions = 16
...@@ -137,13 +139,13 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -137,13 +139,13 @@ class StorageStorageHandlerTests(NeoTestBase):
self.app.pt = Mock({ self.app.pt = Mock({
'getPartitions': num_partitions, 'getPartitions': num_partitions,
}) })
self.operation.askObjectHistoryFrom(conn, min_oid, min_serial, length, self.operation.askObjectHistoryFrom(conn, min_oid, min_serial,
partition) max_serial, length, partition)
self.checkAnswerObjectHistoryFrom(conn) self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom') calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, length, num_partitions, calls[0].checkArgs(min_oid, min_serial, max_serial, length,
partition) num_partitions, partition)
def test_askCheckTIDRange(self): def test_askCheckTIDRange(self):
count = 1 count = 1
......
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
import MySQLdb import MySQLdb
from mock import Mock from mock import Mock
from neo.util import dump, p64, u64 from neo.util import dump, p64, u64
from neo.protocol import CellStates, INVALID_PTID, ZERO_OID, ZERO_TID from neo.protocol import CellStates, INVALID_PTID, ZERO_OID, ZERO_TID, MAX_TID
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.exception import DatabaseFailure from neo.exception import DatabaseFailure
from neo.storage.database.mysqldb import MySQLDatabaseManager from neo.storage.database.mysqldb import MySQLDatabaseManager
...@@ -516,29 +516,40 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -516,29 +516,40 @@ class StorageMySQSLdbTests(NeoTestBase):
self.db.finishTransaction(tid3) self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4) self.db.finishTransaction(tid4)
# Check full result # Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 1, 0) result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
1, 0)
self.assertEqual(result, { self.assertEqual(result, {
oid1: [tid1, tid3], oid1: [tid1, tid3],
oid2: [tid2, tid4], oid2: [tid2, tid4],
}) })
# Lower bound is inclusive # Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, 10, 1, 0) result = self.db.getObjectHistoryFrom(oid1, tid1, MAX_TID, 10, 1, 0)
self.assertEqual(result, { self.assertEqual(result, {
oid1: [tid1, tid3], oid1: [tid1, tid3],
oid2: [tid2, tid4], oid2: [tid2, tid4],
}) })
# Upper bound is inclusive
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, tid3, 10,
1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Length is total number of serials # Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 3, 1, 0) result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 3,
1, 0)
self.assertEqual(result, { self.assertEqual(result, {
oid1: [tid1, tid3], oid1: [tid1, tid3],
oid2: [tid2], oid2: [tid2],
}) })
# Partition constraints are honored # Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 0) result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
2, 0)
self.assertEqual(result, { self.assertEqual(result, {
oid1: [tid1, tid3], oid1: [tid1, tid3],
}) })
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 1) result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
2, 1)
self.assertEqual(result, { self.assertEqual(result, {
oid2: [tid2, tid4], oid2: [tid2, tid4],
}) })
...@@ -575,19 +586,21 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -575,19 +586,21 @@ class StorageMySQSLdbTests(NeoTestBase):
def test_getReplicationTIDList(self): def test_getReplicationTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids # get tids
result = self.db.getReplicationTIDList(tid1, 4, 1, 0) # - all
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 1, 0)
self.checkSet(result, [tid1, tid2, tid3, tid4]) self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 2, 0) # - one partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 2, 0)
self.checkSet(result, [tid1, tid3]) self.checkSet(result, [tid1, tid3])
result = self.db.getReplicationTIDList(tid1, 4, 3, 0) # - min_tid is inclusive
self.checkSet(result, [tid1, tid4]) result = self.db.getReplicationTIDList(tid3, MAX_TID, 10, 1, 0)
# get a subset of tids
result = self.db.getReplicationTIDList(tid3, 4, 1, 0)
self.checkSet(result, [tid3, tid4]) self.checkSet(result, [tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 2, 1, 0) # - max tid is inclusive
result = self.db.getReplicationTIDList(ZERO_TID, tid2, 10, 1, 0)
self.checkSet(result, [tid1, tid2])
# - limit
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 2, 1, 0)
self.checkSet(result, [tid1, tid2]) self.checkSet(result, [tid1, tid2])
result = self.db.getReplicationTIDList(tid1, 1, 3, 1)
self.checkSet(result, [tid2])
def test__getObjectData(self): def test__getObjectData(self):
db = self.db db = self.db
......
...@@ -595,9 +595,11 @@ class ProtocolTests(NeoTestBase): ...@@ -595,9 +595,11 @@ class ProtocolTests(NeoTestBase):
def test_AskTIDsFrom(self): def test_AskTIDsFrom(self):
tid = self.getNextTID() tid = self.getNextTID()
p = Packets.AskTIDsFrom(tid, 1000, 5) tid2 = self.getNextTID()
min_tid, length, partition = p.decode() p = Packets.AskTIDsFrom(tid, tid2, 1000, 5)
min_tid, max_tid, length, partition = p.decode()
self.assertEqual(min_tid, tid) self.assertEqual(min_tid, tid)
self.assertEqual(max_tid, tid2)
self.assertEqual(length, 1000) self.assertEqual(length, 1000)
self.assertEqual(partition, 5) self.assertEqual(partition, 5)
...@@ -607,12 +609,15 @@ class ProtocolTests(NeoTestBase): ...@@ -607,12 +609,15 @@ class ProtocolTests(NeoTestBase):
def test_AskObjectHistoryFrom(self): def test_AskObjectHistoryFrom(self):
oid = self.getOID(1) oid = self.getOID(1)
min_serial = self.getNextTID() min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = 5 length = 5
partition = 4 partition = 4
p = Packets.AskObjectHistoryFrom(oid, min_serial, length, partition) p = Packets.AskObjectHistoryFrom(oid, min_serial, max_serial, length,
p_oid, p_min_serial, p_length, p_partition = p.decode() partition)
p_oid, p_min_serial, p_max_serial, p_length, p_partition = p.decode()
self.assertEqual(p_oid, oid) self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial) self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_max_serial, max_serial)
self.assertEqual(p_length, length) self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition) self.assertEqual(p_partition, partition)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment