Commit b0753366 authored by Vincent Pelletier's avatar Vincent Pelletier

Implement rsync-ish replication.

For further description, see storage/handlers/replication.py .

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2295 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 01735352
...@@ -275,16 +275,10 @@ class EventHandler(object): ...@@ -275,16 +275,10 @@ class EventHandler(object):
def answerObjectHistory(self, conn, oid, history_list): def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askObjectHistoryFrom(self, conn, oid, min_serial, length): def askObjectHistoryFrom(self, conn, oid, min_serial, length, partition):
raise UnexpectedPacketError raise UnexpectedPacketError
def answerObjectHistoryFrom(self, conn, oid, history_list): def answerObjectHistoryFrom(self, conn, object_dict):
raise UnexpectedPacketError
def askOIDs(self, conn, min_oid, length, partition):
raise UnexpectedPacketError
def answerOIDs(self, conn, oid_list):
raise UnexpectedPacketError raise UnexpectedPacketError
def askPartitionList(self, conn, min_offset, max_offset, uuid): def askPartitionList(self, conn, min_offset, max_offset, uuid):
...@@ -359,6 +353,21 @@ class EventHandler(object): ...@@ -359,6 +353,21 @@ class EventHandler(object):
def answerPack(self, conn, status): def answerPack(self, conn, status):
raise UnexpectedPacketError raise UnexpectedPacketError
# Base-class default for AskCheckTIDRange packets: always raises
# UnexpectedPacketError. Handlers that accept this packet define their own
# askCheckTIDRange.
def askCheckTIDRange(self, conn, min_tid, length, partition):
raise UnexpectedPacketError
# Base-class default for AnswerCheckTIDRange packets: always raises
# UnexpectedPacketError. Handlers that accept this packet define their own
# answerCheckTIDRange.
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid):
raise UnexpectedPacketError
# Base-class default for AskCheckSerialRange packets: always raises
# UnexpectedPacketError. Handlers that accept this packet define their own
# askCheckSerialRange.
def askCheckSerialRange(self, conn, min_oid, min_serial, length,
partition):
raise UnexpectedPacketError
# Base-class default for AnswerCheckSerialRange packets: always raises
# UnexpectedPacketError. Handlers that accept this packet define their own
# answerCheckSerialRange.
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial):
raise UnexpectedPacketError
# Error packet handlers. # Error packet handlers.
def error(self, conn, code, message): def error(self, conn, code, message):
...@@ -450,8 +459,6 @@ class EventHandler(object): ...@@ -450,8 +459,6 @@ class EventHandler(object):
d[Packets.AnswerObjectHistory] = self.answerObjectHistory d[Packets.AnswerObjectHistory] = self.answerObjectHistory
d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom
d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom
d[Packets.AskOIDs] = self.askOIDs
d[Packets.AnswerOIDs] = self.answerOIDs
d[Packets.AskPartitionList] = self.askPartitionList d[Packets.AskPartitionList] = self.askPartitionList
d[Packets.AnswerPartitionList] = self.answerPartitionList d[Packets.AnswerPartitionList] = self.answerPartitionList
d[Packets.AskNodeList] = self.askNodeList d[Packets.AskNodeList] = self.askNodeList
...@@ -476,6 +483,10 @@ class EventHandler(object): ...@@ -476,6 +483,10 @@ class EventHandler(object):
d[Packets.AnswerBarrier] = self.answerBarrier d[Packets.AnswerBarrier] = self.answerBarrier
d[Packets.AskPack] = self.askPack d[Packets.AskPack] = self.askPack
d[Packets.AnswerPack] = self.answerPack d[Packets.AnswerPack] = self.answerPack
d[Packets.AskCheckTIDRange] = self.askCheckTIDRange
d[Packets.AnswerCheckTIDRange] = self.answerCheckTIDRange
d[Packets.AskCheckSerialRange] = self.askCheckSerialRange
d[Packets.AnswerCheckSerialRange] = self.answerCheckSerialRange
return d return d
......
...@@ -113,6 +113,7 @@ INVALID_PARTITION = 0xffffffff ...@@ -113,6 +113,7 @@ INVALID_PARTITION = 0xffffffff
ZERO_TID = '\0' * 8 ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8 ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID) OID_LEN = len(INVALID_OID)
TID_LEN = len(INVALID_TID)
UUID_NAMESPACES = { UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S', NodeTypes.STORAGE: 'S',
...@@ -1167,63 +1168,47 @@ class AnswerObjectHistory(Packet): ...@@ -1167,63 +1168,47 @@ class AnswerObjectHistory(Packet):
class AskObjectHistoryFrom(Packet): class AskObjectHistoryFrom(Packet):
""" """
Ask history information for a given object. The order of serials is Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial. S -> S. ascending, and starts at (or above) min_serial for min_oid. S -> S.
""" """
_header_format = '!8s8sL' _header_format = '!8s8sLL'
def _encode(self, oid, min_serial, length): def _encode(self, min_oid, min_serial, length, partition):
return pack(self._header_format, oid, min_serial, length) return pack(self._header_format, min_oid, min_serial, length,
partition)
def _decode(self, body): def _decode(self, body):
return unpack(self._header_format, body) # oid, min_serial, length # min_oid, min_serial, length, partition
return unpack(self._header_format, body)
class AnswerObjectHistoryFrom(AskFinishTransaction): class AnswerObjectHistoryFrom(Packet):
""" """
Answer the requested serials. S -> S. Answer the requested serials. S -> S.
""" """
# This is similar to AskFinishTransaction as TID size is identical to OID
# size:
# - we have a single OID (TID in AskFinishTransaction)
# - we have a list of TIDs (OIDs in AskFinishTransaction)
pass
class AskOIDs(Packet):
"""
Ask for length OIDs starting at min_oid. S -> S.
"""
_header_format = '!8sLL'
def _encode(self, min_oid, length, partition):
return pack(self._header_format, min_oid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_oid, length, partition
class AnswerOIDs(Packet):
"""
Answer the requested OIDs. S -> S.
"""
_header_format = '!L' _header_format = '!L'
_list_entry_format = '8s' _list_entry_format = '!8sL'
_list_entry_len = calcsize(_list_entry_format) _list_entry_len = calcsize(_list_entry_format)
def _encode(self, oid_list): def _encode(self, object_dict):
body = [pack(self._header_format, len(oid_list))] body = [pack(self._header_format, len(object_dict))]
body.extend(oid_list) append = body.append
extend = body.extend
list_entry_format = self._list_entry_format
for oid, serial_list in object_dict.iteritems():
append(pack(list_entry_format, oid, len(serial_list)))
extend(serial_list)
return ''.join(body) return ''.join(body)
def _decode(self, body): def _decode(self, body):
offset = self._header_len body = StringIO(body)
(n,) = unpack(self._header_format, body[:offset]) read = body.read
oid_list = []
list_entry_format = self._list_entry_format list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len list_entry_len = self._list_entry_len
for _ in xrange(n): object_dict = {}
next_offset = offset + list_entry_len dict_len = unpack(self._header_format, read(self._header_len))[0]
oid = unpack(list_entry_format, body[offset:next_offset])[0] for _ in xrange(dict_len):
offset = next_offset oid, serial_len = unpack(list_entry_format, read(list_entry_len))
oid_list.append(oid) object_dict[oid] = [read(TID_LEN) for _ in xrange(serial_len)]
return (oid_list,) return (object_dict, )
class AskPartitionList(Packet): class AskPartitionList(Packet):
""" """
...@@ -1660,6 +1645,73 @@ class AnswerPack(Packet): ...@@ -1660,6 +1645,73 @@ class AnswerPack(Packet):
def _decode(self, body): def _decode(self, body):
return (bool(unpack(self._header_format, body)[0]), ) return (bool(unpack(self._header_format, body)[0]), )
class AskCheckTIDRange(Packet):
    """
    Request statistics about a range of transactions.
    The answer lets a replicating node find out whether its content differs
    from the reference node's.
    S -> S
    """
    # min_tid (8 bytes), length (uint32), partition (uint32), network order.
    _header_format = '!8sLL'

    def _encode(self, min_tid, length, partition):
        fields = (min_tid, length, partition)
        return pack(self._header_format, *fields)

    def _decode(self, body):
        min_tid, length, partition = unpack(self._header_format, body)
        return min_tid, length, partition
class AnswerCheckTIDRange(Packet):
    """
    Stats about a range of transactions.
    Used to know if there are differences between a replicating node and
    reference node.
    S -> S
    """
    # min_tid (8 bytes), length (uint32), count (uint32),
    # tid_checksum (uint64), max_tid (8 bytes), network order.
    _header_format = '!8sLLQ8s'

    def _encode(self, min_tid, length, count, tid_checksum, max_tid):
        return pack(self._header_format, min_tid, length, count, tid_checksum,
            max_tid)

    def _decode(self, body):
        # Field order mirrors _encode; there is no partition field in the
        # answer:
        # min_tid, length, count, tid_checksum, max_tid
        return unpack(self._header_format, body)
class AskCheckSerialRange(Packet):
    """
    Request statistics about a range of object history.
    The answer lets a replicating node find out whether its content differs
    from the reference node's.
    S -> S
    """
    # min_oid (8 bytes), min_serial (8 bytes), length (uint32),
    # partition (uint32), network order.
    _header_format = '!8s8sLL'

    def _encode(self, min_oid, min_serial, length, partition):
        fields = (min_oid, min_serial, length, partition)
        return pack(self._header_format, *fields)

    def _decode(self, body):
        min_oid, min_serial, length, partition = unpack(
            self._header_format, body)
        return min_oid, min_serial, length, partition
class AnswerCheckSerialRange(Packet):
    """
    Statistics about a range of object history.
    Lets a replicating node find out whether its content differs from the
    reference node's.
    S -> S
    """
    # min_oid (8 bytes), min_serial (8 bytes), length (uint32),
    # count (uint32), oid_checksum (uint64), max_oid (8 bytes),
    # serial_checksum (uint64), max_serial (8 bytes), network order.
    _header_format = '!8s8sLLQ8sQ8s'

    def _encode(self, min_oid, min_serial, length, count, oid_checksum,
            max_oid, serial_checksum, max_serial):
        fields = (min_oid, min_serial, length, count, oid_checksum, max_oid,
            serial_checksum, max_serial)
        return pack(self._header_format, *fields)

    def _decode(self, body):
        (min_oid, min_serial, length, count, oid_checksum, max_oid,
            serial_checksum, max_serial) = unpack(self._header_format, body)
        return (min_oid, min_serial, length, count, oid_checksum, max_oid,
            serial_checksum, max_serial)
class Error(Packet): class Error(Packet):
""" """
Error is a special type of message, because this can be sent against Error is a special type of message, because this can be sent against
...@@ -1844,10 +1896,6 @@ class PacketRegistry(dict): ...@@ -1844,10 +1896,6 @@ class PacketRegistry(dict):
0x001F, 0x001F,
AskObjectHistory, AskObjectHistory,
AnswerObjectHistory) AnswerObjectHistory)
AskOIDs, AnswerOIDs = register(
0x0020,
AskOIDs,
AnswerOIDs)
AskPartitionList, AnswerPartitionList = register( AskPartitionList, AnswerPartitionList = register(
0x0021, 0x0021,
AskPartitionList, AskPartitionList,
...@@ -1903,6 +1951,16 @@ class PacketRegistry(dict): ...@@ -1903,6 +1951,16 @@ class PacketRegistry(dict):
0x0038, 0x0038,
AskPack, AskPack,
AnswerPack) AnswerPack)
AskCheckTIDRange, AnswerCheckTIDRange = register(
0x0039,
AskCheckTIDRange,
AnswerCheckTIDRange,
)
AskCheckSerialRange, AnswerCheckSerialRange = register(
0x003A,
AskCheckSerialRange,
AnswerCheckSerialRange,
)
# build a "singleton" # build a "singleton"
Packets = PacketRegistry() Packets = PacketRegistry()
......
...@@ -288,6 +288,12 @@ class Application(object): ...@@ -288,6 +288,12 @@ class Application(object):
while True: while True:
em.poll(1) em.poll(1)
if self.replicator.pending(): if self.replicator.pending():
# Call processDelayedTasks before act, so tasks added in the
# act call are executed after one poll call, so that sent
# packets are already on the network and delayed task
# processing happens in parallel with the same task on the
# other storage node.
self.replicator.processDelayedTasks()
self.replicator.act() self.replicator.act()
def wait(self): def wait(self):
......
...@@ -274,6 +274,11 @@ class DatabaseManager(object): ...@@ -274,6 +274,11 @@ class DatabaseManager(object):
area.""" area."""
raise NotImplementedError raise NotImplementedError
def deleteObject(self, oid, serial=None):
"""Delete given object. If serial is given, only delete that serial for
given oid.

This base implementation only raises NotImplementedError; a database
backend has to provide the actual deletion."""
raise NotImplementedError
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information, """Return a tuple of the list of OIDs, user information,
a description, and extension information, for a given transaction a description, and extension information, for a given transaction
...@@ -282,12 +287,6 @@ class DatabaseManager(object): ...@@ -282,12 +287,6 @@ class DatabaseManager(object):
area as well.""" area as well."""
raise NotImplementedError raise NotImplementedError
def getOIDList(self, min_oid, length, num_partitions, partition_list):
"""Return a list of OIDs in ascending order from a minimal oid,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
def getObjectHistory(self, oid, offset = 0, length = 1): def getObjectHistory(self, oid, offset = 0, length = 1):
"""Return a list of serials and sizes for a given object ID. """Return a list of serials and sizes for a given object ID.
The length specifies the maximum size of such a list. Result starts The length specifies the maximum size of such a list. Result starts
...@@ -295,9 +294,11 @@ class DatabaseManager(object): ...@@ -295,9 +294,11 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None.""" If there is no such object ID in a database, return None."""
raise NotImplementedError raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, length): def getObjectHistoryFrom(self, oid, min_serial, length, num_partitions,
"""Return a list of length serials for a given object ID at (or above) partition):
min_serial, sorted in ascending order.""" """Return a dict of length serials grouped by oid at (or above)
min_oid and min_serial, for given partition, sorted in ascending
order."""
raise NotImplementedError raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
...@@ -307,20 +308,10 @@ class DatabaseManager(object): ...@@ -307,20 +308,10 @@ class DatabaseManager(object):
raise NotImplementedError raise NotImplementedError
def getReplicationTIDList(self, min_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list): partition):
"""Return a list of TIDs in ascending order from an initial tid value, """Return a list of TIDs in ascending order from an initial tid value,
at most the specified length. The list of partitions are passed at most the specified length. The partition number is passed to filter
to filter out non-applicable TIDs.""" out non-applicable TIDs."""
raise NotImplementedError
def getTIDListPresent(self, tid_list):
"""Return a list of TIDs which are present in a database among
the given list."""
raise NotImplementedError
def getSerialListPresent(self, oid, serial_list):
"""Return a list of serials which are present in a database among
the given list."""
raise NotImplementedError raise NotImplementedError
def pack(self, tid, updateObjectDataForPack): def pack(self, tid, updateObjectDataForPack):
......
...@@ -24,7 +24,7 @@ import string ...@@ -24,7 +24,7 @@ import string
from neo.storage.database import DatabaseManager from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure from neo.exception import DatabaseFailure
from neo.protocol import CellStates from neo.protocol import CellStates, ZERO_OID, ZERO_TID
from neo import util from neo import util
LOG_QUERIES = False LOG_QUERIES = False
...@@ -576,6 +576,23 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -576,6 +576,23 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def deleteObject(self, oid, serial=None):
    """Delete rows of the obj table for the given oid.

    When serial is None every row of that oid is deleted, otherwise only
    the row matching both oid and serial. The statement runs between
    begin() and commit(); on any error the transaction is rolled back and
    the error is re-raised.
    """
    to_int = util.u64
    params = {'oid': to_int(oid)}
    statement = 'DELETE FROM obj WHERE oid = %(oid)d'
    if serial is not None:
        params['serial'] = to_int(serial)
        statement += ' AND serial = %(serial)d'
    self.begin()
    try:
        self.query(statement % params)
    except:
        # Deliberately catches everything: roll back, then propagate.
        self.rollback()
        raise
    self.commit()
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
q = self.query q = self.query
tid = util.u64(tid) tid = util.u64(tid)
...@@ -594,20 +611,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -594,20 +611,6 @@ class MySQLDatabaseManager(DatabaseManager):
return oid_list, user, desc, ext, bool(packed) return oid_list, user, desc, ext, bool(packed)
return None return None
def getOIDList(self, min_oid, length, num_partitions,
partition_list):
q = self.query
r = q("""SELECT DISTINCT oid FROM obj WHERE
MOD(oid, %(num_partitions)d) in (%(partitions)s)
AND oid >= %(min_oid)d
ORDER BY oid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'min_oid': util.u64(min_oid),
'length': length,
})
return [util.p64(t[0]) for t in r]
def _getObjectLength(self, oid, value_serial): def _getObjectLength(self, oid, value_serial):
if value_serial is None: if value_serial is None:
raise CreationUndone raise CreationUndone
...@@ -646,18 +649,32 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -646,18 +649,32 @@ class MySQLDatabaseManager(DatabaseManager):
return result return result
return None return None
def getObjectHistoryFrom(self, oid, min_serial, length): def getObjectHistoryFrom(self, min_oid, min_serial, length, num_partitions,
partition):
q = self.query q = self.query
oid = util.u64(oid) u64 = util.u64
p64 = util.p64 p64 = util.p64
r = q("""SELECT serial FROM obj min_oid = u64(min_oid)
WHERE oid = %(oid)d AND serial >= %(min_serial)d min_serial = u64(min_serial)
ORDER BY serial ASC LIMIT %(length)d""" % { r = q('SELECT oid, serial FROM obj '
'oid': oid, 'WHERE ((oid = %(min_oid)d AND serial >= %(min_serial)d) OR '
'min_serial': util.u64(min_serial), 'oid > %(min_oid)d) AND '
'MOD(oid, %(num_partitions)d) = %(partition)s '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': min_oid,
'min_serial': min_serial,
'length': length, 'length': length,
'num_partitions': num_partitions,
'partition': partition,
}) })
return [p64(t[0]) for t in r] result = {}
for oid, serial in r:
try:
serial_list = result[oid]
except KeyError:
serial_list = result[oid] = []
serial_list.append(p64(serial))
return dict((p64(x), y) for x, y in result.iteritems())
def getTIDList(self, offset, length, num_partitions, partition_list): def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query q = self.query
...@@ -669,32 +686,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -669,32 +686,19 @@ class MySQLDatabaseManager(DatabaseManager):
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, length, num_partitions, def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list): partition):
q = self.query q = self.query
r = q("""SELECT tid FROM trans WHERE r = q("""SELECT tid FROM trans WHERE
MOD(tid, %(num_partitions)d) in (%(partitions)s) MOD(tid, %(num_partitions)d) = %(partition)d
AND tid >= %(min_tid)d AND tid >= %(min_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % { ORDER BY tid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions, 'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]), 'partition': partition,
'min_tid': util.u64(min_tid), 'min_tid': util.u64(min_tid),
'length': length, 'length': length,
}) })
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getTIDListPresent(self, tid_list):
q = self.query
r = q("""SELECT tid FROM trans WHERE tid in (%s)""" \
% ','.join([str(util.u64(tid)) for tid in tid_list]))
return [util.p64(t[0]) for t in r]
def getSerialListPresent(self, oid, serial_list):
q = self.query
oid = util.u64(oid)
r = q("""SELECT serial FROM obj WHERE oid = %d AND serial in (%s)""" \
% (oid, ','.join([str(util.u64(serial)) for serial in serial_list])))
return [util.p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial, def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack): updateObjectDataForPack):
q = self.query q = self.query
...@@ -784,3 +788,53 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -784,3 +788,53 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
self.commit() self.commit()
def checkTIDRange(self, min_tid, length, num_partitions, partition):
    """Summarise up to `length` transactions of a partition.

    Considers the rows of the trans table whose tid is at or above min_tid
    and belongs to the given partition (tid modulo num_partitions), in
    ascending tid order.

    Returns a (count, tid_checksum, max_tid) tuple: the number of rows
    found, the XOR of their tids and the highest tid, packed. When no row
    matches, the checksum is 0 and max_tid is ZERO_TID.
    """
    # XXX: XOR is a lame checksum
    query_args = {
        'num_partitions': num_partitions,
        'partition': partition,
        'min_tid': util.u64(min_tid),
        'length': length,
    }
    rows = self.query('SELECT COUNT(*), '
        'BIT_XOR(tid), MAX(tid) FROM ('
          'SELECT tid FROM trans '
          'WHERE MOD(tid, %(num_partitions)d) = %(partition)s '
          'AND tid >= %(min_tid)d '
          'ORDER BY tid ASC LIMIT %(length)d'
        ') AS foo' % query_args)
    count, tid_checksum, max_tid = rows[0]
    if count == 0:
        # Empty range: normalise the checksum and the upper bound.
        return count, 0, ZERO_TID
    return count, tid_checksum, util.p64(max_tid)
def checkSerialRange(self, min_oid, min_serial, length, num_partitions,
        partition):
    """Summarise up to `length` object revisions of a partition.

    Considers the rows of the obj table located at or after
    (min_oid, min_serial) and belonging to the given partition (oid modulo
    num_partitions), ordered by oid then serial.

    Returns (count, oid_checksum, max_oid, serial_checksum, max_serial):
    the number of rows, the XOR of their oids, the packed oid of the last
    row, the XOR of their serials and the packed serial of the last row.
    When no row matches, both checksums are 0, max_oid is ZERO_OID and
    max_serial is ZERO_TID.
    """
    # XXX: XOR is a lame checksum
    to_int = util.u64
    to_packed = util.p64
    rows = self.query('SELECT oid, serial FROM obj WHERE '
        '(oid > %(min_oid)d OR '
        '(oid = %(min_oid)d AND serial >= %(min_serial)d)) '
        'AND MOD(oid, %(num_partitions)d) = %(partition)s '
        'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
            'min_oid': to_int(min_oid),
            'min_serial': to_int(min_serial),
            'length': length,
            'num_partitions': num_partitions,
            'partition': partition,
        })
    count = len(rows)
    if count == 0:
        return count, 0, ZERO_OID, 0, ZERO_TID
    oid_checksum = 0
    serial_checksum = 0
    for last_oid, last_serial in rows:
        oid_checksum ^= last_oid
        serial_checksum ^= last_serial
    # Rows are sorted, so the loop variables now hold the last (highest) row.
    return (count, oid_checksum, to_packed(last_oid), serial_checksum,
        to_packed(last_serial))
...@@ -22,6 +22,48 @@ from neo.handler import EventHandler ...@@ -22,6 +22,48 @@ from neo.handler import EventHandler
from neo.protocol import Packets, ZERO_TID, ZERO_OID from neo.protocol import Packets, ZERO_TID, ZERO_OID
from neo import util from neo import util
# TODO: benchmark how different values behave
RANGE_LENGTH = 4000
MIN_RANGE_LENGTH = 1000
"""
Replication algorithm
Purpose: replicate the content of a reference node into a replicating node,
bringing it up-to-date.
This happens both when a new storage is added to an existing cluster, as well
as when a node was separated from the cluster and rejoins it.
Replication happens per partition. Reference node can change between
partitions.
2 parts, done sequentially:
- Transaction (metadata) replication
- Object (data) replication
Both parts follow the same mechanism:
- On both sides (replicating and reference), compute a checksum of a chunk
(RANGE_LENGTH number of entries). If there is a mismatch, chunk size is
reduced, and scan restarts from same row, until it reaches a minimal length
(MIN_RANGE_LENGTH). Then, it replicates all rows in that chunk. If the
content of chunks match, it moves on to the next chunk.
- Replicating a chunk starts with asking for a list of all entries (only their
identifier) and skipping those both sides have, deleting those which
replicating has and reference doesn't, and asking individually all entries
missing in replicating.
"""
# TODO: Make object replication get ordered by serial first and oid second, so
# changes are in a big segment at the end, rather than in many segments (one
# per object).
# TODO: To improve performance when a pack happened, the following algorithm
# should be used:
# - If reference node packed, find non-existent oids in reference node (their
# creation was undone, and pack pruned them), and delete them.
# - Run current algorithm, starting at our last pack TID.
# - Pack partition at reference's TID.
def checkConnectionIsReplicatorConnection(func): def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw): def decorator(self, conn, *args, **kw):
if self.app.replicator.current_connection is conn: if self.app.replicator.current_connection is conn:
...@@ -51,29 +93,27 @@ class ReplicationHandler(EventHandler): ...@@ -51,29 +93,27 @@ class ReplicationHandler(EventHandler):
uuid, num_partitions, num_replicas, your_uuid): uuid, num_partitions, num_replicas, your_uuid):
# set the UUID on the connection # set the UUID on the connection
conn.setUUID(uuid) conn.setUUID(uuid)
self.startReplication(conn)
def startReplication(self, conn):
conn.ask(self._doAskCheckTIDRange(ZERO_TID), timeout=300)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerTIDsFrom(self, conn, tid_list): def answerTIDsFrom(self, conn, tid_list):
app = self.app app = self.app
if tid_list:
# If I have pending TIDs, check which TIDs I don't have, and # If I have pending TIDs, check which TIDs I don't have, and
# request the data. # request the data.
present_tid_list = app.dm.getTIDListPresent(tid_list) tid_set = frozenset(tid_list)
tid_set = set(tid_list) - set(present_tid_list) my_tid_set = frozenset(app.replicator.getTIDsFromResult())
for tid in tid_set: extra_tid_set = my_tid_set - tid_set
if extra_tid_set:
deleteTransaction = app.dm.deleteTransaction
for tid in extra_tid_set:
deleteTransaction(tid)
missing_tid_set = tid_set - my_tid_set
for tid in missing_tid_set:
conn.ask(Packets.AskTransactionInformation(tid), timeout=300) conn.ask(Packets.AskTransactionInformation(tid), timeout=300)
# And, ask more TIDs.
p = Packets.AskTIDsFrom(add64(tid_list[-1], 1), 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
else:
# If no more TID, a replication of transactions is finished.
# So start to replicate objects now.
p = Packets.AskOIDs(ZERO_OID, 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerTransactionInformation(self, conn, tid, def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list): user, desc, ext, packed, oid_list):
...@@ -83,46 +123,23 @@ class ReplicationHandler(EventHandler): ...@@ -83,46 +123,23 @@ class ReplicationHandler(EventHandler):
False) False)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerOIDs(self, conn, oid_list): def answerObjectHistoryFrom(self, conn, object_dict):
app = self.app
if oid_list:
app.replicator.next_oid = add64(oid_list[-1], 1)
# Pick one up, and ask the history.
oid = oid_list.pop()
conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
timeout=300)
app.replicator.oid_list = oid_list
else:
# Nothing remains, so the replication for this partition is
# finished.
app.replicator.replication_done = True
@checkConnectionIsReplicatorConnection
def answerObjectHistoryFrom(self, conn, oid, serial_list):
app = self.app app = self.app
if serial_list: my_object_dict = app.replicator.getObjectHistoryFromResult()
deleteObject = app.dm.deleteObject
for oid, serial_list in object_dict.iteritems():
# Check if I have objects, request those which I don't have. # Check if I have objects, request those which I don't have.
present_serial_list = app.dm.getSerialListPresent(oid, serial_list) if oid in my_object_dict:
serial_set = set(serial_list) - set(present_serial_list) my_serial_set = frozenset(my_object_dict[oid])
for serial in serial_set: serial_set = frozenset(serial_list)
conn.ask(Packets.AskObject(oid, serial, None), timeout=300) extra_serial_set = my_serial_set - serial_set
for serial in extra_serial_set:
# And, ask more serials. deleteObject(oid, serial)
conn.ask(Packets.AskObjectHistoryFrom(oid, missing_serial_set = serial_set - my_serial_set
add64(serial_list[-1], 1), 1000), timeout=300)
else: else:
# This OID is finished. So advance to next. missing_serial_set = serial_list
oid_list = app.replicator.oid_list for serial in missing_serial_set:
if oid_list: conn.ask(Packets.AskObject(oid, serial, None), timeout=300)
# If I have more pending OIDs, pick one up.
oid = oid_list.pop()
conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
timeout=300)
else:
# Otherwise, acquire more OIDs.
p = Packets.AskOIDs(app.replicator.next_oid, 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
@checkConnectionIsReplicatorConnection @checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start, def answerObject(self, conn, oid, serial_start,
...@@ -134,3 +151,97 @@ class ReplicationHandler(EventHandler): ...@@ -134,3 +151,97 @@ class ReplicationHandler(EventHandler):
del obj del obj
del data del data
def _doAskCheckSerialRange(self, min_oid, min_tid, length=RANGE_LENGTH):
    """Build an AskCheckSerialRange packet for the current partition.

    Before building the packet, calls the replicator's checkSerialRange
    with the same arguments (presumably so the local result is available
    when the answer comes back -- confirm against the replicator).
    """
    repl = self.app.replicator
    rid = repl.current_partition.getRID()
    check_args = (min_oid, min_tid, length, rid)
    repl.checkSerialRange(*check_args)
    return Packets.AskCheckSerialRange(*check_args)
def _doAskCheckTIDRange(self, min_tid, length=RANGE_LENGTH):
    """Build an AskCheckTIDRange packet for the current partition.

    Before building the packet, calls the replicator's checkTIDRange with
    the same arguments (presumably so the local result is available when
    the answer comes back -- confirm against the replicator).
    """
    repl = self.app.replicator
    rid = repl.current_partition.getRID()
    check_args = (min_tid, length, rid)
    repl.checkTIDRange(*check_args)
    return Packets.AskCheckTIDRange(*check_args)
def _doAskTIDsFrom(self, min_tid, length):
    """Build an AskTIDsFrom packet for the current partition.

    Before building the packet, calls the replicator's getTIDsFrom with the
    same arguments (presumably so the local TID list is available when the
    answer comes back -- confirm against the replicator).
    """
    repl = self.app.replicator
    rid = repl.current_partition.getRID()
    request_args = (min_tid, length, rid)
    repl.getTIDsFrom(*request_args)
    return Packets.AskTIDsFrom(*request_args)
def _doAskObjectHistoryFrom(self, min_oid, min_serial, length):
    """Build an AskObjectHistoryFrom packet for the current partition.

    Before building the packet, calls the replicator's getObjectHistoryFrom
    with the same arguments (presumably so the local history is available
    when the answer comes back -- confirm against the replicator).
    """
    repl = self.app.replicator
    rid = repl.current_partition.getRID()
    request_args = (min_oid, min_serial, length, rid)
    repl.getObjectHistoryFrom(*request_args)
    return Packets.AskObjectHistoryFrom(*request_args)
@checkConnectionIsReplicatorConnection
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
        max_tid):
    """Compare the reference node's TID range statistics with ours and
    send the next replication request.

    min_tid and length identify the checked range; count, tid_checksum and
    max_tid are the reference node's results for it.

    - Results differ and the range is already at most MIN_RANGE_LENGTH:
      ask the reference node for the TIDs of that range, then carry on as
      if the range matched.
    - Results differ on a larger range: check a smaller range starting at
      the same min_tid.
    - Otherwise: check the next range when the reference node filled this
      one entirely (count == length), else switch to checking object
      history from the beginning.
    """
    replicator = self.app.replicator
    our = replicator.getTIDCheckResult(min_tid, length)
    his = (count, tid_checksum, max_tid)
    p = None
    if our != his:
        # Something is different...
        if length <= MIN_RANGE_LENGTH:
            # We are already at minimum chunk length, replicate.
            conn.ask(self._doAskTIDsFrom(min_tid, count))
        else:
            # Check a smaller chunk.
            # Note: this could be made into a real binary search, but is
            # it really worth the work ?
            # Note: +1, so we can detect we reached the end when answer
            # comes back.
            # Explicit floor division: the length must stay an integer.
            p = self._doAskCheckTIDRange(min_tid, min(length // 2,
                count + 1))
    if p is None:
        if count == length:
            # Go on with next chunk
            p = self._doAskCheckTIDRange(add64(max_tid, 1))
        else:
            # If no more TID, a replication of transactions is finished.
            # So start to replicate objects now.
            p = self._doAskCheckSerialRange(ZERO_OID, ZERO_TID)
    conn.ask(p)
@checkConnectionIsReplicatorConnection
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
        oid_checksum, max_oid, serial_checksum, max_serial):
    """Compare the reference node's object history statistics with ours and
    send the next replication request, if any.

    min_oid, min_serial and length identify the checked range; the other
    arguments are the reference node's results for it.

    - Results differ and the range is already at most MIN_RANGE_LENGTH:
      ask the reference node for the object history of that range, then
      carry on as if the range matched.
    - Results differ on a larger range: check a smaller range starting at
      the same position.
    - Otherwise: check the next range when the reference node filled this
      one entirely (count == length), else flag replication of the current
      partition as done.
    """
    replicator = self.app.replicator
    our = replicator.getSerialCheckResult(min_oid, min_serial, length)
    his = (count, oid_checksum, max_oid, serial_checksum, max_serial)
    p = None
    if our != his:
        # Something is different...
        if length <= MIN_RANGE_LENGTH:
            # We are already at minimum chunk length, replicate.
            conn.ask(self._doAskObjectHistoryFrom(min_oid, min_serial,
                count))
        else:
            # Check a smaller chunk.
            # Note: this could be made into a real binary search, but is
            # it really worth the work ?
            # Note: +1, so we can detect we reached the end when answer
            # comes back.
            # Explicit floor division: the length must stay an integer.
            p = self._doAskCheckSerialRange(min_oid, min_serial,
                min(length // 2, count + 1))
    if p is None:
        if count == length:
            # Go on with next chunk
            p = self._doAskCheckSerialRange(max_oid, add64(max_serial, 1))
        else:
            # Nothing remains, so the replication for this partition is
            # finished.
            replicator.replication_done = True
    if p is not None:
        conn.ask(p)
...@@ -30,34 +30,32 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -30,34 +30,32 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID() tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID())) conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askOIDs(self, conn, min_oid, length, partition):
# This method is complicated, because I must return OIDs only
# about usable partitions assigned to me.
app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
oid_list = app.dm.getOIDList(min_oid, length,
app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerOIDs(oid_list))
def askTIDsFrom(self, conn, min_tid, length, partition): def askTIDsFrom(self, conn, min_tid, length, partition):
# This method is complicated, because I must return TIDs only
# about usable partitions assigned to me.
app = self.app app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
tid_list = app.dm.getReplicationTIDList(min_tid, length, tid_list = app.dm.getReplicationTIDList(min_tid, length,
app.pt.getPartitions(), partition_list) app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerTIDsFrom(tid_list)) conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, oid, min_serial, length): def askObjectHistoryFrom(self, conn, min_oid, min_serial, length,
partition):
app = self.app
object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def askCheckTIDRange(self, conn, min_tid, length, partition):
app = self.app
count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, count,
tid_checksum, max_tid))
def askCheckSerialRange(self, conn, min_oid, min_serial, length,
partition):
app = self.app app = self.app
history_list = app.dm.getObjectHistoryFrom(oid, min_serial, length) count, oid_checksum, max_oid, serial_checksum, max_serial = \
conn.answer(Packets.AnswerObjectHistoryFrom(oid, history_list)) app.dm.checkSerialRange(min_oid, min_serial, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial))
...@@ -46,6 +46,46 @@ class Partition(object): ...@@ -46,6 +46,46 @@ class Partition(object):
return tid is not None and ( return tid is not None and (
min_pending_tid is None or tid < min_pending_tid) min_pending_tid is None or tid < min_pending_tid)
class Task(object):
    """
    A Task is a callable to execute at another time, with given parameters.
    Execution result is kept and can be retrieved later.
    """

    _func = None
    _args = None
    _kw = None
    _result = None
    _processed = False

    def __init__(self, func, args=(), kw=None):
        """Remember `func` and the positional (`args`) and keyword (`kw`)
        arguments it must be called with. Nothing is executed here."""
        self._func = func
        self._args = args
        if kw is None:
            kw = {}
        self._kw = kw

    def process(self):
        """Call the stored function once and keep its return value.

        Raises ValueError when called a second time.
        """
        if self._processed:
            raise ValueError('You cannot process a single Task twice')
        # Flagged before the call: a task whose function raised still
        # counts as processed and cannot be run again.
        self._processed = True
        self._result = self._func(*self._args, **self._kw)

    def getResult(self):
        """Return the value produced by process().

        Raises ValueError when the task has not been processed yet.
        """
        # Should we instead execute immediately rather than raising ?
        if not self._processed:
            raise ValueError('You cannot get a result until task is executed')
        return self._result

    def __repr__(self):
        if self._processed:
            extra = ' => %r' % (self._result, )
        else:
            extra = ''
        # Single formatting pass: a '%' appearing in the repr of the
        # function or its arguments must never be parsed as a directive.
        return '<%s at %x %r(*%r, **%r)%s>' % (self.__class__.__name__,
            id(self), self._func, self._args, self._kw, extra)
class Replicator(object): class Replicator(object):
"""This class handles replications of objects and transactions. """This class handles replications of objects and transactions.
...@@ -98,21 +138,23 @@ class Replicator(object): ...@@ -98,21 +138,23 @@ class Replicator(object):
# didn't answer yet. # didn't answer yet.
# unfinished_tid_list # unfinished_tid_list
# The list of unfinished TIDs known by master node. # The list of unfinished TIDs known by master node.
# oid_list
# List of OIDs to replicate. Doesn't contains currently-replicated
# object.
# XXX: not defined here
# XXX: accessed (r/w) directly by ReplicationHandler
# next_oid
# Next OID to ask when oid_list is empty.
# XXX: not defined here
# XXX: accessed (r/w) directly by ReplicationHandler
# replication_done # replication_done
# False if we know there is something to replicate. # False if we know there is something to replicate.
# True when current_partition is replicated, or we don't know yet if # True when current_partition is replicated, or we don't know yet if
# there is something to replicate # there is something to replicate
# XXX: accessed (w) directly by ReplicationHandler # XXX: accessed (w) directly by ReplicationHandler
# Class-level defaults: each attribute described above reads as None on an
# instance until some method assigns it a real value (reset() does so for
# the task_* and current_* attributes, among others).
new_partition_dict = None
critical_tid_dict = None
partition_dict = None
# Delayed tasks: task_list holds those still to be processed, task_dict
# maps a key to its task so the result can be fetched later.
task_list = None
task_dict = None
current_partition = None
current_connection = None
waiting_for_unfinished_tids = None
unfinished_tid_list = None
replication_done = None
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
...@@ -129,6 +171,8 @@ class Replicator(object): ...@@ -129,6 +171,8 @@ class Replicator(object):
def reset(self): def reset(self):
"""Reset attributes to restart replicating.""" """Reset attributes to restart replicating."""
self.task_list = []
self.task_dict = {}
self.current_partition = None self.current_partition = None
self.current_connection = None self.current_connection = None
self.waiting_for_unfinished_tids = False self.waiting_for_unfinished_tids = False
...@@ -213,15 +257,12 @@ class Replicator(object): ...@@ -213,15 +257,12 @@ class Replicator(object):
p = Packets.RequestIdentification(NodeTypes.STORAGE, p = Packets.RequestIdentification(NodeTypes.STORAGE,
app.uuid, app.server, app.name) app.uuid, app.server, app.name)
self.current_connection.ask(p) self.current_connection.ask(p)
else:
p = Packets.AskTIDsFrom(ZERO_TID, 1000, self.current_connection.getHandler().startReplication(
self.current_partition.getRID()) self.current_connection)
self.current_connection.ask(p, timeout=300)
self.replication_done = False self.replication_done = False
def _finishReplication(self): def _finishReplication(self):
app = self.app
# TODO: remove try..except: pass # TODO: remove try..except: pass
try: try:
self.partition_dict.pop(self.current_partition.getRID()) self.partition_dict.pop(self.current_partition.getRID())
...@@ -243,7 +284,11 @@ class Replicator(object): ...@@ -243,7 +284,11 @@ class Replicator(object):
self._askCriticalTID() self._askCriticalTID()
if self.current_partition is not None: if self.current_partition is not None:
if self.replication_done: # Don't end replication until we have received all expected
# answers, as we might have asked object data just before the last
# AnswerCheckSerialRange.
if self.replication_done and \
not self.current_connection.isPending():
# finish a replication # finish a replication
logging.info('replication is done for %s' % logging.info('replication is done for %s' %
(self.current_partition.getRID(), )) (self.current_partition.getRID(), ))
...@@ -289,3 +334,57 @@ class Replicator(object): ...@@ -289,3 +334,57 @@ class Replicator(object):
and not self.new_partition_dict.has_key(rid): and not self.new_partition_dict.has_key(rid):
self.new_partition_dict[rid] = Partition(rid) self.new_partition_dict[rid] = Partition(rid)
def _addTask(self, key, func, args=(), kw=None):
    """Register a delayed call of `func(*args, **kw)` under `key`.

    The task is appended to task_list (to be run by processDelayedTasks)
    and stored in task_dict so its result can be fetched by key.
    Raises ValueError when a task is already registered under `key`.
    """
    task = Task(func, args, kw)
    task_dict = self.task_dict
    if key in task_dict:
        raise ValueError('Task with key %r already exists (%r), cannot '
            'add %r' % (key, task_dict[key], task))
    task_dict[key] = task
    self.task_list.append(task)
def processDelayedTasks(self):
    """Run every queued task in order, then start over with an empty
    queue. Does nothing when no task is queued."""
    pending = self.task_list
    if not pending:
        return
    for queued_task in pending:
        queued_task.process()
    self.task_list = []
def checkTIDRange(self, min_tid, length, partition):
    """Queue a delayed app.dm.checkTIDRange call; its result is stored
    under the key ('TID', min_tid, length)."""
    app = self.app
    key = ('TID', min_tid, length)
    check = app.dm.checkTIDRange
    call_args = (min_tid, length, app.pt.getPartitions(), partition)
    self._addTask(key, check, call_args)
def checkSerialRange(self, min_oid, min_serial, length, partition):
    """Queue a delayed app.dm.checkSerialRange call; its result is stored
    under the key ('Serial', min_oid, min_serial, length)."""
    app = self.app
    key = ('Serial', min_oid, min_serial, length)
    check = app.dm.checkSerialRange
    call_args = (min_oid, min_serial, length, app.pt.getPartitions(),
        partition)
    self._addTask(key, check, call_args)
def getTIDsFrom(self, min_tid, length, partition):
    """Queue a delayed app.dm.getReplicationTIDList call; its result is
    stored under the fixed key 'TIDsFrom'."""
    app = self.app
    fetch = app.dm.getReplicationTIDList
    call_args = (min_tid, length, app.pt.getPartitions(), partition)
    self._addTask('TIDsFrom', fetch, call_args)
def getObjectHistoryFrom(self, min_oid, min_serial, length, partition):
    """Queue a delayed app.dm.getObjectHistoryFrom call; its result is
    stored under the fixed key 'ObjectHistoryFrom'."""
    app = self.app
    fetch = app.dm.getObjectHistoryFrom
    call_args = (min_oid, min_serial, length, app.pt.getPartitions(),
        partition)
    self._addTask('ObjectHistoryFrom', fetch, call_args)
def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult()
def getTIDCheckResult(self, min_tid, length):
    """Return the result of the task queued by checkTIDRange for the same
    (min_tid, length)."""
    key = ('TID', min_tid, length)
    return self._getCheckResult(key)
def getSerialCheckResult(self, min_oid, min_serial, length):
    """Return the result of the task queued by checkSerialRange for the
    same (min_oid, min_serial, length)."""
    key = ('Serial', min_oid, min_serial, length)
    return self._getCheckResult(key)
def getTIDsFromResult(self):
    """Return the result of the task queued by getTIDsFrom."""
    key = 'TIDsFrom'
    return self._getCheckResult(key)
def getObjectHistoryFromResult(self):
    """Return the result of the task queued by getObjectHistoryFrom."""
    key = 'ObjectHistoryFrom'
    return self._getCheckResult(key)
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from neo.tests import NeoTestBase
from neo.protocol import Packets, ZERO_OID, ZERO_TID
from neo.storage.handlers.replication import ReplicationHandler, add64
from neo.storage.handlers.replication import RANGE_LENGTH, MIN_RANGE_LENGTH
class FakeDict(object):
    """Dict look-alike whose iteration order is the order of the
    (key, value) pairs it was built from.

    Lookups and any method not defined here are served by a real dict
    built from the same pairs. Duplicate keys are rejected.
    """

    def __init__(self, items):
        lookup = dict(items)
        # Fewer dict entries than pairs means a key was given twice.
        assert len(lookup) == len(items), lookup
        self._items = items
        self._dict = lookup

    def iteritems(self):
        for pair in self._items:
            yield pair

    def iterkeys(self):
        return (key for key, _ in self.iteritems())

    def itervalues(self):
        return (value for _, value in self.iteritems())

    def items(self):
        # Slice, so the caller gets a copy of the same sequence type.
        return self._items[:]

    def keys(self):
        return [pair[0] for pair in self._items]

    def values(self):
        return [pair[1] for pair in self._items]

    def __getitem__(self, key):
        return self._dict[key]

    def __getattr__(self, key):
        # Only reached for names not found the normal way: delegate them
        # to the underlying dict.
        return getattr(self._dict, key)

    def __len__(self):
        return len(self._dict)
class StorageReplicationHandlerTests(NeoTestBase):
    """Unit tests for ReplicationHandler, run against a mocked replicator
    and fake connections."""

    def setup(self):
        # NOTE(review): lowercase name, so unittest's own setUp hook does
        # not call this; left as found - confirm against NeoTestBase.
        pass

    def teardown(self):
        pass

    def getApp(self, conn=None, tid_check_result=(0, 0, ZERO_TID),
            serial_check_result=(0, 0, ZERO_OID, 0, ZERO_TID),
            tid_result=(),
            history_result=None,
            rid=0, critical_tid=ZERO_TID):
        """Build a fake application whose replicator is a Mock.

        tid_check_result, serial_check_result, tid_result and
        history_result are the canned values the mocked replicator returns
        from getTIDCheckResult, getSerialCheckResult, getTIDsFromResult
        and getObjectHistoryFromResult respectively. rid and critical_tid
        are what the mocked current partition reports.
        """
        if history_result is None:
            history_result = {}
        replicator = Mock({
            '__repr__': 'Fake replicator',
            'reset': None,
            'checkSerialRange': None,
            'checkTIDRange': None,
            'getTIDCheckResult': tid_check_result,
            'getSerialCheckResult': serial_check_result,
            'getTIDsFromResult': tid_result,
            'getObjectHistoryFromResult': history_result,
            'getTIDsFrom': None,
            'getObjectHistoryFrom': None,
        })
        replicator.current_partition = Mock({
            'getRID': rid,
            'getCriticalTID': critical_tid,
        })
        replicator.current_connection = conn
        # Alias needed: inside the class body below, "replicator" is the
        # attribute being defined and cannot refer to this local.
        real_replicator = replicator
        class FakeApp(object):
            replicator = real_replicator
            dm = Mock({
                'storeTransaction': None,
            })
        return FakeApp

    def _checkReplicationStarted(self, conn, rid, replicator):
        """Check a first AskCheckTIDRange was sent for partition `rid` and
        that the same check was queued on the replicator."""
        min_tid, length, partition = self.checkAskPacket(conn,
            Packets.AskCheckTIDRange, decode=True)
        self.assertEqual(min_tid, ZERO_TID)
        self.assertEqual(length, RANGE_LENGTH)
        self.assertEqual(partition, rid)
        calls = replicator.mockGetNamedCalls('checkTIDRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(min_tid, length, partition)

    def _checkPacketTIDList(self, conn, tid_list):
        """Check `conn` was asked for exactly the transactions in
        `tid_list`, in any order.

        Matched TIDs are removed from `tid_list`: pass a copy when the
        list is needed afterwards.
        """
        packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
        self.assertEqual(len(packet_list), len(tid_list))
        for packet in packet_list:
            self.assertEqual(packet.getType(),
                Packets.AskTransactionInformation)
            ptid = packet.decode()[0]
            try:
                tid_list.remove(ptid)
            except ValueError:
                # Report as a regular test failure, using only names
                # available in this module.
                self.fail('%r not found in %r' % (ptid, tid_list))

    def _checkPacketSerialList(self, conn, object_list):
        """Check `conn` was asked for exactly the (oid, serial) pairs of
        `object_list`, in that order."""
        packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
        self.assertEqual(len(packet_list), len(object_list))
        for packet, (oid, serial) in zip(packet_list, object_list):
            self.assertEqual(packet.getType(),
                Packets.AskObject)
            self.assertEqual(packet.decode(), (oid, serial, None))

    def test_connectionLost(self):
        app = self.getApp()
        ReplicationHandler(app).connectionLost(None, None)
        self.assertEqual(len(app.replicator.mockGetNamedCalls('reset')), 1)

    def test_connectionFailed(self):
        app = self.getApp()
        ReplicationHandler(app).connectionFailed(None)
        self.assertEqual(len(app.replicator.mockGetNamedCalls('reset')), 1)

    def test_acceptIdentification(self):
        rid = 24
        app = self.getApp(rid=rid)
        conn = self.getFakeConnection()
        replication = ReplicationHandler(app)
        replication.acceptIdentification(conn, None, None, None,
            None, None)
        self._checkReplicationStarted(conn, rid, app.replicator)

    def test_startReplication(self):
        rid = 24
        app = self.getApp(rid=rid)
        conn = self.getFakeConnection()
        ReplicationHandler(app).startReplication(conn)
        self._checkReplicationStarted(conn, rid, app.replicator)

    def test_answerTIDsFrom(self):
        conn = self.getFakeConnection()
        tid_list = [self.getNextTID(), self.getNextTID()]
        app = self.getApp(conn=conn, tid_result=[])
        # With no known TID
        ReplicationHandler(app).answerTIDsFrom(conn, tid_list)
        # Copy: the checker consumes the list it is given.
        self._checkPacketTIDList(conn, tid_list[:])
        # With first TID known
        conn = self.getFakeConnection()
        known_tid_list = [tid_list[0], ]
        unknown_tid_list = [tid_list[1], ]
        app = self.getApp(conn=conn, tid_result=known_tid_list)
        ReplicationHandler(app).answerTIDsFrom(conn, tid_list)
        self._checkPacketTIDList(conn, unknown_tid_list)

    def test_answerTransactionInformation(self):
        conn = self.getFakeConnection()
        app = self.getApp(conn=conn)
        tid = self.getNextTID()
        user = 'foo'
        desc = 'bar'
        ext = 'baz'
        packed = True
        oid_list = [self.getOID(1), self.getOID(2)]
        ReplicationHandler(app).answerTransactionInformation(conn, tid, user,
            desc, ext, packed, oid_list)
        calls = app.dm.mockGetNamedCalls('storeTransaction')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(tid, (), (oid_list, user, desc, ext, packed), False)

    def test_answerObjectHistoryFrom(self):
        conn = self.getFakeConnection()
        oid_1 = self.getOID(1)
        oid_2 = self.getOID(2)
        oid_3 = self.getOID(3)
        # FakeDict: iteration order must be predictable to compare with
        # the order in which packets were sent.
        oid_dict = FakeDict((
            (oid_1, [self.getNextTID(), self.getNextTID()]),
            (oid_2, [self.getNextTID()]),
            (oid_3, [self.getNextTID()]),
        ))
        flat_oid_list = []
        for oid, serial_list in oid_dict.iteritems():
            for serial in serial_list:
                flat_oid_list.append((oid, serial))
        app = self.getApp(conn=conn, history_result={})
        # With no known OID/Serial
        ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
        self._checkPacketSerialList(conn, flat_oid_list)
        # With some known OID/Serials
        conn = self.getFakeConnection()
        app = self.getApp(conn=conn, history_result={
            oid_1: [oid_dict[oid_1][0], ],
            oid_3: [oid_dict[oid_3][0], ],
        })
        ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
        self._checkPacketSerialList(conn, (
            (oid_1, oid_dict[oid_1][1]),
            (oid_2, oid_dict[oid_2][0]),
        ))

    def test_answerObject(self):
        conn = self.getFakeConnection()
        app = self.getApp(conn=conn)
        oid = self.getOID(1)
        serial_start = self.getNextTID()
        serial_end = self.getNextTID()
        compression = 1
        checksum = 2
        data = 'foo'
        data_serial = None
        ReplicationHandler(app).answerObject(conn, oid, serial_start,
            serial_end, compression, checksum, data, data_serial)
        calls = app.dm.mockGetNamedCalls('storeTransaction')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(serial_start, [(oid, compression, checksum, data,
            data_serial)], None, False)

    # CheckTIDRange
    def test_answerCheckTIDRangeIdenticalChunkWithNext(self):
        min_tid = self.getNextTID()
        max_tid = self.getNextTID()
        length = RANGE_LENGTH // 2
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
            conn=conn)
        handler = ReplicationHandler(app)
        # Peer has the same data as we have: length, checksum and max_tid
        # match.
        handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
        # Result: go on with next chunk
        pmin_tid, plength, ppartition = self.checkAskPacket(conn,
            Packets.AskCheckTIDRange, decode=True)
        self.assertEqual(pmin_tid, add64(max_tid, 1))
        self.assertEqual(plength, RANGE_LENGTH)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('checkTIDRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_tid, plength, ppartition)

    def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self):
        min_tid = self.getNextTID()
        max_tid = self.getNextTID()
        length = RANGE_LENGTH // 2
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(tid_check_result=(length - 1, 0, max_tid), rid=rid,
            conn=conn)
        handler = ReplicationHandler(app)
        # Peer has the same data as we have: length, checksum and max_tid
        # match.
        handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
            max_tid)
        # Result: go on with object range checks
        pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
            Packets.AskCheckSerialRange, decode=True)
        self.assertEqual(pmin_oid, ZERO_OID)
        self.assertEqual(pmin_serial, ZERO_TID)
        self.assertEqual(plength, RANGE_LENGTH)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('checkSerialRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)

    def test_answerCheckTIDRangeDifferentBigChunk(self):
        min_tid = self.getNextTID()
        max_tid = self.getNextTID()
        length = RANGE_LENGTH // 2
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
            conn=conn)
        handler = ReplicationHandler(app)
        # Peer has different data
        handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
        # Result: ask again, length halved
        pmin_tid, plength, ppartition = self.checkAskPacket(conn,
            Packets.AskCheckTIDRange, decode=True)
        self.assertEqual(pmin_tid, min_tid)
        self.assertEqual(plength, length // 2)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('checkTIDRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_tid, plength, ppartition)

    def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self):
        min_tid = self.getNextTID()
        max_tid = self.getNextTID()
        length = MIN_RANGE_LENGTH - 1
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
            conn=conn)
        handler = ReplicationHandler(app)
        # Peer has different data
        handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
        # Result: ask tid list, and ask next chunk
        calls = conn.mockGetNamedCalls('ask')
        self.assertEqual(len(calls), 2)
        tid_call, next_call = calls
        tid_packet = tid_call.getParam(0)
        next_packet = next_call.getParam(0)
        self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
        pmin_tid, plength, ppartition = tid_packet.decode()
        self.assertEqual(pmin_tid, min_tid)
        self.assertEqual(plength, length)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_tid, plength, ppartition)
        self.assertEqual(next_packet.getType(), Packets.AskCheckTIDRange)
        pmin_tid, plength, ppartition = next_packet.decode()
        self.assertEqual(pmin_tid, add64(max_tid, 1))
        self.assertEqual(plength, RANGE_LENGTH)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('checkTIDRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_tid, plength, ppartition)

    def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self):
        min_tid = self.getNextTID()
        max_tid = self.getNextTID()
        length = MIN_RANGE_LENGTH - 1
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
            conn=conn)
        handler = ReplicationHandler(app)
        # Peer has different data, and less than length
        handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
            max_tid)
        # Result: ask tid list, and start replicating object range
        calls = conn.mockGetNamedCalls('ask')
        self.assertEqual(len(calls), 2)
        tid_call, next_call = calls
        tid_packet = tid_call.getParam(0)
        next_packet = next_call.getParam(0)
        self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
        pmin_tid, plength, ppartition = tid_packet.decode()
        self.assertEqual(pmin_tid, min_tid)
        self.assertEqual(plength, length - 1)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_tid, plength, ppartition)
        self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange)
        pmin_oid, pmin_serial, plength, ppartition = next_packet.decode()
        self.assertEqual(pmin_oid, ZERO_OID)
        self.assertEqual(pmin_serial, ZERO_TID)
        self.assertEqual(plength, RANGE_LENGTH)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('checkSerialRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)

    # CheckSerialRange
    def test_answerCheckSerialRangeIdenticalChunkWithNext(self):
        min_oid = self.getOID(1)
        max_oid = self.getOID(10)
        min_serial = self.getNextTID()
        max_serial = self.getNextTID()
        length = RANGE_LENGTH // 2
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(serial_check_result=(length, 0, max_oid, 1,
            max_serial), rid=rid, conn=conn)
        handler = ReplicationHandler(app)
        # Peer has the same data as we have
        handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
            length, 0, max_oid, 1, max_serial)
        # Result: go on with next chunk
        pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
            Packets.AskCheckSerialRange, decode=True)
        self.assertEqual(pmin_oid, max_oid)
        self.assertEqual(pmin_serial, add64(max_serial, 1))
        self.assertEqual(plength, RANGE_LENGTH)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('checkSerialRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)

    def test_answerCheckSerialRangeIdenticalChunkWithoutNext(self):
        min_oid = self.getOID(1)
        max_oid = self.getOID(10)
        min_serial = self.getNextTID()
        max_serial = self.getNextTID()
        length = RANGE_LENGTH // 2
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(serial_check_result=(length - 1, 0, max_oid, 1,
            max_serial), rid=rid, conn=conn)
        handler = ReplicationHandler(app)
        # Peer has the same data as we have
        handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
            length - 1, 0, max_oid, 1, max_serial)
        # Result: mark replication as done
        self.checkNoPacketSent(conn)
        self.assertTrue(app.replicator.replication_done)

    def test_answerCheckSerialRangeDifferentBigChunk(self):
        min_oid = self.getOID(1)
        max_oid = self.getOID(10)
        min_serial = self.getNextTID()
        max_serial = self.getNextTID()
        length = RANGE_LENGTH // 2
        rid = 12
        conn = self.getFakeConnection()
        # Our side holds 5 entries less than the peer for this range.
        app = self.getApp(serial_check_result=(length - 5, 0, max_oid, 1,
            max_serial), rid=rid, conn=conn)
        handler = ReplicationHandler(app)
        # Peer has different data
        handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
            length, 0, max_oid, 1, max_serial)
        # Result: ask again, length halved
        pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
            Packets.AskCheckSerialRange, decode=True)
        self.assertEqual(pmin_oid, min_oid)
        self.assertEqual(pmin_serial, min_serial)
        self.assertEqual(plength, length // 2)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('checkSerialRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)

    def test_answerCheckSerialRangeDifferentSmallChunkWithNext(self):
        min_oid = self.getOID(1)
        max_oid = self.getOID(10)
        min_serial = self.getNextTID()
        max_serial = self.getNextTID()
        length = MIN_RANGE_LENGTH - 1
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(serial_check_result=(length - 5, 0, max_oid, 1,
            max_serial), rid=rid, conn=conn)
        handler = ReplicationHandler(app)
        # Peer has different data
        handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
            length, 0, max_oid, 1, max_serial)
        # Result: ask serial list, and ask next chunk
        calls = conn.mockGetNamedCalls('ask')
        self.assertEqual(len(calls), 2)
        serial_call, next_call = calls
        serial_packet = serial_call.getParam(0)
        next_packet = next_call.getParam(0)
        self.assertEqual(serial_packet.getType(), Packets.AskObjectHistoryFrom)
        pmin_oid, pmin_serial, plength, ppartition = serial_packet.decode()
        self.assertEqual(pmin_oid, min_oid)
        self.assertEqual(pmin_serial, min_serial)
        self.assertEqual(plength, length)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
        self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange)
        pmin_oid, pmin_serial, plength, ppartition = next_packet.decode()
        self.assertEqual(pmin_oid, max_oid)
        self.assertEqual(pmin_serial, add64(max_serial, 1))
        self.assertEqual(plength, RANGE_LENGTH)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('checkSerialRange')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)

    def test_answerCheckSerialRangeDifferentSmallChunkWithoutNext(self):
        min_oid = self.getOID(1)
        max_oid = self.getOID(10)
        min_serial = self.getNextTID()
        max_serial = self.getNextTID()
        length = MIN_RANGE_LENGTH - 1
        rid = 12
        conn = self.getFakeConnection()
        app = self.getApp(serial_check_result=(length - 5, 0, max_oid,
            1, max_serial), rid=rid, conn=conn)
        handler = ReplicationHandler(app)
        # Peer has different data, and less than length
        handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
            length - 1, 0, max_oid, 1, max_serial)
        # Result: ask serial list, and mark replication as done
        pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
            Packets.AskObjectHistoryFrom, decode=True)
        self.assertEqual(pmin_oid, min_oid)
        self.assertEqual(pmin_serial, min_serial)
        self.assertEqual(plength, length - 1)
        self.assertEqual(ppartition, rid)
        calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
        self.assertEqual(len(calls), 1)
        calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
        self.assertTrue(app.replicator.replication_done)
# Run this module's tests when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock, ReturnValues
from neo.tests import NeoTestBase
from neo.storage.replicator import Replicator, Partition, Task
from neo.protocol import CellStates, NodeStates, Packets
class StorageReplicatorTests(NeoTestBase):
def setup(self):
pass
def teardown(self):
pass
def test_populate(self):
    """populate() must register only the partition whose cell for this
    node is out of date, and leave the other dicts empty."""
    my_uuid = self.getNewUUID()
    other_uuid = self.getNewUUID()
    # Row of partition 0: our cell is out of date.
    outdated_row = (
        (my_uuid, CellStates.OUT_OF_DATE),
        (other_uuid, CellStates.UP_TO_DATE),
    )
    # Row of partition 1: our cell is up to date.
    current_row = (
        (my_uuid, CellStates.UP_TO_DATE),
        (other_uuid, CellStates.OUT_OF_DATE),
    )
    app = Mock()
    app.uuid = my_uuid
    app.pt = Mock({
        'getPartitions': 2,
        'getRow': ReturnValues(outdated_row, current_row),
    })
    replicator = Replicator(app)
    # Precondition: nothing is set before populate() is called.
    for name in ('new_partition_dict', 'critical_tid_dict',
            'partition_dict'):
        initial = getattr(replicator, name)
        assert initial is None, initial
    replicator.populate()
    new_partitions = replicator.new_partition_dict
    self.assertEqual(len(new_partitions), 1)
    partition = new_partitions[0]
    self.assertEqual(partition.getRID(), 0)
    self.assertEqual(partition.getCriticalTID(), None)
    self.assertEqual(replicator.critical_tid_dict, {})
    self.assertEqual(replicator.partition_dict, {})
def test_reset(self):
    """reset() must give every per-replication attribute its start
    value."""
    replicator = Replicator(None)
    # Precondition: everything reads as None on a fresh instance.
    for name in ('task_list', 'task_dict', 'current_partition',
            'current_connection', 'waiting_for_unfinished_tids',
            'unfinished_tid_list', 'replication_done'):
        initial = getattr(replicator, name)
        assert initial is None, initial
    replicator.reset()
    expected_state = (
        ('task_list', []),
        ('task_dict', {}),
        ('current_partition', None),
        ('current_connection', None),
        ('waiting_for_unfinished_tids', False),
        ('unfinished_tid_list', None),
        ('replication_done', True),
    )
    for name, expected in expected_state:
        self.assertEqual(getattr(replicator, name), expected)
def test_setCriticalTID(self):
    """setCriticalTID() must hand the TID to every partition waiting on
    that master and forget the master's entry."""
    replicator = Replicator(None)
    master_uuid = self.getNewUUID()
    waiting_partitions = [Partition(0), Partition(5)]
    replicator.critical_tid_dict = {master_uuid: waiting_partitions}
    critical_tid = self.getNextTID()

    def critical_tids():
        return [x.getCriticalTID() for x in waiting_partitions]

    self.assertEqual(critical_tids(), [None, None])
    replicator.setCriticalTID(master_uuid, critical_tid)
    self.assertEqual(replicator.critical_tid_dict, {})
    self.assertEqual(critical_tids(), [critical_tid, critical_tid])
def test_setUnfinishedTIDList(self):
    """setUnfinishedTIDList() must store the list and clear the waiting
    flag."""
    replicator = Replicator(None)
    replicator.waiting_for_unfinished_tids = True
    before = replicator.unfinished_tid_list
    assert before is None, before
    unfinished = [self.getNextTID(), ]
    replicator.setUnfinishedTIDList(unfinished)
    self.assertEqual(replicator.unfinished_tid_list, unfinished)
    self.assertFalse(replicator.waiting_for_unfinished_tids)
def test_act(self):
    # Drive Replicator.act() step by step through the replication of one
    # partition: asking the master for last IDs and unfinished TIDs,
    # waiting for the answers, starting replication on a storage node
    # connection, and finally reaching the state where nothing is pending.
    # Also tests "pending"
    uuid = self.getNewUUID()
    master_uuid = self.getNewUUID()
    # Allocated before critical_tid: presumably getNextTID() returns
    # increasing values, which makes this TID the "unfinished tid before
    # critical tid" used further down -- TODO confirm.
    bad_unfinished_tid = self.getNextTID()
    critical_tid = self.getNextTID()
    unfinished_tid = self.getNextTID()
    app = Mock()
    app.em = Mock({
        'register': None,
    })
    def connectorGenerator():
        return Mock()
    app.connector_handler = connectorGenerator
    app.uuid = uuid
    node_addr = ('127.0.0.1', 1234)
    node = Mock({
        'getAddress': node_addr,
    })
    # The partition table mock returns two cells: one on a RUNNING node and
    # one on a node in UNKNOWN state.
    running_cell = Mock({
        'getNodeState': NodeStates.RUNNING,
        'getNode': node,
    })
    unknown_cell = Mock({
        'getNodeState': NodeStates.UNKNOWN,
    })
    # Single partition, reported OUT_OF_DATE for this node's own uuid.
    app.pt = Mock({
        'getPartitions': 1,
        'getRow': ReturnValues(
            ((uuid, CellStates.OUT_OF_DATE), ),
        ),
        'getCellList': [running_cell, unknown_cell],
    })
    # Connection (and its handler) injected later as the replicator's
    # current connection; startReplication calls are counted on the handler.
    node_conn_handler = Mock({
        'startReplication': None,
    })
    node_conn = Mock({
        'getAddress': node_addr,
        'getHandler': node_conn_handler,
    })
    replicator = Replicator(app)
    replicator.populate()
    # One step of the scenario. A fresh fake master connection is installed
    # on every call, so packet checks done after a step only see what that
    # step sent. The replicator must report pending work before acting.
    def act():
        app.master_conn = self.getFakeConnection(uuid=master_uuid)
        self.assertTrue(replicator.pending())
        replicator.act()
    # ask last IDs to infer critical_tid and unfinished tids
    act()
    # Exactly two 'ask' calls are expected on the master connection (the
    # unpacking below fails otherwise); parameter 0 of each is the packet.
    last_ids, unfinished_tids = [x.getParam(0) for x in \
        app.master_conn.mockGetNamedCalls('ask')]
    self.assertEqual(last_ids.getType(), Packets.AskLastIDs)
    self.assertFalse(replicator.new_partition_dict)
    self.assertEqual(unfinished_tids.getType(),
        Packets.AskUnfinishedTransactions)
    self.assertTrue(replicator.waiting_for_unfinished_tids)
    # nothing happens until waiting_for_unfinished_tids becomes False
    act()
    self.checkNoPacketSent(app.master_conn)
    self.assertTrue(replicator.waiting_for_unfinished_tids)
    # Send answers (guaranteed to happen in this order)
    replicator.setCriticalTID(master_uuid, critical_tid)
    act()
    self.checkNoPacketSent(app.master_conn)
    self.assertTrue(replicator.waiting_for_unfinished_tids)
    # first time, there is an unfinished tid before critical tid,
    # replication cannot start, and unfinished TIDs are asked again
    replicator.setUnfinishedTIDList([unfinished_tid, bad_unfinished_tid])
    self.assertFalse(replicator.waiting_for_unfinished_tids)
    # Note: detection that nothing can be replicated happens on first call
    # and unfinished tids are asked again on second call. This is ok, but
    # might change, so just call twice.
    act()
    act()
    self.checkAskPacket(app.master_conn, Packets.AskUnfinishedTransactions)
    self.assertTrue(replicator.waiting_for_unfinished_tids)
    # this time, critical tid check should be satisfied
    replicator.setUnfinishedTIDList([unfinished_tid, ])
    # Inject the storage node connection directly: its handler is the one
    # expected to receive startReplication on the next step.
    replicator.current_connection = node_conn
    act()
    self.assertEqual(replicator.current_partition,
        replicator.partition_dict[0])
    self.assertEqual(len(node_conn_handler.mockGetNamedCalls(
        'startReplication')), 1)
    self.assertFalse(replicator.replication_done)
    # Other calls should do nothing
    replicator.current_connection = Mock()
    act()
    self.checkNoPacketSent(app.master_conn)
    self.checkNoPacketSent(replicator.current_connection)
    # Mark replication over for this partition
    replicator.replication_done = True
    # Don't finish while there are pending answers
    replicator.current_connection = Mock({
        'isPending': True,
    })
    act()
    self.assertTrue(replicator.pending())
    # Same step again, now with a connection reporting no pending answer.
    replicator.current_connection = Mock({
        'isPending': False,
    })
    act()
    # unfinished tid list will not be asked again
    self.assertTrue(replicator.unfinished_tid_list)
    # also, replication is over
    self.assertFalse(replicator.pending())
def test_removePartition(self):
    # removePartition must drop the given offset from whichever dict holds
    # it (partition_dict or new_partition_dict) and accept unknown offsets.
    rep = Replicator(None)
    rep.partition_dict = dict.fromkeys((0, 2))
    rep.new_partition_dict = dict.fromkeys((1, ))
    rep.removePartition(0)
    self.assertEqual(rep.partition_dict, {2: None})
    self.assertEqual(rep.new_partition_dict, {1: None})
    # Remove what is left, one offset from each dict.
    for offset in (1, 2):
        rep.removePartition(offset)
    self.assertEqual(rep.partition_dict, {})
    self.assertEqual(rep.new_partition_dict, {})
    # Must not raise
    rep.removePartition(3)
def test_addPartition(self):
    # addPartition must leave already-known offsets untouched and create a
    # new entry in new_partition_dict for an offset it does not know yet.
    rep = Replicator(None)
    rep.partition_dict = {0: None}
    rep.new_partition_dict = {1: None}
    # Offsets present in either dict: nothing may change.
    for known_offset in (0, 1):
        rep.addPartition(known_offset)
    self.assertEqual(rep.partition_dict, {0: None})
    self.assertEqual(rep.new_partition_dict, {1: None})
    # Unknown offset: only new_partition_dict grows.
    rep.addPartition(2)
    self.assertEqual(rep.partition_dict, {0: None})
    pending = rep.new_partition_dict
    self.assertEqual(len(pending), 2)
    self.assertEqual(pending[1], None)
    added = pending[2]
    self.assertEqual(added.getRID(), 2)
    self.assertEqual(added.getCriticalTID(), None)
def test_processDelayedTasks(self):
    # Tasks queued with _addTask must be run by processDelayedTasks, and
    # their return values must then be available through _getCheckResult.
    replicator = Replicator(None)
    replicator.reset()
    def someCallable(foo, bar=None):
        return (foo, bar)
    replicator._addTask(1, someCallable, args=('foo', ))
    # A second registration under key 1 must be refused (presumably because
    # the key is already in use -- the callable given here is also None).
    self.assertRaises(ValueError, replicator._addTask, 1, None)
    replicator._addTask(2, someCallable, args=('foo', ), kw={'bar': 'bar'})
    replicator.processDelayedTasks()
    self.assertEqual(replicator._getCheckResult(1), ('foo', None))
    self.assertEqual(replicator._getCheckResult(2), ('foo', 'bar'))
    # Also test Task
    task = Task(someCallable, args=('foo', ))
    # No result is available before the task has been processed...
    self.assertRaises(ValueError, task.getResult)
    task.process()
    # ...and a task cannot be processed twice.
    self.assertRaises(ValueError, task.process)
    self.assertEqual(task.getResult(), ('foo', None))
# Run this module's tests when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
...@@ -21,7 +21,7 @@ from collections import deque ...@@ -21,7 +21,7 @@ from collections import deque
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler from neo.storage.handlers.storage import StorageOperationHandler
from neo.protocol import INVALID_PARTITION from neo.protocol import INVALID_PARTITION, Packets
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
class StorageStorageHandlerTests(NeoTestBase): class StorageStorageHandlerTests(NeoTestBase):
...@@ -113,7 +113,7 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -113,7 +113,7 @@ class StorageStorageHandlerTests(NeoTestBase):
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn) self.checkAnswerObject(conn)
def test_25_askTIDsFrom1(self): def test_25_askTIDsFrom(self):
# well case => answer # well case => answer
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
...@@ -122,69 +122,85 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -122,69 +122,85 @@ class StorageStorageHandlerTests(NeoTestBase):
self.operation.askTIDsFrom(conn, tid, 2, 1) self.operation.askTIDsFrom(conn, tid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList') calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, [1, ]) calls[0].checkArgs(tid, 2, 1, 1)
self.checkAnswerTidsFrom(conn)
def test_25_askTIDsFrom2(self):
# invalid partition => answer usable partitions
conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({
'getCellList': (cell, ),
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, [0, ])
self.checkAnswerTidsFrom(conn) self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self): def test_26_askObjectHistoryFrom(self):
oid = self.getOID(2) min_oid = self.getOID(2)
min_tid = self.getNextTID() min_serial = self.getNextTID()
length = 4
partition = 8
num_partitions = 16
tid = self.getNextTID() tid = self.getNextTID()
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistoryFrom': [tid]}) self.app.dm = Mock({'getObjectHistoryFrom': {min_oid: [tid]},})
self.operation.askObjectHistoryFrom(conn, oid, min_tid, 2) self.app.pt = Mock({
'getPartitions': num_partitions,
})
self.operation.askObjectHistoryFrom(conn, min_oid, min_serial, length,
partition)
self.checkAnswerObjectHistoryFrom(conn) self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom') calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, min_tid, 2) calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
partition)
def test_25_askOIDs1(self): def test_askCheckTIDRange(self):
# well case > answer OIDs count = 1
tid_checksum = 2
min_tid = self.getNextTID()
num_partitions = 4
length = 5
partition = 6
max_tid = self.getNextTID()
self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection() conn = self.getFakeConnection()
self.app.pt = Mock({'getPartitions': 1}) self.operation.askCheckTIDRange(conn, min_tid, length, partition)
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
oid = self.getOID(1) self.assertEqual(len(calls), 1)
self.operation.askOIDs(conn, oid, 2, 1) calls[0].checkArgs(min_tid, length, num_partitions, partition)
calls = self.app.dm.mockGetNamedCalls('getOIDList') pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.assertEquals(len(calls), 1) self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
calls[0].checkArgs(oid, 2, 1, [1, ]) decode=True)
self.checkAnswerOids(conn) self.assertEqual(min_tid, pmin_tid)
self.assertEqual(length, plength)
def test_25_askOIDs2(self): self.assertEqual(count, pcount)
# invalid partition => answer usable partitions self.assertEqual(tid_checksum, ptid_checksum)
self.assertEqual(max_tid, pmax_tid)
def test_askCheckSerialRange(self):
count = 1
oid_checksum = 2
min_oid = self.getOID(1)
num_partitions = 4
length = 5
partition = 6
serial_checksum = 7
min_serial = self.getNextTID()
max_serial = self.getNextTID()
max_oid = self.getOID(2)
self.app.dm = Mock({'checkSerialRange': (count, oid_checksum, max_oid,
serial_checksum, max_serial)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection() conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid}) self.operation.askCheckSerialRange(conn, min_oid, min_serial, length,
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) partition)
self.app.pt = Mock({ calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
'getCellList': (cell, ), self.assertEqual(len(calls), 1)
'getPartitions': 1, calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
'getAssignedPartitionList': [0], partition)
}) pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
oid = self.getOID(1) pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
self.operation.askOIDs(conn, oid, 2, INVALID_PARTITION) Packets.AnswerCheckSerialRange, decode=True)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1) self.assertEqual(min_oid, pmin_oid)
calls = self.app.dm.mockGetNamedCalls('getOIDList') self.assertEqual(min_serial, pmin_serial)
self.assertEquals(len(calls), 1) self.assertEqual(length, plength)
calls[0].checkArgs(oid, 2, 1, [0]) self.assertEqual(count, pcount)
self.checkAnswerOids(conn) self.assertEqual(oid_checksum, poid_checksum)
self.assertEqual(max_oid, pmax_oid)
self.assertEqual(serial_checksum, pserial_checksum)
self.assertEqual(max_serial, pmax_serial)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -19,7 +19,7 @@ import unittest ...@@ -19,7 +19,7 @@ import unittest
import MySQLdb import MySQLdb
from mock import Mock from mock import Mock
from neo.util import dump, p64, u64 from neo.util import dump, p64, u64
from neo.protocol import CellStates, INVALID_PTID from neo.protocol import CellStates, INVALID_PTID, ZERO_OID, ZERO_TID
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.exception import DatabaseFailure from neo.exception import DatabaseFailure
from neo.storage.database.mysqldb import MySQLDatabaseManager from neo.storage.database.mysqldb import MySQLDatabaseManager
...@@ -441,6 +441,23 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -441,6 +441,23 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(self.db.getTransaction(tid1, True), None) self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None) self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1, oid2])
txn2, objs2 = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.deleteObject(oid1)
self.assertEqual(self.db.getObject(oid1, tid=tid1), None)
self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1)
self.assertEqual(self.db.getObject(oid2, tid=tid1), False)
self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \
objs2[1][1:])
def test_getTransaction(self): def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
...@@ -459,30 +476,6 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -459,30 +476,6 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False)) self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False))
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getOIDList(self):
# store four objects
oid1, oid2, oid3, oid4 = self.getOIDs(4)
tid = self.getNextTID()
txn, objs = self.getTransaction([oid1, oid2, oid3, oid4])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
# get oids
result = self.db.getOIDList(oid1, 4, 1, [0])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(oid1, 4, 2, [0])
self.checkSet(result, [oid1, oid3])
result = self.db.getOIDList(oid1, 4, 2, [0, 1])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(oid1, 4, 3, [0])
self.checkSet(result, [oid1, oid4])
# get a subset of oids
result = self.db.getOIDList(oid1, 2, 1, [0])
self.checkSet(result, [oid1, oid2])
result = self.db.getOIDList(oid3, 2, 1, [0])
self.checkSet(result, [oid3, oid4])
result = self.db.getOIDList(oid2, 1, 3, [0])
self.checkSet(result, [oid4])
def test_getObjectHistory(self): def test_getObjectHistory(self):
oid = self.getOID(1) oid = self.getOID(1)
tid1, tid2, tid3 = self.getTIDs(3) tid1, tid2, tid3 = self.getTIDs(3)
...@@ -506,6 +499,50 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -506,6 +499,50 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getObjectHistory(oid, 2, 3) result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None) self.assertEqual(result, None)
def test_getObjectHistoryFrom(self):
oid1 = self.getOID(0)
oid2 = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
txn3, objs3 = self.getTransaction([oid1])
txn4, objs4 = self.getTransaction([oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.storeTransaction(tid3, objs3, txn3)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4)
# Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, 10, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 3, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
})
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 1)
self.assertEqual(result, {
oid2: [tid2, tid4],
})
def _storeTransactions(self, count): def _storeTransactions(self, count):
# use OID generator to know result of tid % N # use OID generator to know result of tid % N
tid_list = self.getOIDs(count) tid_list = self.getOIDs(count)
...@@ -538,59 +575,20 @@ class StorageMySQSLdbTests(NeoTestBase): ...@@ -538,59 +575,20 @@ class StorageMySQSLdbTests(NeoTestBase):
def test_getReplicationTIDList(self): def test_getReplicationTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids # get tids
result = self.db.getReplicationTIDList(tid1, 4, 1, [0]) result = self.db.getReplicationTIDList(tid1, 4, 1, 0)
self.checkSet(result, [tid1, tid2, tid3, tid4]) self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0]) result = self.db.getReplicationTIDList(tid1, 4, 2, 0)
self.checkSet(result, [tid1, tid3]) self.checkSet(result, [tid1, tid3])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0, 1]) result = self.db.getReplicationTIDList(tid1, 4, 3, 0)
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 3, [0])
self.checkSet(result, [tid1, tid4]) self.checkSet(result, [tid1, tid4])
# get a subset of tids # get a subset of tids
result = self.db.getReplicationTIDList(tid3, 4, 1, [0]) result = self.db.getReplicationTIDList(tid3, 4, 1, 0)
self.checkSet(result, [tid3, tid4]) self.checkSet(result, [tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 2, 1, [0]) result = self.db.getReplicationTIDList(tid1, 2, 1, 0)
self.checkSet(result, [tid1, tid2]) self.checkSet(result, [tid1, tid2])
result = self.db.getReplicationTIDList(tid1, 1, 3, [1]) result = self.db.getReplicationTIDList(tid1, 1, 3, 1)
self.checkSet(result, [tid2]) self.checkSet(result, [tid2])
def test_getTIDListPresent(self):
oid = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid])
txn4, objs4 = self.getTransaction([oid])
# four tids, two missing
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid4)
result = self.db.getTIDListPresent([tid1, tid2, tid3, tid4])
self.checkSet(result, [tid1, tid4])
result = self.db.getTIDListPresent([tid1, tid2])
self.checkSet(result, [tid1])
self.assertEqual(self.db.getTIDListPresent([tid2, tid3]), [])
def test_getSerialListPresent(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
txn3, objs3 = self.getTransaction([oid2])
txn4, objs4 = self.getTransaction([oid2])
# four object, one revision each
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid4)
result = self.db.getSerialListPresent(oid1, [tid1, tid2])
self.checkSet(result, [tid1])
result = self.db.getSerialListPresent(oid2, [tid3, tid4])
self.checkSet(result, [tid4])
result = self.db.getSerialListPresent(oid1, [tid2])
self.assertEqual(result, [])
result = self.db.getSerialListPresent(oid2, [tid3])
self.assertEqual(result, [])
def test__getObjectData(self): def test__getObjectData(self):
db = self.db db = self.db
db.setup(reset=True) db.setup(reset=True)
......
...@@ -458,24 +458,6 @@ class ProtocolTests(NeoTestBase): ...@@ -458,24 +458,6 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(p_hist_list, hist_list) self.assertEqual(p_hist_list, hist_list)
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
def test_55_askOIDs(self):
oid = self.getOID(1)
p = Packets.AskOIDs(oid, 1000, 5)
min_oid, length, partition = p.decode()
self.assertEqual(min_oid, oid)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5)
def test_56_answerOIDs(self):
oid1 = self.getNextTID()
oid2 = self.getNextTID()
oid3 = self.getNextTID()
oid4 = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4]
p = Packets.AnswerOIDs(oid_list)
p_oid_list = p.decode()[0]
self.assertEqual(p_oid_list, oid_list)
def test_57_notifyReplicationDone(self): def test_57_notifyReplicationDone(self):
offset = 10 offset = 10
p = Packets.NotifyReplicationDone(offset) p = Packets.NotifyReplicationDone(offset)
...@@ -626,14 +608,82 @@ class ProtocolTests(NeoTestBase): ...@@ -626,14 +608,82 @@ class ProtocolTests(NeoTestBase):
oid = self.getOID(1) oid = self.getOID(1)
min_serial = self.getNextTID() min_serial = self.getNextTID()
length = 5 length = 5
p = Packets.AskObjectHistoryFrom(oid, min_serial, length) partition = 4
p_oid, p_min_serial, p_length = p.decode() p = Packets.AskObjectHistoryFrom(oid, min_serial, length, partition)
p_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_oid, oid) self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial) self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length) self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerObjectHistoryFrom(self): def test_AnswerObjectHistoryFrom(self):
self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom) object_dict = {}
for int_oid in xrange(4):
object_dict[self.getOID(int_oid)] = [self.getNextTID() \
for _ in xrange(5)]
p = Packets.AnswerObjectHistoryFrom(object_dict)
p_object_dict = p.decode()[0]
self.assertEqual(object_dict, p_object_dict)
def test_AskCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckTIDRange(min_tid, length, partition)
p_min_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
count = 1
tid_checksum = 42
max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid)
p_min_tid, p_length, p_count, p_tid_checksum, p_max_tid = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_tid, max_tid)
def test_AskCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckSerialRange(min_oid, min_serial, length, partition)
p_min_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
count = 1
oid_checksum = 24
max_oid = self.getOID(5)
tid_checksum = 42
max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial)
p_min_oid, p_min_serial, p_length, p_count, p_oid_checksum, \
p_max_oid, p_tid_checksum, p_max_serial = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_oid_checksum, oid_checksum)
self.assertEqual(p_max_oid, max_oid)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_serial, max_serial)
def test_AskPack(self): def test_AskPack(self):
tid = self.getNextTID() tid = self.getNextTID()
......
...@@ -56,6 +56,8 @@ UNIT_TEST_MODULES = [ ...@@ -56,6 +56,8 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testVerificationHandler', 'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler', 'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions', 'neo.tests.storage.testTransactions',
'neo.tests.storage.testReplicationHandler',
'neo.tests.storage.testReplicator',
# client application # client application
'neo.tests.client.testClientApp', 'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler', 'neo.tests.client.testMasterHandler',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment