Commit b0753366 authored by Vincent Pelletier's avatar Vincent Pelletier

Implement rsync-ish replication.

For further description, see storage/handlers/replication.py .

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2295 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 01735352
......@@ -275,16 +275,10 @@ class EventHandler(object):
def answerObjectHistory(self, conn, oid, history_list):
raise UnexpectedPacketError
def askObjectHistoryFrom(self, conn, oid, min_serial, length):
def askObjectHistoryFrom(self, conn, oid, min_serial, length, partition):
raise UnexpectedPacketError
def answerObjectHistoryFrom(self, conn, oid, history_list):
raise UnexpectedPacketError
def askOIDs(self, conn, min_oid, length, partition):
raise UnexpectedPacketError
def answerOIDs(self, conn, oid_list):
def answerObjectHistoryFrom(self, conn, object_dict):
raise UnexpectedPacketError
def askPartitionList(self, conn, min_offset, max_offset, uuid):
......@@ -358,6 +352,21 @@ class EventHandler(object):
def answerPack(self, conn, status):
raise UnexpectedPacketError
def askCheckTIDRange(self, conn, min_tid, length, partition):
raise UnexpectedPacketError
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid):
raise UnexpectedPacketError
def askCheckSerialRange(self, conn, min_oid, min_serial, length,
partition):
raise UnexpectedPacketError
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial):
raise UnexpectedPacketError
# Error packet handlers.
......@@ -450,8 +459,6 @@ class EventHandler(object):
d[Packets.AnswerObjectHistory] = self.answerObjectHistory
d[Packets.AskObjectHistoryFrom] = self.askObjectHistoryFrom
d[Packets.AnswerObjectHistoryFrom] = self.answerObjectHistoryFrom
d[Packets.AskOIDs] = self.askOIDs
d[Packets.AnswerOIDs] = self.answerOIDs
d[Packets.AskPartitionList] = self.askPartitionList
d[Packets.AnswerPartitionList] = self.answerPartitionList
d[Packets.AskNodeList] = self.askNodeList
......@@ -476,6 +483,10 @@ class EventHandler(object):
d[Packets.AnswerBarrier] = self.answerBarrier
d[Packets.AskPack] = self.askPack
d[Packets.AnswerPack] = self.answerPack
d[Packets.AskCheckTIDRange] = self.askCheckTIDRange
d[Packets.AnswerCheckTIDRange] = self.answerCheckTIDRange
d[Packets.AskCheckSerialRange] = self.askCheckSerialRange
d[Packets.AnswerCheckSerialRange] = self.answerCheckSerialRange
return d
......
......@@ -113,6 +113,7 @@ INVALID_PARTITION = 0xffffffff
ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID)
TID_LEN = len(INVALID_TID)
UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S',
......@@ -1167,63 +1168,47 @@ class AnswerObjectHistory(Packet):
class AskObjectHistoryFrom(Packet):
"""
Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial. S -> S.
ascending, and starts at (or above) min_serial for min_oid. S -> S.
"""
_header_format = '!8s8sL'
_header_format = '!8s8sLL'
def _encode(self, oid, min_serial, length):
return pack(self._header_format, oid, min_serial, length)
def _encode(self, min_oid, min_serial, length, partition):
return pack(self._header_format, min_oid, min_serial, length,
partition)
def _decode(self, body):
return unpack(self._header_format, body) # oid, min_serial, length
# min_oid, min_serial, length, partition
return unpack(self._header_format, body)
class AnswerObjectHistoryFrom(AskFinishTransaction):
class AnswerObjectHistoryFrom(Packet):
"""
Answer the requested serials. S -> S.
"""
# This is similar to AskFinishTransaction as TID size is identical to OID
# size:
# - we have a single OID (TID in AskFinishTransaction)
# - we have a list of TIDs (OIDs in AskFinishTransaction)
pass
class AskOIDs(Packet):
"""
Ask for length OIDs starting at min_oid. S -> S.
"""
_header_format = '!8sLL'
def _encode(self, min_oid, length, partition):
return pack(self._header_format, min_oid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_oid, length, partition
class AnswerOIDs(Packet):
"""
Answer the requested OIDs. S -> S.
"""
_header_format = '!L'
_list_entry_format = '8s'
_list_entry_format = '!8sL'
_list_entry_len = calcsize(_list_entry_format)
def _encode(self, oid_list):
body = [pack(self._header_format, len(oid_list))]
body.extend(oid_list)
def _encode(self, object_dict):
body = [pack(self._header_format, len(object_dict))]
append = body.append
extend = body.extend
list_entry_format = self._list_entry_format
for oid, serial_list in object_dict.iteritems():
append(pack(list_entry_format, oid, len(serial_list)))
extend(serial_list)
return ''.join(body)
def _decode(self, body):
offset = self._header_len
(n,) = unpack(self._header_format, body[:offset])
oid_list = []
body = StringIO(body)
read = body.read
list_entry_format = self._list_entry_format
list_entry_len = self._list_entry_len
for _ in xrange(n):
next_offset = offset + list_entry_len
oid = unpack(list_entry_format, body[offset:next_offset])[0]
offset = next_offset
oid_list.append(oid)
return (oid_list,)
object_dict = {}
dict_len = unpack(self._header_format, read(self._header_len))[0]
for _ in xrange(dict_len):
oid, serial_len = unpack(list_entry_format, read(list_entry_len))
object_dict[oid] = [read(TID_LEN) for _ in xrange(serial_len)]
return (object_dict, )
class AskPartitionList(Packet):
"""
......@@ -1660,6 +1645,73 @@ class AnswerPack(Packet):
def _decode(self, body):
return (bool(unpack(self._header_format, body)[0]), )
class AskCheckTIDRange(Packet):
"""
Ask some stats about a range of transactions.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8sLL'
def _encode(self, min_tid, length, partition):
return pack(self._header_format, min_tid, length, partition)
def _decode(self, body):
return unpack(self._header_format, body) # min_tid, length, partition
class AnswerCheckTIDRange(Packet):
"""
Stats about a range of transactions.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8sLLQ8s'
def _encode(self, min_tid, length, count, tid_checksum, max_tid):
return pack(self._header_format, min_tid, length, count, tid_checksum,
max_tid)
def _decode(self, body):
# min_tid, length, partition, count, tid_checksum, max_tid
return unpack(self._header_format, body)
class AskCheckSerialRange(Packet):
"""
Ask some stats about a range of object history.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8s8sLL'
def _encode(self, min_oid, min_serial, length, partition):
return pack(self._header_format, min_oid, min_serial, length,
partition)
def _decode(self, body):
# min_oid, min_serial, length, partition
return unpack(self._header_format, body)
class AnswerCheckSerialRange(Packet):
"""
Stats about a range of object history.
Used to know if there are differences between a replicating node and
reference node.
S -> S
"""
_header_format = '!8s8sLLQ8sQ8s'
def _encode(self, min_oid, min_serial, length, count, oid_checksum,
max_oid, serial_checksum, max_serial):
return pack(self._header_format, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial)
def _decode(self, body):
# min_oid, min_serial, length, count, oid_checksum, max_oid,
# serial_checksum, max_serial
return unpack(self._header_format, body)
class Error(Packet):
"""
Error is a special type of message, because this can be sent against
......@@ -1844,10 +1896,6 @@ class PacketRegistry(dict):
0x001F,
AskObjectHistory,
AnswerObjectHistory)
AskOIDs, AnswerOIDs = register(
0x0020,
AskOIDs,
AnswerOIDs)
AskPartitionList, AnswerPartitionList = register(
0x0021,
AskPartitionList,
......@@ -1903,6 +1951,16 @@ class PacketRegistry(dict):
0x0038,
AskPack,
AnswerPack)
AskCheckTIDRange, AnswerCheckTIDRange = register(
0x0039,
AskCheckTIDRange,
AnswerCheckTIDRange,
)
AskCheckSerialRange, AnswerCheckSerialRange = register(
0x003A,
AskCheckSerialRange,
AnswerCheckSerialRange,
)
# build a "singleton"
Packets = PacketRegistry()
......
......@@ -288,6 +288,12 @@ class Application(object):
while True:
em.poll(1)
if self.replicator.pending():
# Call processDelayedTasks before act, so tasks added in the
# act call are executed after one poll call, so that sent
# packets are already on the network and delayed task
# processing happens in parallel with the same task on the
# other storage node.
self.replicator.processDelayedTasks()
self.replicator.act()
def wait(self):
......
......@@ -274,6 +274,11 @@ class DatabaseManager(object):
area."""
raise NotImplementedError
def deleteObject(self, oid, serial=None):
"""Delete given object. If serial is given, only delete that serial for
given oid."""
raise NotImplementedError
def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information,
a description, and extension information, for a given transaction
......@@ -282,12 +287,6 @@ class DatabaseManager(object):
area as well."""
raise NotImplementedError
def getOIDList(self, min_oid, length, num_partitions, partition_list):
"""Return a list of OIDs in ascending order from a minimal oid,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
def getObjectHistory(self, oid, offset = 0, length = 1):
"""Return a list of serials and sizes for a given object ID.
The length specifies the maximum size of such a list. Result starts
......@@ -295,9 +294,11 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None."""
raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, length):
"""Return a list of length serials for a given object ID at (or above)
min_serial, sorted in ascending order."""
def getObjectHistoryFrom(self, oid, min_serial, length, num_partitions,
partition):
"""Return a dict of length serials grouped by oid at (or above)
min_oid and min_serial, for given partition, sorted in ascending
order."""
raise NotImplementedError
def getTIDList(self, offset, length, num_partitions, partition_list):
......@@ -307,20 +308,10 @@ class DatabaseManager(object):
raise NotImplementedError
def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list):
partition):
"""Return a list of TIDs in ascending order from an initial tid value,
at most the specified length. The list of partitions are passed
to filter out non-applicable TIDs."""
raise NotImplementedError
def getTIDListPresent(self, tid_list):
"""Return a list of TIDs which are present in a database among
the given list."""
raise NotImplementedError
def getSerialListPresent(self, oid, serial_list):
"""Return a list of serials which are present in a database among
the given list."""
at most the specified length. The partition number is passed to filter
out non-applicable TIDs."""
raise NotImplementedError
def pack(self, tid, updateObjectDataForPack):
......
......@@ -24,7 +24,7 @@ import string
from neo.storage.database import DatabaseManager
from neo.exception import DatabaseFailure
from neo.protocol import CellStates
from neo.protocol import CellStates, ZERO_OID, ZERO_TID
from neo import util
LOG_QUERIES = False
......@@ -576,6 +576,23 @@ class MySQLDatabaseManager(DatabaseManager):
raise
self.commit()
def deleteObject(self, oid, serial=None):
u64 = util.u64
query_param_dict = {
'oid': u64(oid),
}
query_fmt = 'DELETE FROM obj WHERE oid = %(oid)d'
if serial is not None:
query_param_dict['serial'] = u64(serial)
query_fmt = query_fmt + ' AND serial = %(serial)d'
self.begin()
try:
self.query(query_fmt % query_param_dict)
except:
self.rollback()
raise
self.commit()
def getTransaction(self, tid, all = False):
q = self.query
tid = util.u64(tid)
......@@ -594,20 +611,6 @@ class MySQLDatabaseManager(DatabaseManager):
return oid_list, user, desc, ext, bool(packed)
return None
def getOIDList(self, min_oid, length, num_partitions,
partition_list):
q = self.query
r = q("""SELECT DISTINCT oid FROM obj WHERE
MOD(oid, %(num_partitions)d) in (%(partitions)s)
AND oid >= %(min_oid)d
ORDER BY oid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'min_oid': util.u64(min_oid),
'length': length,
})
return [util.p64(t[0]) for t in r]
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
......@@ -646,18 +649,32 @@ class MySQLDatabaseManager(DatabaseManager):
return result
return None
def getObjectHistoryFrom(self, oid, min_serial, length):
def getObjectHistoryFrom(self, min_oid, min_serial, length, num_partitions,
partition):
q = self.query
oid = util.u64(oid)
u64 = util.u64
p64 = util.p64
r = q("""SELECT serial FROM obj
WHERE oid = %(oid)d AND serial >= %(min_serial)d
ORDER BY serial ASC LIMIT %(length)d""" % {
'oid': oid,
'min_serial': util.u64(min_serial),
min_oid = u64(min_oid)
min_serial = u64(min_serial)
r = q('SELECT oid, serial FROM obj '
'WHERE ((oid = %(min_oid)d AND serial >= %(min_serial)d) OR '
'oid > %(min_oid)d) AND '
'MOD(oid, %(num_partitions)d) = %(partition)s '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': min_oid,
'min_serial': min_serial,
'length': length,
'num_partitions': num_partitions,
'partition': partition,
})
return [p64(t[0]) for t in r]
result = {}
for oid, serial in r:
try:
serial_list = result[oid]
except KeyError:
serial_list = result[oid] = []
serial_list.append(p64(serial))
return dict((p64(x), y) for x, y in result.iteritems())
def getTIDList(self, offset, length, num_partitions, partition_list):
q = self.query
......@@ -669,32 +686,19 @@ class MySQLDatabaseManager(DatabaseManager):
return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, length, num_partitions,
partition_list):
partition):
q = self.query
r = q("""SELECT tid FROM trans WHERE
MOD(tid, %(num_partitions)d) in (%(partitions)s)
MOD(tid, %(num_partitions)d) = %(partition)d
AND tid >= %(min_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % {
'num_partitions': num_partitions,
'partitions': ','.join([str(p) for p in partition_list]),
'partition': partition,
'min_tid': util.u64(min_tid),
'length': length,
})
return [util.p64(t[0]) for t in r]
def getTIDListPresent(self, tid_list):
q = self.query
r = q("""SELECT tid FROM trans WHERE tid in (%s)""" \
% ','.join([str(util.u64(tid)) for tid in tid_list]))
return [util.p64(t[0]) for t in r]
def getSerialListPresent(self, oid, serial_list):
q = self.query
oid = util.u64(oid)
r = q("""SELECT serial FROM obj WHERE oid = %d AND serial in (%s)""" \
% (oid, ','.join([str(util.u64(serial)) for serial in serial_list])))
return [util.p64(t[0]) for t in r]
def _updatePackFuture(self, oid, orig_serial, max_serial,
updateObjectDataForPack):
q = self.query
......@@ -783,4 +787,54 @@ class MySQLDatabaseManager(DatabaseManager):
self.rollback()
raise
self.commit()
def checkTIDRange(self, min_tid, length, num_partitions, partition):
# XXX: XOR is a lame checksum
count, tid_checksum, max_tid = self.query('SELECT COUNT(*), '
'BIT_XOR(tid), MAX(tid) FROM ('
'SELECT tid FROM trans '
'WHERE MOD(tid, %(num_partitions)d) = %(partition)s '
'AND tid >= %(min_tid)d '
'ORDER BY tid ASC LIMIT %(length)d'
') AS foo' % {
'num_partitions': num_partitions,
'partition': partition,
'min_tid': util.u64(min_tid),
'length': length,
})[0]
if count == 0:
tid_checksum = 0
max_tid = ZERO_TID
else:
max_tid = util.p64(max_tid)
return count, tid_checksum, max_tid
def checkSerialRange(self, min_oid, min_serial, length, num_partitions,
partition):
# XXX: XOR is a lame checksum
u64 = util.u64
p64 = util.p64
r = self.query('SELECT oid, serial FROM obj WHERE '
'(oid > %(min_oid)d OR '
'(oid = %(min_oid)d AND serial >= %(min_serial)d)) '
'AND MOD(oid, %(num_partitions)d) = %(partition)s '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': u64(min_oid),
'min_serial': u64(min_serial),
'length': length,
'num_partitions': num_partitions,
'partition': partition,
})
count = len(r)
oid_checksum = serial_checksum = 0
if count == 0:
max_oid = ZERO_OID
max_serial = ZERO_TID
else:
for max_oid, max_serial in r:
oid_checksum ^= max_oid
serial_checksum ^= max_serial
max_oid = p64(max_oid)
max_serial = p64(max_serial)
return count, oid_checksum, max_oid, serial_checksum, max_serial
......@@ -22,6 +22,48 @@ from neo.handler import EventHandler
from neo.protocol import Packets, ZERO_TID, ZERO_OID
from neo import util
# TODO: benchmark how different values behave
RANGE_LENGTH = 4000
MIN_RANGE_LENGTH = 1000
"""
Replication algorythm
Purpose: replicate the content of a reference node into a replicating node,
bringing it up-to-date.
This happens both when a new storage is added to en existing cluster, as well
as when a nde was separated from cluster and rejoins it.
Replication happens per partition. Reference node can change between
partitions.
2 parts, done sequentially:
- Transaction (metadata) replication
- Object (data) replication
Both part follow the same mechanism:
- On both sides (replicating and reference), compute a checksum of a chunk
(RANGE_LENGTH number of entries). If there is a mismatch, chunk size is
reduced, and scan restarts from same row, until it reaches a minimal length
(MIN_RANGE_LENGTH). Then, it replicates all rows in that chunk. If the
content of chunks match, it moves on to the next chunk.
- Replicating a chunk starts with asking for a list of all entries (only their
identifier) and skipping those both side have, deleting those which reference
has and replicating doesn't, and asking individually all entries missing in
replicating.
"""
# TODO: Make object replication get ordered by serial first and oid second, so
# changes are in a big segment at the end, rather than in many segments (one
# per object).
# TODO: To improve performance when a pack happened, the following algorithm
# should be used:
# - If reference node packed, find non-existant oids in reference node (their
# creation was undone, and pack pruned them), and delete them.
# - Run current algorithm, starting at our last pack TID.
# - Pack partition at reference's TID.
def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw):
if self.app.replicator.current_connection is conn:
......@@ -51,28 +93,26 @@ class ReplicationHandler(EventHandler):
uuid, num_partitions, num_replicas, your_uuid):
# set the UUID on the connection
conn.setUUID(uuid)
self.startReplication(conn)
def startReplication(self, conn):
conn.ask(self._doAskCheckTIDRange(ZERO_TID), timeout=300)
@checkConnectionIsReplicatorConnection
def answerTIDsFrom(self, conn, tid_list):
app = self.app
if tid_list:
# If I have pending TIDs, check which TIDs I don't have, and
# request the data.
present_tid_list = app.dm.getTIDListPresent(tid_list)
tid_set = set(tid_list) - set(present_tid_list)
for tid in tid_set:
conn.ask(Packets.AskTransactionInformation(tid), timeout=300)
# And, ask more TIDs.
p = Packets.AskTIDsFrom(add64(tid_list[-1], 1), 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
else:
# If no more TID, a replication of transactions is finished.
# So start to replicate objects now.
p = Packets.AskOIDs(ZERO_OID, 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
# If I have pending TIDs, check which TIDs I don't have, and
# request the data.
tid_set = frozenset(tid_list)
my_tid_set = frozenset(app.replicator.getTIDsFromResult())
extra_tid_set = my_tid_set - tid_set
if extra_tid_set:
deleteTransaction = app.dm.deleteTransaction
for tid in extra_tid_set:
deleteTransaction(tid)
missing_tid_set = tid_set - my_tid_set
for tid in missing_tid_set:
conn.ask(Packets.AskTransactionInformation(tid), timeout=300)
@checkConnectionIsReplicatorConnection
def answerTransactionInformation(self, conn, tid,
......@@ -83,46 +123,23 @@ class ReplicationHandler(EventHandler):
False)
@checkConnectionIsReplicatorConnection
def answerOIDs(self, conn, oid_list):
def answerObjectHistoryFrom(self, conn, object_dict):
app = self.app
if oid_list:
app.replicator.next_oid = add64(oid_list[-1], 1)
# Pick one up, and ask the history.
oid = oid_list.pop()
conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
timeout=300)
app.replicator.oid_list = oid_list
else:
# Nothing remains, so the replication for this partition is
# finished.
app.replicator.replication_done = True
@checkConnectionIsReplicatorConnection
def answerObjectHistoryFrom(self, conn, oid, serial_list):
app = self.app
if serial_list:
my_object_dict = app.replicator.getObjectHistoryFromResult()
deleteObject = app.dm.deleteObject
for oid, serial_list in object_dict.iteritems():
# Check if I have objects, request those which I don't have.
present_serial_list = app.dm.getSerialListPresent(oid, serial_list)
serial_set = set(serial_list) - set(present_serial_list)
for serial in serial_set:
conn.ask(Packets.AskObject(oid, serial, None), timeout=300)
# And, ask more serials.
conn.ask(Packets.AskObjectHistoryFrom(oid,
add64(serial_list[-1], 1), 1000), timeout=300)
else:
# This OID is finished. So advance to next.
oid_list = app.replicator.oid_list
if oid_list:
# If I have more pending OIDs, pick one up.
oid = oid_list.pop()
conn.ask(Packets.AskObjectHistoryFrom(oid, ZERO_TID, 1000),
timeout=300)
if oid in my_object_dict:
my_serial_set = frozenset(my_object_dict[oid])
serial_set = frozenset(serial_list)
extra_serial_set = my_serial_set - serial_set
for serial in extra_serial_set:
deleteObject(oid, serial)
missing_serial_set = serial_set - my_serial_set
else:
# Otherwise, acquire more OIDs.
p = Packets.AskOIDs(app.replicator.next_oid, 1000,
app.replicator.current_partition.getRID())
conn.ask(p, timeout=300)
missing_serial_set = serial_list
for serial in missing_serial_set:
conn.ask(Packets.AskObject(oid, serial, None), timeout=300)
@checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start,
......@@ -134,3 +151,97 @@ class ReplicationHandler(EventHandler):
del obj
del data
def _doAskCheckSerialRange(self, min_oid, min_tid, length=RANGE_LENGTH):
replicator = self.app.replicator
partition = replicator.current_partition.getRID()
replicator.checkSerialRange(min_oid, min_tid, length, partition)
return Packets.AskCheckSerialRange(min_oid, min_tid, length, partition)
def _doAskCheckTIDRange(self, min_tid, length=RANGE_LENGTH):
replicator = self.app.replicator
partition = replicator.current_partition.getRID()
replicator.checkTIDRange(min_tid, length, partition)
return Packets.AskCheckTIDRange(min_tid, length, partition)
def _doAskTIDsFrom(self, min_tid, length):
replicator = self.app.replicator
partition = replicator.current_partition.getRID()
replicator.getTIDsFrom(min_tid, length, partition)
return Packets.AskTIDsFrom(min_tid, length, partition)
def _doAskObjectHistoryFrom(self, min_oid, min_serial, length):
replicator = self.app.replicator
partition = replicator.current_partition.getRID()
replicator.getObjectHistoryFrom(min_oid, min_serial, length, partition)
return Packets.AskObjectHistoryFrom(min_oid, min_serial, length,
partition)
@checkConnectionIsReplicatorConnection
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid):
app = self.app
replicator = app.replicator
our = replicator.getTIDCheckResult(min_tid, length)
his = (count, tid_checksum, max_tid)
our_count = our[0]
our_max_tid = our[2]
p = None
if our != his:
# Something is different...
if length <= MIN_RANGE_LENGTH:
# We are already at minimum chunk length, replicate.
conn.ask(self._doAskTIDsFrom(min_tid, count))
else:
# Check a smaller chunk.
# Note: this could be made into a real binary search, but is
# it really worth the work ?
# Note: +1, so we can detect we reached the end when answer
# comes back.
p = self._doAskCheckTIDRange(min_tid, min(length / 2,
count + 1))
if p is None:
if count == length:
# Go on with next chunk
p = self._doAskCheckTIDRange(add64(max_tid, 1))
else:
# If no more TID, a replication of transactions is finished.
# So start to replicate objects now.
p = self._doAskCheckSerialRange(ZERO_OID, ZERO_TID)
conn.ask(p)
@checkConnectionIsReplicatorConnection
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial):
app = self.app
replicator = app.replicator
our = replicator.getSerialCheckResult(min_oid, min_serial, length)
his = (count, oid_checksum, max_oid, serial_checksum, max_serial)
our_count = our[0]
our_max_oid = our[2]
our_max_serial = our[4]
p = None
if our != his:
# Something is different...
if length <= MIN_RANGE_LENGTH:
# We are already at minimum chunk length, replicate.
conn.ask(self._doAskObjectHistoryFrom(min_oid, min_serial,
count))
else:
# Check a smaller chunk.
# Note: this could be made into a real binary search, but is
# it really worth the work ?
# Note: +1, so we can detect we reached the end when answer
# comes back.
p = self._doAskCheckSerialRange(min_oid, min_serial,
min(length / 2, count + 1))
if p is None:
if count == length:
# Go on with next chunk
p = self._doAskCheckSerialRange(max_oid, add64(max_serial, 1))
else:
# Nothing remains, so the replication for this partition is
# finished.
replicator.replication_done = True
if p is not None:
conn.ask(p)
......@@ -30,34 +30,32 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
tid = app.dm.getLastTID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askOIDs(self, conn, min_oid, length, partition):
# This method is complicated, because I must return OIDs only
# about usable partitions assigned to me.
app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
oid_list = app.dm.getOIDList(min_oid, length,
app.pt.getPartitions(), partition_list)
conn.answer(Packets.AnswerOIDs(oid_list))
def askTIDsFrom(self, conn, min_tid, length, partition):
# This method is complicated, because I must return TIDs only
# about usable partitions assigned to me.
app = self.app
if partition == protocol.INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid)
else:
partition_list = [partition]
tid_list = app.dm.getReplicationTIDList(min_tid, length,
app.pt.getPartitions(), partition_list)
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askObjectHistoryFrom(self, conn, oid, min_serial, length):
def askObjectHistoryFrom(self, conn, min_oid, min_serial, length,
partition):
app = self.app
object_dict = app.dm.getObjectHistoryFrom(min_oid, min_serial, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict))
def askCheckTIDRange(self, conn, min_tid, length, partition):
app = self.app
count, tid_checksum, max_tid = app.dm.checkTIDRange(min_tid, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, count,
tid_checksum, max_tid))
def askCheckSerialRange(self, conn, min_oid, min_serial, length,
partition):
app = self.app
history_list = app.dm.getObjectHistoryFrom(oid, min_serial, length)
conn.answer(Packets.AnswerObjectHistoryFrom(oid, history_list))
count, oid_checksum, max_oid, serial_checksum, max_serial = \
app.dm.checkSerialRange(min_oid, min_serial, length,
app.pt.getPartitions(), partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial))
......@@ -46,6 +46,46 @@ class Partition(object):
return tid is not None and (
min_pending_tid is None or tid < min_pending_tid)
class Task(object):
"""
A Task is a callable to execute at another time, with given parameters.
Execution result is kept and can be retrieved later.
"""
_func = None
_args = None
_kw = None
_result = None
_processed = False
def __init__(self, func, args=(), kw=None):
self._func = func
self._args = args
if kw is None:
kw = {}
self._kw = kw
def process(self):
if self._processed:
raise ValueError, 'You cannot process a single Task twice'
self._processed = True
self._result = self._func(*self._args, **self._kw)
def getResult(self):
# Should we instead execute immediately rather than raising ?
if not self._processed:
raise ValueError, 'You cannot get a result until task is executed'
return self._result
def __repr__(self):
fmt = '<%s at %x %r(*%r, **%r)%%s>' % (self.__class__.__name__,
id(self), self._func, self._args, self._kw)
if self._processed:
extra = ' => %r' % (self._result, )
else:
extra = ''
return fmt % (extra, )
class Replicator(object):
"""This class handles replications of objects and transactions.
......@@ -98,21 +138,23 @@ class Replicator(object):
# didn't answer yet.
# unfinished_tid_list
# The list of unfinished TIDs known by master node.
# oid_list
# List of OIDs to replicate. Doesn't contains currently-replicated
# object.
# XXX: not defined here
# XXX: accessed (r/w) directly by ReplicationHandler
# next_oid
# Next OID to ask when oid_list is empty.
# XXX: not defined here
# XXX: accessed (r/w) directly by ReplicationHandler
# replication_done
# False if we know there is something to replicate.
# True when current_partition is replicated, or we don't know yet if
# there is something to replicate
# XXX: accessed (w) directly by ReplicationHandler
new_partition_dict = None
critical_tid_dict = None
partition_dict = None
task_list = None
task_dict = None
current_partition = None
current_connection = None
waiting_for_unfinished_tids = None
unfinished_tid_list = None
replication_done = None
def __init__(self, app):
self.app = app
......@@ -129,6 +171,8 @@ class Replicator(object):
def reset(self):
"""Reset attributes to restart replicating."""
self.task_list = []
self.task_dict = {}
self.current_partition = None
self.current_connection = None
self.waiting_for_unfinished_tids = False
......@@ -213,15 +257,12 @@ class Replicator(object):
p = Packets.RequestIdentification(NodeTypes.STORAGE,
app.uuid, app.server, app.name)
self.current_connection.ask(p)
p = Packets.AskTIDsFrom(ZERO_TID, 1000,
self.current_partition.getRID())
self.current_connection.ask(p, timeout=300)
else:
self.current_connection.getHandler().startReplication(
self.current_connection)
self.replication_done = False
def _finishReplication(self):
app = self.app
# TODO: remove try..except: pass
try:
self.partition_dict.pop(self.current_partition.getRID())
......@@ -243,7 +284,11 @@ class Replicator(object):
self._askCriticalTID()
if self.current_partition is not None:
if self.replication_done:
# Don't end replication until we have received all expected
# answers, as we might have asked object data just before the last
# AnswerCheckSerialRange.
if self.replication_done and \
not self.current_connection.isPending():
# finish a replication
logging.info('replication is done for %s' %
(self.current_partition.getRID(), ))
......@@ -289,3 +334,57 @@ class Replicator(object):
and not self.new_partition_dict.has_key(rid):
self.new_partition_dict[rid] = Partition(rid)
def _addTask(self, key, func, args=(), kw=None):
task = Task(func, args, kw)
task_dict = self.task_dict
if key in task_dict:
raise ValueError, 'Task with key %r already exists (%r), cannot ' \
'add %r' % (key, task_dict[key], task)
task_dict[key] = task
self.task_list.append(task)
def processDelayedTasks(self):
task_list = self.task_list
if task_list:
for task in task_list:
task.process()
self.task_list = []
def checkTIDRange(self, min_tid, length, partition):
app = self.app
self._addTask(('TID', min_tid, length), app.dm.checkTIDRange,
(min_tid, length, app.pt.getPartitions(), partition))
def checkSerialRange(self, min_oid, min_serial, length, partition):
app = self.app
self._addTask(('Serial', min_oid, min_serial, length),
app.dm.checkSerialRange, (min_oid, min_serial, length,
app.pt.getPartitions(), partition))
def getTIDsFrom(self, min_tid, length, partition):
app = self.app
self._addTask('TIDsFrom',
app.dm.getReplicationTIDList, (min_tid, length,
app.pt.getPartitions(), partition))
def getObjectHistoryFrom(self, min_oid, min_serial, length, partition):
app = self.app
self._addTask('ObjectHistoryFrom',
app.dm.getObjectHistoryFrom, (min_oid, min_serial, length,
app.pt.getPartitions(), partition))
def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult()
def getTIDCheckResult(self, min_tid, length):
return self._getCheckResult(('TID', min_tid, length))
def getSerialCheckResult(self, min_oid, min_serial, length):
return self._getCheckResult(('Serial', min_oid, min_serial, length))
def getTIDsFromResult(self):
return self._getCheckResult('TIDsFrom')
def getObjectHistoryFromResult(self):
return self._getCheckResult('ObjectHistoryFrom')
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from neo.tests import NeoTestBase
from neo.protocol import Packets, ZERO_OID, ZERO_TID
from neo.storage.handlers.replication import ReplicationHandler, add64
from neo.storage.handlers.replication import RANGE_LENGTH, MIN_RANGE_LENGTH
class FakeDict(object):
def __init__(self, items):
self._items = items
self._dict = dict(items)
assert len(self._dict) == len(items), self._dict
def iteritems(self):
for item in self._items:
yield item
def iterkeys(self):
for key, value in self.iteritems():
yield key
def itervalues(self):
for key, value in self.iteritems():
yield value
def items(self):
return self._items[:]
def keys(self):
return [x for x, y in self._items]
def values(self):
return [y for x, y in self._items]
def __getitem__(self, key):
return self._dict[key]
def __getattr__(self, key):
return getattr(self._dict, key)
def __len__(self):
return len(self._dict)
class StorageReplicationHandlerTests(NeoTestBase):
def setup(self):
pass
def teardown(self):
pass
def getApp(self, conn=None, tid_check_result=(0, 0, ZERO_TID),
serial_check_result=(0, 0, ZERO_OID, 0, ZERO_TID),
tid_result=(),
history_result=None,
rid=0, critical_tid=ZERO_TID):
if history_result is None:
history_result = {}
replicator = Mock({
'__repr__': 'Fake replicator',
'reset': None,
'checkSerialRange': None,
'checkTIDRange': None,
'getTIDCheckResult': tid_check_result,
'getSerialCheckResult': serial_check_result,
'getTIDsFromResult': tid_result,
'getObjectHistoryFromResult': history_result,
'checkSerialRange': None,
'checkTIDRange': None,
'getTIDsFrom': None,
'getObjectHistoryFrom': None,
})
replicator.current_partition = Mock({
'getRID': rid,
'getCriticalTID': critical_tid,
})
replicator.current_connection = conn
real_replicator = replicator
class FakeApp(object):
replicator = real_replicator
dm = Mock({
'storeTransaction': None,
})
return FakeApp
def _checkReplicationStarted(self, conn, rid, replicator):
min_tid, length, partition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(min_tid, ZERO_TID)
self.assertEqual(length, RANGE_LENGTH)
self.assertEqual(partition, rid)
calls = replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, length, partition)
def _checkPacketTIDList(self, conn, tid_list):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
self.assertEqual(len(packet_list), len(tid_list))
for packet in packet_list:
self.assertEqual(packet.getType(),
Packets.AskTransactionInformation)
ptid = packet.decode()[0]
for tid in tid_list:
if ptid == tid:
tid_list.remove(tid)
break
else:
raise AssertionFailed, '%s not found in %r' % (dump(ptid),
[dump(x) for x in tid_list])
def _checkPacketSerialList(self, conn, object_list):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
self.assertEqual(len(packet_list), len(object_list))
for packet, (oid, serial) in zip(packet_list, object_list):
self.assertEqual(packet.getType(),
Packets.AskObject)
self.assertEqual(packet.decode(), (oid, serial, None))
def test_connectionLost(self):
app = self.getApp()
ReplicationHandler(app).connectionLost(None, None)
self.assertEqual(len(app.replicator.mockGetNamedCalls('reset')), 1)
def test_connectionFailed(self):
app = self.getApp()
ReplicationHandler(app).connectionFailed(None)
self.assertEqual(len(app.replicator.mockGetNamedCalls('reset')), 1)
def test_acceptIdentification(self):
rid = 24
app = self.getApp(rid=rid)
conn = self.getFakeConnection()
replication = ReplicationHandler(app)
replication.acceptIdentification(conn, None, None, None,
None, None)
self._checkReplicationStarted(conn, rid, app.replicator)
def test_startReplication(self):
rid = 24
app = self.getApp(rid=rid)
conn = self.getFakeConnection()
ReplicationHandler(app).startReplication(conn)
self._checkReplicationStarted(conn, rid, app.replicator)
def test_answerTIDsFrom(self):
conn = self.getFakeConnection()
tid_list = [self.getNextTID(), self.getNextTID()]
app = self.getApp(conn=conn, tid_result=[])
# With no known TID
ReplicationHandler(app).answerTIDsFrom(conn, tid_list)
self._checkPacketTIDList(conn, tid_list[:])
# With first TID known
conn = self.getFakeConnection()
known_tid_list = [tid_list[0], ]
unknown_tid_list = [tid_list[1], ]
app = self.getApp(conn=conn, tid_result=known_tid_list)
ReplicationHandler(app).answerTIDsFrom(conn, tid_list)
self._checkPacketTIDList(conn, unknown_tid_list)
def test_answerTransactionInformation(self):
conn = self.getFakeConnection()
app = self.getApp(conn=conn)
tid = self.getNextTID()
user = 'foo'
desc = 'bar'
ext = 'baz'
packed = True
oid_list = [self.getOID(1), self.getOID(2)]
ReplicationHandler(app).answerTransactionInformation(conn, tid, user,
desc, ext, packed, oid_list)
calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, (), (oid_list, user, desc, ext, packed), False)
def test_answerObjectHistoryFrom(self):
conn = self.getFakeConnection()
oid_1 = self.getOID(1)
oid_2 = self.getOID(2)
oid_3 = self.getOID(3)
oid_dict = FakeDict((
(oid_1, [self.getNextTID(), self.getNextTID()]),
(oid_2, [self.getNextTID()]),
(oid_3, [self.getNextTID()]),
))
flat_oid_list = []
for oid, serial_list in oid_dict.iteritems():
for serial in serial_list:
flat_oid_list.append((oid, serial))
app = self.getApp(conn=conn, history_result={})
# With no known OID/Serial
ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
self._checkPacketSerialList(conn, flat_oid_list)
# With some known OID/Serials
conn = self.getFakeConnection()
app = self.getApp(conn=conn, history_result={
oid_1: [oid_dict[oid_1][0], ],
oid_3: [oid_dict[oid_3][0], ],
})
ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
self._checkPacketSerialList(conn, (
(oid_1, oid_dict[oid_1][1]),
(oid_2, oid_dict[oid_2][0]),
))
def test_answerObject(self):
conn = self.getFakeConnection()
app = self.getApp(conn=conn)
oid = self.getOID(1)
serial_start = self.getNextTID()
serial_end = self.getNextTID()
compression = 1
checksum = 2
data = 'foo'
data_serial = None
ReplicationHandler(app).answerObject(conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial)
calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(serial_start, [(oid, compression, checksum, data,
data_serial)], None, False)
# CheckTIDRange
def test_answerCheckTIDRangeIdenticalChunkWithNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with next chunk
pmin_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmin_tid, add64(max_tid, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 1, 0, max_tid), rid=rid,
conn=conn)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
max_tid)
# Result: go on with object range checks
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
def test_answerCheckTIDRangeDifferentBigChunk(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: ask again, length halved
pmin_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition)
def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: ask tid list, and ask next chunk
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 2)
tid_call, next_call = calls
tid_packet = tid_call.getParam(0)
next_packet = next_call.getParam(0)
self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
pmin_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(plength, length)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition)
self.assertEqual(next_packet.getType(), Packets.AskCheckTIDRange)
pmin_tid, plength, ppartition = next_packet.decode()
self.assertEqual(pmin_tid, add64(max_tid, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition)
def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn)
handler = ReplicationHandler(app)
# Peer has different data, and less than length
handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
max_tid)
# Result: ask tid list, and start replicating object range
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 2)
tid_call, next_call = calls
tid_packet = tid_call.getParam(0)
next_packet = next_call.getParam(0)
self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom)
pmin_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition)
self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange)
pmin_oid, pmin_serial, plength, ppartition = next_packet.decode()
self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
# CheckSerialRange
def test_answerCheckSerialRangeIdenticalChunkWithNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: go on with next chunk
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, max_oid)
self.assertEqual(pmin_serial, add64(max_serial, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
def test_answerCheckSerialRangeIdenticalChunkWithoutNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length - 1, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length - 1, 0, max_oid, 1, max_serial)
# Result: mark replication as done
self.checkNoPacketSent(conn)
self.assertTrue(app.replicator.replication_done)
def test_answerCheckSerialRangeDifferentBigChunk(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: ask again, length halved
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
def test_answerCheckSerialRangeDifferentSmallChunkWithNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: ask serial list, and ask next chunk
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 2)
serial_call, next_call = calls
serial_packet = serial_call.getParam(0)
next_packet = next_call.getParam(0)
self.assertEqual(serial_packet.getType(), Packets.AskObjectHistoryFrom)
pmin_oid, pmin_serial, plength, ppartition = serial_packet.decode()
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(plength, length)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
self.assertEqual(next_packet.getType(), Packets.AskCheckSerialRange)
pmin_oid, pmin_serial, plength, ppartition = next_packet.decode()
self.assertEqual(pmin_oid, max_oid)
self.assertEqual(pmin_serial, add64(max_serial, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
def test_answerCheckSerialRangeDifferentSmallChunkWithoutNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid,
1, max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has different data, and less than length
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length - 1, 0, max_oid, 1, max_serial)
# Result: ask tid list, and mark replication as done
pmin_oid, pmin_serial, plength, ppartition = self.checkAskPacket(conn,
Packets.AskObjectHistoryFrom, decode=True)
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
self.assertTrue(app.replicator.replication_done)
if __name__ == "__main__":
unittest.main()
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock, ReturnValues
from neo.tests import NeoTestBase
from neo.storage.replicator import Replicator, Partition, Task
from neo.protocol import CellStates, NodeStates, Packets
class StorageReplicatorTests(NeoTestBase):
def setup(self):
pass
def teardown(self):
pass
def test_populate(self):
my_uuid = self.getNewUUID()
other_uuid = self.getNewUUID()
app = Mock()
app.uuid = my_uuid
app.pt = Mock({
'getPartitions': 2,
'getRow': ReturnValues(
((my_uuid, CellStates.OUT_OF_DATE),
(other_uuid, CellStates.UP_TO_DATE), ),
((my_uuid, CellStates.UP_TO_DATE),
(other_uuid, CellStates.OUT_OF_DATE), ),
),
})
replicator = Replicator(app)
assert replicator.new_partition_dict is None, \
replicator.new_partition_dict
assert replicator.critical_tid_dict is None, \
replicator.critical_tid_dict
assert replicator.partition_dict is None, replicator.partition_dict
replicator.populate()
self.assertEqual(len(replicator.new_partition_dict), 1)
partition = replicator.new_partition_dict[0]
self.assertEqual(partition.getRID(), 0)
self.assertEqual(partition.getCriticalTID(), None)
self.assertEqual(replicator.critical_tid_dict, {})
self.assertEqual(replicator.partition_dict, {})
def test_reset(self):
replicator = Replicator(None)
assert replicator.task_list is None, replicator.task_list
assert replicator.task_dict is None, replicator.task_dict
assert replicator.current_partition is None, \
replicator.current_partition
assert replicator.current_connection is None, \
replicator.current_connection
assert replicator.waiting_for_unfinished_tids is None, \
replicator.waiting_for_unfinished_tids
assert replicator.unfinished_tid_list is None, \
replicator.unfinished_tid_list
assert replicator.replication_done is None, replicator.replication_done
replicator.reset()
self.assertEqual(replicator.task_list, [])
self.assertEqual(replicator.task_dict, {})
self.assertEqual(replicator.current_partition, None)
self.assertEqual(replicator.current_connection, None)
self.assertEqual(replicator.waiting_for_unfinished_tids, False)
self.assertEqual(replicator.unfinished_tid_list, None)
self.assertEqual(replicator.replication_done, True)
def test_setCriticalTID(self):
replicator = Replicator(None)
master_uuid = self.getNewUUID()
partition_list = [Partition(0), Partition(5)]
replicator.critical_tid_dict = {master_uuid: partition_list}
critical_tid = self.getNextTID()
for partition in partition_list:
self.assertEqual(partition.getCriticalTID(), None)
replicator.setCriticalTID(master_uuid, critical_tid)
self.assertEqual(replicator.critical_tid_dict, {})
for partition in partition_list:
self.assertEqual(partition.getCriticalTID(), critical_tid)
def test_setUnfinishedTIDList(self):
replicator = Replicator(None)
replicator.waiting_for_unfinished_tids = True
assert replicator.unfinished_tid_list is None, \
replicator.unfinished_tid_list
tid_list = [self.getNextTID(), ]
replicator.setUnfinishedTIDList(tid_list)
self.assertEqual(replicator.unfinished_tid_list, tid_list)
self.assertFalse(replicator.waiting_for_unfinished_tids)
def test_act(self):
# Also tests "pending"
uuid = self.getNewUUID()
master_uuid = self.getNewUUID()
bad_unfinished_tid = self.getNextTID()
critical_tid = self.getNextTID()
unfinished_tid = self.getNextTID()
app = Mock()
app.em = Mock({
'register': None,
})
def connectorGenerator():
return Mock()
app.connector_handler = connectorGenerator
app.uuid = uuid
node_addr = ('127.0.0.1', 1234)
node = Mock({
'getAddress': node_addr,
})
running_cell = Mock({
'getNodeState': NodeStates.RUNNING,
'getNode': node,
})
unknown_cell = Mock({
'getNodeState': NodeStates.UNKNOWN,
})
app.pt = Mock({
'getPartitions': 1,
'getRow': ReturnValues(
((uuid, CellStates.OUT_OF_DATE), ),
),
'getCellList': [running_cell, unknown_cell],
})
node_conn_handler = Mock({
'startReplication': None,
})
node_conn = Mock({
'getAddress': node_addr,
'getHandler': node_conn_handler,
})
replicator = Replicator(app)
replicator.populate()
def act():
app.master_conn = self.getFakeConnection(uuid=master_uuid)
self.assertTrue(replicator.pending())
replicator.act()
# ask last IDs to infer critical_tid and unfinished tids
act()
last_ids, unfinished_tids = [x.getParam(0) for x in \
app.master_conn.mockGetNamedCalls('ask')]
self.assertEqual(last_ids.getType(), Packets.AskLastIDs)
self.assertFalse(replicator.new_partition_dict)
self.assertEqual(unfinished_tids.getType(),
Packets.AskUnfinishedTransactions)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# nothing happens until waiting_for_unfinished_tids becomes False
act()
self.checkNoPacketSent(app.master_conn)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# Send answers (garanteed to happen in this order)
replicator.setCriticalTID(master_uuid, critical_tid)
act()
self.checkNoPacketSent(app.master_conn)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# first time, there is an unfinished tid before critical tid,
# replication cannot start, and unfinished TIDs are asked again
replicator.setUnfinishedTIDList([unfinished_tid, bad_unfinished_tid])
self.assertFalse(replicator.waiting_for_unfinished_tids)
# Note: detection that nothing can be replicated happens on first call
# and unfinished tids are asked again on second call. This is ok, but
# might change, so just call twice.
act()
act()
self.checkAskPacket(app.master_conn, Packets.AskUnfinishedTransactions)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# this time, critical tid check should be satisfied
replicator.setUnfinishedTIDList([unfinished_tid, ])
replicator.current_connection = node_conn
act()
self.assertEqual(replicator.current_partition,
replicator.partition_dict[0])
self.assertEqual(len(node_conn_handler.mockGetNamedCalls(
'startReplication')), 1)
self.assertFalse(replicator.replication_done)
# Other calls should do nothing
replicator.current_connection = Mock()
act()
self.checkNoPacketSent(app.master_conn)
self.checkNoPacketSent(replicator.current_connection)
# Mark replication over for this partition
replicator.replication_done = True
# Don't finish while there are pending answers
replicator.current_connection = Mock({
'isPending': True,
})
act()
self.assertTrue(replicator.pending())
replicator.current_connection = Mock({
'isPending': False,
})
act()
# unfinished tid list will not be asked again
self.assertTrue(replicator.unfinished_tid_list)
# also, replication is over
self.assertFalse(replicator.pending())
def test_removePartition(self):
replicator = Replicator(None)
replicator.partition_dict = {0: None, 2: None}
replicator.new_partition_dict = {1: None}
replicator.removePartition(0)
self.assertEqual(replicator.partition_dict, {2: None})
self.assertEqual(replicator.new_partition_dict, {1: None})
replicator.removePartition(1)
replicator.removePartition(2)
self.assertEqual(replicator.partition_dict, {})
self.assertEqual(replicator.new_partition_dict, {})
# Must not raise
replicator.removePartition(3)
def test_addPartition(self):
replicator = Replicator(None)
replicator.partition_dict = {0: None}
replicator.new_partition_dict = {1: None}
replicator.addPartition(0)
replicator.addPartition(1)
self.assertEqual(replicator.partition_dict, {0: None})
self.assertEqual(replicator.new_partition_dict, {1: None})
replicator.addPartition(2)
self.assertEqual(replicator.partition_dict, {0: None})
self.assertEqual(len(replicator.new_partition_dict), 2)
self.assertEqual(replicator.new_partition_dict[1], None)
partition = replicator.new_partition_dict[2]
self.assertEqual(partition.getRID(), 2)
self.assertEqual(partition.getCriticalTID(), None)
def test_processDelayedTasks(self):
replicator = Replicator(None)
replicator.reset()
marker = []
def someCallable(foo, bar=None):
return (foo, bar)
replicator._addTask(1, someCallable, args=('foo', ))
self.assertRaises(ValueError, replicator._addTask, 1, None)
replicator._addTask(2, someCallable, args=('foo', ), kw={'bar': 'bar'})
replicator.processDelayedTasks()
self.assertEqual(replicator._getCheckResult(1), ('foo', None))
self.assertEqual(replicator._getCheckResult(2), ('foo', 'bar'))
# Also test Task
task = Task(someCallable, args=('foo', ))
self.assertRaises(ValueError, task.getResult)
task.process()
self.assertRaises(ValueError, task.process)
self.assertEqual(task.getResult(), ('foo', None))
if __name__ == "__main__":
unittest.main()
......@@ -21,7 +21,7 @@ from collections import deque
from neo.tests import NeoTestBase
from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler
from neo.protocol import INVALID_PARTITION
from neo.protocol import INVALID_PARTITION, Packets
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
class StorageStorageHandlerTests(NeoTestBase):
......@@ -113,7 +113,7 @@ class StorageStorageHandlerTests(NeoTestBase):
self.assertEquals(len(self.app.event_queue), 0)
self.checkAnswerObject(conn)
def test_25_askTIDsFrom1(self):
def test_25_askTIDsFrom(self):
# well case => answer
conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
......@@ -122,69 +122,85 @@ class StorageStorageHandlerTests(NeoTestBase):
self.operation.askTIDsFrom(conn, tid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, [1, ])
self.checkAnswerTidsFrom(conn)
def test_25_askTIDsFrom2(self):
# invalid partition => answer usable partitions
conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({
'getCellList': (cell, ),
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
tid = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(tid, 2, 1, [0, ])
calls[0].checkArgs(tid, 2, 1, 1)
self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self):
oid = self.getOID(2)
min_tid = self.getNextTID()
min_oid = self.getOID(2)
min_serial = self.getNextTID()
length = 4
partition = 8
num_partitions = 16
tid = self.getNextTID()
conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistoryFrom': [tid]})
self.operation.askObjectHistoryFrom(conn, oid, min_tid, 2)
self.app.dm = Mock({'getObjectHistoryFrom': {min_oid: [tid]},})
self.app.pt = Mock({
'getPartitions': num_partitions,
})
self.operation.askObjectHistoryFrom(conn, min_oid, min_serial, length,
partition)
self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, min_tid, 2)
calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
partition)
def test_25_askOIDs1(self):
# well case > answer OIDs
def test_askCheckTIDRange(self):
count = 1
tid_checksum = 2
min_tid = self.getNextTID()
num_partitions = 4
length = 5
partition = 6
max_tid = self.getNextTID()
self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection()
self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
oid = self.getOID(1)
self.operation.askOIDs(conn, oid, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, 2, 1, [1, ])
self.checkAnswerOids(conn)
def test_25_askOIDs2(self):
# invalid partition => answer usable partitions
self.operation.askCheckTIDRange(conn, min_tid, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, length, num_partitions, partition)
pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
decode=True)
self.assertEqual(min_tid, pmin_tid)
self.assertEqual(length, plength)
self.assertEqual(count, pcount)
self.assertEqual(tid_checksum, ptid_checksum)
self.assertEqual(max_tid, pmax_tid)
def test_askCheckSerialRange(self):
count = 1
oid_checksum = 2
min_oid = self.getOID(1)
num_partitions = 4
length = 5
partition = 6
serial_checksum = 7
min_serial = self.getNextTID()
max_serial = self.getNextTID()
max_oid = self.getOID(2)
self.app.dm = Mock({'checkSerialRange': (count, oid_checksum, max_oid,
serial_checksum, max_serial)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection()
cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
self.app.pt = Mock({
'getCellList': (cell, ),
'getPartitions': 1,
'getAssignedPartitionList': [0],
})
oid = self.getOID(1)
self.operation.askOIDs(conn, oid, 2, INVALID_PARTITION)
self.assertEquals(len(self.app.pt.mockGetNamedCalls('getAssignedPartitionList')), 1)
calls = self.app.dm.mockGetNamedCalls('getOIDList')
self.assertEquals(len(calls), 1)
calls[0].checkArgs(oid, 2, 1, [0])
self.checkAnswerOids(conn)
self.operation.askCheckSerialRange(conn, min_oid, min_serial, length,
partition)
calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, length, num_partitions,
partition)
pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
Packets.AnswerCheckSerialRange, decode=True)
self.assertEqual(min_oid, pmin_oid)
self.assertEqual(min_serial, pmin_serial)
self.assertEqual(length, plength)
self.assertEqual(count, pcount)
self.assertEqual(oid_checksum, poid_checksum)
self.assertEqual(max_oid, pmax_oid)
self.assertEqual(serial_checksum, pserial_checksum)
self.assertEqual(max_serial, pmax_serial)
if __name__ == "__main__":
unittest.main()
......@@ -19,7 +19,7 @@ import unittest
import MySQLdb
from mock import Mock
from neo.util import dump, p64, u64
from neo.protocol import CellStates, INVALID_PTID
from neo.protocol import CellStates, INVALID_PTID, ZERO_OID, ZERO_TID
from neo.tests import NeoTestBase
from neo.exception import DatabaseFailure
from neo.storage.database.mysqldb import MySQLDatabaseManager
......@@ -441,6 +441,23 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
txn1, objs1 = self.getTransaction([oid1, oid2])
txn2, objs2 = self.getTransaction([oid1, oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.deleteObject(oid1)
self.assertEqual(self.db.getObject(oid1, tid=tid1), None)
self.assertEqual(self.db.getObject(oid1, tid=tid2), None)
self.db.deleteObject(oid2, serial=tid1)
self.assertEqual(self.db.getObject(oid2, tid=tid1), False)
self.assertEqual(self.db.getObject(oid2, tid=tid2), (tid2, None) + \
objs2[1][1:])
def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2)
......@@ -459,30 +476,6 @@ class StorageMySQSLdbTests(NeoTestBase):
self.assertEqual(result, ([oid1], 'user', 'desc', 'ext', False))
self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getOIDList(self):
# store four objects
oid1, oid2, oid3, oid4 = self.getOIDs(4)
tid = self.getNextTID()
txn, objs = self.getTransaction([oid1, oid2, oid3, oid4])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
# get oids
result = self.db.getOIDList(oid1, 4, 1, [0])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(oid1, 4, 2, [0])
self.checkSet(result, [oid1, oid3])
result = self.db.getOIDList(oid1, 4, 2, [0, 1])
self.checkSet(result, [oid1, oid2, oid3, oid4])
result = self.db.getOIDList(oid1, 4, 3, [0])
self.checkSet(result, [oid1, oid4])
# get a subset of oids
result = self.db.getOIDList(oid1, 2, 1, [0])
self.checkSet(result, [oid1, oid2])
result = self.db.getOIDList(oid3, 2, 1, [0])
self.checkSet(result, [oid3, oid4])
result = self.db.getOIDList(oid2, 1, 3, [0])
self.checkSet(result, [oid4])
def test_getObjectHistory(self):
oid = self.getOID(1)
tid1, tid2, tid3 = self.getTIDs(3)
......@@ -506,6 +499,50 @@ class StorageMySQSLdbTests(NeoTestBase):
result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None)
def test_getObjectHistoryFrom(self):
oid1 = self.getOID(0)
oid2 = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
txn3, objs3 = self.getTransaction([oid1])
txn4, objs4 = self.getTransaction([oid2])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.storeTransaction(tid3, objs3, txn3)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4)
# Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, 10, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 3, 1, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
})
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, 10, 2, 1)
self.assertEqual(result, {
oid2: [tid2, tid4],
})
def _storeTransactions(self, count):
# use OID generator to know result of tid % N
tid_list = self.getOIDs(count)
......@@ -538,59 +575,20 @@ class StorageMySQSLdbTests(NeoTestBase):
def test_getReplicationTIDList(self):
tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids
result = self.db.getReplicationTIDList(tid1, 4, 1, [0])
result = self.db.getReplicationTIDList(tid1, 4, 1, 0)
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0])
result = self.db.getReplicationTIDList(tid1, 4, 2, 0)
self.checkSet(result, [tid1, tid3])
result = self.db.getReplicationTIDList(tid1, 4, 2, [0, 1])
self.checkSet(result, [tid1, tid2, tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 4, 3, [0])
result = self.db.getReplicationTIDList(tid1, 4, 3, 0)
self.checkSet(result, [tid1, tid4])
# get a subset of tids
result = self.db.getReplicationTIDList(tid3, 4, 1, [0])
result = self.db.getReplicationTIDList(tid3, 4, 1, 0)
self.checkSet(result, [tid3, tid4])
result = self.db.getReplicationTIDList(tid1, 2, 1, [0])
result = self.db.getReplicationTIDList(tid1, 2, 1, 0)
self.checkSet(result, [tid1, tid2])
result = self.db.getReplicationTIDList(tid1, 1, 3, [1])
result = self.db.getReplicationTIDList(tid1, 1, 3, 1)
self.checkSet(result, [tid2])
def test_getTIDListPresent(self):
oid = self.getOID(1)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid])
txn4, objs4 = self.getTransaction([oid])
# four tids, two missing
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid4)
result = self.db.getTIDListPresent([tid1, tid2, tid3, tid4])
self.checkSet(result, [tid1, tid4])
result = self.db.getTIDListPresent([tid1, tid2])
self.checkSet(result, [tid1])
self.assertEqual(self.db.getTIDListPresent([tid2, tid3]), [])
def test_getSerialListPresent(self):
oid1, oid2 = self.getOIDs(2)
tid1, tid2, tid3, tid4 = self.getTIDs(4)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid1])
txn3, objs3 = self.getTransaction([oid2])
txn4, objs4 = self.getTransaction([oid2])
# four object, one revision each
self.db.storeTransaction(tid1, objs1, txn1)
self.db.finishTransaction(tid1)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.finishTransaction(tid4)
result = self.db.getSerialListPresent(oid1, [tid1, tid2])
self.checkSet(result, [tid1])
result = self.db.getSerialListPresent(oid2, [tid3, tid4])
self.checkSet(result, [tid4])
result = self.db.getSerialListPresent(oid1, [tid2])
self.assertEqual(result, [])
result = self.db.getSerialListPresent(oid2, [tid3])
self.assertEqual(result, [])
def test__getObjectData(self):
db = self.db
db.setup(reset=True)
......
......@@ -458,24 +458,6 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(p_hist_list, hist_list)
self.assertEqual(oid, poid)
def test_55_askOIDs(self):
oid = self.getOID(1)
p = Packets.AskOIDs(oid, 1000, 5)
min_oid, length, partition = p.decode()
self.assertEqual(min_oid, oid)
self.assertEqual(length, 1000)
self.assertEqual(partition, 5)
def test_56_answerOIDs(self):
oid1 = self.getNextTID()
oid2 = self.getNextTID()
oid3 = self.getNextTID()
oid4 = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4]
p = Packets.AnswerOIDs(oid_list)
p_oid_list = p.decode()[0]
self.assertEqual(p_oid_list, oid_list)
def test_57_notifyReplicationDone(self):
offset = 10
p = Packets.NotifyReplicationDone(offset)
......@@ -626,14 +608,82 @@ class ProtocolTests(NeoTestBase):
oid = self.getOID(1)
min_serial = self.getNextTID()
length = 5
p = Packets.AskObjectHistoryFrom(oid, min_serial, length)
p_oid, p_min_serial, p_length = p.decode()
partition = 4
p = Packets.AskObjectHistoryFrom(oid, min_serial, length, partition)
p_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerObjectHistoryFrom(self):
self._testXIDAndYIDList(Packets.AnswerObjectHistoryFrom)
object_dict = {}
for int_oid in xrange(4):
object_dict[self.getOID(int_oid)] = [self.getNextTID() \
for _ in xrange(5)]
p = Packets.AnswerObjectHistoryFrom(object_dict)
p_object_dict = p.decode()[0]
self.assertEqual(object_dict, p_object_dict)
def test_AskCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckTIDRange(min_tid, length, partition)
p_min_tid, p_length, p_partition = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckTIDRange(self):
min_tid = self.getNextTID()
length = 2
count = 1
tid_checksum = 42
max_tid = self.getNextTID()
p = Packets.AnswerCheckTIDRange(min_tid, length, count, tid_checksum,
max_tid)
p_min_tid, p_length, p_count, p_tid_checksum, p_max_tid = p.decode()
self.assertEqual(p_min_tid, min_tid)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_tid, max_tid)
def test_AskCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
partition = 4
p = Packets.AskCheckSerialRange(min_oid, min_serial, length, partition)
p_min_oid, p_min_serial, p_length, p_partition = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerCheckSerialRange(self):
min_oid = self.getOID(1)
min_serial = self.getNextTID()
length = 2
count = 1
oid_checksum = 24
max_oid = self.getOID(5)
tid_checksum = 42
max_serial = self.getNextTID()
p = Packets.AnswerCheckSerialRange(min_oid, min_serial, length, count,
oid_checksum, max_oid, tid_checksum, max_serial)
p_min_oid, p_min_serial, p_length, p_count, p_oid_checksum, \
p_max_oid, p_tid_checksum, p_max_serial = p.decode()
self.assertEqual(p_min_oid, min_oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_count, count)
self.assertEqual(p_oid_checksum, oid_checksum)
self.assertEqual(p_max_oid, max_oid)
self.assertEqual(p_tid_checksum, tid_checksum)
self.assertEqual(p_max_serial, max_serial)
def test_AskPack(self):
tid = self.getNextTID()
......
......@@ -56,6 +56,8 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions',
'neo.tests.storage.testReplicationHandler',
'neo.tests.storage.testReplicator',
# client application
'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment