Commit d8a7a177 authored by Vincent Pelletier's avatar Vincent Pelletier

Improve replication SQL queries.

It is more efficient to provide a boundary value than a row count range.
This fixes replication on partitions with a large number of objects, revisions
or transactions: query time is now constant where it used to increase, causing
timeout problems when query duration exceeded ping time + ping timeout (11s
currently).

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2221 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 44b434d5
...@@ -256,6 +256,12 @@ class EventHandler(object): ...@@ -256,6 +256,12 @@ class EventHandler(object):
def answerTIDs(self, conn, tid_list): def answerTIDs(self, conn, tid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askTIDsFrom(self, conn, min_tid, length, partition):
raise UnexpectedPacketError
def answerTIDsFrom(self, conn, tid_list):
raise UnexpectedPacketError
def askTransactionInformation(self, conn, tid): def askTransactionInformation(self, conn, tid):
raise UnexpectedPacketError raise UnexpectedPacketError
...@@ -269,7 +275,13 @@ class EventHandler(object): ...@@ -269,7 +275,13 @@ class EventHandler(object):
def answerObjectHistory(self, conn, oid, history_list): def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askOIDs(self, conn, first, last, partition): def askObjectHistoryFrom(self, conn, oid, min_serial, length):
raise UnexpectedPacketError
def answerObjectHistoryFrom(self, conn, oid, history_list):
raise UnexpectedPacketError
def askOIDs(self, conn, min_oid, length, partition):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerOIDs(self, conn, oid_list): def answerOIDs(self, conn, oid_list):
...@@ -414,11 +426,15 @@ class EventHandler(object): ...@@ -414,11 +426,15 @@ class EventHandler(object):
d[Packets.AnswerObject] = self.answerObject d[Packets.AnswerObject] = self.answerObject
d[Packets.AskTIDs] = self.askTIDs d[Packets.AskTIDs] = self.askTIDs
d[Packets.AnswerTIDs] = self.answerTIDs d[Packets.AnswerTIDs] = self.answerTIDs
d[Packets.AskTIDsFrom] = self.askTIDsFrom
d[Packets.AnswerTIDsFrom] = self.answerTIDsFrom
d[Packets.AskTransactionInformation] = self.askTransactionInformation d[Packets.AskTransactionInformation] = self.askTransactionInformation
d[Packets.AnswerTransactionInformation] = \ d[Packets.AnswerTransactionInformation] = \
self.answerTransactionInformation self.answerTransactionInformation
d[Packets.AskObjectHistory] = self.askObjectHistory d[Packets.AskObjectHistory] = self.askObjectHistory
d[Packets.AnswerObjectHistory] = self.answerObjectHistory d[Packets.AnswerObjectHistory] = self.answerObjectHistory
d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom
d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom
d[Packets.AskOIDs] = self.askOIDs d[Packets.AskOIDs] = self.askOIDs
d[Packets.AnswerOIDs] = self.answerOIDs d[Packets.AnswerOIDs] = self.answerOIDs
d[Packets.AskPartitionList] = self.askPartitionList d[Packets.AskPartitionList] = self.askPartitionList
......
...@@ -109,6 +109,8 @@ INVALID_OID = '\xff' * 8 ...@@ -109,6 +109,8 @@ INVALID_OID = '\xff' * 8
INVALID_PTID = '\0' * 8 INVALID_PTID = '\0' * 8
INVALID_SERIAL = INVALID_TID INVALID_SERIAL = INVALID_TID
INVALID_PARTITION = 0xffffffff INVALID_PARTITION = 0xffffffff
ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID) OID_LEN = len(INVALID_OID)
UUID_NAMESPACES = { UUID_NAMESPACES = {
...@@ -1024,7 +1026,7 @@ class AnswerObject(Packet): ...@@ -1024,7 +1026,7 @@ class AnswerObject(Packet):
class AskTIDs(Packet): class AskTIDs(Packet):
""" """
Ask for TIDs between a range of offsets. The order of TIDs is descending, Ask for TIDs between a range of offsets. The order of TIDs is descending,
and the range is [first, last). C, S -> S. and the range is [first, last). C -> S.
""" """
_header_format = '!QQL' _header_format = '!QQL'
...@@ -1036,7 +1038,7 @@ class AskTIDs(Packet): ...@@ -1036,7 +1038,7 @@ class AskTIDs(Packet):
class AnswerTIDs(Packet): class AnswerTIDs(Packet):
""" """
Answer the requested TIDs. S -> C, S. Answer the requested TIDs. S -> C.
""" """
_header_format = '!L' _header_format = '!L'
_list_entry_format = '8s' _list_entry_format = '8s'
...@@ -1060,6 +1062,25 @@ class AnswerTIDs(Packet): ...@@ -1060,6 +1062,25 @@ class AnswerTIDs(Packet):
tid_list.append(tid) tid_list.append(tid)
return (tid_list,) return (tid_list,)
class AskTIDsFrom(Packet):
"""
Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
S -> S.
"""
_header_format = '!8sLL'
def _encode(self, min_tid, length, partition):
return pack(self._header_format, min_tid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition
class AnswerTIDsFrom(AnswerTIDs):
"""
Answer the requested TIDs. S -> S
"""
pass
class AskTransactionInformation(Packet): class AskTransactionInformation(Packet):
""" """
Ask information about a transaction. Any -> S. Ask information about a transaction. Any -> S.
...@@ -1105,7 +1126,7 @@ class AnswerTransactionInformation(Packet): ...@@ -1105,7 +1126,7 @@ class AnswerTransactionInformation(Packet):
class AskObjectHistory(Packet): class AskObjectHistory(Packet):
""" """
Ask history information for a given object. The order of serials is Ask history information for a given object. The order of serials is
descending, and the range is [first, last]. C, S -> S. descending, and the range is [first, last]. C -> S.
""" """
_header_format = '!8sQQ' _header_format = '!8sQQ'
...@@ -1118,7 +1139,7 @@ class AskObjectHistory(Packet): ...@@ -1118,7 +1139,7 @@ class AskObjectHistory(Packet):
class AnswerObjectHistory(Packet): class AnswerObjectHistory(Packet):
""" """
Answer history information (serial, size) for an object. S -> C, S. Answer history information (serial, size) for an object. S -> C.
""" """
_header_format = '!8sL' _header_format = '!8sL'
_list_entry_format = '!8sL' _list_entry_format = '!8sL'
...@@ -1144,18 +1165,40 @@ class AnswerObjectHistory(Packet): ...@@ -1144,18 +1165,40 @@ class AnswerObjectHistory(Packet):
history_list.append((serial, size)) history_list.append((serial, size))
return (oid, history_list) return (oid, history_list)
class AskObjectHistoryFrom(Packet):
"""
Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial. S -> S.
"""
_header_format = '!8s8sL'
def _encode(self, oid, min_serial, length):
return pack(self._header_format, oid, min_serial, length)
def _decode(self, body):
return unpack(self._header_format, body) # oid, min_serial, length
class AnswerObjectHistoryFrom(AskFinishTransaction):
"""
Answer the requested serials. S -> S.
"""
# This is similar to AskFinishTransaction as TID size is identical to OID
# size:
# - we have a single OID (TID in AskFinishTransaction)
# - we have a list of TIDs (OIDs in AskFinishTransaction)
pass
class AskOIDs(Packet): class AskOIDs(Packet):
""" """
Ask for OIDs between a range of offsets. The order of OIDs is descending, Ask for length OIDs starting at min_oid. S -> S.
and the range is [first, last). S -> S.
""" """
_header_format = '!QQL' _header_format = '!8sLL'
def _encode(self, first, last, partition): def _encode(self, min_oid, length, partition):
return pack(self._header_format, first, last, partition) return pack(self._header_format, min_oid, length, partition)
def _decode(self, body): def _decode(self, body):
return unpack(self._header_format, body) # first, last, partition return unpack(self._header_format, body) # min_oid, length, partition
class AnswerOIDs(Packet): class AnswerOIDs(Packet):
""" """
...@@ -1787,6 +1830,14 @@ class PacketRegistry(dict): ...@@ -1787,6 +1830,14 @@ class PacketRegistry(dict):
0x0034, 0x0034,
AskHasLock, AskHasLock,
AnswerHasLock) AnswerHasLock)
AskTIDsFrom, AnswerTIDsFrom = register(
0x0035,
AskTIDsFrom,
AnswerTIDsFrom)
AskObjectHistoryFrom, AnswerObjectHistoryFrom = register(
0x0036,
AskObjectHistoryFrom,
AnswerObjectHistoryFrom)
# build a "singleton" # build a "singleton"
Packets = PacketRegistry() Packets = PacketRegistry()
......
...@@ -263,8 +263,8 @@ class DatabaseManager(object): ...@@ -263,8 +263,8 @@ class DatabaseManager(object):
area as well.""" area as well."""
raise NotImplementedError raise NotImplementedError
def getOIDList(self, offset, length, num_partitions, partition_list): def getOIDList(self, min_oid, length, num_partitions, partition_list):
"""Return a list of OIDs in descending order from an offset, """Return a list of OIDs in ascending order from a minimal oid,
at most the specified length. The list of partitions are passed at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs.""" to filter out non-applicable TIDs."""
raise NotImplementedError raise NotImplementedError
...@@ -276,15 +276,20 @@ class DatabaseManager(object): ...@@ -276,15 +276,20 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None.""" If there is no such object ID in a database, return None."""
raise NotImplementedError raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, length):
"""Return a list of length serials for a given object ID at (or above)
min_serial, sorted in ascending order."""
raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
"""Return a list of TIDs in ascending order from an offset, """Return a list of TIDs in ascending order from an offset,
at most the specified length. The list of partitions are passed at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs.""" to filter out non-applicable TIDs."""
raise NotImplementedError raise NotImplementedError
def getReplicationTIDList(self, offset, length, num_partitions, def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list): partition_list):
"""Return a list of TIDs in descending order from an offset, """Return a list of TIDs in ascending order from an initial tid value,
at most the specified length. The list of partitions are passed at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs.""" to filter out non-applicable TIDs."""
raise NotImplementedError raise NotImplementedError
......
...@@ -618,12 +618,18 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -618,12 +618,18 @@ class MySQLDatabaseManager(DatabaseManager):
return oid_list, user, desc, ext, bool(packed) return oid_list, user, desc, ext, bool(packed)
return None return None
def getOIDList(self, offset, length, num_partitions, partition_list): def getOIDList(self, min_oid, length, num_partitions,
partition_list):
q = self.query q = self.query
r = q("""SELECT DISTINCT oid FROM obj WHERE MOD(oid, %d) in (%s) r = q("""SELECT DISTINCT oid FROM obj WHERE
ORDER BY oid DESC LIMIT %d,%d""" \ MOD(oid, %(num_partitions)d) in (%(partitions)s)
% (num_partitions, ','.join([str(p) for p in partition_list]), AND oid >= %(min_oid)d
offset, length)) ORDER BY oid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'min_oid': util.u64(min_oid),
'length': length,
})
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def _getObjectLength(self, oid, value_serial): def _getObjectLength(self, oid, value_serial):
...@@ -662,6 +668,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -662,6 +668,19 @@ class MySQLDatabaseManager(DatabaseManager):
return result return result
return None return None
def getObjectHistoryFrom(self, oid, min_serial, length):
q = self.query
oid = util.u64(oid)
p64 = util.p64
r = q("""SELECT serial FROM obj
WHERE oid = %(oid)d AND serial >= %(min_serial)d
ORDER BY serial ASC LIMIT %(length)d""" % {
'oid': oid,
'min_serial': util.u64(min_serial),
'length': length,
})
return [p64(t[0]) for t in r]
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query q = self.query
r = q("""SELECT tid FROM trans WHERE MOD(tid, %d) in (%s) r = q("""SELECT tid FROM trans WHERE MOD(tid, %d) in (%s)
...@@ -671,13 +690,18 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -671,13 +690,18 @@ class MySQLDatabaseManager(DatabaseManager):
offset, length)) offset, length))
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, offset, length, num_partitions, partition_list): def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list):
q = self.query q = self.query
r = q("""SELECT tid FROM trans WHERE MOD(tid, %d) in (%s) r = q("""SELECT tid FROM trans WHERE
ORDER BY tid ASC LIMIT %d,%d""" \ MOD(tid, %(num_partitions)d) in (%(partitions)s)
% (num_partitions, AND tid >= %(min_tid)d
','.join([str(p) for p in partition_list]), ORDER BY tid ASC LIMIT %(length)d""" % {
offset, length)) 'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'min_tid': util.u64(min_tid),
'length': length,
})
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getTIDListPresent(self, tid_list): def getTIDListPresent(self, tid_list):
......
...@@ -65,16 +65,6 @@ class BaseMasterHandler(EventHandler): ...@@ -65,16 +65,6 @@ class BaseMasterHandler(EventHandler):
class BaseClientAndStorageOperationHandler(EventHandler): class BaseClientAndStorageOperationHandler(EventHandler):
""" Accept requests common to client and storage nodes """ """ Accept requests common to client and storage nodes """
def askObjectHistory(self, conn, oid, first, last):
if first >= last:
raise protocol.ProtocolError( 'invalid offsets')
app = self.app
history_list = app.dm.getObjectHistory(oid, first, last - first)
if history_list is None:
history_list = []
conn.answer(Packets.AnswerObjectHistory(oid, history_list))
def askTransactionInformation(self, conn, tid): def askTransactionInformation(self, conn, tid):
app = self.app app = self.app
t = app.dm.getTransaction(tid) t = app.dm.getTransaction(tid)
......
...@@ -144,3 +144,13 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -144,3 +144,13 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
state = LockState.GRANTED_TO_OTHER state = LockState.GRANTED_TO_OTHER
conn.answer(Packets.AnswerHasLock(oid, state)) conn.answer(Packets.AnswerHasLock(oid, state))
def askObjectHistory(self, conn, oid, first, last):
if first >= last:
raise protocol.ProtocolError( 'invalid offsets')
app = self.app
history_list = app.dm.getObjectHistory(oid, first, last - first)
if history_list is None:
history_list = []
conn.answer(Packets.AnswerObjectHistory(oid, history_list))
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
from neo import logging from neo import logging
from neo.handler import EventHandler from neo.handler import EventHandler
from neo.protocol import Packets from neo.protocol import Packets, ZERO_TID, ZERO_OID
from neo import util
def checkConnectionIsReplicatorConnection(func): def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw): def decorator(self, conn, *args, **kw):
...@@ -31,6 +32,10 @@ def checkConnectionIsReplicatorConnection(func): ...@@ -31,6 +32,10 @@ def checkConnectionIsReplicatorConnection(func):
return result return result
return decorator return decorator
def add64(packed, offset):
"""Add a python number to a 64-bits packed value"""
return util.p64(util.u64(packed) + offset)
class ReplicationHandler(EventHandler): class ReplicationHandler(EventHandler):
"""This class handles events for replications.""" """This class handles events for replications."""
...@@ -48,7 +53,7 @@ class ReplicationHandler(EventHandler): ...@@ -48,7 +53,7 @@ class ReplicationHandler(EventHandler):
conn.setUUID(uuid) conn.setUUID(uuid)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerTIDs(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
app = self.app app = self.app
if tid_list: if tid_list:
# If I have pending TIDs, check which TIDs I don't have, and # If I have pending TIDs, check which TIDs I don't have, and
...@@ -59,18 +64,15 @@ class ReplicationHandler(EventHandler): ...@@ -59,18 +64,15 @@ class ReplicationHandler(EventHandler):
conn.ask(Packets.AskTransactionInformation(tid), timeout=300) conn.ask(Packets.AskTransactionInformation(tid), timeout=300)
# And, ask more TIDs. # And, ask more TIDs.
app.replicator.tid_offset += 1000 p = Packets.AskTIDsFrom(add64(tid_list[-1], 1), 1000,
offset = app.replicator.tid_offset
p = Packets.AskTIDs(offset, offset + 1000,
app.replicator.current_partition.getRID()) app.replicator.current_partition.getRID())
conn.ask(p, timeout=300) conn.ask(p, timeout=300)
else: else:
# If no more TID, a replication of transactions is finished. # If no more TID, a replication of transactions is finished.
# So start to replicate objects now. # So start to replicate objects now.
p = Packets.AskOIDs(0, 1000, p = Packets.AskOIDs(ZERO_OID, 1000,
app.replicator.current_partition.getRID()) app.replicator.current_partition.getRID())
conn.ask(p, timeout=300) conn.ask(p, timeout=300)
app.replicator.oid_offset = 0
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerTransactionInformation(self, conn, tid, def answerTransactionInformation(self, conn, tid,
...@@ -84,10 +86,11 @@ class ReplicationHandler(EventHandler): ...@@ -84,10 +86,11 @@ class ReplicationHandler(EventHandler):
def answerOIDs(self, conn, oid_list): def answerOIDs(self, conn, oid_list):
app = self.app app = self.app
if oid_list: if oid_list:
app.replicator.next_oid = add64(oid_list[-1], 1)
# Pick one up, and ask the history. # Pick one up, and ask the history.
oid = oid_list.pop() oid = oid_list.pop()
conn.ask(Packets.AskObjectHistory(oid, 0, 1000), timeout=300) conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
app.replicator.serial_offset = 0 timeout=300)
app.replicator.oid_list = oid_list app.replicator.oid_list = oid_list
else: else:
# Nothing remains, so the replication for this partition is # Nothing remains, so the replication for this partition is
...@@ -95,34 +98,29 @@ class ReplicationHandler(EventHandler): ...@@ -95,34 +98,29 @@ class ReplicationHandler(EventHandler):
app.replicator.replication_done = True app.replicator.replication_done = True
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerObjectHistory(self, conn, oid, history_list): def answerObjectHistoryFrom(self, conn, oid, serial_list):
app = self.app app = self.app
if history_list: if serial_list:
# Check if I have objects, request those which I don't have. # Check if I have objects, request those which I don't have.
serial_list = [t[0] for t in history_list]
present_serial_list = app.dm.getSerialListPresent(oid, serial_list) present_serial_list = app.dm.getSerialListPresent(oid, serial_list)
serial_set = set(serial_list) - set(present_serial_list) serial_set = set(serial_list) - set(present_serial_list)
for serial in serial_set: for serial in serial_set:
conn.ask(Packets.AskObject(oid, serial, None), timeout=300) conn.ask(Packets.AskObject(oid, serial, None), timeout=300)
# And, ask more serials. # And, ask more serials.
app.replicator.serial_offset += 1000 conn.ask(Packets.AskObjectHistoryFrom(oid,
offset = app.replicator.serial_offset add64(serial_list[-1], 1), 1000), timeout=300)
p = Packets.AskObjectHistory(oid, offset, offset + 1000)
conn.ask(p, timeout=300)
else: else:
# This OID is finished. So advance to next. # This OID is finished. So advance to next.
oid_list = app.replicator.oid_list oid_list = app.replicator.oid_list
if oid_list: if oid_list:
# If I have more pending OIDs, pick one up. # If I have more pending OIDs, pick one up.
oid = oid_list.pop() oid = oid_list.pop()
conn.ask(Packets.AskObjectHistory(oid, 0, 1000), timeout=300) conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
app.replicator.serial_offset = 0 timeout=300)
else: else:
# Otherwise, acquire more OIDs. # Otherwise, acquire more OIDs.
app.replicator.oid_offset += 1000 p = Packets.AskOIDs(app.replicator.next_oid, 1000,
offset = app.replicator.oid_offset
p = Packets.AskOIDs(offset, offset + 1000,
app.replicator.current_partition.getRID()) app.replicator.current_partition.getRID())
conn.ask(p, timeout=300) conn.ask(p, timeout=300)
......
...@@ -30,36 +30,34 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -30,36 +30,34 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID() tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID())) conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askOIDs(self, conn, first, last, partition): def askOIDs(self, conn, min_oid, length, partition):
# This method is complicated, because I must return OIDs only # This method is complicated, because I must return OIDs only
# about usable partitions assigned to me. # about usable partitions assigned to me.
if first >= last:
raise protocol.ProtocolError('invalid offsets')
app = self.app app = self.app
if partition == protocol.INVALID_PARTITION: if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid) partition_list = app.pt.getAssignedPartitionList(app.uuid)
else: else:
partition_list = [partition] partition_list = [partition]
oid_list = app.dm.getOIDList(first, last - first, oid_list = app.dm.getOIDList(min_oid, length,
app.pt.getPartitions(), partition_list) app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerOIDs(oid_list)) conn.answer(Packets.AnswerOIDs(oid_list))
def askTIDs(self, conn, first, last, partition): def askTIDsFrom(self, conn, min_tid, length, partition):
# This method is complicated, because I must return TIDs only # This method is complicated, because I must return TIDs only
# about usable partitions assigned to me. # about usable partitions assigned to me.
if first >= last:
raise protocol.ProtocolError('invalid offsets')
app = self.app app = self.app
if partition == protocol.INVALID_PARTITION: if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid) partition_list = app.pt.getAssignedPartitionList(app.uuid)
else: else:
partition_list = [partition] partition_list = [partition]
tid_list = app.dm.getReplicationTIDList(first, last - first, tid_list = app.dm.getReplicationTIDList(min_tid, length,
app.pt.getPartitions(), partition_list) app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerTIDs(tid_list)) conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, oid, min_serial, length):
app = self.app
history_list = app.dm.getObjectHistoryFrom(oid, min_serial, length)
conn.answer(Packets.AnswerObjectHistoryFrom(oid, history_list))
...@@ -19,7 +19,7 @@ from neo import logging ...@@ -19,7 +19,7 @@ from neo import logging
from random import choice from random import choice
from neo.storage.handlers import replication from neo.storage.handlers import replication
from neo.protocol import NodeTypes, NodeStates, CellStates, Packets from neo.protocol import NodeTypes, NodeStates, CellStates, Packets, ZERO_TID
from neo.connection import ClientConnection from neo.connection import ClientConnection
from neo.util import dump from neo.util import dump
...@@ -38,7 +38,7 @@ class Partition(object): ...@@ -38,7 +38,7 @@ class Partition(object):
def setCriticalTID(self, tid): def setCriticalTID(self, tid):
if tid is None: if tid is None:
tid = '\x00' * 8 tid = ZERO_TID
self.tid = tid self.tid = tid
def safe(self, min_pending_tid): def safe(self, min_pending_tid):
...@@ -81,7 +81,6 @@ class Replicator(object): ...@@ -81,7 +81,6 @@ class Replicator(object):
self.app = app self.app = app
self.new_partition_dict = self._getOutdatedPartitionList() self.new_partition_dict = self._getOutdatedPartitionList()
self.critical_tid_dict = {} self.critical_tid_dict = {}
self.tid_offset = 0
self.reset() self.reset()
def reset(self): def reset(self):
...@@ -172,8 +171,8 @@ class Replicator(object): ...@@ -172,8 +171,8 @@ class Replicator(object):
app.uuid, app.server, app.name) app.uuid, app.server, app.name)
self.current_connection.ask(p) self.current_connection.ask(p)
self.tid_offset = 0 p = Packets.AskTIDsFrom(ZERO_TID, 1000,
p = Packets.AskTIDs(0, 1000, self.current_partition.getRID()) self.current_partition.getRID())
self.current_connection.ask(p, timeout=300) self.current_connection.ask(p, timeout=300)
self.replication_done = False self.replication_done = False
......
...@@ -364,9 +364,15 @@ class NeoTestBase(unittest.TestCase): ...@@ -364,9 +364,15 @@ class NeoTestBase(unittest.TestCase):
def checkAnswerTids(self, conn, **kw): def checkAnswerTids(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw) return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)
def checkAnswerTidsFrom(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDsFrom, **kw)
def checkAnswerObjectHistory(self, conn, **kw): def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw) return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerObjectHistoryFrom(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistoryFrom, **kw)
def checkAnswerStoreTransaction(self, conn, **kw): def checkAnswerStoreTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw) return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)
......
...@@ -113,28 +113,19 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -113,28 +113,19 @@ class StorageStorageHandlerTests(NeoTestBase):
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn) self.checkAnswerObject(conn)
def test_25_askTIDs1(self): def test_25_askTIDsFrom1(self):
# invalid offsets => error
app = self.app
app.pt = Mock()
app.dm = Mock()
conn = self.getFakeConnection()
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getReplicationTIDList')), 0)
def test_25_askTIDs2(self):
# well case => answer # well case => answer
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.operation.askTIDs(conn, 1, 2, 1) tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList') calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ]) calls[0].checkArgs(tid, 2, 1, [1, ])
self.checkAnswerTids(conn) self.checkAnswerTidsFrom(conn)
def test_25_askTIDs3(self): def test_25_askTIDsFrom2(self):
# invalid partition => answer usable partitions # invalid partition => answer usable partitions
conn = self.getFakeConnection() conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid}) cell = Mock({'getUUID':self.app.uuid})
...@@ -144,59 +135,39 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -144,59 +135,39 @@ class StorageStorageHandlerTests(NeoTestBase):
'getPartitions': 1, 'getPartitions': 1,
'getAssignedPartitionList': [0], 'getAssignedPartitionList': [0],
}) })
self.operation.askTIDs(conn, 1, 2, INVALID_PARTITION) tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1) self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList') calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [0, ]) calls[0].checkArgs(tid, 2, 1, [0, ])
self.checkAnswerTids(conn) self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistory1(self): def test_26_askObjectHistoryFrom(self):
# invalid offsets => error oid = self.getOID(2)
app = self.app min_tid = self.getNextTID()
app.dm = Mock()
conn = self.getFakeConnection()
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn,
1, 1, None)
self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_26_askObjectHistory2(self):
oid1 = self.getOID(1)
oid2 = self.getOID(2)
tid = self.getNextTID() tid = self.getNextTID()
# first case: empty history
conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistory': None})
self.operation.askObjectHistory(conn, oid1, 1, 2)
self.checkAnswerObjectHistory(conn)
# second case: not empty history
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistory': [(tid, 0, ), ]}) self.app.dm = Mock({'getObjectHistoryFrom': [tid]})
self.operation.askObjectHistory(conn, oid2, 1, 2) self.operation.askObjectHistoryFrom(conn, oid, min_tid, 2)
self.checkAnswerObjectHistory(conn) self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, min_tid, 2)
def test_25_askOIDs1(self): def test_25_askOIDs1(self):
# invalid offsets => error
app = self.app
app.pt = Mock()
app.dm = Mock()
conn = self.getFakeConnection()
self.checkProtocolErrorRaised(self.operation.askOIDs, conn, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getOIDList')), 0)
def test_25_askOIDs2(self):
# well case > answer OIDs # well case > answer OIDs
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
self.operation.askOIDs(conn, 1, 2, 1) oid = self.getOID(1)
self.operation.askOIDs(conn, oid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList') calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [1, ]) calls[0].checkArgs(oid, 2, 1, [1, ])
self.checkAnswerOids(conn) self.checkAnswerOids(conn)
def test_25_askOIDs3(self): def test_25_askOIDs2(self):
# invalid partition => answer usable partitions # invalid partition => answer usable partitions
conn = self.getFakeConnection() conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid}) cell = Mock({'getUUID':self.app.uuid})
...@@ -206,11 +177,12 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -206,11 +177,12 @@ class StorageStorageHandlerTests(NeoTestBase):
'getPartitions': 1, 'getPartitions': 1,
'getAssignedPartitionList': [0], 'getAssignedPartitionList': [0],
}) })
self.operation.askOIDs(conn, 1, 2, INVALID_PARTITION) oid = self.getOID(1)
self.operation.askOIDs(conn, oid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1) self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList') calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(1, 1, 1, [0]) calls[0].checkArgs(oid, 2, 1, [0])
self.checkAnswerOids(conn) self.checkAnswerOids(conn)
......
...@@ -457,20 +457,20 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -457,20 +457,20 @@ class StorageMySQSLdbTests(NeoTestBase):
self.db.storeTransaction(tid, objs, txn) self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid) self.db.finishTransaction(tid)
# get oids # get oids
result = self.db.getOIDList(0, 4, 1, [0]) result = self.db.getOIDList(oid1, 4, 1, [0])
self.checkSet(result, [oid1, oid2, oid3, oid4]) self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(0, 4, 2, [0]) result = self.db.getOIDList(oid1, 4, 2, [0])
self.checkSet(result, [oid1, oid3]) self.checkSet(result, [oid1, oid3])
result = self.db.getOIDList(0, 4, 2, [0, 1]) result = self.db.getOIDList(oid1, 4, 2, [0, 1])
self.checkSet(result, [oid1, oid2, oid3, oid4]) self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(0, 4, 3, [0]) result = self.db.getOIDList(oid1, 4, 3, [0])
self.checkSet(result, [oid1, oid4]) self.checkSet(result, [oid1, oid4])
# get a subset of oids # get a subset of oids
result = self.db.getOIDList(2, 4, 1, [0]) result = self.db.getOIDList(oid1, 2, 1, [0])
self.checkSet(result, [oid1, oid2]) self.checkSet(result, [oid1, oid2])
result = self.db.getOIDList(0, 2, 1, [0]) result = self.db.getOIDList(oid3, 2, 1, [0])
self.checkSet(result, [oid3, oid4]) self.checkSet(result, [oid3, oid4])
result = self.db.getOIDList(0, 1, 3, [0]) result = self.db.getOIDList(oid2, 1, 3, [0])
self.checkSet(result, [oid4]) self.checkSet(result, [oid4])
def test_getObjectHistory(self): def test_getObjectHistory(self):
...@@ -496,23 +496,18 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -496,23 +496,18 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getObjectHistory(oid, 2, 3) result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None) self.assertEqual(result, None)
def test_getTIDList(self): def _storeTransactions(self, count):
# use OID generator to know result of tid % N # use OID generator to know result of tid % N
tid1, tid2, tid3, tid4 = self.getOIDs(4) tid_list = self.getOIDs(count)
oid = self.getOID(1) oid = self.getOID(1)
txn1, objs1 = self.getTransaction([oid]) for tid in tid_list:
txn2, objs2 = self.getTransaction([oid]) txn, objs = self.getTransaction([oid])
txn3, objs3 = self.getTransaction([oid]) self.db.storeTransaction(tid, objs, txn)
txn4, objs4 = self.getTransaction([oid]) self.db.finishTransaction(tid)
# store four transaction return tid_list
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2) def test_getTIDList(self):
self.db.storeTransaction(tid3, objs3, txn3) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4)
# get tids # get tids
result = self.db.getTIDList(0, 4, 1, [0]) result = self.db.getTIDList(0, 4, 1, [0])
self.checkSet(result, [tid1, tid2, tid3, tid4]) self.checkSet(result, [tid1, tid2, tid3, tid4])
...@@ -530,6 +525,25 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -530,6 +525,25 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getTIDList(0, 1, 3, [0]) result = self.db.getTIDList(0, 1, 3, [0])
self.checkSet(result, [tid4]) self.checkSet(result, [tid4])
def test_getReplicationTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
result = self.db.getReplicationTIDList(tid1, 4, 1, [0])
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0])
self.checkSet(result, [tid1, tid3])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0, 1])
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 3, [0])
self.checkSet(result, [tid1, tid4])
# get a subset of tids
result = self.db.getReplicationTIDList(tid3, 4, 1, [0])
self.checkSet(result, [tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 2, 1, [0])
self.checkSet(result, [tid1, tid2])
result = self.db.getReplicationTIDList(tid1, 1, 3, [1])
self.checkSet(result, [tid2])
def test_getTIDListPresent(self): def test_getTIDListPresent(self):
oid = self.getOID(1) oid = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
......
...@@ -269,13 +269,16 @@ class ProtocolTests(NeoTestBase): ...@@ -269,13 +269,16 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(p_oid_list, oid_list) self.assertEqual(p_oid_list, oid_list)
def test_36_askFinishTransaction(self): def test_36_askFinishTransaction(self):
self._testXIDAndYIDList(Packets.AskFinishTransaction)
def _testXIDAndYIDList(self, packet):
oid1 = self.getNextTID() oid1 = self.getNextTID()
oid2 = self.getNextTID() oid2 = self.getNextTID()
oid3 = self.getNextTID() oid3 = self.getNextTID()
oid4 = self.getNextTID() oid4 = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4] oid_list = [oid1, oid2, oid3, oid4]
p = Packets.AskFinishTransaction(tid, oid_list) p = packet(tid, oid_list)
p_tid, p_oid_list = p.decode() p_tid, p_oid_list = p.decode()
self.assertEqual(p_tid, tid) self.assertEqual(p_tid, tid)
self.assertEqual(p_oid_list, oid_list) self.assertEqual(p_oid_list, oid_list)
...@@ -404,12 +407,15 @@ class ProtocolTests(NeoTestBase): ...@@ -404,12 +407,15 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(partition, 5) self.assertEqual(partition, 5)
def test_50_answerTIDs(self): def test_50_answerTIDs(self):
self._test_AnswerTIDs(Packets.AnswerTIDs)
def _test_AnswerTIDs(self, packet):
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
tid3 = self.getNextTID() tid3 = self.getNextTID()
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid_list = [tid1, tid2, tid3, tid4] tid_list = [tid1, tid2, tid3, tid4]
p = Packets.AnswerTIDs(tid_list) p = packet(tid_list)
p_tid_list = p.decode()[0] p_tid_list = p.decode()[0]
self.assertEqual(p_tid_list, tid_list) self.assertEqual(p_tid_list, tid_list)
...@@ -457,10 +463,11 @@ class ProtocolTests(NeoTestBase): ...@@ -457,10 +463,11 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
def test_55_askOIDs(self): def test_55_askOIDs(self):
p = Packets.AskOIDs(1, 10, 5) oid = self.getOID(1)
first, last, partition = p.decode() p = Packets.AskOIDs(oid, 1000, 5)
self.assertEqual(first, 1) min_oid, length, partition = p.decode()
self.assertEqual(last, 10) self.assertEqual(min_oid, oid)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5) self.assertEqual(partition, 5)
def test_56_answerOIDs(self): def test_56_answerOIDs(self):
...@@ -602,6 +609,30 @@ class ProtocolTests(NeoTestBase): ...@@ -602,6 +609,30 @@ class ProtocolTests(NeoTestBase):
msg = 'test' msg = 'test'
self.assertEqual(Packets.Notify(msg).decode(), (msg, )) self.assertEqual(Packets.Notify(msg).decode(), (msg, ))
def test_AskTIDsFrom(self):
tid = self.getNextTID()
p = Packets.AskTIDsFrom(tid, 1000, 5)
min_tid, length, partition = p.decode()
self.assertEqual(min_tid, tid)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5)
def test_AnswerTIDsFrom(self):
self._test_AnswerTIDs(Packets.AnswerTIDsFrom)
def test_AskObjectHistoryFrom(self):
oid = self.getOID(1)
min_serial = self.getNextTID()
length = 5
p = Packets.AskObjectHistoryFrom(oid, min_serial, length)
p_oid, p_min_serial, p_length = p.decode()
self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
def test_AnswerObjectHistoryFrom(self):
self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment