Commit 8e3c7b01 authored by Julien Muchembled

Implements backup using specialised storage nodes and relying on replication

Replication is also fully reimplemented:
- It is no longer done on whole partitions.
- It runs at the lowest priority so as not to degrade performance for client nodes.

The schema of the MySQL table is changed to optimize the storage layout: rows are
now grouped by age, for good partial replication performance.
This certainly also speeds up simple loads/stores.
parent 75d83690
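The gain from grouping rows by age can be pictured with a small toy model (hypothetical helper, not part of this commit): once the primary key of 'obj' starts with (partition, serial), all rows newer than a given TID form one contiguous key range, so partial replication becomes a single range scan instead of scattered per-object lookups.

    from bisect import bisect_right

    # toy rows ordered like the new primary key: (partition, serial, oid)
    rows = sorted((0, serial, oid) for serial in range(1, 6) for oid in (10, 11, 12))

    def rows_to_replicate(rows, partition, min_serial):
        # everything committed after min_serial is one contiguous slice
        start = bisect_right(rows, (partition, min_serial, float('inf')))
        return rows[start:]

    # only the objects of the 2 most recent transactions are read
    assert len(rows_to_replicate(rows, 0, 3)) == 6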
...@@ -111,42 +111,17 @@ RC - Review output of pylint (CODE)
consider using query(request, args) instead of query(request % args)
- Make listening address and port optional, and if they are not provided
listen on all interfaces on any available port.
- Replication throttling (HIGH AVAILABILITY)
In its current implementation, replication runs at full speed, which
degrades performance for client nodes. Replication should allow
throttling, and that throttling should be configurable.
See "Replication pipelining".
- Make replication speed configurable (HIGH AVAILABILITY)
In its current implementation, replication runs at the lowest priority, so as
not to degrade performance for client nodes. But when there is only one
storage left for a partition, it may be desirable to guarantee a minimum
speed, to avoid complete data loss if another failure happens too early.
- Pack segmentation & throttling (HIGH AVAILABILITY)
In its current implementation, pack runs in one call on all storage nodes
at the same time, which locks down the whole cluster. This task should
be split in chunks and processed in "background" on storage nodes.
Packing throttling should probably be at the lowest possible priority
(below interactive use and below replication).
- Replication pipelining (SPEED)
Replication currently works with too many exchanges between the replicating
storage and the reference storage, and network latency can become a
significant limit.
This should be changed to have just one initial request from
replicating storage, and multiple packets from reference storage with
database range checksums. When receiving these checksums, replicating
storage must compare with what it has, and ask row lists (might not even
be required) and data when there are differences. Quick fetching from
network with asynchronous checking (=queueing) + congestion control
(asking reference storage's to pause its packet flow) will probably be
required.
This should make it easier to throttle replication workload on reference
storage node, as it can decide to postpone replication-related packets on
its own.
- Partial replication (SPEED)
In its current implementation, replication always happens on a whole
partition. In typical use, only a few last transactions will have been
missed, so replicating only past a given TID would be much faster.
To achieve this, storage nodes must store 2 values:
- a pack identifier, which must be different each time a pack occurs
(increasing number sequence, TID-ish, etc) to trigger a
whole-partition replication when a pack happened (this could be
improved too, later)
- the latest (-ish) transaction committed locally, to use as a lower
replication boundary
- tpc_finish failures propagation to master (FUNCTIONALITY)
When asked to lock transaction data, if something goes wrong the master
node must be informed.
......
...@@ -9,7 +9,7 @@ SQL commands to migrate each storage from NEO 0.10.x::
CREATE TABLE new_data (id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, hash BINARY(20) NOT NULL UNIQUE, compression TINYINT UNSIGNED NULL, value LONGBLOB NULL) ENGINE = InnoDB SELECT DISTINCT obj.hash as hash, compression, value FROM obj, data WHERE obj.hash=data.hash ORDER BY serial;
DROP TABLE data;
RENAME TABLE new_data TO data;
CREATE TABLE new_obj (partition SMALLINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL, data_id BIGINT UNSIGNED NULL, value_serial BIGINT UNSIGNED NULL, PRIMARY KEY (partition, oid, serial), KEY (data_id)) ENGINE = InnoDB SELECT partition, oid, serial, data.id as data_id, value_serial FROM obj LEFT JOIN data ON (obj.hash=data.hash);
CREATE TABLE new_obj (partition SMALLINT UNSIGNED NOT NULL, oid BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL, data_id BIGINT UNSIGNED NULL, value_serial BIGINT UNSIGNED NULL, PRIMARY KEY (partition, serial, oid), KEY (partition, oid, serial), KEY (data_id)) ENGINE = InnoDB SELECT partition, oid, serial, data.id as data_id, value_serial FROM obj LEFT JOIN data ON (obj.hash=data.hash);
DROP TABLE obj;
RENAME TABLE new_obj TO obj;
ALTER TABLE tobj CHANGE hash data_id BIGINT UNSIGNED NULL;
......
...@@ -959,7 +959,7 @@ class Application(object): ...@@ -959,7 +959,7 @@ class Application(object):
tid_list = [] tid_list = []
# request a tid list for each partition # request a tid list for each partition
for offset in xrange(self.pt.getPartitions()): for offset in xrange(self.pt.getPartitions()):
p = Packets.AskTIDsFrom(start, stop, limit, [offset]) p = Packets.AskTIDsFrom(start, stop, limit, offset)
for node, conn in self.cp.iterateForObject(offset, readable=True): for node, conn in self.cp.iterateForObject(offset, readable=True):
try: try:
r = self._askStorage(conn, p) r = self._askStorage(conn, p)
......
...@@ -90,3 +90,8 @@ class ConfigurationManager(object): ...@@ -90,3 +90,8 @@ class ConfigurationManager(object):
# only from command line # only from command line
return util.bin(self.argument_list.get('uuid', None)) return util.bin(self.argument_list.get('uuid', None))
def getUpstreamCluster(self):
return self.__get('upstream_cluster', True)
def getUpstreamMasters(self):
return util.parseMasterList(self.__get('upstream_masters'))
...@@ -79,6 +79,9 @@ class EpollEventManager(object): ...@@ -79,6 +79,9 @@ class EpollEventManager(object):
self.epoll.unregister(fd) self.epoll.unregister(fd)
del self.connection_dict[fd] del self.connection_dict[fd]
def isIdle(self):
return not (self._pending_processing or self.writer_set)
def _addPendingConnection(self, conn): def _addPendingConnection(self, conn):
pending_processing = self._pending_processing pending_processing = self._pending_processing
if conn not in pending_processing: if conn not in pending_processing:
......
...@@ -48,6 +48,7 @@ class ErrorCodes(Enum): ...@@ -48,6 +48,7 @@ class ErrorCodes(Enum):
PROTOCOL_ERROR = Enum.Item(4) PROTOCOL_ERROR = Enum.Item(4)
BROKEN_NODE = Enum.Item(5) BROKEN_NODE = Enum.Item(5)
ALREADY_PENDING = Enum.Item(7) ALREADY_PENDING = Enum.Item(7)
REPLICATION_ERROR = Enum.Item(8)
ErrorCodes = ErrorCodes() ErrorCodes = ErrorCodes()
class ClusterStates(Enum): class ClusterStates(Enum):
...@@ -55,6 +56,9 @@ class ClusterStates(Enum): ...@@ -55,6 +56,9 @@ class ClusterStates(Enum):
VERIFYING = Enum.Item(2) VERIFYING = Enum.Item(2)
RUNNING = Enum.Item(3) RUNNING = Enum.Item(3)
STOPPING = Enum.Item(4) STOPPING = Enum.Item(4)
STARTING_BACKUP = Enum.Item(5)
BACKINGUP = Enum.Item(6)
STOPPING_BACKUP = Enum.Item(7)
ClusterStates = ClusterStates() ClusterStates = ClusterStates()
class NodeTypes(Enum): class NodeTypes(Enum):
...@@ -117,6 +121,7 @@ ZERO_TID = '\0' * 8 ...@@ -117,6 +121,7 @@ ZERO_TID = '\0' * 8
ZERO_OID = '\0' * 8 ZERO_OID = '\0' * 8
OID_LEN = len(INVALID_OID) OID_LEN = len(INVALID_OID)
TID_LEN = len(INVALID_TID) TID_LEN = len(INVALID_TID)
MAX_TID = '\x7f' + '\xff' * 7 # SQLite does not accept numbers above 2^63-1
UUID_NAMESPACES = { UUID_NAMESPACES = {
NodeTypes.STORAGE: 'S', NodeTypes.STORAGE: 'S',
...@@ -723,6 +728,7 @@ class LastIDs(Packet): ...@@ -723,6 +728,7 @@ class LastIDs(Packet):
POID('last_oid'), POID('last_oid'),
PTID('last_tid'), PTID('last_tid'),
PPTID('last_ptid'), PPTID('last_ptid'),
PTID('backup_tid'),
) )
class PartitionTable(Packet): class PartitionTable(Packet):
...@@ -760,16 +766,6 @@ class PartitionChanges(Packet): ...@@ -760,16 +766,6 @@ class PartitionChanges(Packet):
), ),
) )
class ReplicationDone(Packet):
"""
Notify the master node that a partition has been successfully replicated from
one storage to another.
S -> M
"""
_fmt = PStruct('notify_replication_done',
PNumber('offset'),
)
class StartOperation(Packet): class StartOperation(Packet):
""" """
Tell a storage nodes to start an operation. Until a storage node receives Tell a storage nodes to start an operation. Until a storage node receives
...@@ -965,7 +961,7 @@ class GetObject(Packet): ...@@ -965,7 +961,7 @@ class GetObject(Packet):
""" """
Ask a stored object by its OID and a serial or a TID if given. If a serial Ask a stored object by its OID and a serial or a TID if given. If a serial
is specified, the specified revision of an object will be returned. If is specified, the specified revision of an object will be returned. If
a TID is specified, an object right before the TID will be returned. S,C -> S. a TID is specified, an object right before the TID will be returned. C -> S.
Answer the requested object. S -> C. Answer the requested object. S -> C.
""" """
_fmt = PStruct('ask_object', _fmt = PStruct('ask_object',
...@@ -1003,16 +999,14 @@ class TIDList(Packet): ...@@ -1003,16 +999,14 @@ class TIDList(Packet):
class TIDListFrom(Packet): class TIDListFrom(Packet):
""" """
Ask for length TIDs starting at min_tid. The order of TIDs is ascending. Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
S -> S. C -> S.
Answer the requested TIDs. S -> S Answer the requested TIDs. S -> C
""" """
_fmt = PStruct('tid_list_from', _fmt = PStruct('tid_list_from',
PTID('min_tid'), PTID('min_tid'),
PTID('max_tid'), PTID('max_tid'),
PNumber('length'), PNumber('length'),
PList('partition_list',
PNumber('partition'), PNumber('partition'),
),
) )
_answer = PStruct('answer_tids', _answer = PStruct('answer_tids',
...@@ -1054,27 +1048,6 @@ class ObjectHistory(Packet): ...@@ -1054,27 +1048,6 @@ class ObjectHistory(Packet):
PFHistoryList, PFHistoryList,
) )
class ObjectHistoryFrom(Packet):
"""
Ask history information for a given object. The order of serials is
ascending, and starts at (or above) min_serial for min_oid. S -> S.
Answer the requested serials. S -> S.
"""
_fmt = PStruct('ask_object_history',
POID('min_oid'),
PTID('min_serial'),
PTID('max_serial'),
PNumber('length'),
PNumber('partition'),
)
_answer = PStruct('ask_finish_transaction',
PDict('object_dict',
POID('oid'),
PFTidList,
),
)
class PartitionList(Packet): class PartitionList(Packet):
""" """
All the following messages are for neoctl to admin node All the following messages are for neoctl to admin node
...@@ -1341,6 +1314,110 @@ class NotifyReady(Packet): ...@@ -1341,6 +1314,110 @@ class NotifyReady(Packet):
""" """
pass pass
# replication
class FetchTransactions(Packet):
"""
S -> S
"""
_fmt = PStruct('ask_transaction_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
PFTidList, # already known transactions
)
_answer = PStruct('answer_transaction_list',
PTID('pack_tid'),
PTID('next_tid'),
PFTidList, # transactions to delete
)
class AddTransaction(Packet):
"""
S -> S
"""
_fmt = PStruct('add_transaction',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PBoolean('packed'),
PFOidList,
)
class FetchObjects(Packet):
"""
S -> S
"""
_fmt = PStruct('ask_object_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
PDict('object_dict', # already known objects
PTID('serial'),
PFOidList,
),
)
_answer = PStruct('answer_object_list',
PTID('pack_tid'),
PTID('next_tid'),
POID('next_oid'),
PDict('object_dict', # objects to delete
PTID('serial'),
PFOidList,
),
)
class AddObject(Packet):
"""
S -> S
"""
_fmt = PStruct('add_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
class Replicate(Packet):
"""
M -> S
"""
_fmt = PStruct('replicate',
PTID('tid'),
PString('upstream_name'),
PDict('source_dict',
PNumber('partition'),
PAddress('address'),
)
)
class ReplicationDone(Packet):
"""
Notify the master node that a partition has been successfully replicated from
one storage to another.
S -> M
"""
_fmt = PStruct('notify_replication_done',
PNumber('offset'),
PTID('tid'),
)
class Truncate(Packet):
"""
M -> S
"""
_fmt = PStruct('ask_truncate',
PTID('tid'),
)
_answer = PFEmpty
StaticRegistry = {} StaticRegistry = {}
def register(request, ignore_when_closed=None): def register(request, ignore_when_closed=None):
""" Register a packet in the packet registry """ """ Register a packet in the packet registry """
...@@ -1516,16 +1593,12 @@ class Packets(dict): ...@@ -1516,16 +1593,12 @@ class Packets(dict):
ClusterState) ClusterState)
NotifyLastOID = register( NotifyLastOID = register(
NotifyLastOID) NotifyLastOID)
NotifyReplicationDone = register(
ReplicationDone)
AskObjectUndoSerial, AnswerObjectUndoSerial = register( AskObjectUndoSerial, AnswerObjectUndoSerial = register(
ObjectUndoSerial) ObjectUndoSerial)
AskHasLock, AnswerHasLock = register( AskHasLock, AnswerHasLock = register(
HasLock) HasLock)
AskTIDsFrom, AnswerTIDsFrom = register( AskTIDsFrom, AnswerTIDsFrom = register(
TIDListFrom) TIDListFrom)
AskObjectHistoryFrom, AnswerObjectHistoryFrom = register(
ObjectHistoryFrom)
AskPack, AnswerPack = register( AskPack, AnswerPack = register(
Pack, ignore_when_closed=False) Pack, ignore_when_closed=False)
AskCheckTIDRange, AnswerCheckTIDRange = register( AskCheckTIDRange, AnswerCheckTIDRange = register(
...@@ -1540,6 +1613,20 @@ class Packets(dict): ...@@ -1540,6 +1613,20 @@ class Packets(dict):
CheckCurrentSerial) CheckCurrentSerial)
NotifyTransactionFinished = register( NotifyTransactionFinished = register(
NotifyTransactionFinished) NotifyTransactionFinished)
Replicate = register(
Replicate)
NotifyReplicationDone = register(
ReplicationDone)
AskFetchTransactions, AnswerFetchTransactions = register(
FetchTransactions)
AskFetchObjects, AnswerFetchObjects = register(
FetchObjects)
AddTransaction = register(
AddTransaction)
AddObject = register(
AddObject)
AskTruncate, AnswerTruncate = register(
Truncate)
def Errors(): def Errors():
registry_dict = {} registry_dict = {}
......
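As a rough illustration of the new exchange (a hedged sketch, not the actual handler code): AskFetchTransactions carries the TIDs the replicating storage already has in the requested range, the answer lists which of those must be deleted, and the missing transactions then follow as AddTransaction packets. The underlying set arithmetic is just:

    def diff_transactions(upstream_tids, known_tids, min_tid, max_tid):
        in_range = lambda tid: min_tid <= tid <= max_tid
        upstream = set(filter(in_range, upstream_tids))
        known = set(filter(in_range, known_tids))
        to_delete = sorted(known - upstream)  # returned in answer_transaction_list
        to_add = sorted(upstream - known)     # then sent as AddTransaction packets
        return to_delete, to_add

    assert diff_transactions({1, 2, 4}, {1, 3}, 1, 5) == ([3], [2, 4])

FetchObjects follows the same pattern, with the already-known items keyed by serial in 'object_dict'.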
...@@ -150,6 +150,11 @@ class PartitionTable(object): ...@@ -150,6 +150,11 @@ class PartitionTable(object):
return True return True
return False return False
def getCell(self, offset, uuid):
for cell in self.partition_list[offset]:
if cell.getUUID() == uuid:
return cell
def setCell(self, offset, node, state): def setCell(self, offset, node, state):
if state == CellStates.DISCARDED: if state == CellStates.DISCARDED:
return self.removeCell(offset, node) return self.removeCell(offset, node)
...@@ -157,28 +162,19 @@ class PartitionTable(object): ...@@ -157,28 +162,19 @@ class PartitionTable(object):
raise PartitionTableException('Invalid node state') raise PartitionTableException('Invalid node state')
self.count_dict.setdefault(node, 0) self.count_dict.setdefault(node, 0)
row = self.partition_list[offset] for cell in self.partition_list[offset]:
if len(row) == 0: if cell.getNode() is node:
# Create a new row.
row = [Cell(node, state), ]
if state != CellStates.FEEDING:
self.count_dict[node] += 1
self.partition_list[offset] = row
self.num_filled_rows += 1
else:
# XXX this can be slow, but it is necessary to remove a duplicate,
# if any.
for cell in row:
if cell.getNode() == node:
row.remove(cell)
if not cell.isFeeding(): if not cell.isFeeding():
self.count_dict[node] -= 1 self.count_dict[node] -= 1
cell.setState(state)
break break
else:
row = self.partition_list[offset]
self.num_filled_rows += not row
row.append(Cell(node, state)) row.append(Cell(node, state))
if state != CellStates.FEEDING: if state != CellStates.FEEDING:
self.count_dict[node] += 1 self.count_dict[node] += 1
return (offset, node.getUUID(), state) return offset, node.getUUID(), state
def removeCell(self, offset, node): def removeCell(self, offset, node):
row = self.partition_list[offset] row = self.partition_list[offset]
......
...@@ -28,6 +28,10 @@ from neo.lib.event import EventManager ...@@ -28,6 +28,10 @@ from neo.lib.event import EventManager
from neo.lib.connection import ListeningConnection, ClientConnection from neo.lib.connection import ListeningConnection, ClientConnection
from neo.lib.exception import ElectionFailure, PrimaryFailure, OperationFailure from neo.lib.exception import ElectionFailure, PrimaryFailure, OperationFailure
from neo.lib.util import dump from neo.lib.util import dump
class StateChangedException(Exception): pass
from .backup_app import BackupApplication
from .handlers import election, identification, secondary from .handlers import election, identification, secondary
from .handlers import administration, client, storage, shutdown from .handlers import administration, client, storage, shutdown
from .pt import PartitionTable from .pt import PartitionTable
...@@ -41,6 +45,8 @@ class Application(object): ...@@ -41,6 +45,8 @@ class Application(object):
packing = None packing = None
# Latest completely committed TID
last_transaction = ZERO_TID last_transaction = ZERO_TID
backup_tid = None
backup_app = None
def __init__(self, config): def __init__(self, config):
# Internal attributes. # Internal attributes.
...@@ -90,16 +96,29 @@ class Application(object): ...@@ -90,16 +96,29 @@ class Application(object):
self._current_manager = None self._current_manager = None
# backup
upstream_cluster = config.getUpstreamCluster()
if upstream_cluster:
if upstream_cluster == self.name:
raise ValueError("upstream cluster name must be"
" different from cluster name")
self.backup_app = BackupApplication(self, upstream_cluster,
*config.getUpstreamMasters())
registerLiveDebugger(on_log=self.log) registerLiveDebugger(on_log=self.log)
def close(self): def close(self):
self.listening_conn = None self.listening_conn = None
if self.backup_app is not None:
self.backup_app.close()
self.nm.close() self.nm.close()
self.em.close() self.em.close()
del self.__dict__ del self.__dict__
def log(self): def log(self):
self.em.log() self.em.log()
if self.backup_app is not None:
self.backup_app.log()
self.nm.log() self.nm.log()
self.tm.log() self.tm.log()
if self.pt is not None: if self.pt is not None:
...@@ -257,28 +276,30 @@ class Application(object): ...@@ -257,28 +276,30 @@ class Application(object):
a shutdown. a shutdown.
""" """
neo.lib.logging.info('provide service') neo.lib.logging.info('provide service')
em = self.em poll = self.em.poll
self.tm.reset() self.tm.reset()
self.changeClusterState(ClusterStates.RUNNING) self.changeClusterState(ClusterStates.RUNNING)
# Now everything is passive. # Now everything is passive.
while True:
try: try:
em.poll(1) while True:
poll(1)
except OperationFailure: except OperationFailure:
# If not operational, send Stop Operation packets to storage # If not operational, send Stop Operation packets to storage
# nodes and client nodes. Abort connections to client nodes. # nodes and client nodes. Abort connections to client nodes.
neo.lib.logging.critical('No longer operational') neo.lib.logging.critical('No longer operational')
except StateChangedException, e:
assert e.args[0] == ClusterStates.STARTING_BACKUP
self.backup_tid = tid = self.getLastTransaction()
self.pt.setBackupTidDict(dict((node.getUUID(), tid)
for node in self.nm.getStorageList(only_identified=True)))
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
if node.isStorage() or node.isClient(): if node.isStorage() or node.isClient():
node.notify(Packets.StopOperation()) node.notify(Packets.StopOperation())
if node.isClient(): if node.isClient():
node.getConnection().abort() node.getConnection().abort()
# Then, go back, and restart.
return
def playPrimaryRole(self): def playPrimaryRole(self):
neo.lib.logging.info( neo.lib.logging.info(
'play the primary role with %r', self.listening_conn) 'play the primary role with %r', self.listening_conn)
...@@ -314,6 +335,12 @@ class Application(object): ...@@ -314,6 +335,12 @@ class Application(object):
self.runManager(RecoveryManager) self.runManager(RecoveryManager)
while True: while True:
self.runManager(VerificationManager) self.runManager(VerificationManager)
if self.backup_tid:
if self.backup_app is None:
raise RuntimeError("No upstream cluster to backup"
" defined in configuration")
self.backup_app.provideService()
else:
self.provideService() self.provideService()
def playSecondaryRole(self): def playSecondaryRole(self):
...@@ -364,7 +391,8 @@ class Application(object): ...@@ -364,7 +391,8 @@ class Application(object):
# select the storage handler # select the storage handler
client_handler = client.ClientServiceHandler(self) client_handler = client.ClientServiceHandler(self)
if state == ClusterStates.RUNNING: if state in (ClusterStates.RUNNING, ClusterStates.STARTING_BACKUP,
ClusterStates.BACKINGUP, ClusterStates.STOPPING_BACKUP):
storage_handler = storage.StorageServiceHandler(self) storage_handler = storage.StorageServiceHandler(self)
elif self._current_manager is not None: elif self._current_manager is not None:
storage_handler = self._current_manager.getHandler() storage_handler = self._current_manager.getHandler()
...@@ -389,6 +417,7 @@ class Application(object): ...@@ -389,6 +417,7 @@ class Application(object):
handler = storage_handler handler = storage_handler
else: else:
continue # keep handler continue # keep handler
if type(handler) is not type(conn.getLastHandler()):
conn.setHandler(handler) conn.setHandler(handler)
handler.connectionCompleted(conn) handler.connectionCompleted(conn)
self.cluster_state = state self.cluster_state = state
...@@ -437,19 +466,13 @@ class Application(object): ...@@ -437,19 +466,13 @@ class Application(object):
sys.exit() sys.exit()
def identifyStorageNode(self, uuid, node): def identifyStorageNode(self, uuid, node):
if self.cluster_state == ClusterStates.STOPPING:
raise NotReadyError
state = NodeStates.RUNNING state = NodeStates.RUNNING
handler = None
if self.cluster_state == ClusterStates.RUNNING:
if uuid is None or node is None: if uuid is None or node is None:
# same as for verification # same as for verification
state = NodeStates.PENDING state = NodeStates.PENDING
handler = storage.StorageServiceHandler(self) return uuid, state, storage.StorageServiceHandler(self)
elif self.cluster_state == ClusterStates.STOPPING:
raise NotReadyError
else:
raise RuntimeError('unhandled cluster state: %s' %
(self.cluster_state, ))
return (uuid, state, handler)
def identifyNode(self, node_type, uuid, node): def identifyNode(self, node_type, uuid, node):
......
##############################################################################
#
# Copyright (c) 2011 Nexedi SARL and Contributors. All Rights Reserved.
# Julien Muchembled <jm@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsibility of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# guarantees and support are strongly advised to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
##############################################################################
import random, weakref
from bisect import bisect
import neo.lib
from neo.lib.bootstrap import BootstrapManager
from neo.lib.connector import getConnectorHandler
from neo.lib.exception import PrimaryFailure
from neo.lib.node import NodeManager
from neo.lib.protocol import CellStates, ClusterStates, NodeTypes, Packets
from neo.lib.protocol import INVALID_TID, ZERO_TID
from neo.lib.util import add64, u64, dump
from .app import StateChangedException
from .pt import PartitionTable
from .handlers.backup import BackupHandler
"""
Backup algorithm
This implementation relies on normal storage replication.
Storage nodes that are specialised for backup are not in the same NEO cluster,
but are managed by another master in a different cluster.
When the cluster is in BACKINGUP state, its master acts like a client to the
master of the main cluster. It gets notified of new data thanks to
invalidations, and in turn notifies its storage nodes what and when to
replicate.
Storage nodes stay in UP_TO_DATE state, even if partitions are synchronized up
to different TIDs. They remember that they are in this state, and when the
cluster switches to RUNNING state, the DB is truncated at the last TID for
which all data is available.
Among the backup storage nodes assigned to a partition, one is chosen as the
primary for that partition. This means that only this node fetches data from
the upstream cluster, to minimize bandwidth between clusters. Other replicas
synchronize from the primary node.
There is no UUID conflict between the 2 clusters:
- Storage nodes connect anonymously to upstream.
- The master node receives a new UUID from the upstream master and uses it
  only when communicating with it.
"""
class BackupApplication(object):
pt = None
def __init__(self, app, name, master_addresses, connector_name):
self.app = weakref.proxy(app)
self.name = name
self.nm = NodeManager()
self.connector_handler = getConnectorHandler(connector_name)
for master_address in master_addresses:
self.nm.createMaster(address=master_address)
em = property(lambda self: self.app.em)
def close(self):
self.nm.close()
del self.__dict__
def log(self):
self.nm.log()
if self.pt is not None:
self.pt.log()
def provideService(self):
neo.lib.logging.info('provide backup')
poll = self.em.poll
app = self.app
pt = app.pt
while True:
app.changeClusterState(ClusterStates.STARTING_BACKUP)
bootstrap = BootstrapManager(self, self.name, NodeTypes.CLIENT)
# {offset -> node}
self.primary_partition_dict = {}
# [[tid]]
self.tid_list = tuple([] for _ in xrange(pt.getPartitions()))
try:
node, conn, uuid, num_partitions, num_replicas = \
bootstrap.getPrimaryConnection(self.connector_handler)
try:
app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node
if num_partitions != pt.getPartitions():
raise RuntimeError("inconsistent number of partitions")
self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskNodeInformation())
conn.ask(Packets.AskPartitionTable())
conn.ask(Packets.AskLastTransaction())
# debug variable to log how big 'tid_list' can be.
self.debug_tid_count = 0
while True:
poll(1)
except PrimaryFailure, msg:
neo.lib.logging.error('upstream master is down: %s', msg)
finally:
app.backup_tid = pt.getBackupTid()
try:
conn.close()
except PrimaryFailure:
pass
try:
del self.pt
except AttributeError:
pass
except StateChangedException, e:
app.changeClusterState(*e.args)
last_tid = app.getLastTransaction()
if last_tid < app.backup_tid:
neo.lib.logging.warning(
"Truncating at %s (last_tid was %s)",
dump(app.backup_tid), dump(last_tid))
p = Packets.AskTruncate(app.backup_tid)
connection_list = []
for node in app.nm.getStorageList(only_identified=True):
conn = node.getConnection()
conn.ask(p)
connection_list.append(conn)
for conn in connection_list:
while conn.isPending():
poll(1)
app.setLastTransaction(app.backup_tid)
del app.backup_tid
break
finally:
del self.primary_partition_dict, self.tid_list
def nodeLost(self, node):
getCellList = self.app.pt.getCellList
trigger_set = set()
for offset, primary_node in self.primary_partition_dict.items():
if primary_node is not node:
continue
cell_list = getCellList(offset, readable=True)
cell = max(cell_list, key=lambda cell: cell.backup_tid)
tid = cell.backup_tid
self.primary_partition_dict[offset] = primary_node = cell.getNode()
p = Packets.Replicate(tid, '', {offset: primary_node.getAddress()})
for cell in cell_list:
cell.replicating = tid
if cell.backup_tid < tid:
neo.lib.logging.debug(
"ask %s to replicate partition %u up to %u from %r",
dump(cell.getUUID()), offset, u64(tid),
dump(primary_node.getUUID()))
cell.getNode().getConnection().notify(p)
trigger_set.add(primary_node)
for node in trigger_set:
self.triggerBackup(node)
def invalidatePartitions(self, tid, partition_set):
app = self.app
prev_tid = app.getLastTransaction()
app.setLastTransaction(tid)
pt = app.pt
getByUUID = app.nm.getByUUID
trigger_set = set()
for offset in xrange(pt.getPartitions()):
try:
last_max_tid = self.tid_list[offset][-1]
except IndexError:
last_max_tid = INVALID_TID
if offset in partition_set:
self.tid_list[offset].append(tid)
node_list = []
for cell in pt.getCellList(offset, readable=True):
node = cell.getNode()
assert node.isConnected()
node_list.append(node)
if last_max_tid <= cell.backup_tid:
# This is the last time we can increase
# 'backup_tid' without replication.
neo.lib.logging.debug(
"partition %u: updating backup_tid of %r to %u",
offset, cell, u64(prev_tid))
cell.backup_tid = prev_tid
assert node_list
trigger_set.update(node_list)
# Make sure we have a primary storage for this partition.
if offset not in self.primary_partition_dict:
self.primary_partition_dict[offset] = \
random.choice(node_list)
else:
# Partition not touched, so increase 'backup_tid' of all
# "up-to-date" replicas, without having to replicate.
for cell in pt.getCellList(offset, readable=True):
if last_max_tid <= cell.backup_tid:
cell.backup_tid = tid
neo.lib.logging.debug(
"partition %u: updating backup_tid of %r to %u",
offset, cell, u64(tid))
for node in trigger_set:
self.triggerBackup(node)
count = sum(map(len, self.tid_list))
if self.debug_tid_count < count:
neo.lib.logging.debug("Maximum number of tracked tids: %u", count)
self.debug_tid_count = count
def triggerBackup(self, node):
tid_list = self.tid_list
tid = self.app.getLastTransaction()
replicate_list = []
for offset, cell in self.app.pt.iterNodeCell(node):
max_tid = tid_list[offset]
if max_tid and self.primary_partition_dict[offset] is node and \
max(cell.backup_tid, cell.replicating) < max_tid[-1]:
cell.replicating = tid
replicate_list.append(offset)
if not replicate_list:
return
getByUUID = self.nm.getByUUID
getCellList = self.pt.getCellList
source_dict = {}
address_set = set()
for offset in replicate_list:
cell_list = getCellList(offset, readable=True)
random.shuffle(cell_list)
assert cell_list, offset
for cell in cell_list:
addr = cell.getAddress()
if addr in address_set:
break
else:
address_set.add(addr)
source_dict[offset] = addr
neo.lib.logging.debug(
"ask %s to replicate partition %u up to %u from %r",
dump(node.getUUID()), offset, u64(tid), addr)
node.getConnection().notify(Packets.Replicate(
tid, self.name, source_dict))
def notifyReplicationDone(self, node, offset, tid):
app = self.app
cell = app.pt.getCell(offset, node.getUUID())
tid_list = self.tid_list[offset]
if tid_list: # may be empty if the cell is out-of-date
# or if we're not fully initialized
if tid < tid_list[0]:
cell.replicating = tid
else:
try:
tid = add64(tid_list[bisect(tid_list, tid)], -1)
except IndexError:
tid = app.getLastTransaction()
neo.lib.logging.debug("partition %u: updating backup_tid of %r to %u",
offset, cell, u64(tid))
cell.backup_tid = tid
# Forget tids we won't need anymore.
cell_list = app.pt.getCellList(offset, readable=True)
del tid_list[:bisect(tid_list, min(x.backup_tid for x in cell_list))]
primary_node = self.primary_partition_dict.get(offset)
primary = primary_node is node
result = None if primary else app.pt.setUpToDate(node, offset)
if app.getClusterState() == ClusterStates.BACKINGUP:
assert not cell.isOutOfDate()
if result: # was out-of-date
max_tid, = [x.backup_tid for x in cell_list
if x.getNode() is primary_node]
if tid < max_tid:
cell.replicating = max_tid
neo.lib.logging.debug(
"ask %s to replicate partition %u up to %u from %r",
dump(node.getUUID()), offset, u64(max_tid),
dump(primary_node.getUUID()))
node.getConnection().notify(Packets.Replicate(max_tid,
'', {offset: primary_node.getAddress()}))
else:
self.triggerBackup(node)
if primary:
# Notify secondary storages that they can replicate from
# primary ones, even if they are already replicating.
p = Packets.Replicate(tid, '', {offset: node.getAddress()})
for cell in cell_list:
if max(cell.backup_tid, cell.replicating) < tid:
cell.replicating = tid
neo.lib.logging.debug(
"ask %s to replicate partition %u up to %u from"
" %r", dump(cell.getUUID()), offset, u64(tid),
dump(node.getUUID()))
cell.getNode().getConnection().notify(p)
return result
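The subtlest part of notifyReplicationDone() above is how a report of "replicated up to tid" becomes the cell's new backup_tid: the cell is actually valid until just before the next tracked invalidation. A standalone sketch of that rule, assuming integer TIDs instead of 8-byte strings (so add64(x, -1) becomes x - 1):

    from bisect import bisect

    def new_backup_tid(tid, tracked_tids, last_tid):
        # tracked_tids: sorted TIDs of invalidations seen for this partition
        if not tracked_tids or tid < tracked_tids[0]:
            return None  # too early to conclude anything for backup_tid
        try:
            return tracked_tids[bisect(tracked_tids, tid)] - 1
        except IndexError:
            return last_tid  # replicated past every known invalidation

    assert new_backup_tid(25, [10, 20, 30], 40) == 29
    assert new_backup_tid(35, [10, 20, 30], 40) == 40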
...@@ -18,15 +18,18 @@ ...@@ -18,15 +18,18 @@
import neo import neo
from . import MasterHandler from . import MasterHandler
from ..app import StateChangedException
from neo.lib.protocol import ClusterStates, NodeStates, Packets, ProtocolError from neo.lib.protocol import ClusterStates, NodeStates, Packets, ProtocolError
from neo.lib.protocol import Errors from neo.lib.protocol import Errors
from neo.lib.util import dump from neo.lib.util import dump
CLUSTER_STATE_WORKFLOW = { CLUSTER_STATE_WORKFLOW = {
# destination: sources # destination: sources
ClusterStates.VERIFYING: set([ClusterStates.RECOVERING]), ClusterStates.VERIFYING: (ClusterStates.RECOVERING,),
ClusterStates.STOPPING: set([ClusterStates.RECOVERING, ClusterStates.STARTING_BACKUP: (ClusterStates.RUNNING,
ClusterStates.VERIFYING, ClusterStates.RUNNING]), ClusterStates.STOPPING_BACKUP),
ClusterStates.STOPPING_BACKUP: (ClusterStates.BACKINGUP,
ClusterStates.STARTING_BACKUP),
} }
class AdministrationHandler(MasterHandler): class AdministrationHandler(MasterHandler):
...@@ -42,16 +45,17 @@ class AdministrationHandler(MasterHandler): ...@@ -42,16 +45,17 @@ class AdministrationHandler(MasterHandler):
conn.answer(Packets.AnswerPrimary(app.uuid, [])) conn.answer(Packets.AnswerPrimary(app.uuid, []))
def setClusterState(self, conn, state): def setClusterState(self, conn, state):
app = self.app
# check request # check request
if state not in CLUSTER_STATE_WORKFLOW: try:
if app.cluster_state not in CLUSTER_STATE_WORKFLOW[state]:
raise ProtocolError('Can not switch to this state')
except KeyError:
raise ProtocolError('Invalid state requested') raise ProtocolError('Invalid state requested')
valid_current_states = CLUSTER_STATE_WORKFLOW[state]
if self.app.cluster_state not in valid_current_states:
raise ProtocolError('Cannot switch to this state')
# change state # change state
if state == ClusterStates.VERIFYING: if state == ClusterStates.VERIFYING:
storage_list = self.app.nm.getStorageList(only_identified=True) storage_list = app.nm.getStorageList(only_identified=True)
if not storage_list: if not storage_list:
raise ProtocolError('Cannot exit recovery without any ' raise ProtocolError('Cannot exit recovery without any '
'storage node') 'storage node')
...@@ -60,15 +64,18 @@ class AdministrationHandler(MasterHandler): ...@@ -60,15 +64,18 @@ class AdministrationHandler(MasterHandler):
if node.getConnection().isPending(): if node.getConnection().isPending():
raise ProtocolError('Cannot exit recovery now: node %r is ' raise ProtocolError('Cannot exit recovery now: node %r is '
'entering cluster' % (node, )) 'entering cluster' % (node, ))
self.app._startup_allowed = True app._startup_allowed = True
else: state = app.cluster_state
self.app.changeClusterState(state) elif state == ClusterStates.STARTING_BACKUP:
if app.tm.hasPending() or app.nm.getClientList(True):
raise ProtocolError("Can not switch to %s state with pending"
" transactions or connected clients" % state)
elif state != ClusterStates.STOPPING_BACKUP:
app.changeClusterState(state)
# answer
conn.answer(Errors.Ack('Cluster state changed')) conn.answer(Errors.Ack('Cluster state changed'))
if state == ClusterStates.STOPPING: if state != app.cluster_state:
self.app.cluster_state = state raise StateChangedException(state)
self.app.shutdown()
def setNodeState(self, conn, uuid, state, modify_partition_table): def setNodeState(self, conn, uuid, state, modify_partition_table):
neo.lib.logging.info("set node state for %s-%s : %s" % neo.lib.logging.info("set node state for %s-%s : %s" %
......
##############################################################################
#
# Copyright (c) 2011 Nexedi SARL and Contributors. All Rights Reserved.
# Julien Muchembled <jm@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsibility of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# guarantees and support are strongly advised to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
##############################################################################
from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler
from neo.lib.protocol import CellStates
class BackupHandler(EventHandler):
"""Handler dedicated to upstream master during BACKINGUP state"""
def connectionLost(self, conn, new_state):
if self.app.app.listening_conn: # if running
raise PrimaryFailure('connection lost')
def answerPartitionTable(self, conn, ptid, row_list):
self.app.pt.load(ptid, row_list, self.app.nm)
def notifyPartitionChanges(self, conn, ptid, cell_list):
self.app.pt.update(ptid, cell_list, self.app.nm)
def answerNodeInformation(self, conn):
pass
def notifyNodeInformation(self, conn, node_list):
self.app.nm.update(node_list)
def answerLastTransaction(self, conn, tid):
app = self.app
app.invalidatePartitions(tid, set(xrange(app.pt.getPartitions())))
def invalidateObjects(self, conn, tid, oid_list):
app = self.app
getPartition = app.app.pt.getPartition
partition_set = set(map(getPartition, oid_list))
partition_set.add(getPartition(tid))
app.invalidatePartitions(tid, partition_set)
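For reference, invalidateObjects() above reduces an upstream invalidation to the set of partitions that need to be replicated again. Assuming getPartition() is the usual modulo rule and using integers instead of 8-byte OIDs/TIDs, the computation is simply:

    def partitions_to_invalidate(tid, oid_list, num_partitions):
        get_partition = lambda x: x % num_partitions
        partition_set = set(map(get_partition, oid_list))
        partition_set.add(get_partition(tid))  # partition holding the transaction itself
        return partition_set

    assert partitions_to_invalidate(7, [3, 8, 15], 4) == {0, 3}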
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib import neo.lib
from neo.lib.protocol import Packets, ProtocolError from neo.lib.protocol import ClusterStates, Packets, ProtocolError
from neo.lib.exception import OperationFailure from neo.lib.exception import OperationFailure
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.connector import ConnectorConnectionClosedException from neo.lib.connector import ConnectorConnectionClosedException
...@@ -45,14 +45,18 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -45,14 +45,18 @@ class StorageServiceHandler(BaseServiceHandler):
if not app.pt.operational(): if not app.pt.operational():
raise OperationFailure, 'cannot continue operation' raise OperationFailure, 'cannot continue operation'
app.tm.forget(conn.getUUID()) app.tm.forget(conn.getUUID())
if app.getClusterState() == ClusterStates.BACKINGUP:
app.backup_app.nodeLost(node)
if app.packing is not None: if app.packing is not None:
self.answerPack(conn, False) self.answerPack(conn, False)
def askLastIDs(self, conn): def askLastIDs(self, conn):
app = self.app app = self.app
loid = app.tm.getLastOID() conn.answer(Packets.AnswerLastIDs(
ltid = app.tm.getLastTID() app.tm.getLastOID(),
conn.answer(Packets.AnswerLastIDs(loid, ltid, app.pt.getID())) app.tm.getLastTID(),
app.pt.getID(),
app.backup_tid))
def askUnfinishedTransactions(self, conn): def askUnfinishedTransactions(self, conn):
tm = self.app.tm tm = self.app.tm
...@@ -68,15 +72,26 @@ class StorageServiceHandler(BaseServiceHandler): ...@@ -68,15 +72,26 @@ class StorageServiceHandler(BaseServiceHandler):
# transaction locked on this storage node # transaction locked on this storage node
self.app.tm.lock(ttid, conn.getUUID()) self.app.tm.lock(ttid, conn.getUUID())
def notifyReplicationDone(self, conn, offset): def notifyReplicationDone(self, conn, offset, tid):
node = self.app.nm.getByUUID(conn.getUUID()) app = self.app
neo.lib.logging.debug("%s is up for offset %s" % (node, offset)) node = app.nm.getByUUID(conn.getUUID())
if app.backup_tid:
cell_list = app.backup_app.notifyReplicationDone(node, offset, tid)
if not cell_list:
return
else:
try: try:
cell_list = self.app.pt.setUpToDate(node, offset) cell_list = self.app.pt.setUpToDate(node, offset)
if not cell_list:
raise ProtocolError('Non-oudated partition')
except PartitionTableException, e: except PartitionTableException, e:
raise ProtocolError(str(e)) raise ProtocolError(str(e))
neo.lib.logging.debug("%s is up for offset %s", node, offset)
self.app.broadcastPartitionChanges(cell_list) self.app.broadcastPartitionChanges(cell_list)
def answerTruncate(self, conn):
pass
def answerPack(self, conn, status): def answerPack(self, conn, status):
app = self.app app = self.app
if app.packing is not None: if app.packing is not None:
......
...@@ -17,11 +17,25 @@ ...@@ -17,11 +17,25 @@
import neo.lib.pt import neo.lib.pt
from struct import pack, unpack from struct import pack, unpack
from neo.lib.protocol import CellStates from neo.lib.protocol import CellStates, ZERO_TID
from neo.lib.pt import PartitionTableException
from neo.lib.pt import PartitionTable
class PartitionTable(PartitionTable):
class Cell(neo.lib.pt.Cell):
replicating = ZERO_TID
def setState(self, state):
try:
if CellStates.OUT_OF_DATE == state != self.state:
del self.backup_tid, self.replicating
except AttributeError:
pass
return super(Cell, self).setState(state)
neo.lib.pt.Cell = Cell
class PartitionTable(neo.lib.pt.PartitionTable):
"""This class manages a partition table for the primary master node""" """This class manages a partition table for the primary master node"""
def setID(self, id): def setID(self, id):
...@@ -54,7 +68,7 @@ class PartitionTable(PartitionTable): ...@@ -54,7 +68,7 @@ class PartitionTable(PartitionTable):
row = [] row = []
for _ in xrange(repeats): for _ in xrange(repeats):
node = node_list[index] node = node_list[index]
row.append(neo.lib.pt.Cell(node)) row.append(Cell(node))
self.count_dict[node] = self.count_dict.get(node, 0) + 1 self.count_dict[node] = self.count_dict.get(node, 0) + 1
index += 1 index += 1
if index == len(node_list): if index == len(node_list):
...@@ -88,7 +102,7 @@ class PartitionTable(PartitionTable): ...@@ -88,7 +102,7 @@ class PartitionTable(PartitionTable):
node_list = [c.getNode() for c in row] node_list = [c.getNode() for c in row]
n = self.findLeastUsedNode(node_list) n = self.findLeastUsedNode(node_list)
if n is not None: if n is not None:
row.append(neo.lib.pt.Cell(n, row.append(Cell(n,
CellStates.OUT_OF_DATE)) CellStates.OUT_OF_DATE))
self.count_dict[n] += 1 self.count_dict[n] += 1
cell_list.append((offset, n.getUUID(), cell_list.append((offset, n.getUUID(),
...@@ -132,11 +146,11 @@ class PartitionTable(PartitionTable): ...@@ -132,11 +146,11 @@ class PartitionTable(PartitionTable):
# check the partition is assigned and known as outdated # check the partition is assigned and known as outdated
for cell in self.getCellList(offset): for cell in self.getCellList(offset):
if cell.getUUID() == uuid: if cell.getUUID() == uuid:
if not cell.isOutOfDate(): if cell.isOutOfDate():
raise PartitionTableException('Non-oudated partition')
break break
return
else: else:
raise PartitionTableException('Non-assigned partition') raise neo.lib.pt.PartitionTableException('Non-assigned partition')
# update the partition table # update the partition table
cell_list = [self.setCell(offset, node, CellStates.UP_TO_DATE)] cell_list = [self.setCell(offset, node, CellStates.UP_TO_DATE)]
...@@ -177,7 +191,7 @@ class PartitionTable(PartitionTable): ...@@ -177,7 +191,7 @@ class PartitionTable(PartitionTable):
else: else:
if num_cells <= self.nr: if num_cells <= self.nr:
row.append(neo.lib.pt.Cell(node, CellStates.OUT_OF_DATE)) row.append(Cell(node, CellStates.OUT_OF_DATE))
cell_list.append((offset, node.getUUID(), cell_list.append((offset, node.getUUID(),
CellStates.OUT_OF_DATE)) CellStates.OUT_OF_DATE))
node_count += 1 node_count += 1
...@@ -196,7 +210,7 @@ class PartitionTable(PartitionTable): ...@@ -196,7 +210,7 @@ class PartitionTable(PartitionTable):
CellStates.FEEDING)) CellStates.FEEDING))
# Don't count a feeding cell. # Don't count a feeding cell.
self.count_dict[max_cell.getNode()] -= 1 self.count_dict[max_cell.getNode()] -= 1
row.append(neo.lib.pt.Cell(node, CellStates.OUT_OF_DATE)) row.append(Cell(node, CellStates.OUT_OF_DATE))
cell_list.append((offset, node.getUUID(), cell_list.append((offset, node.getUUID(),
CellStates.OUT_OF_DATE)) CellStates.OUT_OF_DATE))
node_count += 1 node_count += 1
...@@ -277,7 +291,7 @@ class PartitionTable(PartitionTable): ...@@ -277,7 +291,7 @@ class PartitionTable(PartitionTable):
node = self.findLeastUsedNode([cell.getNode() for cell in row]) node = self.findLeastUsedNode([cell.getNode() for cell in row])
if node is None: if node is None:
break break
row.append(neo.lib.pt.Cell(node, CellStates.OUT_OF_DATE)) row.append(Cell(node, CellStates.OUT_OF_DATE))
changed_cell_list.append((offset, node.getUUID(), changed_cell_list.append((offset, node.getUUID(),
CellStates.OUT_OF_DATE)) CellStates.OUT_OF_DATE))
self.count_dict[node] += 1 self.count_dict[node] += 1
...@@ -309,6 +323,13 @@ class PartitionTable(PartitionTable): ...@@ -309,6 +323,13 @@ class PartitionTable(PartitionTable):
CellStates.OUT_OF_DATE)) CellStates.OUT_OF_DATE))
return change_list return change_list
def iterNodeCell(self, node):
for offset, row in enumerate(self.partition_list):
for cell in row:
if cell.getNode() is node:
yield offset, cell
break
def getUpToDateCellNodeSet(self): def getUpToDateCellNodeSet(self):
""" """
Return a set of all nodes which are part of at least one UP TO DATE Return a set of all nodes which are part of at least one UP TO DATE
...@@ -329,3 +350,16 @@ class PartitionTable(PartitionTable): ...@@ -329,3 +350,16 @@ class PartitionTable(PartitionTable):
for cell in row for cell in row
if cell.isOutOfDate()) if cell.isOutOfDate())
def setBackupTidDict(self, backup_tid_dict):
for row in self.partition_list:
for cell in row:
cell.backup_tid = backup_tid_dict.get(cell.getUUID(),
ZERO_TID)
def getBackupTid(self):
try:
return min(max(cell.backup_tid for cell in row
if not cell.isOutOfDate())
for row in self.partition_list)
except ValueError:
return ZERO_TID
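getBackupTid() above defines the cluster-wide backup TID as the minimum over all partitions of the most advanced non-out-of-date replica. A sketch with plain numbers instead of Cell objects (hypothetical data layout):

    ZERO_TID = 0

    def get_backup_tid(rows):
        # rows: {offset: [(backup_tid, is_out_of_date), ...]}
        try:
            return min(max(tid for tid, out_of_date in row if not out_of_date)
                       for row in rows.values())
        except ValueError:  # some partition has no usable cell yet
            return ZERO_TID

    rows = {0: [(15, False), (12, False)],
            1: [(20, False), (8, True)]}
    assert get_backup_tid(rows) == 15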
...@@ -33,6 +33,7 @@ class RecoveryManager(MasterHandler): ...@@ -33,6 +33,7 @@ class RecoveryManager(MasterHandler):
super(RecoveryManager, self).__init__(app) super(RecoveryManager, self).__init__(app)
# The target node's uuid to request next. # The target node's uuid to request next.
self.target_ptid = None self.target_ptid = None
self.backup_tid_dict = {}
def getHandler(self): def getHandler(self):
return self return self
...@@ -98,6 +99,9 @@ class RecoveryManager(MasterHandler): ...@@ -98,6 +99,9 @@ class RecoveryManager(MasterHandler):
app.tm.setLastOID(ZERO_OID) app.tm.setLastOID(ZERO_OID)
pt.make(allowed_node_set) pt.make(allowed_node_set)
self._broadcastPartitionTable(pt.getID(), pt.getRowList()) self._broadcastPartitionTable(pt.getID(), pt.getRowList())
elif app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict)
app.backup_tid = pt.getBackupTid()
app.setLastTransaction(app.tm.getLastTID()) app.setLastTransaction(app.tm.getLastTID())
neo.lib.logging.debug( neo.lib.logging.debug(
...@@ -118,7 +122,7 @@ class RecoveryManager(MasterHandler): ...@@ -118,7 +122,7 @@ class RecoveryManager(MasterHandler):
# ask the last IDs to perform the recovery # ask the last IDs to perform the recovery
conn.ask(Packets.AskLastIDs()) conn.ask(Packets.AskLastIDs())
def answerLastIDs(self, conn, loid, ltid, lptid): def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
# Get max values. # Get max values.
if loid is not None: if loid is not None:
self.app.tm.setLastOID(loid) self.app.tm.setLastOID(loid)
...@@ -128,6 +132,7 @@ class RecoveryManager(MasterHandler): ...@@ -128,6 +132,7 @@ class RecoveryManager(MasterHandler):
# something newer # something newer
self.target_ptid = lptid self.target_ptid = lptid
conn.ask(Packets.AskPartitionTable()) conn.ask(Packets.AskPartitionTable())
self.backup_tid_dict[conn.getUUID()] = backup_tid
def answerPartitionTable(self, conn, ptid, row_list): def answerPartitionTable(self, conn, ptid, row_list):
if ptid != self.target_ptid: if ptid != self.target_ptid:
...@@ -136,6 +141,7 @@ class RecoveryManager(MasterHandler): ...@@ -136,6 +141,7 @@ class RecoveryManager(MasterHandler):
dump(self.target_ptid)) dump(self.target_ptid))
else: else:
self._broadcastPartitionTable(ptid, row_list) self._broadcastPartitionTable(ptid, row_list)
self.app.backup_tid = self.backup_tid_dict[conn.getUUID()]
def _broadcastPartitionTable(self, ptid, row_list): def _broadcastPartitionTable(self, ptid, row_list):
try: try:
......
...@@ -113,19 +113,21 @@ class VerificationManager(BaseServiceHandler): ...@@ -113,19 +113,21 @@ class VerificationManager(BaseServiceHandler):
def verifyData(self): def verifyData(self):
"""Verify the data in storage nodes and clean them up, if necessary.""" """Verify the data in storage nodes and clean them up, if necessary."""
app = self.app
em, nm = self.app.em, self.app.nm
# wait for any missing node # wait for any missing node
neo.lib.logging.debug('waiting for the cluster to be operational') neo.lib.logging.debug('waiting for the cluster to be operational')
while not self.app.pt.operational(): while not app.pt.operational():
em.poll(1) app.em.poll(1)
if app.backup_tid:
return
neo.lib.logging.info('start to verify data') neo.lib.logging.info('start to verify data')
getIdentifiedList = app.nm.getIdentifiedList
# Gather all unfinished transactions. # Gather all unfinished transactions.
self._askStorageNodesAndWait(Packets.AskUnfinishedTransactions(), self._askStorageNodesAndWait(Packets.AskUnfinishedTransactions(),
[x for x in self.app.nm.getIdentifiedList() if x.isStorage()]) [x for x in getIdentifiedList() if x.isStorage()])
# Gather OIDs for each unfinished TID, and verify whether the # Gather OIDs for each unfinished TID, and verify whether the
# transaction can be finished or must be aborted. This could be # transaction can be finished or must be aborted. This could be
...@@ -136,17 +138,16 @@ class VerificationManager(BaseServiceHandler): ...@@ -136,17 +138,16 @@ class VerificationManager(BaseServiceHandler):
if uuid_set is None: if uuid_set is None:
packet = Packets.DeleteTransaction(tid, self._oid_set or []) packet = Packets.DeleteTransaction(tid, self._oid_set or [])
# Make sure that no node has this transaction. # Make sure that no node has this transaction.
for node in self.app.nm.getIdentifiedList(): for node in getIdentifiedList():
if node.isStorage(): if node.isStorage():
node.notify(packet) node.notify(packet)
else: else:
packet = Packets.CommitTransaction(tid) packet = Packets.CommitTransaction(tid)
for node in self.app.nm.getIdentifiedList(pool_set=uuid_set): for node in getIdentifiedList(pool_set=uuid_set):
node.notify(packet) node.notify(packet)
self._oid_set = set() self._oid_set = set()
# If possible, send the packets now. # If possible, send the packets now.
em.poll(0) app.em.poll(0)
def verifyTransaction(self, tid): def verifyTransaction(self, tid):
em = self.app.em em = self.app.em
...@@ -189,11 +190,11 @@ class VerificationManager(BaseServiceHandler): ...@@ -189,11 +190,11 @@ class VerificationManager(BaseServiceHandler):
return uuid_set return uuid_set
def answerLastIDs(self, conn, loid, ltid, lptid): def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
# FIXME: this packet should not be allowed here; the master already
# accepted the current partition table and last IDs. As they were manually
# approved during recovery, there is no need to check them here.
pass raise RuntimeError
def answerUnfinishedTransactions(self, conn, max_tid, tid_list): def answerUnfinishedTransactions(self, conn, max_tid, tid_list):
uuid = conn.getUUID() uuid = conn.getUUID()
......
...@@ -54,15 +54,10 @@ UNIT_TEST_MODULES = [ ...@@ -54,15 +54,10 @@ UNIT_TEST_MODULES = [
'neo.tests.storage.testInitializationHandler', 'neo.tests.storage.testInitializationHandler',
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp', 'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorageHandler', 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testStorageMySQLdb',
'neo.tests.storage.testStorageBTree',
'neo.tests.storage.testVerificationHandler', 'neo.tests.storage.testVerificationHandler',
'neo.tests.storage.testIdentificationHandler', 'neo.tests.storage.testIdentificationHandler',
'neo.tests.storage.testTransactions', 'neo.tests.storage.testTransactions',
'neo.tests.storage.testReplicationHandler',
'neo.tests.storage.testReplicator',
'neo.tests.storage.testReplication',
# client application # client application
'neo.tests.client.testClientApp', 'neo.tests.client.testClientApp',
'neo.tests.client.testMasterHandler', 'neo.tests.client.testMasterHandler',
...@@ -70,6 +65,7 @@ UNIT_TEST_MODULES = [ ...@@ -70,6 +65,7 @@ UNIT_TEST_MODULES = [
'neo.tests.client.testConnectionPool', 'neo.tests.client.testConnectionPool',
# light functional tests # light functional tests
'neo.tests.threaded.test', 'neo.tests.threaded.test',
'neo.tests.threaded.testReplication',
] ]
FUNC_TEST_MODULES = [ FUNC_TEST_MODULES = [
......
...@@ -113,28 +113,21 @@ class Application(object): ...@@ -113,28 +113,21 @@ class Application(object):
"""Load persistent configuration data from the database. """Load persistent configuration data from the database.
If data is not present, generate it.""" If data is not present, generate it."""
def NoneOnKeyError(getter):
try:
return getter()
except KeyError:
return None
dm = self.dm dm = self.dm
# check cluster name # check cluster name
try: name = dm.getName()
dm_name = dm.getName() if name is None:
except KeyError:
dm.setName(self.name) dm.setName(self.name)
else: elif name != self.name:
if dm_name != self.name: raise RuntimeError('name %r does not match with the database: %r'
raise RuntimeError('name %r does not match with the ' % (self.name, dm_name))
'database: %r' % (self.name, dm_name))
# load configuration # load configuration
self.uuid = NoneOnKeyError(dm.getUUID) self.uuid = dm.getUUID()
num_partitions = NoneOnKeyError(dm.getNumPartitions) num_partitions = dm.getNumPartitions()
num_replicas = NoneOnKeyError(dm.getNumReplicas) num_replicas = dm.getNumReplicas()
ptid = NoneOnKeyError(dm.getPTID) ptid = dm.getPTID()
# check partition table configuration # check partition table configuration
if num_partitions is not None and num_replicas is not None: if num_partitions is not None and num_replicas is not None:
...@@ -152,10 +145,7 @@ class Application(object): ...@@ -152,10 +145,7 @@ class Application(object):
def loadPartitionTable(self): def loadPartitionTable(self):
"""Load a partition table from the database.""" """Load a partition table from the database."""
try:
ptid = self.dm.getPTID() ptid = self.dm.getPTID()
except KeyError:
ptid = None
cell_list = self.dm.getPartitionTable() cell_list = self.dm.getPartitionTable()
new_cell_list = [] new_cell_list = []
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
...@@ -216,9 +206,7 @@ class Application(object): ...@@ -216,9 +206,7 @@ class Application(object):
except OperationFailure, msg: except OperationFailure, msg:
neo.lib.logging.error('operation stopped: %s', msg) neo.lib.logging.error('operation stopped: %s', msg)
except PrimaryFailure, msg: except PrimaryFailure, msg:
self.replicator.masterLost()
neo.lib.logging.error('primary master is down: %s', msg) neo.lib.logging.error('primary master is down: %s', msg)
self.master_node = None
def connectToPrimary(self): def connectToPrimary(self):
"""Find a primary master node, and connect to it. """Find a primary master node, and connect to it.
...@@ -296,6 +284,7 @@ class Application(object): ...@@ -296,6 +284,7 @@ class Application(object):
neo.lib.logging.info('doing operation') neo.lib.logging.info('doing operation')
_poll = self._poll _poll = self._poll
isIdle = self.em.isIdle
handler = master.MasterOperationHandler(self) handler = master.MasterOperationHandler(self)
self.master_conn.setHandler(handler) self.master_conn.setHandler(handler)
...@@ -304,16 +293,21 @@ class Application(object): ...@@ -304,16 +293,21 @@ class Application(object):
self.dm.dropUnfinishedData() self.dm.dropUnfinishedData()
self.tm.reset() self.tm.reset()
self.task_queue = task_queue = deque()
try:
while True: while True:
while task_queue and isIdle():
try:
task_queue[-1].next()
task_queue.rotate()
except StopIteration:
task_queue.pop()
_poll() _poll()
if self.replicator.pending(): finally:
# Call processDelayedTasks before act, so tasks added in the del self.task_queue
# act call are executed after one poll call, so that sent # Abort any replication, whether we are feeding or out-of-date.
# packets are already on the network and delayed task for node in self.nm.getStorageList(only_identified=True):
# processing happens in parallel with the same task on the node.getConnection().close()
# other storage node.
self.replicator.processDelayedTasks()
self.replicator.act()
def wait(self): def wait(self):
# change handler # change handler
...@@ -368,6 +362,13 @@ class Application(object): ...@@ -368,6 +362,13 @@ class Application(object):
neo.lib.logging.info(' %r:%r: %r:%r %r %r', key, event.__name__, neo.lib.logging.info(' %r:%r: %r:%r %r %r', key, event.__name__,
_msg_id, _conn, args) _msg_id, _conn, args)
def newTask(self, iterator):
try:
iterator.next()
except StopIteration:
return
self.task_queue.appendleft(iterator)
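# Illustrative sketch (not part of the diff) of the cooperative scheduling
# pattern above: newTask() primes and registers a generator, and the
# operation loop steps the oldest one whenever the event manager is idle.
# Names and chunk counts below are hypothetical; only the deque handling
# mirrors the code.
from collections import deque

def _task_queue_sketch():
    task_queue = deque()

    def new_task(iterator):
        try:
            iterator.next()             # prime it, as newTask() does
        except StopIteration:
            return
        task_queue.appendleft(iterator)

    def replicate(name, chunks):
        for chunk in xrange(chunks):
            print '%s: chunk %u' % (name, chunk)
            yield                       # yield control back to the event loop

    new_task(replicate('partition 0', 3))
    new_task(replicate('partition 1', 2))
    # Stands in for "while task_queue and isIdle(): ..." in the real loop.
    while task_queue:
        try:
            task_queue[-1].next()       # step the oldest pending task
            task_queue.rotate()         # round-robin: move it behind the rest
        except StopIteration:
            task_queue.pop()            # task finished, drop it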
def shutdown(self, erase=False): def shutdown(self, erase=False):
"""Close all connections and exit""" """Close all connections and exit"""
for c in self.em.getConnectionList(): for c in self.em.getConnectionList():
......
...@@ -15,10 +15,13 @@ ...@@ -15,10 +15,13 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
LOG_QUERIES = False
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
from .manager import DatabaseManager from .manager import DatabaseManager
from .sqlite import SQLiteDatabaseManager
DATABASE_MANAGER_DICT = {} DATABASE_MANAGER_DICT = {'SQLite': SQLiteDatabaseManager}
try: try:
from .mysqldb import MySQLDatabaseManager from .mysqldb import MySQLDatabaseManager
...@@ -27,17 +30,6 @@ except ImportError: ...@@ -27,17 +30,6 @@ except ImportError:
else: else:
DATABASE_MANAGER_DICT['MySQL'] = MySQLDatabaseManager DATABASE_MANAGER_DICT['MySQL'] = MySQLDatabaseManager
try:
from .btree import BTreeDatabaseManager
except ImportError:
pass
else:
# XXX: warning: name might change in the future.
DATABASE_MANAGER_DICT['BTree'] = BTreeDatabaseManager
if not DATABASE_MANAGER_DICT:
raise ImportError('No database back-end available.')
def buildDatabaseManager(name, args=(), kw={}): def buildDatabaseManager(name, args=(), kw={}):
if name is None: if name is None:
name = DATABASE_MANAGER_DICT.keys()[0] name = DATABASE_MANAGER_DICT.keys()[0]
......
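# Hypothetical usage sketch (not part of the diff); the module path and the
# (database, wait) constructor arguments are assumptions inferred from other
# files touched by this commit, not a documented API.
#   from neo.storage.database import buildDatabaseManager
#   dm = buildDatabaseManager('SQLite', ('/tmp/test.neo.db', 0))
#   dm.setup(reset=1)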
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
Naive b-tree implementation.
Simple, though not so well tested.
Not persistent! (no data is retained after process exit)
"""
from BTrees.OOBTree import OOBTree as _OOBTree
import neo.lib
from hashlib import sha1
from . import DatabaseManager
from .manager import CreationUndone
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID
from neo.lib import util
# Keep dropped trees in memory to avoid instantiating when not needed.
TREE_POOL = []
# How many empty BTree instances to keep in RAM
MAX_TREE_POOL_SIZE = 100
def batchDelete(tree, tester_callback=None, deleter_callback=None, **kw):
"""
Iterate over the given BTree and delete matching entries.
tree BTree
Tree to delete entries from.
tester_callback function(key, value) -> boolean
Called with each key, value pair found in tree.
If the return value is true, delete the entry; otherwise, skip to the next key.
deleter_callback function(tree, key_list) -> None (default: None)
Custom function used to delete the selected entries.
**kw
Keyword arguments for the tree iteration methods (iterkeys/iteritems).
"""
if tester_callback is None:
key_list = list(safeIter(tree.iterkeys, **kw))
else:
key_list = [key for key, value in safeIter(tree.iteritems, **kw)
if tester_callback(key, value)]
if deleter_callback is None:
for key in key_list:
del tree[key]
else:
deleter_callback(tree, key_list)
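# Illustrative usage sketch for batchDelete (not part of the original file);
# the tree contents are hypothetical.
def _batchDelete_example():
    t = _OOBTree()
    t.update({1: 'a', 2: 'b', 3: 'c', 4: 'd'})
    # Drop every entry with an even key, restricted to keys >= 2.
    batchDelete(t, lambda key, value: key % 2 == 0, min=2)
    assert list(t.keys()) == [1, 3]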
def OOBTree():
try:
result = TREE_POOL.pop()
except IndexError:
result = _OOBTree()
# Next btree we prune will have room, restore prune method
global prune
prune = _prune
return result
def _prune(tree):
tree.clear()
TREE_POOL.append(tree)
if len(TREE_POOL) >= MAX_TREE_POOL_SIZE:
# Already at/above max pool size, disable ourselves.
global prune
prune = _noPrune
def _noPrune(_):
pass
prune = _prune
def iterObjSerials(obj):
for tserial in obj.values():
for serial in tserial.keys():
yield serial
def descItems(tree):
try:
key = tree.maxKey()
except ValueError:
pass
else:
while True:
yield (key, tree[key])
try:
key = tree.maxKey(key - 1)
except ValueError:
break
def descKeys(tree):
try:
key = tree.maxKey()
except ValueError:
pass
else:
while True:
yield key
try:
key = tree.maxKey(key - 1)
except ValueError:
break
def safeIter(func, *args, **kw):
try:
some_list = func(*args, **kw)
except ValueError:
some_list = []
return some_list
class BTreeDatabaseManager(DatabaseManager):
def __init__(self, database, wait):
super(BTreeDatabaseManager, self).__init__(database, wait)
self.setup(reset=1)
@property
def _num_partitions(self):
return self._config['partitions']
def setup(self, reset=0):
if reset:
self._data = OOBTree()
self._obj = OOBTree()
self._trans = OOBTree()
self._tobj = OOBTree()
self._ttrans = OOBTree()
self._pt = {}
self._config = {}
self._uncommitted_data = {}
def _begin(self):
pass
def _commit(self):
pass
def _rollback(self):
pass
def getConfiguration(self, key):
return self._config[key]
def _setConfiguration(self, key, value):
self._config[key] = value
def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid)
def _getPackTID(self):
try:
result = int(self.getConfiguration('_pack_tid'))
except KeyError:
result = -1
return result
def getPartitionTable(self):
pt = []
append = pt.append
for (offset, uuid), state in self._pt.iteritems():
append((offset, uuid, state))
return pt
def getLastTID(self, all=True):
try:
ltid = self._trans.maxKey()
except ValueError:
ltid = None
if all:
try:
tmp_ltid = self._ttrans.maxKey()
except ValueError:
tmp_ltid = None
tmp_serial = None
for tserial in self._tobj.values():
try:
max_tmp_serial = tserial.maxKey()
except ValueError:
pass
else:
tmp_serial = max(tmp_serial, max_tmp_serial)
ltid = max(ltid, tmp_ltid, tmp_serial)
if ltid is not None:
ltid = util.p64(ltid)
return ltid
def getUnfinishedTIDList(self):
p64 = util.p64
tid_set = set(p64(x) for x in self._ttrans.keys())
tid_set.update(p64(x) for x in iterObjSerials(self._tobj))
return list(tid_set)
def objectPresent(self, oid, tid, all=True):
u64 = util.u64
oid = u64(oid)
tid = u64(tid)
try:
result = self._obj[oid].has_key(tid)
except KeyError:
if all:
try:
result = self._tobj[oid].has_key(tid)
except KeyError:
result = False
else:
result = False
return result
def _getObject(self, oid, tid=None, before_tid=None):
tserial = self._obj.get(oid)
if tserial is not None:
if tid is None:
try:
if before_tid is None:
tid = tserial.maxKey()
else:
tid = tserial.maxKey(before_tid - 1)
except ValueError:
return False
try:
checksum, value_serial = tserial[tid]
except KeyError:
return False
try:
next_serial = tserial.minKey(tid + 1)
except ValueError:
next_serial = None
if checksum is None:
compression = data = None
else:
compression, data, _ = self._data[checksum]
return tid, next_serial, compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset):
pt = self._pt
if reset:
pt.clear()
for offset, uuid, state in cell_list:
# TODO: this logic should move out of database manager
# add 'dropCells(cell_list)' to API and use one query
key = (offset, uuid)
if state == CellStates.DISCARDED:
pt.pop(key, None)
else:
pt[key] = int(state)
self.setPTID(ptid)
def changePartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, False)
def setPartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, True)
def _oidDeleterCallback(self, oid):
data = self._data
uncommitted_data = self._uncommitted_data
def deleter_callback(tree, key_list):
for tid in key_list:
checksum = tree.pop(tid)[0]
if checksum:
index = data[checksum][2]
index.remove((oid, tid))
if not index and checksum not in uncommitted_data:
del data[checksum]
return deleter_callback
def _objDeleterCallback(self, tree, key_list):
data = self._data
checksum_list = []
checksum_set = set()
for oid in key_list:
tserial = tree.pop(oid)
for tid, (checksum, _) in tserial.items():
if checksum:
index = data[checksum][2]
try:
index.remove((oid, tid))
except KeyError: # _tobj
checksum_list.append(checksum)
checksum_set.add(checksum)
prune(tserial)
self.unlockData(checksum_list)
self._pruneData(checksum_set)
def dropPartitions(self, offset_list):
offset_list = frozenset(offset_list)
num_partitions = self._num_partitions
def same_partition(key, _):
return key % num_partitions in offset_list
batchDelete(self._obj, same_partition, self._objDeleterCallback)
batchDelete(self._trans, same_partition)
def dropUnfinishedData(self):
batchDelete(self._tobj, deleter_callback=self._objDeleterCallback)
self._ttrans.clear()
def storeTransaction(self, tid, object_list, transaction, temporary=True):
u64 = util.u64
tid = u64(tid)
if temporary:
obj = self._tobj
trans = self._ttrans
else:
obj = self._obj
trans = self._trans
data = self._data
for oid, checksum, value_serial in object_list:
oid = u64(oid)
if value_serial:
value_serial = u64(value_serial)
checksum = self._obj[oid][value_serial][0]
if temporary:
self.storeData(checksum)
if checksum:
if not temporary:
data[checksum][2].add((oid, tid))
try:
tserial = obj[oid]
except KeyError:
tserial = obj[oid] = OOBTree()
tserial[tid] = checksum, value_serial
if transaction is not None:
oid_list, user, desc, ext, packed = transaction
trans[tid] = (tuple(oid_list), user, desc, ext, packed)
def _pruneData(self, checksum_list):
data = self._data
for checksum in set(checksum_list).difference(self._uncommitted_data):
if not data[checksum][2]:
del data[checksum]
def _storeData(self, checksum, data, compression):
try:
if self._data[checksum][:2] != (compression, data):
raise AssertionError("hash collision")
except KeyError:
self._data[checksum] = compression, data, set()
return checksum
def finishTransaction(self, tid):
tid = util.u64(tid)
self._popTransactionFromTObj(tid, True)
ttrans = self._ttrans
try:
data = ttrans[tid]
except KeyError:
pass
else:
del ttrans[tid]
self._trans[tid] = data
def _popTransactionFromTObj(self, tid, to_obj):
checksum_list = []
if to_obj:
deleter_callback = None
obj = self._obj
def callback(oid, data):
try:
tserial = obj[oid]
except KeyError:
tserial = obj[oid] = OOBTree()
tserial[tid] = data
checksum = data[0]
if checksum:
self._data[checksum][2].add((oid, tid))
checksum_list.append(checksum)
else:
deleter_callback = self._objDeleterCallback
callback = lambda oid, data: None
def tester_callback(oid, tserial):
try:
data = tserial[tid]
except KeyError:
pass
else:
del tserial[tid]
callback(oid, data)
return not tserial
batchDelete(self._tobj, tester_callback, deleter_callback)
self.unlockData(checksum_list)
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
tid = u64(tid)
self._popTransactionFromTObj(tid, False)
try:
del self._ttrans[tid]
except KeyError:
pass
for oid in oid_list:
self._deleteObject(u64(oid), tid)
try:
del self._trans[tid]
except KeyError:
pass
def deleteTransactionsAbove(self, partition, tid, max_tid):
num_partitions = self._num_partitions
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(self._trans, same_partition,
min=util.u64(tid), max=util.u64(max_tid))
def deleteObject(self, oid, serial=None):
u64 = util.u64
self._deleteObject(u64(oid), serial and u64(serial))
def _deleteObject(self, oid, serial=None):
obj = self._obj
try:
tserial = obj[oid]
except KeyError:
return
batchDelete(tserial, deleter_callback=self._oidDeleterCallback(oid),
min=serial, max=serial)
if not tserial:
del obj[oid]
def deleteObjectsAbove(self, partition, oid, serial, max_tid):
obj = self._obj
u64 = util.u64
oid = u64(oid)
serial = u64(serial)
max_tid = u64(max_tid)
num_partitions = self._num_partitions
if oid % num_partitions == partition:
try:
tserial = obj[oid]
except KeyError:
pass
else:
batchDelete(tserial, min=serial, max=max_tid,
deleter_callback=self._oidDeleterCallback(oid))
if not tserial:
del obj[oid]
def same_partition(key, _):
return key % num_partitions == partition
batchDelete(obj, same_partition, self._objDeleterCallback,
min=oid, excludemin=True, max=max_tid)
def getTransaction(self, tid, all=False):
tid = util.u64(tid)
try:
result = self._trans[tid]
except KeyError:
if all:
try:
result = self._ttrans[tid]
except KeyError:
result = None
else:
result = None
if result is not None:
oid_list, user, desc, ext, packed = result
result = (list(oid_list), user, desc, ext, packed)
return result
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
checksum, value_serial = self._obj[oid][value_serial]
if checksum is None:
neo.lib.logging.info("Multiple levels of indirection when " \
"searching for object data for oid %d at tid %d. This " \
"causes suboptimal performance." % (oid, value_serial))
return self._getObjectLength(oid, value_serial)
return len(self._data[checksum][1])
def getObjectHistory(self, oid, offset=0, length=1):
# FIXME: This method doesn't take client's current transaction id as
# parameter, which means it can return transactions in the future of
# client's transaction.
oid = util.u64(oid)
p64 = util.p64
pack_tid = self._getPackTID()
try:
tserial = self._obj[oid]
except KeyError:
result = None
else:
result = []
append = result.append
tserial_iter = descItems(tserial)
while offset > 0:
tserial_iter.next()
offset -= 1
data = self._data
for serial, (checksum, value_serial) in tserial_iter:
if length == 0 or serial < pack_tid:
break
length -= 1
if checksum is None:
try:
data_length = self._getObjectLength(oid, value_serial)
except CreationUndone:
data_length = 0
else:
data_length = len(data[checksum][1])
append((p64(serial), data_length))
if not result:
result = None
return result
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition):
u64 = util.u64
p64 = util.p64
min_oid = u64(min_oid)
min_serial = u64(min_serial)
max_serial = u64(max_serial)
result = {}
num_partitions = self._num_partitions
for oid, tserial in safeIter(self._obj.items, min=min_oid):
if oid % num_partitions == partition:
if length == 0:
break
if oid == min_oid:
try:
tid_seq = tserial.keys(min=min_serial, max=max_serial)
except ValueError:
continue
else:
tid_seq = tserial.keys(max=max_serial)
if not tid_seq:
continue
result[p64(oid)] = tid_list = []
append = tid_list.append
for tid in tid_seq:
if length == 0:
break
length -= 1
append(p64(tid))
else:
continue
break
return result
def getTIDList(self, offset, length, partition_list):
p64 = util.p64
partition_list = frozenset(partition_list)
result = []
append = result.append
trans_iter = descKeys(self._trans)
num_partitions = self._num_partitions
while offset > 0:
tid = trans_iter.next()
if tid % num_partitions in partition_list:
offset -= 1
for tid in trans_iter:
if tid % num_partitions in partition_list:
if length == 0:
break
length -= 1
append(p64(tid))
return result
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
p64 = util.p64
u64 = util.u64
result = []
append = result.append
num_partitions = self._num_partitions
for tid in safeIter(self._trans.keys, min=u64(min_tid), max=u64(max_tid)):
if tid % num_partitions == partition:
if length == 0:
break
length -= 1
append(p64(tid))
return result
def _updatePackFuture(self, oid, orig_serial, max_serial):
# Before deleting this object's revision, see if there is any
# transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location.
new_serial = None
obj = self._obj
for tree in (obj, self._tobj):
try:
tserial = tree[oid]
except KeyError:
continue
for serial, (checksum, value_serial) in tserial.iteritems(
min=max_serial):
if value_serial == orig_serial:
tserial[serial] = checksum, new_serial
if not new_serial:
new_serial = serial
return new_serial
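# Standalone sketch (not part of the original file) of the repointing logic
# above, using a plain dict {serial: (checksum, value_serial)} for one oid;
# the real method walks the obj/tobj BTrees instead.
def _update_pack_future_sketch(revisions, orig_serial, max_serial):
    new_serial = None
    for serial in sorted(s for s in revisions if s >= max_serial):
        checksum, value_serial = revisions[serial]
        if value_serial == orig_serial:
            # The first future reference keeps None so the caller stores the
            # data there; later ones are repointed to that first revision.
            revisions[serial] = checksum, new_serial
            if not new_serial:
                new_serial = serial
    return new_serial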
def pack(self, tid, updateObjectDataForPack):
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
self._setPackTID(tid)
def obj_callback(oid, tserial):
try:
max_serial = tserial.maxKey(tid)
except ValueError:
# No entry before pack TID, nothing to pack on this object.
pass
else:
if tserial[max_serial][0] is None:
# Last version before/at pack TID is a creation undo, drop
# it too.
max_serial += 1
def serial_callback(serial, value):
new_serial = updatePackFuture(oid, serial, max_serial)
if new_serial:
new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial),
new_serial, value[0])
batchDelete(tserial, serial_callback,
self._oidDeleterCallback(oid),
max=max_serial, excludemax=True)
return not tserial
batchDelete(self._obj, obj_callback, self._objDeleterCallback)
def checkTIDRange(self, min_tid, max_tid, length, partition):
if length:
tid_list = []
num_partitions = self._num_partitions
for tid in safeIter(self._trans.keys, min=util.u64(min_tid),
max=util.u64(max_tid)):
if tid % num_partitions == partition:
tid_list.append(tid)
if len(tid_list) >= length:
break
if tid_list:
return (len(tid_list),
sha1(','.join(map(str, tid_list))).digest(),
util.p64(tid_list[-1]))
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
if length:
u64 = util.u64
min_oid = u64(min_oid)
max_tid = u64(max_tid)
oid_list = []
serial_list = []
num_partitions = self._num_partitions
for oid, tserial in safeIter(self._obj.items, min=min_oid):
if oid % num_partitions == partition:
try:
if oid == min_oid:
tserial = tserial.keys(min=u64(min_serial),
max=max_tid)
else:
tserial = tserial.keys(max=max_tid)
except ValueError:
continue
for serial in tserial:
oid_list.append(oid)
serial_list.append(serial)
if len(oid_list) >= length:
break
else:
continue
break
if oid_list:
p64 = util.p64
return (len(oid_list),
sha1(','.join(map(str, oid_list))).digest(),
p64(oid_list[-1]),
sha1(','.join(map(str, serial_list))).digest(),
p64(serial_list[-1]))
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import neo.lib import neo.lib
from neo.lib import util from neo.lib import util
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import ZERO_TID
class CreationUndone(Exception): class CreationUndone(Exception):
pass pass
...@@ -37,34 +38,6 @@ class DatabaseManager(object): ...@@ -37,34 +38,6 @@ class DatabaseManager(object):
"""Called during instanciation, to process database parameter.""" """Called during instanciation, to process database parameter."""
pass pass
def isUnderTransaction(self):
return self._under_transaction
def begin(self):
"""
Begin a transaction
"""
if self._under_transaction:
raise DatabaseFailure('A transaction has already begun')
self._begin()
self._under_transaction = True
def commit(self):
"""
Commit the current transaction
"""
if not self._under_transaction:
raise DatabaseFailure('The transaction has not begun')
self._commit()
self._under_transaction = False
def rollback(self):
"""
Rollback the current transaction
"""
self._rollback()
self._under_transaction = False
def setup(self, reset = 0): def setup(self, reset = 0):
"""Set up a database """Set up a database
...@@ -79,14 +52,33 @@ class DatabaseManager(object): ...@@ -79,14 +52,33 @@ class DatabaseManager(object):
""" """
raise NotImplementedError raise NotImplementedError
def _begin(self): def __enter__(self):
raise NotImplementedError """
Begin a transaction
"""
if self._under_transaction:
raise DatabaseFailure('A transaction has already begun')
r = self.begin()
self._under_transaction = True
return r
def _commit(self): def __exit__(self, exc_type, exc_value, tb):
raise NotImplementedError if not self._under_transaction:
raise DatabaseFailure('The transaction has not begun')
self._under_transaction = False
if exc_type is None:
self.commit()
else:
self.rollback()
def _rollback(self): def begin(self):
raise NotImplementedError pass
def commit(self):
pass
def rollback(self):
pass
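# Standalone sketch (not part of the diff) of the transaction protocol defined
# above: __enter__ begins a transaction and binds whatever begin() returns
# (the backends below return their query callable), and __exit__ commits on
# success or rolls back if the block raised. The class and queries here are
# hypothetical.
class _TransactionalSketch(object):
    _under_transaction = False

    def query(self, sql):
        print 'QUERY', sql

    def begin(self):
        print 'BEGIN'
        return self.query

    def commit(self):
        print 'COMMIT'

    def rollback(self):
        print 'ROLLBACK'

    def __enter__(self):
        assert not self._under_transaction
        r = self.begin()
        self._under_transaction = True
        return r

    def __exit__(self, exc_type, exc_value, tb):
        self._under_transaction = False
        if exc_type is None:
            self.commit()
        else:
            self.rollback()

def _transaction_example():
    # Both queries run in one transaction; an exception would roll it back.
    with _TransactionalSketch() as q:
        q("DELETE FROM tobj")
        q("DELETE FROM ttrans")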
def _getPartition(self, oid_or_tid): def _getPartition(self, oid_or_tid):
return oid_or_tid % self.getNumPartitions() return oid_or_tid % self.getNumPartitions()
...@@ -104,13 +96,8 @@ class DatabaseManager(object): ...@@ -104,13 +96,8 @@ class DatabaseManager(object):
if self._under_transaction: if self._under_transaction:
self._setConfiguration(key, value) self._setConfiguration(key, value)
else: else:
self.begin() with self:
try:
self._setConfiguration(key, value) self._setConfiguration(key, value)
except:
self.rollback()
raise
self.commit()
def _setConfiguration(self, key, value): def _setConfiguration(self, key, value):
raise NotImplementedError raise NotImplementedError
...@@ -171,7 +158,9 @@ class DatabaseManager(object): ...@@ -171,7 +158,9 @@ class DatabaseManager(object):
""" """
Load a Partition Table ID from a database. Load a Partition Table ID from a database.
""" """
return long(self.getConfiguration('ptid')) ptid = self.getConfiguration('ptid')
if ptid is not None:
return long(ptid)
def setPTID(self, ptid): def setPTID(self, ptid):
""" """
...@@ -194,18 +183,31 @@ class DatabaseManager(object): ...@@ -194,18 +183,31 @@ class DatabaseManager(object):
""" """
self.setConfiguration('loid', util.dump(loid)) self.setConfiguration('loid', util.dump(loid))
def getBackupTID(self):
return util.bin(self.getConfiguration('backup_tid'))
def setBackupTID(self, backup_tid):
return self.setConfiguration('backup_tid', util.dump(backup_tid))
def getPartitionTable(self): def getPartitionTable(self):
"""Return a whole partition table as a tuple of rows. Each row """Return a whole partition table as a tuple of rows. Each row
is again a tuple of an offset (row ID), an UUID of a storage is again a tuple of an offset (row ID), an UUID of a storage
node, and a cell state.""" node, and a cell state."""
raise NotImplementedError raise NotImplementedError
def getLastTID(self, all = True): def _getLastTIDs(self, all=True):
"""Return the last TID in a database. If all is true,
unfinished transactions must be taken into account. If there
is no TID in the database, return None."""
raise NotImplementedError raise NotImplementedError
def getLastTIDs(self, all=True):
trans, obj = self._getLastTIDs()
if trans:
tid = max(trans.itervalues())
if obj:
tid = max(tid, max(obj.itervalues()))
else:
tid = max(obj.itervalues()) if obj else None
return tid, trans, obj
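# Worked example (hypothetical values, not part of the diff): _getLastTIDs()
# returns per-partition maxima, with the None key carrying the maximum of the
# unfinished ttrans/tobj data; getLastTIDs() reduces them to a single tid.
def _getLastTIDs_example():
    from neo.lib.util import p64
    trans = {0: p64(8), 1: p64(12)}       # last committed tid per partition
    obj = {0: p64(9), None: p64(15)}      # includes an unfinished serial
    tid = max(trans.itervalues())
    tid = max(tid, max(obj.itervalues()))
    assert tid == p64(15)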
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
"""Return a list of unfinished transaction's IDs.""" """Return a list of unfinished transaction's IDs."""
raise NotImplementedError raise NotImplementedError
...@@ -352,13 +354,8 @@ class DatabaseManager(object): ...@@ -352,13 +354,8 @@ class DatabaseManager(object):
else: else:
del refcount[data_id] del refcount[data_id]
if prune: if prune:
self.begin() with self:
try:
self._pruneData(data_id_list) self._pruneData(data_id_list)
except:
self.rollback()
raise
self.commit()
__getDataTID = set() __getDataTID = set()
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
...@@ -466,23 +463,24 @@ class DatabaseManager(object): ...@@ -466,23 +463,24 @@ class DatabaseManager(object):
an oid list""" an oid list"""
raise NotImplementedError raise NotImplementedError
def deleteTransactionsAbove(self, partition, tid, max_tid):
"""Delete all transactions above given TID (inclued) in given
partition, but never above max_tid (in case transactions are committed
during replication)."""
raise NotImplementedError
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
"""Delete given object. If serial is given, only delete that serial for """Delete given object. If serial is given, only delete that serial for
given oid.""" given oid."""
raise NotImplementedError raise NotImplementedError
def deleteObjectsAbove(self, partition, oid, serial, max_tid): def _deleteRange(self, partition, min_tid=None, max_tid=None):
"""Delete all objects above given OID and serial (inclued) in given """Delete all objects and transactions between given min_tid (excluded)
partition, but never above max_tid (in case objects are stored during and max_tid (included)"""
replication)"""
raise NotImplementedError raise NotImplementedError
def truncate(self, tid):
assert tid not in (None, ZERO_TID), tid
with self:
assert self.getBackupTID()
self.setBackupTID(tid)
for partition in xrange(self.getNumPartitions()):
self._deleteRange(partition, tid)
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
"""Return a tuple of the list of OIDs, user information, """Return a tuple of the list of OIDs, user information,
a description, and extension information, for a given transaction a description, and extension information, for a given transaction
...@@ -498,10 +496,10 @@ class DatabaseManager(object): ...@@ -498,10 +496,10 @@ class DatabaseManager(object):
If there is no such object ID in a database, return None.""" If there is no such object ID in a database, return None."""
raise NotImplementedError raise NotImplementedError
def getObjectHistoryFrom(self, oid, min_serial, max_serial, length, def getReplicationObjectList(self, min_tid, max_tid, length, partition,
partition): min_oid):
"""Return a dict of length serials grouped by oid at (or above) """Return a dict of length oids grouped by serial at (or above)
min_oid and min_serial and below max_serial, for given partition, min_tid and min_oid and below max_tid, for given partition,
sorted in ascending order.""" sorted in ascending order."""
raise NotImplementedError raise NotImplementedError
......
...@@ -27,14 +27,12 @@ import re ...@@ -27,14 +27,12 @@ import re
import string import string
import time import time
from . import DatabaseManager from . import DatabaseManager, LOG_QUERIES
from .manager import CreationUndone from .manager import CreationUndone
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
from neo.lib import util from neo.lib import util
LOG_QUERIES = False
def splitOIDField(tid, oids): def splitOIDField(tid, oids):
if (len(oids) % 8) != 0 or len(oids) == 0: if (len(oids) % 8) != 0 or len(oids) == 0:
raise DatabaseFailure('invalid oids length for tid %d: %d' % (tid, raise DatabaseFailure('invalid oids length for tid %d: %d' % (tid,
...@@ -99,18 +97,22 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -99,18 +97,22 @@ class MySQLDatabaseManager(DatabaseManager):
self.conn.query("SET SESSION group_concat_max_len = -1") self.conn.query("SET SESSION group_concat_max_len = -1")
self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION") self.conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION")
def _begin(self): def begin(self):
self.query("""BEGIN""") q = self.query
q("BEGIN")
return q
def _commit(self):
if LOG_QUERIES: if LOG_QUERIES:
def commit(self):
neo.lib.logging.debug('committing...') neo.lib.logging.debug('committing...')
self.conn.commit() self.conn.commit()
def _rollback(self): def rollback(self):
if LOG_QUERIES:
neo.lib.logging.debug('aborting...') neo.lib.logging.debug('aborting...')
self.conn.rollback() self.conn.rollback()
else:
commit = property(lambda self: self.conn.commit)
rollback = property(lambda self: self.conn.rollback)
def query(self, query): def query(self, query):
"""Query data from a database.""" """Query data from a database."""
...@@ -194,7 +196,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -194,7 +196,8 @@ class MySQLDatabaseManager(DatabaseManager):
serial BIGINT UNSIGNED NOT NULL, serial BIGINT UNSIGNED NOT NULL,
data_id BIGINT UNSIGNED NULL, data_id BIGINT UNSIGNED NULL,
value_serial BIGINT UNSIGNED NULL, value_serial BIGINT UNSIGNED NULL,
PRIMARY KEY (partition, oid, serial), PRIMARY KEY (partition, serial, oid),
KEY (partition, oid, serial),
KEY (data_id) KEY (data_id)
) ENGINE = InnoDB""" + p) ) ENGINE = InnoDB""" + p)
...@@ -233,15 +236,15 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -233,15 +236,15 @@ class MySQLDatabaseManager(DatabaseManager):
" FROM tobj WHERE data_id IS NOT NULL GROUP BY data_id") or ()) " FROM tobj WHERE data_id IS NOT NULL GROUP BY data_id") or ())
def getConfiguration(self, key): def getConfiguration(self, key):
if key in self._config: try:
return self._config[key] return self._config[key]
q = self.query except KeyError:
e = self.escape sql_key = self.escape(str(key))
sql_key = e(str(key))
try: try:
r = q("SELECT value FROM config WHERE name = '%s'" % sql_key)[0][0] r = self.query("SELECT value FROM config WHERE name = '%s'"
% sql_key)[0][0]
except IndexError: except IndexError:
raise KeyError, key r = None
self._config[key] = r self._config[key] = r
return r return r
...@@ -251,20 +254,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -251,20 +254,19 @@ class MySQLDatabaseManager(DatabaseManager):
self._config[key] = value self._config[key] = value
key = e(str(key)) key = e(str(key))
if value is None: if value is None:
value = 'NULL' q("DELETE FROM config WHERE name = '%s'" % key)
else: else:
value = "'%s'" % (e(str(value)), ) value = e(str(value))
q("""REPLACE INTO config VALUES ('%s', %s)""" % (key, value)) q("REPLACE INTO config VALUES ('%s', '%s')" % (key, value))
def _setPackTID(self, tid): def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid) self._setConfiguration('_pack_tid', tid)
def _getPackTID(self): def _getPackTID(self):
try: try:
result = int(self.getConfiguration('_pack_tid')) return int(self.getConfiguration('_pack_tid'))
except KeyError: except TypeError:
result = -1 return -1
return result
def getPartitionTable(self): def getPartitionTable(self):
q = self.query q = self.query
...@@ -275,58 +277,42 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -275,58 +277,42 @@ class MySQLDatabaseManager(DatabaseManager):
pt.append((offset, uuid, state)) pt.append((offset, uuid, state))
return pt return pt
def getLastTID(self, all = True): def _getLastTIDs(self, all=True):
# XXX this does not consider serials in obj. p64 = util.p64
# I am not sure if this is really harmful. For safety, with self as q:
# check for tobj only at the moment. The reason why obj is trans = dict((partition, p64(tid))
# not tested is that it is too slow to get the max serial for partition, tid in q("SELECT partition, MAX(tid)"
# from obj when it has a huge number of objects, because " FROM trans GROUP BY partition"))
# serial is the second part of the primary key, so the index obj = dict((partition, p64(tid))
# is not used in this case. If doing it, it is better to for partition, tid in q("SELECT partition, MAX(serial)"
# make another index for serial, but I doubt the cost increase " FROM obj GROUP BY partition"))
# is worth.
q = self.query
self.begin()
ltid = q("SELECT MAX(value) FROM (SELECT MAX(tid) AS value FROM trans "
"GROUP BY partition) AS foo")[0][0]
if all: if all:
tmp_ltid = q("""SELECT MAX(tid) FROM ttrans""")[0][0] tid = q("SELECT MAX(tid) FROM ttrans")[0][0]
if ltid is None or (tmp_ltid is not None and ltid < tmp_ltid): if tid is not None:
ltid = tmp_ltid trans[None] = p64(tid)
tmp_serial = q("""SELECT MAX(serial) FROM tobj""")[0][0] tid = q("SELECT MAX(serial) FROM tobj")[0][0]
if ltid is None or (tmp_serial is not None and ltid < tmp_serial): if tid is not None:
ltid = tmp_serial obj[None] = p64(tid)
self.commit() return trans, obj
if ltid is not None:
ltid = util.p64(ltid)
return ltid
def getUnfinishedTIDList(self): def getUnfinishedTIDList(self):
q = self.query
tid_set = set() tid_set = set()
self.begin() with self as q:
r = q("""SELECT tid FROM ttrans""") r = q("""SELECT tid FROM ttrans""")
tid_set.update((util.p64(t[0]) for t in r)) tid_set.update((util.p64(t[0]) for t in r))
r = q("""SELECT serial FROM tobj""") r = q("""SELECT serial FROM tobj""")
self.commit()
tid_set.update((util.p64(t[0]) for t in r)) tid_set.update((util.p64(t[0]) for t in r))
return list(tid_set) return list(tid_set)
def objectPresent(self, oid, tid, all = True): def objectPresent(self, oid, tid, all = True):
q = self.query
oid = util.u64(oid) oid = util.u64(oid)
tid = util.u64(tid) tid = util.u64(tid)
partition = self._getPartition(oid) partition = self._getPartition(oid)
self.begin() with self as q:
r = q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND " return q("SELECT oid FROM obj WHERE partition=%d AND oid=%d AND "
"serial=%d" % (partition, oid, tid)) "serial=%d" % (partition, oid, tid)) or all and \
if not r and all: q("SELECT oid FROM tobj WHERE serial=%d AND oid=%d"
r = q("SELECT oid FROM tobj WHERE serial=%d AND oid=%d"
% (tid, oid)) % (tid, oid))
self.commit()
if r:
return True
return False
def _getObject(self, oid, tid=None, before_tid=None): def _getObject(self, oid, tid=None, before_tid=None):
q = self.query q = self.query
...@@ -357,11 +343,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -357,11 +343,9 @@ class MySQLDatabaseManager(DatabaseManager):
return serial, next_serial, compression, checksum, data, value_serial return serial, next_serial, compression, checksum, data, value_serial
def doSetPartitionTable(self, ptid, cell_list, reset): def doSetPartitionTable(self, ptid, cell_list, reset):
q = self.query
e = self.escape e = self.escape
offset_list = [] offset_list = []
self.begin() with self as q:
try:
if reset: if reset:
q("""TRUNCATE pt""") q("""TRUNCATE pt""")
for offset, uuid, state in cell_list: for offset, uuid, state in cell_list:
...@@ -377,10 +361,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -377,10 +361,6 @@ class MySQLDatabaseManager(DatabaseManager):
ON DUPLICATE KEY UPDATE state = %d""" \ ON DUPLICATE KEY UPDATE state = %d""" \
% (offset, uuid, state, state)) % (offset, uuid, state, state))
self.setPTID(ptid) self.setPTID(ptid)
except:
self.rollback()
raise
self.commit()
if self._use_partition: if self._use_partition:
for offset in offset_list: for offset in offset_list:
add = """ALTER TABLE %%s ADD PARTITION ( add = """ALTER TABLE %%s ADD PARTITION (
...@@ -399,9 +379,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -399,9 +379,7 @@ class MySQLDatabaseManager(DatabaseManager):
self.doSetPartitionTable(ptid, cell_list, True) self.doSetPartitionTable(ptid, cell_list, True)
def dropPartitions(self, offset_list): def dropPartitions(self, offset_list):
q = self.query with self as q:
self.begin()
try:
# XXX: these queries are inefficient (execution time increase with # XXX: these queries are inefficient (execution time increase with
# row count, although we use indexes) when there are rows to # row count, although we use indexes) when there are rows to
# delete. It should be done as an idle task, by chunks. # delete. It should be done as an idle task, by chunks.
...@@ -413,10 +391,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -413,10 +391,6 @@ class MySQLDatabaseManager(DatabaseManager):
q("DELETE FROM obj" + where) q("DELETE FROM obj" + where)
q("DELETE FROM trans" + where) q("DELETE FROM trans" + where)
self._pruneData(data_id_list) self._pruneData(data_id_list)
except:
self.rollback()
raise
self.commit()
if self._use_partition: if self._use_partition:
drop = "ALTER TABLE %s DROP PARTITION" + \ drop = "ALTER TABLE %s DROP PARTITION" + \
','.join(' p%u' % i for i in offset_list) ','.join(' p%u' % i for i in offset_list)
...@@ -428,20 +402,13 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -428,20 +402,13 @@ class MySQLDatabaseManager(DatabaseManager):
raise raise
def dropUnfinishedData(self): def dropUnfinishedData(self):
q = self.query with self as q:
self.begin()
try:
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x] data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("""TRUNCATE tobj""") q("""TRUNCATE tobj""")
q("""TRUNCATE ttrans""") q("""TRUNCATE ttrans""")
except:
self.rollback()
raise
self.commit()
self.unlockData(data_id_list, True) self.unlockData(data_id_list, True)
def storeTransaction(self, tid, object_list, transaction, temporary = True): def storeTransaction(self, tid, object_list, transaction, temporary = True):
q = self.query
e = self.escape e = self.escape
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
...@@ -453,8 +420,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -453,8 +420,7 @@ class MySQLDatabaseManager(DatabaseManager):
obj_table = 'obj' obj_table = 'obj'
trans_table = 'trans' trans_table = 'trans'
self.begin() with self as q:
try:
for oid, data_id, value_serial in object_list: for oid, data_id, value_serial in object_list:
oid = u64(oid) oid = u64(oid)
partition = self._getPartition(oid) partition = self._getPartition(oid)
...@@ -481,10 +447,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -481,10 +447,6 @@ class MySQLDatabaseManager(DatabaseManager):
q("REPLACE INTO %s VALUES (%d, %d, %i, '%s', '%s', '%s', '%s')" q("REPLACE INTO %s VALUES (%d, %d, %i, '%s', '%s', '%s', '%s')"
% (trans_table, partition, tid, packed, oids, user, desc, % (trans_table, partition, tid, packed, oids, user, desc,
ext)) ext))
except:
self.rollback()
raise
self.commit()
def _pruneData(self, data_id_list): def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data) data_id_list = set(data_id_list).difference(self._uncommitted_data)
...@@ -497,24 +459,19 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -497,24 +459,19 @@ class MySQLDatabaseManager(DatabaseManager):
def _storeData(self, checksum, data, compression): def _storeData(self, checksum, data, compression):
e = self.escape e = self.escape
checksum = e(checksum) checksum = e(checksum)
self.begin() with self as q:
try:
try: try:
self.query("INSERT INTO data VALUES (NULL, '%s', %d, '%s')" % q("INSERT INTO data VALUES (NULL, '%s', %d, '%s')" %
(checksum, compression, e(data))) (checksum, compression, e(data)))
except IntegrityError, (code, _): except IntegrityError, (code, _):
if code != DUP_ENTRY: if code != DUP_ENTRY:
raise raise
(r, c, d), = self.query("SELECT id, compression, value" (r, c, d), = q("SELECT id, compression, value"
" FROM data WHERE hash='%s'" % checksum) " FROM data WHERE hash='%s'" % checksum)
if c != compression or d != data: if c != compression or d != data:
raise raise
else: else:
r = self.conn.insert_id() r = self.conn.insert_id()
except:
self.rollback()
raise
self.commit()
return r return r
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
...@@ -540,27 +497,20 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -540,27 +497,20 @@ class MySQLDatabaseManager(DatabaseManager):
def finishTransaction(self, tid): def finishTransaction(self, tid):
q = self.query q = self.query
tid = util.u64(tid) tid = util.u64(tid)
self.begin() with self as q:
try:
sql = " FROM tobj WHERE serial=%d" % tid sql = " FROM tobj WHERE serial=%d" % tid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
q("INSERT INTO obj SELECT *" + sql) q("INSERT INTO obj SELECT *" + sql)
q("DELETE FROM tobj WHERE serial=%d" % tid) q("DELETE FROM tobj WHERE serial=%d" % tid)
q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid) q("INSERT INTO trans SELECT * FROM ttrans WHERE tid=%d" % tid)
q("DELETE FROM ttrans WHERE tid=%d" % tid) q("DELETE FROM ttrans WHERE tid=%d" % tid)
except:
self.rollback()
raise
self.commit()
self.unlockData(data_id_list) self.unlockData(data_id_list)
def deleteTransaction(self, tid, oid_list=()): def deleteTransaction(self, tid, oid_list=()):
q = self.query
u64 = util.u64 u64 = util.u64
tid = u64(tid) tid = u64(tid)
getPartition = self._getPartition getPartition = self._getPartition
self.begin() with self as q:
try:
sql = " FROM tobj WHERE serial=%d" % tid sql = " FROM tobj WHERE serial=%d" % tid
data_id_list = [x for x, in q("SELECT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT data_id" + sql) if x]
self.unlockData(data_id_list) self.unlockData(data_id_list)
...@@ -578,77 +528,45 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -578,77 +528,45 @@ class MySQLDatabaseManager(DatabaseManager):
q("DELETE" + sql) q("DELETE" + sql)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
except:
self.rollback()
raise
self.commit()
def deleteTransactionsAbove(self, partition, tid, max_tid):
self.begin()
try:
self.query('DELETE FROM trans WHERE partition=%(partition)d AND '
'%(tid)d <= tid AND tid <= %(max_tid)d' % {
'partition': partition,
'tid': util.u64(tid),
'max_tid': util.u64(max_tid),
})
except:
self.rollback()
raise
self.commit()
def deleteObject(self, oid, serial=None): def deleteObject(self, oid, serial=None):
q = self.query
u64 = util.u64 u64 = util.u64
oid = u64(oid) oid = u64(oid)
sql = " FROM obj WHERE partition=%d AND oid=%d" \ sql = " FROM obj WHERE partition=%d AND oid=%d" \
% (self._getPartition(oid), oid) % (self._getPartition(oid), oid)
if serial: if serial:
sql += ' AND serial=%d' % u64(serial) sql += ' AND serial=%d' % u64(serial)
self.begin() with self as q:
try:
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x]
q("DELETE" + sql) q("DELETE" + sql)
self._pruneData(data_id_list) self._pruneData(data_id_list)
except:
self.rollback()
raise
self.commit()
def deleteObjectsAbove(self, partition, oid, serial, max_tid): def _deleteRange(self, partition, min_tid=None, max_tid=None):
sql = " WHERE partition=%d" % partition
if min_tid:
sql += " AND %d < tid" % util.u64(min_tid)
if max_tid:
sql += " AND tid <= %d" % util.u64(max_tid)
q = self.query q = self.query
u64 = util.u64 q("DELETE FROM trans" + sql)
oid = u64(oid) sql = " FROM obj" + sql.replace('tid', 'serial')
sql = (" FROM obj WHERE partition=%d AND serial <= %d"
" AND (oid > %d OR (oid = %d AND serial >= %d))" %
(partition, u64(max_tid), oid, oid, u64(serial)))
self.begin()
try:
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x] data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql) if x]
q("DELETE" + sql) q("DELETE" + sql)
self._pruneData(data_id_list) self._pruneData(data_id_list)
except:
self.rollback()
raise
self.commit()
def getTransaction(self, tid, all = False): def getTransaction(self, tid, all = False):
q = self.query
tid = util.u64(tid) tid = util.u64(tid)
self.begin() with self as q:
r = q("""SELECT oids, user, description, ext, packed FROM trans r = q("SELECT oids, user, description, ext, packed FROM trans"
WHERE partition = %d AND tid = %d""" \ " WHERE partition = %d AND tid = %d"
% (self._getPartition(tid), tid)) % (self._getPartition(tid), tid))
if not r and all: if not r and all:
r = q("""SELECT oids, user, description, ext, packed FROM ttrans r = q("SELECT oids, user, description, ext, packed FROM ttrans"
WHERE tid = %d""" \ " WHERE tid = %d" % tid)
% tid)
self.commit()
if r: if r:
oids, user, desc, ext, packed = r[0] oids, user, desc, ext, packed = r[0]
oid_list = splitOIDField(tid, oids) oid_list = splitOIDField(tid, oids)
return oid_list, user, desc, ext, bool(packed) return oid_list, user, desc, ext, bool(packed)
return None
def _getObjectLength(self, oid, value_serial): def _getObjectLength(self, oid, value_serial):
if value_serial is None: if value_serial is None:
...@@ -690,34 +608,17 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -690,34 +608,17 @@ class MySQLDatabaseManager(DatabaseManager):
return result return result
return None return None
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length, def getReplicationObjectList(self, min_tid, max_tid, length, partition,
partition): min_oid):
q = self.query
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
min_oid = u64(min_oid) min_tid = u64(min_tid)
min_serial = u64(min_serial) r = self.query('SELECT serial, oid FROM obj'
max_serial = u64(max_serial) ' WHERE partition = %d AND serial <= %d'
r = q('SELECT oid, serial FROM obj ' ' AND (serial = %d AND %d <= oid OR %d < serial)'
'WHERE partition = %(partition)s ' ' ORDER BY serial ASC, oid ASC LIMIT %d' % (
'AND serial <= %(max_serial)d ' partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length))
'AND ((oid = %(min_oid)d AND serial >= %(min_serial)d) ' return [(p64(serial), p64(oid)) for serial, oid in r]
'OR oid > %(min_oid)d) '
'ORDER BY oid ASC, serial ASC LIMIT %(length)d' % {
'min_oid': min_oid,
'min_serial': min_serial,
'max_serial': max_serial,
'length': length,
'partition': partition,
})
result = {}
for oid, serial in r:
try:
serial_list = result[oid]
except KeyError:
serial_list = result[oid] = []
serial_list.append(p64(serial))
return dict((p64(x), y) for x, y in result.iteritems())
def getTIDList(self, offset, length, partition_list): def getTIDList(self, offset, length, partition_list):
q = self.query q = self.query
...@@ -727,12 +628,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -727,12 +628,11 @@ class MySQLDatabaseManager(DatabaseManager):
return [util.p64(t[0]) for t in r] return [util.p64(t[0]) for t in r]
def getReplicationTIDList(self, min_tid, max_tid, length, partition): def getReplicationTIDList(self, min_tid, max_tid, length, partition):
q = self.query
u64 = util.u64 u64 = util.u64
p64 = util.p64 p64 = util.p64
min_tid = u64(min_tid) min_tid = u64(min_tid)
max_tid = u64(max_tid) max_tid = u64(max_tid)
r = q("""SELECT tid FROM trans r = self.query("""SELECT tid FROM trans
WHERE partition = %(partition)d WHERE partition = %(partition)d
AND tid >= %(min_tid)d AND tid <= %(max_tid)d AND tid >= %(min_tid)d AND tid <= %(max_tid)d
ORDER BY tid ASC LIMIT %(length)d""" % { ORDER BY tid ASC LIMIT %(length)d""" % {
...@@ -772,13 +672,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -772,13 +672,11 @@ class MySQLDatabaseManager(DatabaseManager):
def pack(self, tid, updateObjectDataForPack): def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture) # TODO: unit test (along with updatePackFuture)
q = self.query
p64 = util.p64 p64 = util.p64
tid = util.u64(tid) tid = util.u64(tid)
updatePackFuture = self._updatePackFuture updatePackFuture = self._updatePackFuture
getPartition = self._getPartition getPartition = self._getPartition
self.begin() with self as q:
try:
self._setPackTID(tid) self._setPackTID(tid)
for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, ' for count, oid, max_serial in q('SELECT COUNT(*) - 1, oid, '
'MAX(serial) FROM obj WHERE serial <= %d GROUP BY oid' 'MAX(serial) FROM obj WHERE serial <= %d GROUP BY oid'
...@@ -804,10 +702,6 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -804,10 +702,6 @@ class MySQLDatabaseManager(DatabaseManager):
q('DELETE' + sql) q('DELETE' + sql)
data_id_set.discard(None) data_id_set.discard(None)
self._pruneData(data_id_set) self._pruneData(data_id_set)
except:
self.rollback()
raise
self.commit()
def checkTIDRange(self, min_tid, max_tid, length, partition): def checkTIDRange(self, min_tid, max_tid, length, partition):
count, tid_checksum, max_tid = self.query( count, tid_checksum, max_tid = self.query(
...@@ -816,11 +710,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -816,11 +710,11 @@ class MySQLDatabaseManager(DatabaseManager):
WHERE partition = %(partition)s WHERE partition = %(partition)s
AND tid >= %(min_tid)d AND tid >= %(min_tid)d
AND tid <= %(max_tid)d AND tid <= %(max_tid)d
ORDER BY tid ASC LIMIT %(length)d) AS t""" % { ORDER BY tid ASC %(limit)s) AS t""" % {
'partition': partition, 'partition': partition,
'min_tid': util.u64(min_tid), 'min_tid': util.u64(min_tid),
'max_tid': util.u64(max_tid), 'max_tid': util.u64(max_tid),
'length': length, 'limit': '' if length is None else 'LIMIT %u' % length,
})[0] })[0]
if count: if count:
return count, a2b_hex(tid_checksum), util.p64(max_tid) return count, a2b_hex(tid_checksum), util.p64(max_tid)
...@@ -839,11 +733,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -839,11 +733,11 @@ class MySQLDatabaseManager(DatabaseManager):
AND serial <= %(max_tid)d AND serial <= %(max_tid)d
AND (oid > %(min_oid)d OR AND (oid > %(min_oid)d OR
oid = %(min_oid)d AND serial >= %(min_serial)d) oid = %(min_oid)d AND serial >= %(min_serial)d)
ORDER BY oid ASC, serial ASC LIMIT %(length)d""" % { ORDER BY oid ASC, serial ASC %(limit)s""" % {
'min_oid': u64(min_oid), 'min_oid': u64(min_oid),
'min_serial': u64(min_serial), 'min_serial': u64(min_serial),
'max_tid': u64(max_tid), 'max_tid': u64(max_tid),
'length': length, 'limit': '' if length is None else 'LIMIT %u' % length,
'partition': partition, 'partition': partition,
}) })
if r: if r:
......
#
# Copyright (C) 2012 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import sqlite3
import neo.lib
from array import array
from hashlib import sha1
import re
import string
from . import DatabaseManager, LOG_QUERIES
from .manager import CreationUndone
from neo.lib.exception import DatabaseFailure
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, ZERO_HASH
from neo.lib import util
def splitOIDField(tid, oids):
if (len(oids) % 8) != 0 or len(oids) == 0:
raise DatabaseFailure('invalid oids length for tid %d: %d' % (tid,
len(oids)))
oid_list = []
append = oid_list.append
for i in xrange(0, len(oids), 8):
append(oids[i:i+8])
return oid_list
class SQLiteDatabaseManager(DatabaseManager):
"""This class manages a database on SQLite.
CAUTION: Make sure we never use statement journal files, as explained at
http://www.sqlite.org/tempfiles.html.
In other words, temporary files (by default in /var/tmp!) must
never be used for small requests.
"""
def __init__(self, *args, **kw):
super(SQLiteDatabaseManager, self).__init__(*args, **kw)
self._config = {}
self._connect()
def _parse(self, database):
self.db = database
def close(self):
self.conn.close()
def _connect(self):
neo.lib.logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, isolation_level=None,
check_same_thread=False)
def begin(self):
q = self.query
q("BEGIN IMMEDIATE")
return q
if LOG_QUERIES:
def commit(self):
neo.lib.logging.debug('committing...')
self.conn.commit()
def rollback(self):
neo.lib.logging.debug('aborting...')
self.conn.rollback()
def query(self, query):
printable_char_list = []
for c in query.split('\n', 1)[0][:70]:
if c not in string.printable or c in '\t\x0b\x0c\r':
c = '\\x%02x' % ord(c)
printable_char_list.append(c)
neo.lib.logging.debug('querying %s...',
''.join(printable_char_list))
return self.conn.execute(query)
else:
commit = property(lambda self: self.conn.commit)
rollback = property(lambda self: self.conn.rollback)
query = property(lambda self: self.conn.execute)
def setup(self, reset = 0):
self._config.clear()
q = self.query
if reset:
for t in 'config', 'pt', 'trans', 'obj', 'data', 'ttrans', 'tobj':
q('DROP TABLE IF EXISTS ' + t)
# The table "config" stores configuration parameters which affect the
# persistent data.
q("""CREATE TABLE IF NOT EXISTS config (
name TEXT NOT NULL PRIMARY KEY,
value BLOB)
""")
# The table "pt" stores a partition table.
q("""CREATE TABLE IF NOT EXISTS pt (
rid INTEGER NOT NULL,
uuid BLOB NOT NULL,
state INTEGER NOT NULL,
PRIMARY KEY (rid, uuid))
""")
# The table "trans" stores information on committed transactions.
q("""CREATE TABLE IF NOT EXISTS trans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL,
PRIMARY KEY (partition, tid))
""")
# The table "obj" stores committed object metadata.
q("""CREATE TABLE IF NOT EXISTS obj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
serial INTEGER NOT NULL,
data_id INTEGER,
value_serial INTEGER,
PRIMARY KEY (partition, serial, oid))
""")
q("""CREATE INDEX IF NOT EXISTS _obj_i1 ON
obj(partition, oid, serial)
""")
q("""CREATE INDEX IF NOT EXISTS _obj_i2 ON
obj(data_id)
""")
# The table "data" stores object data.
q("""CREATE TABLE IF NOT EXISTS data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
hash BLOB NOT NULL UNIQUE,
compression INTEGER,
value BLOB)
""")
# The table "ttrans" stores information on uncommitted transactions.
q("""CREATE TABLE IF NOT EXISTS ttrans (
partition INTEGER NOT NULL,
tid INTEGER NOT NULL,
packed BOOLEAN NOT NULL,
oids BLOB NOT NULL,
user BLOB NOT NULL,
description BLOB NOT NULL,
ext BLOB NOT NULL)
""")
# The table "tobj" stores uncommitted object metadata.
q("""CREATE TABLE IF NOT EXISTS tobj (
partition INTEGER NOT NULL,
oid INTEGER NOT NULL,
serial INTEGER NOT NULL,
data_id INTEGER,
value_serial INTEGER,
PRIMARY KEY (serial, oid))
""")
self._uncommitted_data = dict(q("SELECT data_id, count(*)"
" FROM tobj WHERE data_id IS NOT NULL GROUP BY data_id"))
def getConfiguration(self, key):
try:
return self._config[key]
except KeyError:
try:
r = str(self.query("SELECT value FROM config WHERE name=?",
(key,)).fetchone()[0])
except TypeError:
r = None
self._config[key] = r
return r
def _setConfiguration(self, key, value):
q = self.query
self._config[key] = value
if value is None:
q("DELETE FROM config WHERE name=?", (key,))
else:
q("REPLACE INTO config VALUES (?,?)", (key, buffer(str(value))))
def _setPackTID(self, tid):
self._setConfiguration('_pack_tid', tid)
def _getPackTID(self):
try:
return int(self.getConfiguration('_pack_tid'))
except TypeError:
return -1
def getPartitionTable(self):
return [(offset, util.bin(uuid), state)
for offset, uuid, state in self.query(
"SELECT rid, uuid, state FROM pt")]
def _getLastTIDs(self, all=True):
p64 = util.p64
with self as q:
trans = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(tid)"
" FROM trans GROUP BY partition"))
obj = dict((partition, p64(tid))
for partition, tid in q("SELECT partition, MAX(serial)"
" FROM obj GROUP BY partition"))
if all:
tid = q("SELECT MAX(tid) FROM ttrans").fetchone()[0]
if tid is not None:
trans[None] = p64(tid)
tid = q("SELECT MAX(serial) FROM tobj").fetchone()[0]
if tid is not None:
obj[None] = p64(tid)
return trans, obj
def getUnfinishedTIDList(self):
p64 = util.p64
tid_set = set()
with self as q:
tid_set.update((p64(t[0]) for t in q("SELECT tid FROM ttrans")))
tid_set.update((p64(t[0]) for t in q("SELECT serial FROM tobj")))
return list(tid_set)
def objectPresent(self, oid, tid, all=True):
oid = util.u64(oid)
tid = util.u64(tid)
with self as q:
r = q("SELECT 1 FROM obj WHERE partition=? AND oid=? AND serial=?",
(self._getPartition(oid), oid, tid)).fetchone()
if not r and all:
r = q("SELECT 1 FROM tobj WHERE serial=? AND oid=?",
(tid, oid)).fetchone()
return bool(r)
def _getObject(self, oid, tid=None, before_tid=None):
q = self.query
partition = self._getPartition(oid)
sql = ('SELECT serial, compression, data.hash, value, value_serial'
' FROM obj LEFT JOIN data ON obj.data_id = data.id'
' WHERE partition=? AND oid=?')
if tid is not None:
r = q(sql + ' AND serial=?', (partition, oid, tid))
elif before_tid is not None:
r = q(sql + ' AND serial<? ORDER BY serial DESC LIMIT 1',
(partition, oid, before_tid))
else:
r = q(sql + ' ORDER BY serial DESC LIMIT 1', (partition, oid))
try:
serial, compression, checksum, data, value_serial = r.fetchone()
except TypeError:
return None
r = q("""SELECT serial FROM obj
WHERE partition=? AND oid=? AND serial>?
ORDER BY serial LIMIT 1""",
(partition, oid, serial)).fetchone()
if checksum:
checksum = str(checksum)
data = str(data)
return serial, r and r[0], compression, checksum, data, value_serial
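# _getObject returns (serial, next_serial, compression, checksum, data,
# value_serial) for the matching revision, or None when the object has no
# matching revision; next_serial is None for the most recent one.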
def doSetPartitionTable(self, ptid, cell_list, reset):
with self as q:
if reset:
q("DELETE FROM pt")
for offset, uuid, state in cell_list:
uuid = buffer(util.dump(uuid))
# TODO: this logic should move out of database manager
# add 'dropCells(cell_list)' to API and use one query
# WKRD: Why does SQLite need a statement journal file
# whereas we try to replace only 1 value ?
# We don't want to remove the 'NOT NULL' constraint
# so we must simulate a "REPLACE OR FAIL".
q("DELETE FROM pt WHERE rid=? AND uuid=?", (offset, uuid))
if state != CellStates.DISCARDED:
q("INSERT OR FAIL INTO pt VALUES (?,?,?)",
(offset, uuid, int(state)))
self.setPTID(ptid)
def changePartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, False)
def setPartitionTable(self, ptid, cell_list):
self.doSetPartitionTable(ptid, cell_list, True)
def dropPartitions(self, offset_list):
where = " WHERE partition=?"
with self as q:
for partition in offset_list:
data_id_list = [x for x, in
q("SELECT DISTINCT data_id FROM obj" + where,
(partition,)) if x]
q("DELETE FROM obj" + where, (partition,))
q("DELETE FROM trans" + where, (partition,))
self._pruneData(data_id_list)
def dropUnfinishedData(self):
with self as q:
data_id_list = [x for x, in q("SELECT data_id FROM tobj") if x]
q("DELETE FROM tobj")
q("DELETE FROM ttrans")
self.unlockData(data_id_list, True)
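# storeTransaction writes either to the temporary tables (ttrans/tobj) when
# temporary=True, or directly to the final tables (trans/obj) when called
# during replication with temporary=False.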
def storeTransaction(self, tid, object_list, transaction, temporary=True):
u64 = util.u64
tid = u64(tid)
T = 't' if temporary else ''
obj_sql = "INSERT OR FAIL INTO %sobj VALUES (?,?,?,?,?)" % T
with self as q:
for oid, data_id, value_serial in object_list:
oid = u64(oid)
partition = self._getPartition(oid)
if value_serial:
value_serial = u64(value_serial)
(data_id,), = q("SELECT data_id FROM obj"
" WHERE partition=? AND oid=? AND serial=?",
(partition, oid, value_serial))
if temporary:
self.storeData(data_id)
try:
q(obj_sql, (partition, oid, tid, data_id, value_serial))
except sqlite3.IntegrityError:
# This may happen if a previous replication of 'obj' was
# interrupted.
if not T:
r, = q("SELECT data_id, value_serial FROM obj"
" WHERE partition=? AND oid=? AND serial=?",
(partition, oid, tid))
if r == (data_id, value_serial):
continue
raise
if transaction is not None:
oid_list, user, desc, ext, packed = transaction
partition = self._getPartition(tid)
assert packed in (0, 1)
q("INSERT OR FAIL INTO %strans VALUES (?,?,?,?,?,?,?)" % T,
(partition, tid, packed, buffer(''.join(oid_list)),
buffer(user), buffer(desc), buffer(ext)))
def _pruneData(self, data_id_list):
data_id_list = set(data_id_list).difference(self._uncommitted_data)
if data_id_list:
q = self.query
data_id_list.difference_update(x for x, in q(
"SELECT DISTINCT data_id FROM obj WHERE data_id IN (%s)"
% ",".join(map(str, data_id_list))))
q("DELETE FROM data WHERE id IN (%s)"
% ",".join(map(str, data_id_list)))
def _storeData(self, checksum, data, compression):
H = buffer(checksum)
with self as q:
try:
return q("INSERT INTO data VALUES (NULL,?,?,?)",
(H, compression, buffer(data))).lastrowid
except sqlite3.IntegrityError, e:
if e.args[0] == 'column hash is not unique':
(r, c, d), = q("SELECT id, compression, value"
" FROM data WHERE hash=?", (H,))
if c == compression and str(d) == data:
return r
raise
def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getPartition(oid)
sql = 'SELECT serial, data_id, value_serial FROM obj' \
' WHERE partition=? AND oid=?'
if tid is not None:
r = self.query(sql + ' AND serial=?', (partition, oid, tid))
elif before_tid is not None:
r = self.query(sql + ' AND serial<? ORDER BY serial DESC LIMIT 1',
(partition, oid, before_tid))
else:
r = self.query(sql + ' ORDER BY serial DESC LIMIT 1',
(partition, oid))
r = r.fetchone()
if r:
serial, data_id, value_serial = r
if value_serial is None and data_id:
return serial, serial
return serial, value_serial
return None, None
def finishTransaction(self, tid):
args = util.u64(tid),
with self as q:
sql = " FROM tobj WHERE serial=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, args) if x]
q("INSERT OR FAIL INTO obj SELECT *" + sql, args)
q("DELETE FROM tobj WHERE serial=?", args)
q("INSERT OR FAIL INTO trans SELECT * FROM ttrans WHERE tid=?",
args)
q("DELETE FROM ttrans WHERE tid=?", args)
self.unlockData(data_id_list)
def deleteTransaction(self, tid, oid_list=()):
u64 = util.u64
tid = u64(tid)
getPartition = self._getPartition
with self as q:
sql = " FROM tobj WHERE serial=?"
data_id_list = [x for x, in q("SELECT data_id" + sql, (tid,)) if x]
self.unlockData(data_id_list)
q("DELETE" + sql, (tid,))
q("DELETE FROM ttrans WHERE tid=?", (tid,))
q("DELETE FROM trans WHERE partition=? AND tid=?",
(getPartition(tid), tid))
# delete from obj using indexes
data_id_set = set()
for oid in oid_list:
oid = u64(oid)
sql = " FROM obj WHERE partition=? AND oid=? AND serial=?"
args = getPartition(oid), oid, tid
data_id_set.update(*q("SELECT data_id" + sql, args))
q("DELETE" + sql, args)
data_id_set.discard(None)
self._pruneData(data_id_set)
def deleteObject(self, oid, serial=None):
oid = util.u64(oid)
sql = " FROM obj WHERE partition=? AND oid=?"
args = [self._getPartition(oid), oid]
if serial:
sql += " AND serial=?"
args.append(util.u64(serial))
with self as q:
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql, args)
if x]
q("DELETE" + sql, args)
self._pruneData(data_id_list)
def _deleteRange(self, partition, min_tid=None, max_tid=None):
sql = " WHERE partition=?"
args = [partition]
if min_tid:
sql += " AND ? < tid"
args.append(util.u64(min_tid))
if max_tid:
sql += " AND tid <= ?"
args.append(util.u64(max_tid))
q = self.query
q("DELETE FROM trans" + sql, args)
sql = " FROM obj" + sql.replace('tid', 'serial')
data_id_list = [x for x, in q("SELECT DISTINCT data_id" + sql, args)
if x]
q("DELETE" + sql, args)
self._pruneData(data_id_list)
def getTransaction(self, tid, all=False):
tid = util.u64(tid)
with self as q:
r = q("SELECT oids, user, description, ext, packed FROM trans"
" WHERE partition=? AND tid=?",
(self._getPartition(tid), tid)).fetchone()
if not r and all:
r = q("SELECT oids, user, description, ext, packed FROM ttrans"
" WHERE tid=?", (tid,)).fetchone()
if r:
oids, user, description, ext, packed = r
return splitOIDField(tid, oids), str(user), \
str(description), str(ext), packed
def _getObjectLength(self, oid, value_serial):
if value_serial is None:
raise CreationUndone
length, value_serial = self.query("""SELECT LENGTH(value), value_serial
FROM obj LEFT JOIN data ON obj.data_id=data.id
WHERE partition=? AND oid=? AND serial=?""",
(self._getPartition(oid), oid, value_serial)).fetchone()
if length is None:
neo.lib.logging.info("Multiple levels of indirection"
" when searching for object data for oid %d at tid %d."
" This causes suboptimal performance.", oid, value_serial)
length = self._getObjectLength(oid, value_serial)
return length
def getObjectHistory(self, oid, offset=0, length=1):
# FIXME: This method doesn't take client's current transaction id as
# parameter, which means it can return transactions in the future of
# client's transaction.
p64 = util.p64
oid = util.u64(oid)
pack_tid = self._getPackTID()
result = []
append = result.append
with self as q:
for serial, length, value_serial in q("""\
SELECT serial, LENGTH(value), value_serial
FROM obj LEFT JOIN data ON obj.data_id = data.id
WHERE partition=? AND oid=? AND serial>=?
ORDER BY serial DESC LIMIT ?,?""",
(self._getPartition(oid), oid, pack_tid, offset, length)):
if length is None:
try:
length = self._getObjectLength(oid, value_serial)
except CreationUndone:
length = 0
append((p64(serial), length))
return result or None
def getReplicationObjectList(self, min_tid, max_tid, length, partition,
min_oid):
u64 = util.u64
p64 = util.p64
min_tid = u64(min_tid)
return [(p64(serial), p64(oid)) for serial, oid in self.query("""\
SELECT serial, oid FROM obj
WHERE partition=? AND serial<=?
AND (serial=? AND ?<=oid OR ?<serial)
ORDER BY serial ASC, oid ASC LIMIT ?""",
(partition, u64(max_tid), min_tid, u64(min_oid), min_tid, length))]
def getTIDList(self, offset, length, partition_list):
p64 = util.p64
return [p64(t[0]) for t in self.query("""\
SELECT tid FROM trans WHERE partition in (%s)
ORDER BY tid DESC LIMIT %d,%d"""
% (','.join(map(str, partition_list)), offset, length))]
def getReplicationTIDList(self, min_tid, max_tid, length, partition):
u64 = util.u64
p64 = util.p64
min_tid = u64(min_tid)
max_tid = u64(max_tid)
return [p64(t[0]) for t in self.query("""\
SELECT tid FROM trans
WHERE partition=? AND ?<=tid AND tid<=?
ORDER BY tid ASC LIMIT ?""",
(partition, min_tid, max_tid, length))]
def _updatePackFuture(self, oid, orig_serial, max_serial):
# Before deleting this object's revision, see if there is any
# transaction referencing its value at max_serial or above.
# If there is, copy value to the first future transaction. Any further
# reference is just updated to point to the new data location.
partition = self._getPartition(oid)
value_serial = None
q = self.query
for T in '', 't':
update = """UPDATE OR FAIL %sobj SET value_serial=?
WHERE partition=? AND oid=? AND serial=?""" % T
for serial, in q("""SELECT serial FROM %sobj
WHERE partition=? AND oid=? AND serial>=? AND value_serial=?
ORDER BY serial ASC""" % T,
(partition, oid, max_serial, orig_serial)):
q(update, (value_serial, partition, oid, serial))
if value_serial is None:
# First found, mark its serial for future reference.
value_serial = serial
return value_serial
def pack(self, tid, updateObjectDataForPack):
# TODO: unit test (along with updatePackFuture)
p64 = util.p64
tid = util.u64(tid)
updatePackFuture = self._updatePackFuture
getPartition = self._getPartition
with self as q:
self._setPackTID(tid)
for count, oid, max_serial in q("SELECT COUNT(*) - 1, oid,"
" MAX(serial) FROM obj WHERE serial<=? GROUP BY oid",
(tid,)):
partition = getPartition(oid)
if q("SELECT 1 FROM obj WHERE partition=?"
" AND oid=? AND serial=? AND data_id IS NULL",
(partition, oid, max_serial)).fetchone():
max_serial += 1
elif not count:
continue
# There are things to delete for this object
data_id_set = set()
sql = " FROM obj WHERE partition=? AND oid=? AND serial<?"
args = partition, oid, max_serial
for serial, data_id in q("SELECT serial, data_id" + sql, args):
data_id_set.add(data_id)
new_serial = updatePackFuture(oid, serial, max_serial)
if new_serial:
new_serial = p64(new_serial)
updateObjectDataForPack(p64(oid), p64(serial),
new_serial, data_id)
q("DELETE" + sql, args)
data_id_set.discard(None)
self._pruneData(data_id_set)
def checkTIDRange(self, min_tid, max_tid, length, partition):
count, tids, max_tid = self.query("""\
SELECT COUNT(*), GROUP_CONCAT(tid), MAX(tid)
FROM (SELECT tid FROM trans
WHERE partition=? AND ?<=tid AND tid<=?
ORDER BY tid ASC LIMIT ?) AS t""",
(partition, util.u64(min_tid), util.u64(max_tid),
-1 if length is None else length)).fetchone()
if count:
return count, sha1(tids).digest(), util.p64(max_tid)
return 0, ZERO_HASH, ZERO_TID
def checkSerialRange(self, min_oid, min_serial, max_tid, length, partition):
u64 = util.u64
# We don't ask SQLite to compute everything (like in checkTIDRange)
# because it's difficult to get the last serial _for the last oid_.
# We would need a function (that could be named 'LAST') that returns the
# last grouped value, instead of the greatest one.
min_oid = u64(min_oid)
r = self.query("""\
SELECT oid, serial
FROM obj
WHERE partition=? AND serial<=?
AND (oid>? OR oid=? AND serial>=?)
ORDER BY oid ASC, serial ASC LIMIT ?""",
(partition, u64(max_tid), min_oid, min_oid, u64(min_serial),
-1 if length is None else length)).fetchall()
if r:
p64 = util.p64
return (len(r),
sha1(','.join(str(x[0]) for x in r)).digest(),
p64(r[-1][0]),
sha1(','.join(str(x[1]) for x in r)).digest(),
p64(r[-1][1]))
return 0, ZERO_HASH, ZERO_OID, ZERO_HASH, ZERO_TID
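# Illustration only (not part of the module): the range-checksum queries above
# are how replication compares two nodes cheaply. This standalone sketch uses a
# made-up in-memory 'trans' table and sample TIDs to compute the same
# (count, sha1-of-concatenated-TIDs, max tid) triple as checkTIDRange.
import sqlite3
from hashlib import sha1

example_conn = sqlite3.connect(':memory:')
example_conn.execute("CREATE TABLE trans (partition INTEGER, tid INTEGER)")
example_conn.executemany("INSERT INTO trans VALUES (?,?)",
                         [(0, 10), (0, 11), (0, 13), (1, 12)])
count, tids, max_tid = example_conn.execute("""\
    SELECT COUNT(*), GROUP_CONCAT(tid), MAX(tid)
    FROM (SELECT tid FROM trans
          WHERE partition=? AND ?<=tid AND tid<=?
          ORDER BY tid ASC LIMIT ?) AS t""",
    (0, 0, 20, -1)).fetchone()
if count:
    # Two nodes holding the same TIDs in this range produce the same digest.
    print(count, sha1(tids.encode()).hexdigest(), max_tid)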
...@@ -18,15 +18,15 @@ ...@@ -18,15 +18,15 @@
import neo import neo
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib import protocol
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.exception import PrimaryFailure, OperationFailure from neo.lib.exception import PrimaryFailure, OperationFailure
from neo.lib.protocol import NodeStates, NodeTypes, Packets, Errors, ZERO_HASH from neo.lib.protocol import NodeStates, NodeTypes
class BaseMasterHandler(EventHandler): class BaseMasterHandler(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
if self.app.listening_conn: # if running if self.app.listening_conn: # if running
self.app.master_node = None
raise PrimaryFailure('connection lost') raise PrimaryFailure('connection lost')
def stopOperation(self, conn): def stopOperation(self, conn):
...@@ -62,44 +62,5 @@ class BaseMasterHandler(EventHandler): ...@@ -62,44 +62,5 @@ class BaseMasterHandler(EventHandler):
dump(uuid)) dump(uuid))
self.app.tm.abortFor(uuid) self.app.tm.abortFor(uuid)
def answerUnfinishedTransactions(self, conn, *args, **kw):
class BaseClientAndStorageOperationHandler(EventHandler): self.app.replicator.setUnfinishedTIDList(*args, **kw)
""" Accept requests common to client and storage nodes """
def askTransactionInformation(self, conn, tid):
app = self.app
t = app.dm.getTransaction(tid)
if t is None:
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
conn.answer(p)
def _askObject(self, oid, serial, tid):
raise NotImplementedError
def askObject(self, conn, oid, serial, tid):
app = self.app
if self.app.tm.loadLocked(oid):
# Delay the response.
app.queueEvent(self.askObject, conn, (oid, serial, tid))
return
o = self._askObject(oid, serial, tid)
if o is None:
neo.lib.logging.debug('oid = %s does not exist', dump(oid))
p = Errors.OidDoesNotExist(dump(oid))
elif o is False:
neo.lib.logging.debug('oid = %s not found', dump(oid))
p = Errors.OidNotFound(dump(oid))
else:
serial, next_serial, compression, checksum, data, data_serial = o
neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial))
if checksum is None:
checksum = ZERO_HASH
data = ''
p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data, data_serial)
conn.answer(p)
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib import neo.lib
from neo.lib import protocol from neo.lib.handler import EventHandler
from neo.lib.util import dump, makeChecksum from neo.lib.util import dump, makeChecksum
from neo.lib.protocol import Packets, LockState, Errors, ZERO_HASH from neo.lib.protocol import Packets, LockState, Errors, ProtocolError, \
from . import BaseClientAndStorageOperationHandler ZERO_HASH, INVALID_PARTITION
from ..transactions import ConflictError, DelayedError from ..transactions import ConflictError, DelayedError
from ..exception import AlreadyPendingError from ..exception import AlreadyPendingError
import time import time
...@@ -28,10 +28,40 @@ import time ...@@ -28,10 +28,40 @@ import time
# Set to None to disable. # Set to None to disable.
SLOW_STORE = 2 SLOW_STORE = 2
class ClientOperationHandler(BaseClientAndStorageOperationHandler): class ClientOperationHandler(EventHandler):
def _askObject(self, oid, serial, ttid): def askTransactionInformation(self, conn, tid):
return self.app.dm.getObject(oid, serial, ttid) t = self.app.dm.getTransaction(tid)
if t is None:
p = Errors.TidNotFound('%s does not exist' % dump(tid))
else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0])
conn.answer(p)
def askObject(self, conn, oid, serial, tid):
app = self.app
if app.tm.loadLocked(oid):
# Delay the response.
app.queueEvent(self.askObject, conn, (oid, serial, tid))
return
o = app.dm.getObject(oid, serial, tid)
if o is None:
neo.lib.logging.debug('oid = %s does not exist', dump(oid))
p = Errors.OidDoesNotExist(dump(oid))
elif o is False:
neo.lib.logging.debug('oid = %s not found', dump(oid))
p = Errors.OidNotFound(dump(oid))
else:
serial, next_serial, compression, checksum, data, data_serial = o
neo.lib.logging.debug('oid = %s, serial = %s, next_serial = %s',
dump(oid), dump(serial), dump(next_serial))
if checksum is None:
checksum = ZERO_HASH
data = ''
p = Packets.AnswerObject(oid, serial, next_serial,
compression, checksum, data, data_serial)
conn.answer(p)
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
uuid = conn.getUUID() uuid = conn.getUUID()
...@@ -96,22 +126,18 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -96,22 +126,18 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
self._askStoreObject(conn, oid, serial, compression, checksum, data, self._askStoreObject(conn, oid, serial, compression, checksum, data,
data_serial, ttid, unlock, time.time()) data_serial, ttid, unlock, time.time())
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list): def askTIDsFrom(self, conn, min_tid, max_tid, length, partition):
getReplicationTIDList = self.app.dm.getReplicationTIDList conn.answer(Packets.AnswerTIDsFrom(self.app.dm.getReplicationTIDList(
tid_list = [] min_tid, max_tid, length, partition)))
extend = tid_list.extend
for partition in partition_list:
extend(getReplicationTIDList(min_tid, max_tid, length, partition))
conn.answer(Packets.AnswerTIDsFrom(tid_list))
def askTIDs(self, conn, first, last, partition): def askTIDs(self, conn, first, last, partition):
# This method is complicated, because I must return TIDs only # This method is complicated, because I must return TIDs only
# about usable partitions assigned to me. # about usable partitions assigned to me.
if first >= last: if first >= last:
raise protocol.ProtocolError('invalid offsets') raise ProtocolError('invalid offsets')
app = self.app app = self.app
if partition == protocol.INVALID_PARTITION: if partition == INVALID_PARTITION:
partition_list = app.pt.getAssignedPartitionList(app.uuid) partition_list = app.pt.getAssignedPartitionList(app.uuid)
else: else:
partition_list = [partition] partition_list = [partition]
...@@ -149,7 +175,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -149,7 +175,7 @@ class ClientOperationHandler(BaseClientAndStorageOperationHandler):
def askObjectHistory(self, conn, oid, first, last): def askObjectHistory(self, conn, oid, first, last):
if first >= last: if first >= last:
raise protocol.ProtocolError( 'invalid offsets') raise ProtocolError('invalid offsets')
app = self.app app = self.app
history_list = app.dm.getObjectHistory(oid, first, last - first) history_list = app.dm.getObjectHistory(oid, first, last - first)
......
...@@ -21,6 +21,7 @@ from neo.lib.handler import EventHandler ...@@ -21,6 +21,7 @@ from neo.lib.handler import EventHandler
from neo.lib.protocol import NodeTypes, Packets, NotReadyError from neo.lib.protocol import NodeTypes, Packets, NotReadyError
from neo.lib.protocol import ProtocolError, BrokenNodeDisallowedError from neo.lib.protocol import ProtocolError, BrokenNodeDisallowedError
from neo.lib.util import dump from neo.lib.util import dump
from .storage import StorageOperationHandler
class IdentificationHandler(EventHandler): class IdentificationHandler(EventHandler):
""" Handler used for incoming connections during operation state """ """ Handler used for incoming connections during operation state """
...@@ -35,6 +36,14 @@ class IdentificationHandler(EventHandler): ...@@ -35,6 +36,14 @@ class IdentificationHandler(EventHandler):
if not self.app.ready: if not self.app.ready:
raise NotReadyError raise NotReadyError
app = self.app app = self.app
if uuid is None:
if node_type != NodeTypes.STORAGE:
raise ProtocolError('reject anonymous non-storage node')
handler = StorageOperationHandler(self.app)
conn.setHandler(handler)
else:
if uuid == app.uuid:
raise ProtocolError("uuid conflict or loopback connection")
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
# If this node is broken, reject it. # If this node is broken, reject it.
if node is not None and node.isBroken(): if node is not None and node.isBroken():
...@@ -51,21 +60,18 @@ class IdentificationHandler(EventHandler): ...@@ -51,21 +60,18 @@ class IdentificationHandler(EventHandler):
assert not node.isConnected() assert not node.isConnected()
node.setRunning() node.setRunning()
elif node_type == NodeTypes.STORAGE: elif node_type == NodeTypes.STORAGE:
from .storage import StorageOperationHandler
handler = StorageOperationHandler
if node is None: if node is None:
neo.lib.logging.error('reject an unknown storage node %s', neo.lib.logging.error('reject an unknown storage node %s',
dump(uuid)) dump(uuid))
raise NotReadyError raise NotReadyError
handler = StorageOperationHandler
else: else:
raise ProtocolError('reject non-client-or-storage node') raise ProtocolError('reject non-client-or-storage node')
# apply the handler and set up the connection # apply the handler and set up the connection
handler = handler(self.app) handler = handler(self.app)
conn.setHandler(handler) conn.setHandler(handler)
node.setConnection(conn) node.setConnection(conn, app.uuid < uuid)
args = (NodeTypes.STORAGE, app.uuid, app.pt.getPartitions(),
app.pt.getReplicas(), uuid)
# accept the identification and trigger an event # accept the identification and trigger an event
conn.answer(Packets.AcceptIdentification(*args)) conn.answer(Packets.AcceptIdentification(NodeTypes.STORAGE, uuid and
app.uuid, app.pt.getPartitions(), app.pt.getReplicas(), uuid))
handler.connectionCompleted(conn) handler.connectionCompleted(conn)
...@@ -25,10 +25,6 @@ class InitializationHandler(BaseMasterHandler): ...@@ -25,10 +25,6 @@ class InitializationHandler(BaseMasterHandler):
def answerNodeInformation(self, conn): def answerNodeInformation(self, conn):
pass pass
def notifyNodeInformation(self, conn, node_list):
# the whole node list is received here
BaseMasterHandler.notifyNodeInformation(self, conn, node_list)
def answerPartitionTable(self, conn, ptid, row_list): def answerPartitionTable(self, conn, ptid, row_list):
app = self.app app = self.app
pt = app.pt pt = app.pt
...@@ -53,8 +49,9 @@ class InitializationHandler(BaseMasterHandler): ...@@ -53,8 +49,9 @@ class InitializationHandler(BaseMasterHandler):
app.dm.setPartitionTable(ptid, cell_list) app.dm.setPartitionTable(ptid, cell_list)
def answerLastIDs(self, conn, loid, ltid, lptid): def answerLastIDs(self, conn, loid, ltid, lptid, backup_tid):
self.app.dm.setLastOID(loid) self.app.dm.setLastOID(loid)
self.app.dm.setBackupTID(backup_tid)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
# XXX: This is safe to ignore those notifications because all of the # XXX: This is safe to ignore those notifications because all of the
......
...@@ -24,11 +24,8 @@ from . import BaseMasterHandler ...@@ -24,11 +24,8 @@ from . import BaseMasterHandler
class MasterOperationHandler(BaseMasterHandler): class MasterOperationHandler(BaseMasterHandler):
""" This handler is used for the primary master """ """ This handler is used for the primary master """
def answerUnfinishedTransactions(self, conn, max_tid, ttid_list): def notifyTransactionFinished(self, conn, *args, **kw):
self.app.replicator.setUnfinishedTIDList(max_tid, ttid_list) self.app.replicator.transactionFinished(*args, **kw)
def notifyTransactionFinished(self, conn, ttid, max_tid):
self.app.replicator.transactionFinished(ttid, max_tid)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
...@@ -44,14 +41,7 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -44,14 +41,7 @@ class MasterOperationHandler(BaseMasterHandler):
app.dm.changePartitionTable(ptid, cell_list) app.dm.changePartitionTable(ptid, cell_list)
# Check changes for replications # Check changes for replications
if app.replicator is not None: app.replicator.notifyPartitionChanges(cell_list)
for offset, uuid, state in cell_list:
if uuid == app.uuid:
# If this is for myself, this can affect replications.
if state == CellStates.DISCARDED:
app.replicator.removePartition(offset)
elif state == CellStates.OUT_OF_DATE:
app.replicator.addPartition(offset)
def askLockInformation(self, conn, ttid, tid, oid_list): def askLockInformation(self, conn, ttid, tid, oid_list):
if not ttid in self.app.tm: if not ttid in self.app.tm:
...@@ -74,3 +64,11 @@ class MasterOperationHandler(BaseMasterHandler): ...@@ -74,3 +64,11 @@ class MasterOperationHandler(BaseMasterHandler):
if not conn.isClosed(): if not conn.isClosed():
conn.answer(Packets.AnswerPack(True)) conn.answer(Packets.AnswerPack(True))
def replicate(self, conn, tid, upstream_name, source_dict):
self.app.replicator.backup(tid,
dict((p, (a, upstream_name))
for p, a in source_dict.iteritems()))
def askTruncate(self, conn, tid):
self.app.dm.truncate(tid)
conn.answer(Packets.AnswerTruncate())
#
# Copyright (C) 2006-2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from functools import wraps
import neo.lib
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, ZERO_HASH, ZERO_TID, ZERO_OID
from neo.lib.util import add64, u64
# TODO: benchmark how different values behave
RANGE_LENGTH = 4000
MIN_RANGE_LENGTH = 1000
CHECK_CHUNK = 0
CHECK_REPLICATE = 1
CHECK_DONE = 2
"""
Replication algorithm
Purpose: replicate the content of a reference node into a replicating node,
bringing it up-to-date.
This happens both when a new storage is added to an existing cluster and when
a node was separated from the cluster and rejoins it.
Replication happens per partition. Reference node can change between
partitions.
2 parts, done sequentially:
- Transaction (metadata) replication
- Object (data) replication
Both parts follow the same mechanism:
- On both sides (replicating and reference), compute a checksum of a chunk
(RANGE_LENGTH entries). If there is a mismatch, the chunk size is reduced and
the scan restarts from the same row, until it reaches a minimal length
(MIN_RANGE_LENGTH); then all rows in that chunk are replicated. If the
contents of the chunks match, the scan moves on to the next chunk.
- Replicating a chunk starts with asking for the list of all entries (only
their identifiers), skipping those both sides have, deleting those that the
replicating node has and the reference doesn't, and asking individually for
every entry missing on the replicating side.
"""
# TODO: Make object replication get ordered by serial first and oid second, so
# changes are in a big segment at the end, rather than in many segments (one
# per object).
# TODO: To improve performance when a pack happened, the following algorithm
# should be used:
# - If the reference node packed, find non-existent oids in the reference node (their
# creation was undone, and pack pruned them), and delete them.
# - Run current algorithm, starting at our last pack TID.
# - Pack partition at reference's TID.
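# Illustration only (not part of the handler): the chunk/checksum walk
# described above, reduced to plain Python. Two sorted TID lists stand in for
# the two nodes' databases, chunk sizes are shrunk for readability, and list
# indexes replace the real TID boundaries, so this is only a sketch of the
# halving behaviour.
from hashlib import sha1 as _sha1

_RANGE = 8      # stands in for RANGE_LENGTH
_MIN_RANGE = 2  # stands in for MIN_RANGE_LENGTH

def _checksum(tid_list):
    return _sha1(','.join(map(str, tid_list)).encode()).digest()

def _compare(reference, replica):
    to_replicate = []
    start, length = 0, _RANGE
    while start < len(reference):
        if _checksum(reference[start:start+length]) == \
                _checksum(replica[start:start+length]):
            start += length               # chunk matches: move on
            length = _RANGE
        elif length <= _MIN_RANGE:
            # minimal chunk still differs: replicate it entirely
            to_replicate.append((start, reference[start:start+length]))
            start += length
            length = _RANGE
        else:
            length //= 2                  # mismatch: recheck same rows, smaller
    return to_replicate

# A replica missing the last TID ends up replicating only the final small chunk.
print(_compare(list(range(20)), list(range(19))))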
def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw):
if self.app.replicator.isCurrentConnection(conn):
return func(self, conn, *args, **kw)
# Should probably raise & close connection...
return wraps(func)(decorator)
class ReplicationHandler(EventHandler):
"""This class handles events for replications."""
def connectionLost(self, conn, new_state):
replicator = self.app.replicator
if replicator.isCurrentConnection(conn):
if replicator.pending():
neo.lib.logging.warning(
'replication is stopped due to a lost connection')
replicator.storageLost()
def connectionFailed(self, conn):
neo.lib.logging.warning(
'replication is stopped due to connection failure')
self.app.replicator.storageLost()
def acceptIdentification(self, conn, node_type,
uuid, num_partitions, num_replicas, your_uuid):
self.startReplication(conn)
def startReplication(self, conn):
max_tid = self.app.replicator.getCurrentCriticalTID()
conn.ask(self._doAskCheckTIDRange(ZERO_TID, max_tid), timeout=300)
@checkConnectionIsReplicatorConnection
def answerTIDsFrom(self, conn, tid_list):
assert tid_list
app = self.app
ask = conn.ask
# If I have pending TIDs, check which TIDs I don't have, and
# request the data.
tid_set = frozenset(tid_list)
my_tid_set = frozenset(app.replicator.getTIDsFromResult())
extra_tid_set = my_tid_set - tid_set
if extra_tid_set:
deleteTransaction = app.dm.deleteTransaction
for tid in extra_tid_set:
deleteTransaction(tid)
missing_tid_set = tid_set - my_tid_set
for tid in missing_tid_set:
ask(Packets.AskTransactionInformation(tid), timeout=300)
if len(tid_list) == MIN_RANGE_LENGTH:
# If we received fewer, we knew it before sending AskTIDsFrom, and
# we should have finished TID replication at that time.
max_tid = self.app.replicator.getCurrentCriticalTID()
ask(self._doAskCheckTIDRange(add64(tid_list[-1], 1), max_tid,
RANGE_LENGTH))
@checkConnectionIsReplicatorConnection
def answerTransactionInformation(self, conn, tid,
user, desc, ext, packed, oid_list):
app = self.app
# Directly store the transaction.
app.dm.storeTransaction(tid, (), (oid_list, user, desc, ext, packed),
False)
@checkConnectionIsReplicatorConnection
def answerObjectHistoryFrom(self, conn, object_dict):
assert object_dict
app = self.app
ask = conn.ask
deleteObject = app.dm.deleteObject
my_object_dict = app.replicator.getObjectHistoryFromResult()
object_set = set()
max_oid = max(object_dict.iterkeys())
max_serial = max(object_dict[max_oid])
for oid, serial_list in object_dict.iteritems():
for serial in serial_list:
object_set.add((oid, serial))
my_object_set = set()
for oid, serial_list in my_object_dict.iteritems():
filter = lambda x: True
if max_oid is not None:
if oid > max_oid:
continue
elif oid == max_oid:
filter = lambda x: x <= max_serial
for serial in serial_list:
if filter(serial):
my_object_set.add((oid, serial))
extra_object_set = my_object_set - object_set
for oid, serial in extra_object_set:
deleteObject(oid, serial)
missing_object_set = object_set - my_object_set
for oid, serial in missing_object_set:
if not app.dm.objectPresent(oid, serial):
ask(Packets.AskObject(oid, serial, None), timeout=300)
if sum(map(len, object_dict.itervalues())) == MIN_RANGE_LENGTH:
max_tid = self.app.replicator.getCurrentCriticalTID()
ask(self._doAskCheckSerialRange(max_oid, add64(max_serial, 1),
max_tid, RANGE_LENGTH))
@checkConnectionIsReplicatorConnection
def answerObject(self, conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial):
dm = self.app.dm
if data or checksum != ZERO_HASH:
data_id = dm.storeData(checksum, data, compression)
else:
data_id = None
# Directly store the transaction.
obj = oid, data_id, data_serial
dm.storeTransaction(serial_start, [obj], None, False)
def _doAskCheckSerialRange(self, min_oid, min_tid, max_tid,
length=RANGE_LENGTH):
replicator = self.app.replicator
partition = replicator.getCurrentOffset()
neo.lib.logging.debug("Check serial range (offset=%s, min_oid=%x,"
" min_tid=%x, max_tid=%x, length=%s)", partition, u64(min_oid),
u64(min_tid), u64(max_tid), length)
check_args = (min_oid, min_tid, max_tid, length, partition)
replicator.checkSerialRange(*check_args)
return Packets.AskCheckSerialRange(*check_args)
def _doAskCheckTIDRange(self, min_tid, max_tid, length=RANGE_LENGTH):
replicator = self.app.replicator
partition = replicator.getCurrentOffset()
neo.lib.logging.debug(
"Check TID range (offset=%s, min_tid=%x, max_tid=%x, length=%s)",
partition, u64(min_tid), u64(max_tid), length)
replicator.checkTIDRange(min_tid, max_tid, length, partition)
return Packets.AskCheckTIDRange(min_tid, max_tid, length, partition)
def _doAskTIDsFrom(self, min_tid, length):
replicator = self.app.replicator
partition_id = replicator.getCurrentOffset()
max_tid = replicator.getCurrentCriticalTID()
replicator.getTIDsFrom(min_tid, max_tid, length, partition_id)
neo.lib.logging.debug("Ask TIDs (offset=%s, min_tid=%x, max_tid=%x,"
"length=%s)", partition_id, u64(min_tid), u64(max_tid), length)
return Packets.AskTIDsFrom(min_tid, max_tid, length, [partition_id])
def _doAskObjectHistoryFrom(self, min_oid, min_serial, length):
replicator = self.app.replicator
partition_id = replicator.getCurrentOffset()
max_serial = replicator.getCurrentCriticalTID()
replicator.getObjectHistoryFrom(min_oid, min_serial, max_serial,
length, partition_id)
return Packets.AskObjectHistoryFrom(min_oid, min_serial, max_serial,
length, partition_id)
def _checkRange(self, match, current_boundary, next_boundary, length,
count):
if count == 0:
# Reference storage has no data for this chunk, stop and truncate.
return CHECK_DONE, (current_boundary, )
if match:
# Same data on both sides
if length < RANGE_LENGTH and length == count:
# ...and previous check detected a difference - and we still
# haven't reached the end. This means that we just checked the
# first half of a chunk which, as a whole, is different. So the
# next test must happen on the next chunk.
recheck_min_boundary = next_boundary
else:
# ...and we just checked a whole chunk, move on to the next
# one.
recheck_min_boundary = None
else:
# Something is different in current chunk
recheck_min_boundary = current_boundary
if recheck_min_boundary is None:
if count == length:
# Go on with next chunk
action = CHECK_CHUNK
params = (next_boundary, RANGE_LENGTH)
else:
# No more chunks.
action = CHECK_DONE
params = (next_boundary, )
else:
# We must recheck current chunk.
if not match and count <= MIN_RANGE_LENGTH:
# We are already at minimum chunk length, replicate.
action = CHECK_REPLICATE
params = (recheck_min_boundary, )
else:
# Check a smaller chunk.
# Note: +1, so we can detect we reached the end when answer
# comes back.
action = CHECK_CHUNK
params = (recheck_min_boundary, max(min(length / 2, count + 1),
MIN_RANGE_LENGTH))
return action, params
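# Example with RANGE_LENGTH=4000 and MIN_RANGE_LENGTH=1000: a mismatching chunk
# of 4000 entries is rechecked from the same boundary with length 2000, then
# 1000; once a mismatching chunk is down to MIN_RANGE_LENGTH entries,
# CHECK_REPLICATE is returned and that whole chunk gets replicated. When the
# chunks match but the reference returned fewer entries than asked
# (count < length), the end of the range was reached and CHECK_DONE is returned.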
@checkConnectionIsReplicatorConnection
def answerCheckTIDRange(self, conn, min_tid, length, count, tid_checksum,
max_tid):
pkt_min_tid = min_tid
ask = conn.ask
app = self.app
replicator = app.replicator
next_tid = add64(max_tid, 1)
action, params = self._checkRange(
replicator.getTIDCheckResult(min_tid, length) == (
count, tid_checksum, max_tid), min_tid, next_tid, length,
count)
critical_tid = replicator.getCurrentCriticalTID()
if action == CHECK_REPLICATE:
(min_tid, ) = params
ask(self._doAskTIDsFrom(min_tid, count))
if length != count:
action = CHECK_DONE
params = (next_tid, )
if action == CHECK_CHUNK:
(min_tid, count) = params
if min_tid >= critical_tid:
# Stop if past critical TID
action = CHECK_DONE
params = (next_tid, )
else:
ask(self._doAskCheckTIDRange(min_tid, critical_tid, count))
if action == CHECK_DONE:
# Delete all transactions we might have which are beyond what peer
# knows.
(last_tid, ) = params
offset = replicator.getCurrentOffset()
neo.lib.logging.debug("TID range checked (offset=%s, min_tid=%x,"
" length=%s, count=%s, max_tid=%x, last_tid=%x,"
" critical_tid=%x)", offset, u64(pkt_min_tid), length, count,
u64(max_tid), u64(last_tid), u64(critical_tid))
app.dm.deleteTransactionsAbove(offset, last_tid, critical_tid)
# If no more TID, a replication of transactions is finished.
# So start to replicate objects now.
ask(self._doAskCheckSerialRange(ZERO_OID, ZERO_TID, critical_tid))
@checkConnectionIsReplicatorConnection
def answerCheckSerialRange(self, conn, min_oid, min_serial, length, count,
oid_checksum, max_oid, serial_checksum, max_serial):
ask = conn.ask
app = self.app
replicator = app.replicator
next_params = (max_oid, add64(max_serial, 1))
action, params = self._checkRange(
replicator.getSerialCheckResult(min_oid, min_serial, length) == (
count, oid_checksum, max_oid, serial_checksum, max_serial),
(min_oid, min_serial), next_params, length, count)
if action == CHECK_REPLICATE:
((min_oid, min_serial), ) = params
ask(self._doAskObjectHistoryFrom(min_oid, min_serial, count))
if length != count:
action = CHECK_DONE
params = (next_params, )
if action == CHECK_CHUNK:
((min_oid, min_serial), count) = params
max_tid = replicator.getCurrentCriticalTID()
ask(self._doAskCheckSerialRange(min_oid, min_serial, max_tid, count))
if action == CHECK_DONE:
# Delete all objects we might have which are beyond what peer
# knows.
((last_oid, last_serial), ) = params
offset = replicator.getCurrentOffset()
max_tid = replicator.getCurrentCriticalTID()
neo.lib.logging.debug("Serial range checked (offset=%s, min_oid=%x,"
" min_serial=%x, length=%s, count=%s, max_oid=%x,"
" max_serial=%x, last_oid=%x, last_serial=%x, critical_tid=%x)",
offset, u64(min_oid), u64(min_serial), length, count,
u64(max_oid), u64(max_serial), u64(last_oid), u64(last_serial),
u64(max_tid))
app.dm.deleteObjectsAbove(offset, last_oid, last_serial, max_tid)
# Nothing remains, so the replication for this partition is
# finished.
replicator.setReplicationDone()
...@@ -15,36 +15,101 @@ ...@@ -15,36 +15,101 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
from . import BaseClientAndStorageOperationHandler import weakref
from neo.lib.protocol import Packets from functools import wraps
import neo.lib
from neo.lib.connector import ConnectorConnectionClosedException
from neo.lib.handler import EventHandler
from neo.lib.protocol import Errors, NodeStates, Packets, \
ZERO_HASH, ZERO_TID, ZERO_OID
from neo.lib.util import add64, u64
class StorageOperationHandler(BaseClientAndStorageOperationHandler): def checkConnectionIsReplicatorConnection(func):
def decorator(self, conn, *args, **kw):
assert self.app.replicator.getCurrentConnection() is conn
return func(self, conn, *args, **kw)
return wraps(func)(decorator)
def _askObject(self, oid, serial, tid): class StorageOperationHandler(EventHandler):
result = self.app.dm.getObject(oid, serial, tid) """This class handles events for replications."""
if result and result[5]:
return result[:2] + (None, None, None) + result[4:]
return result
def askLastIDs(self, conn): def connectionLost(self, conn, new_state):
app = self.app if self.app.listening_conn and conn.isClient():
oid = app.dm.getLastOID() # XXX: Connection and Node should merged.
tid = app.dm.getLastTID() uuid = conn.getUUID()
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID())) if uuid:
node = self.app.nm.getByUUID(uuid)
def askTIDsFrom(self, conn, min_tid, max_tid, length, partition_list): else:
assert len(partition_list) == 1, partition_list node = self.app.nm.getByAddress(conn.getAddress())
tid_list = self.app.dm.getReplicationTIDList(min_tid, max_tid, length, node.setState(NodeStates.DOWN)
partition_list[0]) replicator = self.app.replicator
conn.answer(Packets.AnswerTIDsFrom(tid_list)) if replicator.current_node is node:
replicator.abort()
def askObjectHistoryFrom(self, conn, min_oid, min_serial, max_serial,
length, partition): # Client
object_dict = self.app.dm.getObjectHistoryFrom(min_oid, min_serial,
max_serial, length, partition) def connectionFailed(self, conn):
conn.answer(Packets.AnswerObjectHistoryFrom(object_dict)) if self.app.listening_conn:
self.app.replicator.abort()
@checkConnectionIsReplicatorConnection
def acceptIdentification(self, conn, node_type,
uuid, num_partitions, num_replicas, your_uuid):
self.app.replicator.fetchTransactions()
@checkConnectionIsReplicatorConnection
def answerFetchTransactions(self, conn, pack_tid, next_tid, tid_list):
if tid_list:
deleteTransaction = self.app.dm.deleteTransaction
for tid in tid_list:
deleteTransaction(tid)
assert not pack_tid, "TODO"
if next_tid:
self.app.replicator.fetchTransactions(next_tid)
else:
self.app.replicator.fetchObjects()
@checkConnectionIsReplicatorConnection
def addTransaction(self, conn, tid, user, desc, ext, packed, oid_list):
# Directly store the transaction.
self.app.dm.storeTransaction(tid, (),
(oid_list, user, desc, ext, packed), False)
@checkConnectionIsReplicatorConnection
def answerFetchObjects(self, conn, pack_tid, next_tid,
next_oid, object_dict):
if object_dict:
deleteObject = self.app.dm.deleteObject
for serial, oid_list in object_dict.iteritems():
for oid in oid_list:
deleteObject(oid, serial)
assert not pack_tid, "TODO"
if next_tid:
self.app.replicator.fetchObjects(next_tid, next_oid)
else:
self.app.replicator.finish()
@checkConnectionIsReplicatorConnection
def addObject(self, conn, oid, serial, compression,
checksum, data, data_serial):
dm = self.app.dm
if data or checksum != ZERO_HASH:
data_id = dm.storeData(checksum, data, compression)
else:
data_id = None
# Directly store the transaction.
obj = oid, data_id, data_serial
dm.storeTransaction(serial, (obj,), None, False)
@checkConnectionIsReplicatorConnection
def replicationError(self, conn, message):
self.app.replicator.abort('source message: ' + message)
# Server (all methods must set connection as server so that it isn't closed
# if client tasks are finished)
def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition): def askCheckTIDRange(self, conn, min_tid, max_tid, length, partition):
conn.asServer()
count, tid_checksum, max_tid = self.app.dm.checkTIDRange(min_tid, count, tid_checksum, max_tid = self.app.dm.checkTIDRange(min_tid,
max_tid, length, partition) max_tid, length, partition)
conn.answer(Packets.AnswerCheckTIDRange(min_tid, length, conn.answer(Packets.AnswerCheckTIDRange(min_tid, length,
...@@ -52,9 +117,91 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler): ...@@ -52,9 +117,91 @@ class StorageOperationHandler(BaseClientAndStorageOperationHandler):
def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length, def askCheckSerialRange(self, conn, min_oid, min_serial, max_tid, length,
partition): partition):
conn.asServer()
count, oid_checksum, max_oid, serial_checksum, max_serial = \ count, oid_checksum, max_oid, serial_checksum, max_serial = \
self.app.dm.checkSerialRange(min_oid, min_serial, max_tid, length, self.app.dm.checkSerialRange(min_oid, min_serial, max_tid, length,
partition) partition)
conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length, conn.answer(Packets.AnswerCheckSerialRange(min_oid, min_serial, length,
count, oid_checksum, max_oid, serial_checksum, max_serial)) count, oid_checksum, max_oid, serial_checksum, max_serial))
def askFetchTransactions(self, conn, partition, length, min_tid, max_tid,
tid_list):
app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
return conn.answer(Errors.ReplicationError(
"partition %u not readable" % partition))
conn.asServer()
msg_id = conn.getPeerId()
conn = weakref.proxy(conn)
peer_tid_set = set(tid_list)
dm = app.dm
tid_list = dm.getReplicationTIDList(min_tid, max_tid, length + 1,
partition)
next_tid = tid_list.pop() if length < len(tid_list) else None
def push():
try:
pack_tid = None # TODO
for tid in tid_list:
if tid in peer_tid_set:
peer_tid_set.remove(tid)
else:
t = dm.getTransaction(tid)
if t is None:
conn.answer(Errors.ReplicationError(
"partition %u dropped" % partition))
return
oid_list, user, desc, ext, packed = t
conn.notify(Packets.AddTransaction(
tid, user, desc, ext, packed, oid_list))
yield
conn.answer(Packets.AnswerFetchTransactions(
pack_tid, next_tid, peer_tid_set), msg_id)
yield
except (weakref.ReferenceError, ConnectorConnectionClosedException):
pass
app.newTask(push())
def askFetchObjects(self, conn, partition, length, min_tid, max_tid,
min_oid, object_dict):
app = self.app
cell = app.pt.getCell(partition, app.uuid)
if cell is None or cell.isOutOfDate():
return conn.answer(Errors.ReplicationError(
"partition %u not readable" % partition))
conn.asServer()
msg_id = conn.getPeerId()
conn = weakref.proxy(conn)
dm = app.dm
object_list = dm.getReplicationObjectList(min_tid, max_tid, length,
partition, min_oid)
if length < len(object_list):
next_tid, next_oid = object_list.pop()
else:
next_tid = next_oid = None
def push():
try:
pack_tid = None # TODO
for serial, oid in object_list:
oid_set = object_dict.get(serial)
if oid_set:
if type(oid_set) is list:
object_dict[serial] = oid_set = set(oid_set)
if oid in oid_set:
oid_set.remove(oid)
if not oid_set:
del object_dict[serial]
continue
object = dm.getObject(oid, serial)
if object is None:
conn.answer(Errors.ReplicationError(
"partition %u dropped" % partition))
return
conn.notify(Packets.AddObject(oid, serial, *object[2:]))
yield
conn.answer(Packets.AnswerFetchObjects(
pack_tid, next_tid, next_oid, object_dict), msg_id)
yield
except (weakref.ReferenceError, ConnectorConnectionClosedException):
pass
app.newTask(push())
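# Illustration only (not part of the handler): askFetchTransactions above is
# essentially a set difference streamed as packets. Transactions the peer did
# not announce are pushed one by one (AddTransaction), and whatever remains of
# the peer's announced list is returned in AnswerFetchTransactions for
# deletion. A standalone sketch of that bookkeeping, with plain sets instead of
# connections and database rows:
def _plan_fetch(local_tids, peer_tids):
    to_send = []                  # would become AddTransaction packets
    to_delete = set(peer_tids)    # peer TIDs the seeder cannot confirm
    for tid in sorted(local_tids):
        if tid in to_delete:
            to_delete.remove(tid)     # both sides have it: nothing to do
        else:
            to_send.append(tid)       # peer misses it: push it
    return to_send, to_delete         # to_delete goes into the final answer

# Example: the peer already has 2 and 3; 3 no longer exists on the seeder.
print(_plan_fetch({1, 2, 4, 5}, {2, 3}))    # -> ([1, 4, 5], {3})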
...@@ -27,15 +27,11 @@ class VerificationHandler(BaseMasterHandler): ...@@ -27,15 +27,11 @@ class VerificationHandler(BaseMasterHandler):
def askLastIDs(self, conn): def askLastIDs(self, conn):
app = self.app app = self.app
try: conn.answer(Packets.AnswerLastIDs(
oid = app.dm.getLastOID() app.dm.getLastOID(),
except KeyError: app.dm.getLastTIDs()[0],
oid = None app.pt.getID(),
try: app.dm.getBackupTID()))
tid = app.dm.getLastTID()
except KeyError:
tid = None
conn.answer(Packets.AnswerLastIDs(oid, tid, app.pt.getID()))
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
......
...@@ -15,363 +15,300 @@ ...@@ -15,363 +15,300 @@
# along with this program; if not, write to the Free Software # along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import neo.lib """
from random import choice Replication algorithm
from .handlers import replication Purpose: replicate the content of a reference node into a replicating node,
from neo.lib.protocol import NodeTypes, NodeStates, Packets bringing it up-to-date. This happens in the following cases:
from neo.lib.connection import ClientConnection - A new storage is added to en existing cluster.
from neo.lib.util import dump - A node was separated from cluster and rejoins it.
- In a backup cluster, the master notifies a node that new data exists upstream
(note that in this case, the cell is always marked as UP_TO_DATE).
class Partition(object): Replication happens per partition. Reference node can change between
"""This class abstracts the state of a partition.""" partitions.
def __init__(self, offset, max_tid, ttid_list): 2 parts, done sequentially:
# Possible optimization: - Transaction (metadata) replication
# _pending_ttid_list & _critical_tid can be shared amongst partitions - Object (data) replication
# created at the same time (cf Replicator.setUnfinishedTIDList).
# Replicator.transactionFinished would only have to iterate on these
# different sets, instead of all partitions.
self._offset = offset
self._pending_ttid_list = set(ttid_list)
# pending upper bound
self._critical_tid = max_tid
def getOffset(self): Both parts follow the same mechanism:
return self._offset - The range of data to replicate is split into chunks of FETCH_COUNT items
(transaction or object).
- For every chunk, the requesting node sends to seeding node the list of items
it already has.
- Before answering, the seeding node sends 1 packet for every missing item.
- The seeding node finally answers with the list of items to delete (usually
empty).
def getCriticalTID(self): Replication is partial, starting from the greatest stored tid in the partition:
return self._critical_tid - For transactions, this tid is excluded from replication.
- For objects, this tid is included unless the storage already knows it has
all oids for it.
def transactionFinished(self, ttid, max_tid): There is no check that item values on both nodes match.
self._pending_ttid_list.remove(ttid)
assert max_tid is not None TODO: Packing and replication currently fail when they happen at the same time.
# final upper bound """
self._critical_tid = max_tid
def safe(self): import random
return not self._pending_ttid_list
class Task(object): import neo.lib
""" from neo.lib.protocol import CellStates, NodeTypes, NodeStates, Packets, \
A Task is a callable to execute at another time, with given parameters. INVALID_TID, ZERO_TID, ZERO_OID
Execution result is kept and can be retrieved later. from neo.lib.connection import ClientConnection
""" from neo.lib.util import add64, u64
from .handlers.storage import StorageOperationHandler
_func = None FETCH_COUNT = 1000
_args = None
_kw = None
_result = None
_processed = False
def __init__(self, func, args=(), kw=None):
self._func = func
self._args = args
if kw is None:
kw = {}
self._kw = kw
def process(self): class Partition(object):
if self._processed:
raise ValueError, 'You cannot process a single Task twice'
self._processed = True
self._result = self._func(*self._args, **self._kw)
def getResult(self): __slots__ = 'next_trans', 'next_obj', 'max_ttid'
# Should we instead execute immediately rather than raising ?
if not self._processed:
raise ValueError, 'You cannot get a result until task is executed'
return self._result
def __repr__(self): def __repr__(self):
fmt = '<%s at %x %r(*%r, **%r)%%s>' % (self.__class__.__name__, return '<%s(%s) at 0x%x>' % (self.__class__.__name__,
id(self), self._func, self._args, self._kw) ', '.join('%s=%r' % (x, getattr(self, x)) for x in self.__slots__
if self._processed: if hasattr(self, x)),
extra = ' => %r' % (self._result, ) id(self))
else:
extra = ''
return fmt % (extra, )
class Replicator(object): class Replicator(object):
"""This class handles replications of objects and transactions.
Assumptions:
- Client nodes recognize partition changes reasonably quickly.
- When an out of date partition is added, next transaction ID
is given after the change is notified and serialized.
Procedures:
- Get the last TID right after a partition is added. This TID
is called a "critical TID", because this and TIDs before this
may not be present in this storage node yet. After a critical
TID, all transactions must exist in this storage node.
- Check if a primary master node still has pending transactions
before and at a critical TID. If so, I must wait for them to be
committed or aborted.
- In order to copy data, first get the list of TIDs. This is done
part by part, because the list can be very huge. When getting
a part of the list, I verify if they are in my database, and
ask data only for non-existing TIDs. This is performed until
the check reaches a critical TID.
- Next, get the list of OIDs. And, for each OID, ask the history,
namely, a list of serials. This is also done part by part, and
I ask only non-existing data. """
# new_partition_set
# outdated partitions for which no pending transactions was asked to
# primary master yet
# partition_dict
# outdated partitions with pending transaction and temporary critical
# tid
# current_partition
# partition being currently synchronised
# current_connection
# connection to a storage node we are replicating from
# waiting_for_unfinished_tids
# unfinished tids have been asked to primary master node, but it
# didn't answer yet.
# replication_done
# False if we know there is something to replicate.
# True when current_partition is replicated, or we don't know yet if
# there is something to replicate
current_node = None
current_partition = None current_partition = None
current_connection = None
waiting_for_unfinished_tids = False
replication_done = True
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
self.new_partition_set = set()
self.partition_dict = {}
self.task_list = []
self.task_dict = {}
def masterLost(self):
"""
When connection to primary master is lost, stop waiting for unfinished
transactions.
"""
self.waiting_for_unfinished_tids = False
def storageLost(self): def getCurrentConnection(self):
""" node = self.current_node
Restart replicating. if node is not None and node.isConnected():
""" return node.getConnection()
self.reset()
def populate(self): def setUnfinishedTIDList(self, max_tid, ttid_list, offset_list):
"""
Populate partitions to replicate. Must be called when partition
table is the one accepted by primary master.
Implies a reset.
"""
partition_list = self.app.pt.getOutdatedOffsetListFor(self.app.uuid)
self.new_partition_set = set(partition_list)
self.partition_dict = {}
self.reset()
def reset(self):
"""Reset attributes to restart replicating."""
self.task_list = []
self.task_dict = {}
self.current_partition = None
self.current_connection = None
self.replication_done = True
def pending(self):
"""Return whether there is any pending partition."""
return bool(self.partition_dict or self.new_partition_set)
def getCurrentOffset(self):
assert self.current_partition is not None
return self.current_partition.getOffset()
def getCurrentCriticalTID(self):
assert self.current_partition is not None
return self.current_partition.getCriticalTID()
def setReplicationDone(self):
""" Callback from ReplicationHandler """
self.replication_done = True
def isCurrentConnection(self, conn):
return self.current_connection is conn
def setUnfinishedTIDList(self, max_tid, ttid_list):
"""This is a callback from MasterOperationHandler.""" """This is a callback from MasterOperationHandler."""
neo.lib.logging.debug('setting unfinished TTIDs %s', if ttid_list:
','.join(map(dump, ttid_list))) self.ttid_set.update(ttid_list)
# all new outdated partition must wait those ttid max_ttid = max(ttid_list)
new_partition_set = self.new_partition_set else:
while new_partition_set: max_ttid = None
offset = new_partition_set.pop() for offset in offset_list:
self.partition_dict[offset] = Partition(offset, max_tid, ttid_list) self.partition_dict[offset].max_ttid = max_ttid
self.waiting_for_unfinished_tids = False self.replicate_dict[offset] = max_tid
self._nextPartition()
def transactionFinished(self, ttid, max_tid): def transactionFinished(self, ttid, max_tid):
""" Callback from MasterOperationHandler """ """ Callback from MasterOperationHandler """
for partition in self.partition_dict.itervalues(): self.ttid_set.remove(ttid)
partition.transactionFinished(ttid, max_tid) min_ttid = min(self.ttid_set) if self.ttid_set else INVALID_TID
for offset, p in self.partition_dict.iteritems():
if p.max_ttid and p.max_ttid < min_ttid:
p.max_ttid = None
self.replicate_dict[offset] = max_tid
self._nextPartition()
def getBackupTID(self):
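# Return the highest TID up to which every partition assigned to this node
# and not out-of-date is known to be fully replicated, or None when nothing
# can be guaranteed yet.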
outdated_set = set(self.app.pt.getOutdatedOffsetListFor(self.app.uuid))
tid = INVALID_TID
for offset, p in self.partition_dict.iteritems():
if offset not in outdated_set:
tid = min(tid, p.next_trans, p.next_obj)
if tid not in (ZERO_TID, INVALID_TID):
return add64(tid, -1)
def _askUnfinishedTIDs(self): def populate(self):
conn = self.app.master_conn
conn.ask(Packets.AskUnfinishedTransactions())
self.waiting_for_unfinished_tids = True
def _startReplication(self):
# Choose a storage node for the source.
app = self.app
cell_list = app.pt.getCellList(self.current_partition.getOffset(), pt = app.pt
readable=True) uuid = app.uuid
node_list = [cell.getNode() for cell in cell_list self.partition_dict = p = {}
if cell.getNodeState() == NodeStates.RUNNING] self.replicate_dict = {}
self.source_dict = {}
self.ttid_set = set()
last_tid, last_trans_dict, last_obj_dict = app.dm.getLastTIDs()
backup_tid = app.dm.getBackupTID()
if backup_tid and last_tid < backup_tid:
last_tid = backup_tid
outdated_list = []
for offset in xrange(pt.getPartitions()):
for cell in pt.getCellList(offset):
if cell.getUUID() == uuid:
self.partition_dict[offset] = p = Partition()
if cell.isOutOfDate():
outdated_list.append(offset)
try: try:
node = choice(node_list) p.next_trans = add64(last_trans_dict[offset], 1)
except IndexError: except KeyError:
# Not operational. p.next_trans = ZERO_TID
neo.lib.logging.error('not operational', exc_info = 1) p.next_obj = last_obj_dict.get(offset, ZERO_TID)
self.current_partition = None p.max_ttid = INVALID_TID
return else:
p.next_trans = p.next_obj = last_tid
p.max_ttid = None
if outdated_list:
self.app.master_conn.ask(Packets.AskUnfinishedTransactions(),
offset_list=outdated_list)
addr = node.getAddress() def notifyPartitionChanges(self, cell_list):
if addr is None: """This is a callback from MasterOperationHandler."""
neo.lib.logging.error("no address known for the selected node %s" % abort = False
(dump(node.getUUID()), )) added_list = []
app = self.app
for offset, uuid, state in cell_list:
if uuid == app.uuid:
if state == CellStates.DISCARDED:
del self.partition_dict[offset]
self.replicate_dict.pop(offset, None)
self.source_dict.pop(offset, None)
abort = abort or self.current_partition == offset
elif state == CellStates.OUT_OF_DATE:
assert offset not in self.partition_dict
self.partition_dict[offset] = p = Partition()
p.next_trans = p.next_obj = ZERO_TID
p.max_ttid = INVALID_TID
added_list.append(offset)
if added_list:
self.app.master_conn.ask(Packets.AskUnfinishedTransactions(),
offset_list=added_list)
if abort:
self.abort()
def backup(self, tid, source_dict):
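# Backup mode: schedule replication of the given partitions up to 'tid',
# fetching each one from the upstream storage recorded in source_dict
# (offset -> (address, upstream cluster name), as used by _nextPartition).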
for offset in source_dict:
self.replicate_dict[offset] = tid
self.source_dict.update(source_dict)
self._nextPartition()
def _nextPartition(self):
# XXX: One connection to another storage may remain open forever.
# All other previous connections are automatically closed
# after some time of inactivity.
# This should be improved in several ways:
# - Keeping connections open between 2 clusters (backup case) is
# quite a good thing because establishing a connection costs
# time/bandwidth and replication is actually never finished.
# - When all storages of a non-backup cluster are up-to-date,
# there's no reason to keep any connection open.
if self.current_partition is not None or not self.replicate_dict:
return return
app = self.app
connection = self.current_connection # Choose a partition with no unfinished transaction if possible.
if connection is None or connection.getAddress() != addr: for offset in self.replicate_dict:
handler = replication.ReplicationHandler(app) if not self.partition_dict[offset].max_ttid:
self.current_connection = ClientConnection(app.em, handler, break
try:
addr, name = self.source_dict[offset]
except KeyError:
assert self.app.pt.getCell(offset, self.app.uuid).isOutOfDate()
node = random.choice([cell.getNode()
for cell in app.pt.getCellList(offset, readable=True)
if cell.getNodeState() == NodeStates.RUNNING])
name = None
else:
node = app.nm.getByAddress(addr)
if node is None:
assert name, addr
node = app.nm.createStorage(address=addr)
self.current_partition = offset
previous_node = self.current_node
self.current_node = node
if node.isConnected():
node.getConnection().asClient()
self.fetchTransactions()
if node is previous_node:
return
else:
assert name or node.getUUID() != app.uuid, "loopback connection"
conn = ClientConnection(app.em, StorageOperationHandler(app),
node=node, connector=app.connector_handler())
p = Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
app.uuid, app.server, app.name) None if name else app.uuid, app.server, name or app.name))
self.current_connection.ask(p) if previous_node is not None and previous_node.isConnected():
if connection is not None: previous_node.getConnection().closeClient()
connection.close()
def fetchTransactions(self, min_tid=None):
offset = self.current_partition
p = self.partition_dict[offset]
if min_tid:
p.next_trans = min_tid
else: else:
connection.getHandler().startReplication(connection)
self.replication_done = False
def _finishReplication(self):
# TODO: remove try..except: pass
try: try:
# Notify to a primary master node that my cell is now up-to-date. addr, name = self.source_dict[offset]
conn = self.app.master_conn
offset = self.current_partition.getOffset()
self.partition_dict.pop(offset)
conn.notify(Packets.NotifyReplicationDone(offset))
except KeyError: except KeyError:
pass pass
if self.pending():
self.current_partition = None
else: else:
self.current_connection.close() if addr != self.current_node.getAddress():
return self.abort()
def act(self): min_tid = p.next_trans
self.replicate_tid = self.replicate_dict.pop(offset)
if self.current_partition is not None: neo.lib.logging.debug("starting replication of <partition=%u"
# Don't end replication until we have received all expected " min_tid=%u max_tid=%u> from %r", offset, u64(min_tid),
# answers, as we might have asked object data just before the last u64(self.replicate_tid), self.current_node)
# AnswerCheckSerialRange. max_tid = self.replicate_tid
if self.replication_done and \ tid_list = self.app.dm.getReplicationTIDList(min_tid, max_tid,
not self.current_connection.isPending(): FETCH_COUNT, offset)
# finish a replication self.current_node.getConnection().ask(Packets.AskFetchTransactions(
neo.lib.logging.info('replication is done for %s' % offset, FETCH_COUNT, min_tid, max_tid, tid_list))
(self.current_partition.getOffset(), ))
self._finishReplication() def fetchObjects(self, min_tid=None, min_oid=ZERO_OID):
return offset = self.current_partition
p = self.partition_dict[offset]
if self.waiting_for_unfinished_tids: max_tid = self.replicate_tid
# Still waiting. if min_tid:
neo.lib.logging.debug('waiting for unfinished tids') if p.next_obj < self.next_backup_tid:
self.app.dm.setBackupTID(min_tid)
else:
min_tid = p.next_obj
p.next_trans = p.next_obj = add64(max_tid, 1)
if self.app.dm.getBackupTID() is None or \
self.app.pt.getCell(offset, self.app.uuid).isOutOfDate():
self.next_backup_tid = ZERO_TID
else:
self.next_backup_tid = self.getBackupTID()
p.next_obj = min_tid
object_dict = {}
for serial, oid in self.app.dm.getReplicationObjectList(min_tid,
max_tid, FETCH_COUNT, offset, min_oid):
try:
object_dict[serial].append(oid)
except KeyError:
object_dict[serial] = [oid]
self.current_node.getConnection().ask(Packets.AskFetchObjects(
offset, FETCH_COUNT, min_tid, max_tid, min_oid, object_dict))
def finish(self):
offset = self.current_partition
tid = self.replicate_tid
del self.current_partition, self.replicate_tid, self.next_backup_tid
p = self.partition_dict[offset]
p.next_obj = add64(tid, 1)
self.app.dm.setBackupTID(self.getBackupTID())
if not p.max_ttid:
p = Packets.NotifyReplicationDone(offset, tid)
self.app.master_conn.notify(p)
neo.lib.logging.debug("partition %u replicated up to %u from %r",
offset, u64(tid), self.current_node)
self._nextPartition()
def abort(self, message=''):
offset = self.current_partition
if offset is None:
return return
del self.current_partition
if self.new_partition_set: neo.lib.logging.warning('replication aborted for partition %u%s',
# Ask pending transactions. offset, message and ' (%s)' % message)
neo.lib.logging.debug('asking unfinished tids') if self.app.master_node is None:
self._askUnfinishedTIDs()
return return
if offset in self.partition_dict:
# Try to select something. # XXX: Try another partition if possible, to increase probability to
for partition in self.partition_dict.values(): # connect to another node. It would be better to explicitly
# XXX: replication could start up to the initial critical tid, that # search for another node instead.
# is below the pending transactions, then finish when all pending tid = self.replicate_dict.pop(offset, None) or self.replicate_tid
# transactions are committed. if self.replicate_dict:
if partition.safe(): self._nextPartition()
self.current_partition = partition self.replicate_dict[offset] = tid
break
else: else:
# Not yet. self.replicate_dict[offset] = tid
neo.lib.logging.debug('not ready yet') self._nextPartition()
return else: # partition removed
self._nextPartition()
self._startReplication()
def removePartition(self, offset):
"""This is a callback from MasterOperationHandler."""
self.partition_dict.pop(offset, None)
self.new_partition_set.discard(offset)
def addPartition(self, offset):
"""This is a callback from MasterOperationHandler."""
if not self.partition_dict.has_key(offset):
self.new_partition_set.add(offset)
def _addTask(self, key, func, args=(), kw=None):
task = Task(func, args, kw)
task_dict = self.task_dict
if key in task_dict:
raise ValueError, 'Task with key %r already exists (%r), cannot ' \
'add %r' % (key, task_dict[key], task)
task_dict[key] = task
self.task_list.append(task)
def processDelayedTasks(self):
task_list = self.task_list
if task_list:
for task in task_list:
task.process()
self.task_list = []
def checkTIDRange(self, min_tid, max_tid, length, partition):
self._addTask(('TID', min_tid, length),
self.app.dm.checkTIDRange, (min_tid, max_tid, length, partition))
def checkSerialRange(self, min_oid, min_serial, max_tid, length,
partition):
self._addTask(('Serial', min_oid, min_serial, length),
self.app.dm.checkSerialRange, (min_oid, min_serial, max_tid, length,
partition))
def getTIDsFrom(self, min_tid, max_tid, length, partition):
self._addTask('TIDsFrom', self.app.dm.getReplicationTIDList,
(min_tid, max_tid, length, partition))
def getObjectHistoryFrom(self, min_oid, min_serial, max_serial, length,
partition):
self._addTask('ObjectHistoryFrom', self.app.dm.getObjectHistoryFrom,
(min_oid, min_serial, max_serial, length, partition))
def _getCheckResult(self, key):
return self.task_dict.pop(key).getResult()
def getTIDCheckResult(self, min_tid, length):
return self._getCheckResult(('TID', min_tid, length))
def getSerialCheckResult(self, min_oid, min_serial, length):
return self._getCheckResult(('Serial', min_oid, min_serial, length))
def getTIDsFromResult(self):
return self._getCheckResult('TIDsFrom')
def getObjectHistoryFromResult(self):
return self._getCheckResult('ObjectHistoryFrom')
...@@ -131,6 +131,11 @@ class NeoTestBase(unittest.TestCase):
sys.stdout.write('\n')
sys.stdout.flush()
class failureException(AssertionError):
def __init__(self, msg=None):
neo.lib.logging.error(msg)
AssertionError.__init__(self, msg)
failIfEqual = failUnlessEqual = assertEquals = assertNotEquals = None
def assertNotEqual(self, first, second, msg=None):
......
...@@ -25,6 +25,7 @@ import signal
import random
import weakref
import MySQLdb
import sqlite3
import unittest
import tempfile
import traceback
...@@ -242,9 +243,15 @@ class NEOCluster(object):
self.cleanup_on_delete = cleanup_on_delete
self.verbose = verbose
self.uuid_set = set()
self.db_list = db_list
if adapter == 'MySQL':
self.db_user = db_user
self.db_password = db_password
self.db_list = db_list self.db_template = '%s:%s@%%s' % (db_user, db_password)
elif adapter == 'SQLite':
self.db_template = os.path.join(temp_dir, '%s.sqlite')
else:
assert False, adapter
self.address_type = address_type
self.local_ip = local_ip = IP_VERSION_FORMAT_DICT[self.address_type]
self.setupDB(clear_databases)
...@@ -290,7 +297,7 @@ class NEOCluster(object):
self.local_ip),
0 ),
'--masters': self.master_nodes,
'--database': '%s:%s@%s' % (db_user, db_password, db), '--database': self.db_template % db,
'--adapter': adapter,
})
# create neoctl
...@@ -316,6 +323,17 @@ class NEOCluster(object):
if self.adapter == 'MySQL':
setupMySQLdb(self.db_list, self.db_user, self.db_password,
clear_databases)
elif self.adapter == 'SQLite':
if clear_databases:
for db in self.db_list:
try:
os.remove(self.db_template % db)
except OSError, e:
if e.errno != errno.ENOENT:
raise
else:
neo.lib.logging.debug('%r deleted',
self.db_template % db)
def run(self, except_storages=()):
""" Start cluster processes except some storage nodes """
...@@ -402,11 +420,14 @@ class NEOCluster(object):
db = ZODB.DB(storage=self.getZODBStorage(**kw))
return (db, db.open())
def getSQLConnection(self, db, autocommit=False): def getSQLConnection(self, db):
assert db in self.db_list
if self.adapter == 'MySQL':
conn = MySQLdb.Connect(user=self.db_user, passwd=self.db_password,
db=db)
conn.autocommit(autocommit) conn.autocommit(True)
elif self.adapter == 'SQLite':
conn = sqlite3.connect(self.db_template % db, isolation_level=None)
return conn
def _getProcessList(self, type):
......
...@@ -234,6 +234,9 @@ class ClientTests(NEOFunctionalTest):
temp_dir=self.getTempDirectory())
neoctl = self.neo.getNEOCTL()
self.neo.start()
# BUG: The following 2 lines create 2 apps, i.e. 2 TCP connections
# to the storage, so there may be a race condition at network
# level and 'st2.store' may be effective before 'st1.store'.
db1, conn1 = self.neo.getZODBConnection()
db2, conn2 = self.neo.getZODBConnection()
st1, st2 = conn1._storage, conn2._storage
......
...@@ -35,7 +35,7 @@ class ClusterTests(NEOFunctionalTest):
def testClusterStartup(self):
neo = NEOCluster(['test_neo1', 'test_neo2'], replicas=1,
adapter='MySQL', temp_dir=self.getTempDirectory()) temp_dir=self.getTempDirectory())
neoctl = neo.getNEOCTL()
neo.run()
# Running a new cluster doesn't exit Recovery state.
......
...@@ -23,7 +23,7 @@ from persistent import Persistent
from . import NEOCluster, NEOFunctionalTest
from neo.lib.protocol import ClusterStates, NodeStates
from ZODB.tests.StorageTestBase import zodb_pickle
from MySQLdb import ProgrammingError import MySQLdb, sqlite3
from MySQLdb.constants.ER import NO_SUCH_TABLE
class PObject(Persistent):
...@@ -46,9 +46,11 @@ class StorageTests(NEOFunctionalTest):
NEOFunctionalTest.tearDown(self)
def queryCount(self, db, query):
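# MySQLdb connections expose query()/store_result() whereas sqlite3
# connections only provide execute(); the AttributeError fallback lets this
# helper work with both adapters.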
try:
db.query(query)
result = db.store_result().fetch_row()[0][0] except AttributeError:
return result return db.execute(query).fetchone()[0]
return db.store_result().fetch_row()[0][0]
def __setup(self, storage_number=2, pending_number=0, replicas=1,
partitions=10, master_count=2):
...@@ -58,7 +60,6 @@ class StorageTests(NEOFunctionalTest):
partitions=partitions, replicas=replicas,
temp_dir=self.getTempDirectory(),
clear_databases=True,
adapter='MySQL',
) )
# too many pending storage nodes requested
assert pending_number <= storage_number
...@@ -80,7 +81,7 @@ class StorageTests(NEOFunctionalTest):
db.close()
def __checkDatabase(self, db_name):
db = self.neo.getSQLConnection(db_name, autocommit=True) db = self.neo.getSQLConnection(db_name)
# wait for the sql transaction to be committed
def callback(last_try):
object_number = self.queryCount(db, 'select count(*) from obj')
...@@ -124,13 +125,16 @@ class StorageTests(NEOFunctionalTest):
def __checkReplicateCount(self, db_name, target_count, timeout=0, delay=1):
db = self.neo.getSQLConnection(db_name)
def callback(last_try):
replicate_count = 0
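# The 'pt' table may not exist yet on a freshly created database, so a
# missing table is treated as a replicate count of 0 for both adapters.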
try:
replicate_count = self.queryCount(db,
'select count(distinct uuid) from pt')
except ProgrammingError, exc: except MySQLdb.ProgrammingError, e:
if exc[0] != NO_SUCH_TABLE: if e[0] != NO_SUCH_TABLE:
raise
except sqlite3.OperationalError, e:
if not e[0].startswith('no such table:'):
raise raise
replicate_count = 0
if last_try is not None and last_try < replicate_count:
raise AssertionError, 'Regression: %s became %s' % \
(last_try, replicate_count)
......
...@@ -85,7 +85,7 @@ class MasterRecoveryTests(NeoUnitTestBase):
self.assertTrue(ptid2 > self.app.pt.getID())
self.assertTrue(oid2 > self.app.tm.getLastOID())
self.assertTrue(tid2 > self.app.tm.getLastTID())
recovery.answerLastIDs(conn, oid2, tid2, ptid2) recovery.answerLastIDs(conn, oid2, tid2, ptid2, None)
self.assertEqual(oid2, self.app.tm.getLastOID())
self.assertEqual(tid2, self.app.tm.getLastTID())
self.assertEqual(ptid2, recovery.target_ptid)
......
...@@ -130,10 +130,11 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.app.tm.setLastTID(tid)
service.askLastIDs(conn)
packet = self.checkAnswerLastIDs(conn)
loid, ltid, lptid = packet.decode() loid, ltid, lptid, backup_tid = packet.decode()
self.assertEqual(loid, oid)
self.assertEqual(ltid, tid)
self.assertEqual(lptid, ptid)
self.assertEqual(backup_tid, None)
def test_13_askUnfinishedTransactions(self):
service = self.service
......
...@@ -3,7 +3,7 @@
import math, os, random, sys, time
from cStringIO import StringIO
from persistent.TimeStamp import TimeStamp
from ZODB.utils import p64, newTid from ZODB.utils import p64, u64
from ZODB.BaseStorage import TransactionRecord
from ZODB.FileStorage import FileStorage
...@@ -44,6 +44,7 @@ class DummyZODB(object):
self.new_ratio = new_ratio
self.next_oid = 0
self.err_count = 0
self.tid = u64('TID\0\0\0\0\0')
def __call__(self):
variate = self.random.lognormvariate
...@@ -63,9 +64,11 @@ class DummyZODB(object):
yield p64(oid), int(round(variate(self.obj_size_mu,
self.obj_size_sigma))) or 1
def as_storage(self, transaction_count, dummy_data_file=None): def as_storage(self, stop, dummy_data_file=None):
if dummy_data_file is None:
dummy_data_file = DummyData(self.random)
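# 'stop' is either a predicate taking the number of transactions generated
# so far, or an int meaning "generate that many transactions"; the int form
# is normalized into a predicate just below.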
if isinstance(stop, int):
stop = (lambda x: lambda y: x <= y)(stop)
class dummy_change(object):
data_txn = None
version = ''
...@@ -97,12 +100,14 @@ class DummyZODB(object):
size = 0
def iterator(storage, *args):
args = ' ', '', '', {}
tid = None i = 0
for i in xrange(1, transaction_count+1): variate = self.random.lognormvariate
tid = newTid(tid) while not stop(i):
t = dummy_transaction(tid, *args) self.tid += max(1, int(variate(10, 3)))
t = dummy_transaction(p64(self.tid), *args)
storage.size += t.size
yield t
i += 1
def getSize(self):
return self.size
return dummy_storage()
......
...@@ -164,19 +164,6 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
calls[0].checkArgs(tid)
self.checkNoPacketSent(conn)
def test_31_answerUnfinishedTransactions(self):
# set unfinished TID on replicator
conn = self.getFakeConnection()
self.app.replicator = Mock()
self.operation.answerUnfinishedTransactions(
conn=conn,
max_tid=INVALID_TID,
ttid_list=(INVALID_TID, ),
)
calls = self.app.replicator.mockGetNamedCalls('setUnfinishedTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(INVALID_TID, (INVALID_TID, ))
def test_askPack(self):
self.app.dm = Mock({'pack': None})
conn = self.getFakeConnection()
......
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from struct import pack
from collections import deque
from .. import NeoUnitTestBase
from neo.storage.database import buildDatabaseManager
from neo.storage.handlers.replication import ReplicationHandler
from neo.storage.handlers.replication import RANGE_LENGTH
from neo.storage.handlers.storage import StorageOperationHandler
from neo.storage.replicator import Replicator
from neo.lib.protocol import ZERO_OID, ZERO_TID
MAX_TRANSACTIONS = 10000
MAX_OBJECTS = 100000
MAX_TID = '\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE' # != INVALID_TID
class FakeConnection(object):
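# Minimal in-memory stand-in for a connection: packets passed to
# ask/answer/notify are queued, and process() dispatches them to the peer's
# handler, so the test can pump traffic between two fake nodes without
# real sockets.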
def __init__(self):
self._msg_id = 0
self._queue = deque()
def allocateId(self):
self._msg_id += 1
return self._msg_id
def _addPacket(self, packet, *args, **kw):
packet.setId(self.allocateId())
self._queue.append(packet)
ask = _addPacket
answer = _addPacket
notify = _addPacket
def setPeerId(self, msg_id):
pass
def process(self, dhandler, dconn):
if not self._queue:
return False
while self._queue:
dhandler.dispatch(dconn, self._queue.popleft())
return True
class ReplicationTests(NeoUnitTestBase):
def checkReplicationProcess(self, reference, outdated):
pt = Mock({'getPartitions': 1})
# reference application
rapp = Mock({})
rapp.pt = pt
rapp.dm = reference
rapp.tm = Mock({'loadLocked': False})
mconn = FakeConnection()
rapp.master_conn = mconn
# outdated application
oapp = Mock({})
oapp.dm = outdated
oapp.pt = pt
oapp.master_conn = mconn
oapp.replicator = Replicator(oapp)
oapp.replicator.getCurrentOffset = lambda: 0
oapp.replicator.isCurrentConnection = lambda c: True
oapp.replicator.getCurrentCriticalTID = lambda: MAX_TID
# handlers and connections
rhandler = StorageOperationHandler(rapp)
rconn = FakeConnection()
ohandler = ReplicationHandler(oapp)
oconn = FakeConnection()
# run replication
ohandler.startReplication(oconn)
process = True
while process:
process = oconn.process(rhandler, rconn)
oapp.replicator.processDelayedTasks()
process |= rconn.process(ohandler, oconn)
# check transactions
for tid in reference.getTIDList(0, MAX_TRANSACTIONS, [0]):
self.assertEqual(
reference.getTransaction(tid),
outdated.getTransaction(tid),
)
for tid in outdated.getTIDList(0, MAX_TRANSACTIONS, [0]):
self.assertEqual(
outdated.getTransaction(tid),
reference.getTransaction(tid),
)
# check replication TID lists
params = ZERO_TID, '\xFF' * 8, MAX_TRANSACTIONS, 0
self.assertEqual(
reference.getReplicationTIDList(*params),
outdated.getReplicationTIDList(*params),
)
# check objects
params = ZERO_OID, ZERO_TID, '\xFF' * 8, MAX_OBJECTS, 0
self.assertEqual(
reference.getObjectHistoryFrom(*params),
outdated.getObjectHistoryFrom(*params),
)
def buildStorage(self, transactions, objects, name='BTree', database=None):
def makeid(oid_or_tid):
return pack('!Q', oid_or_tid)
storage = buildDatabaseManager(name, (database, 0))
storage.setup(reset=True)
storage.setNumPartitions(1)
storage._transactions = transactions
storage._objects = objects
# store transactions
for tid in transactions:
transaction = ([ZERO_OID], 'user', 'desc', '', False)
storage.storeTransaction(makeid(tid), [], transaction, False)
# store object history
H = "0" * 20
storage.storeData(H, '', 0)
storage.unlockData((H,))
for tid, oid_list in objects.iteritems():
object_list = [(makeid(oid), H, None) for oid in oid_list]
storage.storeTransaction(makeid(tid), object_list, None, False)
return storage
def testReplication0(self):
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=[1, 2, 3],
objects={1: [1], 2: [1], 3: [1]},
),
outdated=self.buildStorage(
transactions=[],
objects={},
),
)
def testReplication1(self):
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=[1, 2, 3],
objects={1: [1], 2: [1], 3: [1]},
),
outdated=self.buildStorage(
transactions=[1],
objects={1: [1]},
),
)
def testReplication2(self):
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=[1, 2, 3],
objects={1: [1, 2, 3]},
),
outdated=self.buildStorage(
transactions=[1, 2, 3],
objects={1: [1, 2, 3]},
),
)
def testChunkBeginning(self):
ref_number = range(RANGE_LENGTH + 1)
out_number = range(RANGE_LENGTH)
obj_list = [1, 2, 3]
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list),
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list),
),
)
def testChunkEnd(self):
ref_number = range(RANGE_LENGTH)
out_number = range(RANGE_LENGTH - 1)
obj_list = [1, 2, 3]
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list)
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list)
),
)
def testChunkMiddle(self):
obj_list = [1, 2, 3]
ref_number = range(RANGE_LENGTH)
out_number = range(4000)
out_number.remove(3000)
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list)
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list)
),
)
def testFullChunkPart(self):
obj_list = [1, 2, 3]
ref_number = range(1001)
out_number = {}
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list)
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list)
),
)
def testSameData(self):
obj_list = [1, 2, 3]
number = range(RANGE_LENGTH * 2)
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=number,
objects=dict.fromkeys(number, obj_list)
),
outdated=self.buildStorage(
transactions=number,
objects=dict.fromkeys(number, obj_list)
),
)
def testTooManyData(self):
obj_list = [0, 1]
ref_number = range(RANGE_LENGTH)
out_number = range(RANGE_LENGTH + 2)
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=ref_number,
objects=dict.fromkeys(ref_number, obj_list)
),
outdated=self.buildStorage(
transactions=out_number,
objects=dict.fromkeys(out_number, obj_list)
),
)
def testMissingObject(self):
self.checkReplicationProcess(
reference=self.buildStorage(
transactions=[1, 2],
objects=dict.fromkeys([1, 2], [1, 2]),
),
outdated=self.buildStorage(
transactions=[1, 2],
objects=dict.fromkeys([1], [1]),
),
)
if __name__ == "__main__":
unittest.main()
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from neo.lib.util import add64, dump
from .. import NeoUnitTestBase
from neo.lib.protocol import Packets, ZERO_OID, ZERO_TID
from neo.storage.handlers.replication import ReplicationHandler
from neo.storage.handlers.replication import RANGE_LENGTH, MIN_RANGE_LENGTH
class FakeDict(object):
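# dict-like helper preserving the insertion order of its items, so that the
# test controls the iteration order of OIDs (plain dicts give no ordering
# guarantee on Python 2).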
def __init__(self, items):
self._items = items
self._dict = dict(items)
assert len(self._dict) == len(items), self._dict
def iteritems(self):
for item in self._items:
yield item
def iterkeys(self):
for key, value in self.iteritems():
yield key
def itervalues(self):
for key, value in self.iteritems():
yield value
def items(self):
return self._items[:]
def keys(self):
return [x for x, y in self._items]
def values(self):
return [y for x, y in self._items]
def __getitem__(self, key):
return self._dict[key]
def __getattr__(self, key):
return getattr(self._dict, key)
def __len__(self):
return len(self._dict)
class StorageReplicationHandlerTests(NeoUnitTestBase):
def setup(self):
pass
def teardown(self):
pass
def getApp(self, conn=None, tid_check_result=(0, 0, ZERO_TID),
serial_check_result=(0, 0, ZERO_OID, 0, ZERO_TID),
tid_result=(),
history_result=None,
rid=0, critical_tid=ZERO_TID,
num_partitions=1,
):
if history_result is None:
history_result = {}
replicator = Mock({
'__repr__': 'Fake replicator',
'reset': None,
'checkSerialRange': None,
'checkTIDRange': None,
'getTIDCheckResult': tid_check_result,
'getSerialCheckResult': serial_check_result,
'getTIDsFromResult': tid_result,
'getObjectHistoryFromResult': history_result,
'checkSerialRange': None,
'checkTIDRange': None,
'getTIDsFrom': None,
'getObjectHistoryFrom': None,
'getCurrentOffset': rid,
'getCurrentCriticalTID': critical_tid,
})
def isCurrentConnection(other_conn):
return other_conn is conn
replicator.isCurrentConnection = isCurrentConnection
real_replicator = replicator
class FakeApp(object):
replicator = real_replicator
dm = Mock({
'storeTransaction': None,
'deleteObject': None,
})
pt = Mock({
'getPartitions': num_partitions,
})
return FakeApp
def _checkReplicationStarted(self, conn, rid, replicator):
min_tid, max_tid, length, partition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(min_tid, ZERO_TID)
self.assertEqual(length, RANGE_LENGTH)
self.assertEqual(partition, rid)
calls = replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, max_tid, length, partition)
def _checkPacketTIDList(self, conn, tid_list, next_tid, app):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
packet_list, next_range = packet_list[:-1], packet_list[-1]
self.assertEqual(type(next_range), Packets.AskCheckTIDRange)
pmin_tid, plength, ppartition = next_range.decode()
self.assertEqual(pmin_tid, add64(next_tid, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, app.replicator.getCurrentOffset())
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, plength, ppartition)
self.assertEqual(len(packet_list), len(tid_list))
for packet in packet_list:
self.assertEqual(type(packet),
Packets.AskTransactionInformation)
ptid = packet.decode()[0]
for tid in tid_list:
if ptid == tid:
tid_list.remove(tid)
break
else:
raise AssertionError('%s not found in %r'
% (dump(ptid), map(dump, tid_list)))
def _checkPacketSerialList(self, conn, object_list, next_oid, next_serial, app):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
packet_list, next_range = packet_list[:-1], packet_list[-1]
self.assertEqual(type(next_range), Packets.AskCheckSerialRange)
pmin_oid, pmin_serial, plength, ppartition = next_range.decode()
self.assertEqual(pmin_oid, next_oid)
self.assertEqual(pmin_serial, add64(next_serial, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, app.replicator.getCurrentOffset())
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, plength, ppartition)
self.assertEqual(len(packet_list), len(object_list),
([x.decode() for x in packet_list], object_list))
reference_set = set((x + (None, ) for x in object_list))
packet_set = set((x.decode() for x in packet_list))
assert len(packet_list) == len(reference_set) == len(packet_set)
self.assertEqual(reference_set, packet_set)
def test_connectionLost(self):
app = self.getApp()
ReplicationHandler(app).connectionLost(None, None)
self.assertEqual(len(app.replicator.mockGetNamedCalls('storageLost')), 1)
def test_connectionFailed(self):
app = self.getApp()
ReplicationHandler(app).connectionFailed(None)
self.assertEqual(len(app.replicator.mockGetNamedCalls('storageLost')), 1)
def test_acceptIdentification(self):
rid = 24
app = self.getApp(rid=rid)
conn = self.getFakeConnection()
replication = ReplicationHandler(app)
replication.acceptIdentification(conn, None, None, None,
None, None)
self._checkReplicationStarted(conn, rid, app.replicator)
def test_startReplication(self):
rid = 24
app = self.getApp(rid=rid)
conn = self.getFakeConnection()
ReplicationHandler(app).startReplication(conn)
self._checkReplicationStarted(conn, rid, app.replicator)
def test_answerTIDsFrom(self):
conn = self.getFakeConnection()
tid_list = [self.getOID(0), self.getOID(1), self.getOID(2)]
app = self.getApp(conn=conn, tid_result=[])
# With no known TID
ReplicationHandler(app).answerTIDsFrom(conn, tid_list)
# With some TIDs known
conn = self.getFakeConnection()
known_tid_list = [tid_list[0], tid_list[1]]
unknown_tid_list = [tid_list[2], ]
app = self.getApp(conn=conn, tid_result=known_tid_list)
ReplicationHandler(app).answerTIDsFrom(conn, tid_list[1:])
calls = app.dm.mockGetNamedCalls('deleteTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid_list[0])
def test_answerTransactionInformation(self):
conn = self.getFakeConnection()
app = self.getApp(conn=conn)
tid = self.getNextTID()
user = 'foo'
desc = 'bar'
ext = 'baz'
packed = True
oid_list = [self.getOID(1), self.getOID(2)]
ReplicationHandler(app).answerTransactionInformation(conn, tid, user,
desc, ext, packed, oid_list)
calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, (), (oid_list, user, desc, ext, packed), False)
def test_answerObjectHistoryFrom(self):
conn = self.getFakeConnection()
oid_1 = self.getOID(1)
oid_2 = self.getOID(2)
oid_3 = self.getOID(3)
oid_4 = self.getOID(4)
oid_5 = self.getOID(5)
tid_list = map(self.getOID, xrange(7))
oid_dict = FakeDict((
(oid_1, [tid_list[0], tid_list[1]]),
(oid_2, [tid_list[2], tid_list[3]]),
(oid_4, [tid_list[5]]),
))
flat_oid_list = []
for oid, serial_list in oid_dict.iteritems():
for serial in serial_list:
flat_oid_list.append((oid, serial))
app = self.getApp(conn=conn, history_result={})
# With no known OID/Serial
ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
# With some known OID/Serials
# For the test to be realistic, history_result should contain the same
# number of serials as oid_dict, otherwise it just tests the special
# case of the last check in a partition.
conn = self.getFakeConnection()
app = self.getApp(conn=conn, history_result={
oid_1: [oid_dict[oid_1][0], ],
oid_3: [tid_list[2]],
oid_4: [tid_list[4], oid_dict[oid_4][0], tid_list[6]],
oid_5: [tid_list[6]],
})
ReplicationHandler(app).answerObjectHistoryFrom(conn, oid_dict)
calls = app.dm.mockGetNamedCalls('deleteObject')
actual_deletes = set(((x.getParam(0), x.getParam(1)) for x in calls))
expected_deletes = set((
(oid_3, tid_list[2]),
(oid_4, tid_list[4]),
))
self.assertEqual(actual_deletes, expected_deletes)
def test_answerObject(self):
conn = self.getFakeConnection()
app = self.getApp(conn=conn)
oid = self.getOID(1)
serial_start = self.getNextTID()
serial_end = self.getNextTID()
compression = 1
checksum = "0" * 20
data = 'foo'
data_serial = None
app.dm.mockAddReturnValues(storeData=checksum)
ReplicationHandler(app).answerObject(conn, oid, serial_start,
serial_end, compression, checksum, data, data_serial)
calls = app.dm.mockGetNamedCalls('storeTransaction')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(serial_start, [(oid, checksum, data_serial)],
None, False)
# CheckTIDRange
def test_answerCheckTIDFullRangeIdenticalChunkWithNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
assert max_tid < critical_tid
length = RANGE_LENGTH
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with next chunk
pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmin_tid, add64(max_tid, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDSmallRangeIdenticalChunkWithNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
assert max_tid < critical_tid
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with next chunk
pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(pmin_tid, add64(max_tid, 1))
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkAboveCriticalTID(self):
critical_tid = self.getNextTID()
min_tid = self.getNextTID()
max_tid = self.getNextTID()
assert critical_tid < max_tid
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: go on with object range checks
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeIdenticalChunkWithoutNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
num_partitions = 13
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 1, 0, max_tid), rid=rid,
conn=conn, num_partitions=num_partitions)
handler = ReplicationHandler(app)
# Peer has the same data as we have: length, checksum and max_tid
# match.
handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
max_tid)
# Result: go on with object range checks
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, ZERO_OID)
self.assertEqual(pmin_serial, ZERO_TID)
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteTransactionsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(rid, add64(max_tid, 1), ZERO_TID)
def test_answerCheckTIDRangeDifferentBigChunk(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
assert min_tid < max_tid < critical_tid, (min_tid, max_tid,
critical_tid)
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: ask again, length halved
pmin_tid, pmax_tid, plength, ppartition = self.checkAskPacket(conn,
Packets.AskCheckTIDRange, decode=True)
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition)
def test_answerCheckTIDRangeDifferentSmallChunkWithNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckTIDRange(conn, min_tid, length, length, 0, max_tid)
# Result: ask tid list, and ask next chunk
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1)
tid_packet = calls[0].getParam(0)
self.assertEqual(type(tid_packet), Packets.AskTIDsFrom)
pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length)
self.assertEqual(ppartition, [rid])
calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition[0])
def test_answerCheckTIDRangeDifferentSmallChunkWithoutNext(self):
min_tid = self.getNextTID()
max_tid = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_tid), rid=rid,
conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has different data, and less than length
handler.answerCheckTIDRange(conn, min_tid, length, length - 1, 0,
max_tid)
# Result: ask tid list, and start replicating object range
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 2)
tid_packet = calls[0].getParam(0)
self.assertEqual(type(tid_packet), Packets.AskTIDsFrom)
pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid)
self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, [rid])
calls = app.replicator.mockGetNamedCalls('getTIDsFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_tid, pmax_tid, plength, ppartition[0])
# CheckSerialRange
def test_answerCheckSerialFullRangeIdenticalChunkWithNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH
rid = 12
conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: go on with next chunk
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, max_oid)
self.assertEqual(pmin_serial, add64(max_serial, 1))
self.assertEqual(plength, RANGE_LENGTH)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialSmallRangeIdenticalChunkWithNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: go on with next chunk
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, max_oid)
self.assertEqual(pmin_serial, add64(max_serial, 1))
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialRangeIdenticalChunkWithoutNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
num_partitions = 13
conn = self.getFakeConnection()
app = self.getApp(serial_check_result=(length - 1, 0, max_oid, 1,
max_serial), rid=rid, conn=conn, num_partitions=num_partitions)
handler = ReplicationHandler(app)
# Peer has the same data as we have
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length - 1, 0, max_oid, 1, max_serial)
# Result: mark replication as done
self.checkNoPacketSent(conn)
self.assertTrue(app.replicator.replication_done)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(rid, max_oid, add64(max_serial, 1), ZERO_TID)
def test_answerCheckSerialRangeDifferentBigChunk(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = RANGE_LENGTH / 2
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, 1,
max_serial), rid=rid, conn=conn)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: ask again, length halved
pmin_oid, pmin_serial, pmax_tid, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskCheckSerialRange, decode=True)
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(plength, length / 2)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_tid, plength, ppartition)
def test_answerCheckSerialRangeDifferentSmallChunkWithNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid, 1,
max_serial), rid=rid, conn=conn, critical_tid=critical_tid)
handler = ReplicationHandler(app)
# Peer has different data
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length, 0, max_oid, 1, max_serial)
# Result: ask serial list, and ask next chunk
calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1)
serial_packet = calls[0].getParam(0)
self.assertEqual(type(serial_packet), Packets.AskObjectHistoryFrom)
pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \
serial_packet.decode()
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(pmax_serial, critical_tid)
self.assertEqual(plength, length)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_serial, plength,
ppartition)
def test_answerCheckSerialRangeDifferentSmallChunkWithoutNext(self):
min_oid = self.getOID(1)
max_oid = self.getOID(10)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
critical_tid = self.getNextTID()
length = MIN_RANGE_LENGTH - 1
rid = 12
num_partitions = 13
conn = self.getFakeConnection()
app = self.getApp(tid_check_result=(length - 5, 0, max_oid,
1, max_serial), rid=rid, conn=conn, critical_tid=critical_tid,
num_partitions=num_partitions,
)
handler = ReplicationHandler(app)
# Peer has different data, and less than length
handler.answerCheckSerialRange(conn, min_oid, min_serial, length,
length - 1, 0, max_oid, 1, max_serial)
# Result: ask tid list, and mark replication as done
pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \
self.checkAskPacket(conn, Packets.AskObjectHistoryFrom,
decode=True)
self.assertEqual(pmin_oid, min_oid)
self.assertEqual(pmin_serial, min_serial)
self.assertEqual(pmax_serial, critical_tid)
self.assertEqual(plength, length - 1)
self.assertEqual(ppartition, rid)
calls = app.replicator.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(pmin_oid, pmin_serial, pmax_serial, plength,
ppartition)
self.assertTrue(app.replicator.replication_done)
# ...and delete partition tail
calls = app.dm.mockGetNamedCalls('deleteObjectsAbove')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(rid, max_oid, add64(max_serial, 1), critical_tid)
if __name__ == "__main__":
unittest.main()
#
# Copyright (C) 2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock, ReturnValues
from .. import NeoUnitTestBase
from neo.storage.replicator import Replicator, Partition, Task
from neo.lib.protocol import CellStates, NodeStates, Packets
class StorageReplicatorTests(NeoUnitTestBase):
def setup(self):
pass
def teardown(self):
pass
def test_populate(self):
my_uuid = self.getNewUUID()
other_uuid = self.getNewUUID()
app = Mock()
app.uuid = my_uuid
app.pt = Mock({
'getPartitions': 2,
'getOutdatedOffsetListFor': [0],
})
replicator = Replicator(app)
self.assertEqual(replicator.new_partition_set, set())
replicator.populate()
self.assertEqual(replicator.new_partition_set, set([0]))
def test_reset(self):
replicator = Replicator(None)
replicator.task_list = ['foo']
replicator.task_dict = {'foo': 'bar'}
replicator.current_partition = 'foo'
replicator.current_connection = 'foo'
replicator.replication_done = 'foo'
replicator.reset()
self.assertEqual(replicator.task_list, [])
self.assertEqual(replicator.task_dict, {})
self.assertEqual(replicator.current_partition, None)
self.assertEqual(replicator.current_connection, None)
self.assertTrue(replicator.replication_done)
def test_setCriticalTID(self):
critical_tid = self.getNextTID()
partition = Partition(0, critical_tid, [])
self.assertEqual(partition.getCriticalTID(), critical_tid)
self.assertEqual(partition.getOffset(), 0)
def test_act(self):
# Also tests "pending"
uuid = self.getNewUUID()
master_uuid = self.getNewUUID()
critical_tid_0 = self.getNextTID()
critical_tid_1 = self.getNextTID()
critical_tid_2 = self.getNextTID()
unfinished_ttid_1 = self.getOID(1)
unfinished_ttid_2 = self.getOID(2)
app = Mock()
app.server = ('127.0.0.1', 10000)
app.name = 'fake cluster'
app.em = Mock({
'register': None,
})
def connectorGenerator():
return Mock()
app.connector_handler = connectorGenerator
app.uuid = uuid
node_addr = ('127.0.0.1', 1234)
node = Mock({
'getAddress': node_addr,
})
running_cell = Mock({
'getNodeState': NodeStates.RUNNING,
'getNode': node,
})
unknown_cell = Mock({
'getNodeState': NodeStates.UNKNOWN,
})
app.pt = Mock({
'getCellList': [running_cell, unknown_cell],
'getOutdatedOffsetListFor': [0],
'getPartition': 0,
})
node_conn_handler = Mock({
'startReplication': None,
})
node_conn = Mock({
'getAddress': node_addr,
'getHandler': node_conn_handler,
})
replicator = Replicator(app)
replicator.populate()
def act():
app.master_conn = self.getFakeConnection(uuid=master_uuid)
self.assertTrue(replicator.pending())
replicator.act()
# ask unfinished tids
act()
unfinished_tids = app.master_conn.mockGetNamedCalls('ask')[0].getParam(0)
self.assertTrue(replicator.new_partition_set)
self.assertEqual(type(unfinished_tids),
Packets.AskUnfinishedTransactions)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# nothing happens until waiting_for_unfinished_tids becomes False
act()
self.checkNoPacketSent(app.master_conn)
self.assertTrue(replicator.waiting_for_unfinished_tids)
# first time, there is an unfinished tid before critical tid,
# replication cannot start, and unfinished TIDs are asked again
replicator.setUnfinishedTIDList(critical_tid_0,
[unfinished_ttid_1, unfinished_ttid_2])
self.assertFalse(replicator.waiting_for_unfinished_tids)
# Note: detection that nothing can be replicated happens on first call
# and unfinished tids are asked again on second call. This is ok, but
# might change, so just call twice.
act()
replicator.transactionFinished(unfinished_ttid_1, critical_tid_1)
act()
replicator.transactionFinished(unfinished_ttid_2, critical_tid_2)
replicator.current_connection = node_conn
act()
self.assertEqual(replicator.current_partition,
replicator.partition_dict[0])
self.assertEqual(len(node_conn_handler.mockGetNamedCalls(
'startReplication')), 1)
self.assertFalse(replicator.replication_done)
# Other calls should do nothing
replicator.current_connection = Mock()
act()
self.checkNoPacketSent(app.master_conn)
self.checkNoPacketSent(replicator.current_connection)
# Mark replication over for this partition
replicator.replication_done = True
# Don't finish while there are pending answers
replicator.current_connection = Mock({
'isPending': True,
})
act()
self.assertTrue(replicator.pending())
replicator.current_connection = Mock({
'isPending': False,
})
act()
# also, replication is over
self.assertFalse(replicator.pending())
def test_removePartition(self):
replicator = Replicator(None)
replicator.partition_dict = {0: None, 2: None}
replicator.new_partition_set = set([1])
replicator.removePartition(0)
self.assertEqual(replicator.partition_dict, {2: None})
self.assertEqual(replicator.new_partition_set, set([1]))
replicator.removePartition(1)
replicator.removePartition(2)
self.assertEqual(replicator.partition_dict, {})
self.assertEqual(replicator.new_partition_set, set())
# Must not raise
replicator.removePartition(3)
def test_addPartition(self):
replicator = Replicator(None)
replicator.partition_dict = {0: None}
replicator.new_partition_set = set([1])
replicator.addPartition(0)
replicator.addPartition(1)
self.assertEqual(replicator.partition_dict, {0: None})
self.assertEqual(replicator.new_partition_set, set([1]))
replicator.addPartition(2)
self.assertEqual(replicator.partition_dict, {0: None})
self.assertEqual(len(replicator.new_partition_set), 2)
self.assertEqual(replicator.new_partition_set, set([1, 2]))
def test_processDelayedTasks(self):
replicator = Replicator(None)
replicator.reset()
marker = []
def someCallable(foo, bar=None):
return (foo, bar)
replicator._addTask(1, someCallable, args=('foo', ))
self.assertRaises(ValueError, replicator._addTask, 1, None)
replicator._addTask(2, someCallable, args=('foo', ), kw={'bar': 'bar'})
replicator.processDelayedTasks()
self.assertEqual(replicator._getCheckResult(1), ('foo', None))
self.assertEqual(replicator._getCheckResult(2), ('foo', 'bar'))
# Also test Task
task = Task(someCallable, args=('foo', ))
self.assertRaises(ValueError, task.getResult)
task.process()
self.assertRaises(ValueError, task.process)
self.assertEqual(task.getResult(), ('foo', None))
if __name__ == "__main__":
unittest.main()
...@@ -18,11 +18,10 @@ ...@@ -18,11 +18,10 @@
import unittest import unittest
from mock import Mock from mock import Mock
from neo.lib.util import dump, p64, u64 from neo.lib.util import dump, p64, u64
from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID from neo.lib.protocol import CellStates, ZERO_OID, ZERO_TID, MAX_TID
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.exception import DatabaseFailure from neo.lib.exception import DatabaseFailure
MAX_TID = '\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFE' # != INVALID_TID
class StorageDBTests(NeoUnitTestBase): class StorageDBTests(NeoUnitTestBase):
...@@ -74,7 +73,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -74,7 +73,7 @@ class StorageDBTests(NeoUnitTestBase):
def checkConfigEntry(self, get_call, set_call, value): def checkConfigEntry(self, get_call, set_call, value):
# generic test for all configuration entries accessors # generic test for all configuration entries accessors
self.assertRaises(KeyError, get_call) self.assertEqual(get_call(), None)
set_call(value) set_call(value)
self.assertEqual(get_call(), value) self.assertEqual(get_call(), value)
set_call(value * 2) set_call(value * 2)
...@@ -92,6 +91,29 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -92,6 +91,29 @@ class StorageDBTests(NeoUnitTestBase):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1)) self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1))
def test_transaction(self):
db = self.getDB()
x = []
class DB(db.__class__):
begin = lambda self: x.append('begin')
commit = lambda self: x.append('commit')
rollback = lambda self: x.append('rollback')
db.__class__ = DB
with db:
self.assertEqual(x.pop(), 'begin')
self.assertEqual(x.pop(), 'commit')
try:
with db:
self.assertEqual(x.pop(), 'begin')
with db:
self.fail()
self.fail()
except DatabaseFailure:
pass
self.assertEqual(x.pop(), 'rollback')
self.assertRaises(DatabaseFailure, db.__exit__, None, None, None)
self.assertFalse(x)
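# The new test above pins down the database manager's context-manager
# contract: entering begins a transaction, a clean exit commits, an exception
# rolls back and is re-raised, and both nested use and exiting outside a
# transaction raise DatabaseFailure. Below is a minimal stand-alone sketch of
# semantics consistent with those assertions (not NEO's actual implementation;
# the real managers also appear to yield their query callable, as
# "with self.dm as q" in switchTables() later suggests):

class DatabaseFailure(Exception):
    """Stand-in for neo.lib.exception.DatabaseFailure."""

class TransactionalDB(object):
    def __init__(self):
        self._in_txn = False
    # begin/commit/rollback would issue the real BEGIN/COMMIT/ROLLBACK;
    # the test above monkey-patches them to record calls.
    def begin(self): pass
    def commit(self): pass
    def rollback(self): pass
    def __enter__(self):
        if self._in_txn:
            raise DatabaseFailure("already in a transaction")
        self._in_txn = True
        self.begin()
        return self
    def __exit__(self, exc_type, exc_value, tb):
        if not self._in_txn:
            raise DatabaseFailure("not in a transaction")
        self._in_txn = False
        if exc_type is None:
            self.commit()
        else:
            self.rollback()
        # Returning None lets the original exception propagate.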
def test_getPartitionTable(self): def test_getPartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1) ptid = self.getPTID(1)
...@@ -128,21 +150,22 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -128,21 +150,22 @@ class StorageDBTests(NeoUnitTestBase):
def checkSet(self, list1, list2): def checkSet(self, list1, list2):
self.assertEqual(set(list1), set(list2)) self.assertEqual(set(list1), set(list2))
def test_getLastTID(self): def test_getLastTIDs(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
txn, objs = self.getTransaction([oid1, oid2]) txn, objs = self.getTransaction([oid1, oid2])
# max TID is in obj table
self.db.storeTransaction(tid1, objs, txn, False) self.db.storeTransaction(tid1, objs, txn, False)
self.db.storeTransaction(tid2, objs, txn, False) self.db.storeTransaction(tid2, objs, txn, False)
self.assertEqual(self.db.getLastTID(), tid2) self.assertEqual(self.db.getLastTIDs(), (tid2, {0: tid2}, {0: tid2}))
# max tid is in ttrans table
self.db.storeTransaction(tid3, objs, txn) self.db.storeTransaction(tid3, objs, txn)
result = self.db.getLastTID() tids = {0: tid2, None: tid3}
self.assertEqual(self.db.getLastTID(), tid3) self.assertEqual(self.db.getLastTIDs(), (tid3, tids, tids))
# max tid is in tobj (serial)
self.db.storeTransaction(tid4, objs, None) self.db.storeTransaction(tid4, objs, None)
self.assertEqual(self.db.getLastTID(), tid4) self.assertEqual(self.db.getLastTIDs(),
(tid4, tids, {0: tid2, None: tid4}))
self.db.finishTransaction(tid3)
self.assertEqual(self.db.getLastTIDs(),
(tid4, {0: tid3}, {0: tid3, None: tid4}))
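# Reading of the new assertions: getLastTIDs() seems to return a
# (last_tid, trans_dict, obj_dict) triple where each dict maps a partition
# number to the last committed TID found in trans/obj, with the None key
# tracking the greatest TID still pending in the temporary ttrans/tobj
# tables; finishTransaction() then moves that TID under its partition key.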
def test_getUnfinishedTIDList(self): def test_getUnfinishedTIDList(self):
tid1, tid2, tid3, tid4 = self.getTIDs(4) tid1, tid2, tid3, tid4 = self.getTIDs(4)
...@@ -294,7 +317,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -294,7 +317,7 @@ class StorageDBTests(NeoUnitTestBase):
txn1, objs1 = self.getTransaction([oid1]) txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2]) txn2, objs2 = self.getTransaction([oid2])
# nothing in database # nothing in database
self.assertEqual(self.db.getLastTID(), None) self.assertEqual(self.db.getLastTIDs(), (None, {}, {}))
self.assertEqual(self.db.getUnfinishedTIDList(), []) self.assertEqual(self.db.getUnfinishedTIDList(), [])
self.assertEqual(self.db.getObject(oid1), None) self.assertEqual(self.db.getObject(oid1), None)
self.assertEqual(self.db.getObject(oid2), None) self.assertEqual(self.db.getObject(oid2), None)
...@@ -362,24 +385,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -362,24 +385,6 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getTransaction(tid1, True), None) self.assertEqual(self.db.getTransaction(tid1, True), None)
self.assertEqual(self.db.getTransaction(tid2, True), None) self.assertEqual(self.db.getTransaction(tid2, True), None)
def test_deleteTransactionsAbove(self):
self.setNumPartitions(2)
tid1 = self.getOID(0)
tid2 = self.getOID(1)
tid3 = self.getOID(2)
oid1 = self.getOID(1)
for tid in (tid1, tid2, tid3):
txn, objs = self.getTransaction([oid1])
self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid)
self.db.deleteTransactionsAbove(0, tid2, tid3)
# Right partition, below cutoff
self.assertNotEqual(self.db.getTransaction(tid1, True), None)
# Wrong partition, above cutoff
self.assertNotEqual(self.db.getTransaction(tid2, True), None)
# Right partition, above cutoff
self.assertEqual(self.db.getTransaction(tid3, True), None)
def test_deleteObject(self): def test_deleteObject(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
...@@ -397,34 +402,28 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -397,34 +402,28 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid2, tid=tid2), self.assertEqual(self.db.getObject(oid2, tid=tid2),
(tid2, None, 1, "0" * 20, '', None)) (tid2, None, 1, "0" * 20, '', None))
def test_deleteObjectsAbove(self): def test_deleteRange(self):
self.setNumPartitions(2) np = 4
tid1 = self.getOID(1) self.setNumPartitions(np)
tid2 = self.getOID(2) t1, t2, t3 = map(self.getOID, (1, 2, 3))
tid3 = self.getOID(3) oid_list = self.getOIDs(np * 2)
oid1 = self.getOID(0) for tid in t1, t2, t3:
oid2 = self.getOID(1) txn, objs = self.getTransaction(oid_list)
oid3 = self.getOID(2)
for tid in (tid1, tid2, tid3):
txn, objs = self.getTransaction([oid1, oid2, oid3])
self.db.storeTransaction(tid, objs, txn) self.db.storeTransaction(tid, objs, txn)
self.db.finishTransaction(tid) self.db.finishTransaction(tid)
self.db.deleteObjectsAbove(0, oid1, tid2, tid3) def check(offset, tid_list, *tids):
# Check getObjectHistoryFrom because MySQL adapter use two tables self.assertEqual(self.db.getReplicationTIDList(ZERO_TID,
# that must be synchronized MAX_TID, len(tid_list) + 1, offset), tid_list)
self.assertEqual(self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, expected = [(t, oid_list[offset+i]) for t in tids for i in 0, np]
MAX_TID, 10, 0), {oid1: [tid1]}) self.assertEqual(self.db.getReplicationObjectList(ZERO_TID,
# Right partition, below cutoff MAX_TID, len(expected) + 1, offset, ZERO_OID), expected)
self.assertNotEqual(self.db.getObject(oid1, tid=tid1), None) self.db._deleteRange(0, MAX_TID)
# Right partition, above tid cutoff self.db._deleteRange(0, max_tid=ZERO_TID)
self.assertFalse(self.db.getObject(oid1, tid=tid2)) check(0, [], t1, t2, t3)
self.assertFalse(self.db.getObject(oid1, tid=tid3)) self.db._deleteRange(0); check(0, [])
# Wrong partition, above cutoff self.db._deleteRange(1, t2); check(1, [t1], t1, t2)
self.assertNotEqual(self.db.getObject(oid2, tid=tid1), None) self.db._deleteRange(2, max_tid=t2); check(2, [], t3)
self.assertNotEqual(self.db.getObject(oid2, tid=tid2), None) self.db._deleteRange(3, t1, t2); check(3, [t3], t1, t3)
self.assertNotEqual(self.db.getObject(oid2, tid=tid3), None)
# Right partition, above cutoff
self.assertEqual(self.db.getObject(oid3), None)
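# The sequence of _deleteRange() calls and check() expectations above is
# consistent with the following semantics (an inference from this test, not
# the adapter's code): _deleteRange(partition, min_tid=None, max_tid=None)
# drops the rows of one partition whose TID t satisfies min_tid < t <= max_tid,
# each bound being unlimited when omitted.
def _would_delete(t, min_tid=None, max_tid=None):
    """Hypothetical predicate matching the deletions observed above."""
    return (min_tid is None or min_tid < t) and (max_tid is None or t <= max_tid)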
def test_getTransaction(self): def test_getTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
...@@ -467,59 +466,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -467,59 +466,6 @@ class StorageDBTests(NeoUnitTestBase):
result = self.db.getObjectHistory(oid, 2, 3) result = self.db.getObjectHistory(oid, 2, 3)
self.assertEqual(result, None) self.assertEqual(result, None)
def test_getObjectHistoryFrom(self):
self.setNumPartitions(2)
oid1 = self.getOID(0)
oid2 = self.getOID(2)
oid3 = self.getOID(1)
tid1, tid2, tid3, tid4, tid5 = self.getTIDs(5)
txn1, objs1 = self.getTransaction([oid1])
txn2, objs2 = self.getTransaction([oid2])
txn3, objs3 = self.getTransaction([oid1])
txn4, objs4 = self.getTransaction([oid2])
txn5, objs5 = self.getTransaction([oid3])
self.db.storeTransaction(tid1, objs1, txn1)
self.db.storeTransaction(tid2, objs2, txn2)
self.db.storeTransaction(tid3, objs3, txn3)
self.db.storeTransaction(tid4, objs4, txn4)
self.db.storeTransaction(tid5, objs5, txn5)
self.db.finishTransaction(tid1)
self.db.finishTransaction(tid2)
self.db.finishTransaction(tid3)
self.db.finishTransaction(tid4)
self.db.finishTransaction(tid5)
# Check full result
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Lower bound is inclusive
result = self.db.getObjectHistoryFrom(oid1, tid1, MAX_TID, 10, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2, tid4],
})
# Upper bound is inclusive
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, tid3, 10, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Length is total number of serials
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 3, 0)
self.assertEqual(result, {
oid1: [tid1, tid3],
oid2: [tid2],
})
# Partition constraints are honored
result = self.db.getObjectHistoryFrom(ZERO_OID, ZERO_TID, MAX_TID, 10,
1)
self.assertEqual(result, {
oid3: [tid5],
})
def _storeTransactions(self, count): def _storeTransactions(self, count):
# use OID generator to know result of tid % N # use OID generator to know result of tid % N
tid_list = self.getOIDs(count) tid_list = self.getOIDs(count)
......
#
# Copyright (C) 2009-2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest
from mock import Mock
from collections import deque
from .. import NeoUnitTestBase
from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler
from neo.lib.protocol import INVALID_PARTITION, Packets
from neo.lib.protocol import INVALID_TID, INVALID_OID
class StorageStorageHandlerTests(NeoUnitTestBase):
def checkHandleUnexpectedPacket(self, _call, _msg_type, _listening=True, **kwargs):
conn = self.getFakeConnection(address=("127.0.0.1", self.master_port),
is_server=_listening)
# hook
self.operation.peerBroken = lambda c: c.peerBrokendCalled()
self.checkUnexpectedPacketRaised(_call, conn=conn, **kwargs)
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
self.app.transaction_dict = {}
self.app.store_lock_dict = {}
self.app.load_lock_dict = {}
self.app.event_queue = deque()
self.app.event_queue_dict = {}
# handler
self.operation = StorageOperationHandler(self.app)
# set pmn
self.master_uuid = self.getNewUUID()
pmn = self.app.nm.getMasterList()[0]
pmn.setUUID(self.master_uuid)
self.app.primary_master_node = pmn
self.master_port = 10010
def test_18_askTransactionInformation1(self):
# transaction does not exist
conn = self.getFakeConnection()
self.app.dm = Mock({'getNumPartitions': 1})
self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn)
def test_18_askTransactionInformation2(self):
# answer
conn = self.getFakeConnection()
tid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)]
dm = Mock({"getTransaction": (oid_list, 'user', 'desc', '', False), })
self.app.dm = dm
self.operation.askTransactionInformation(conn, tid)
self.checkAnswerTransactionInformation(conn)
def test_24_askObject1(self):
# delayed response
conn = self.getFakeConnection()
oid = self.getOID(1)
tid = self.getNextTID()
serial = self.getNextTID()
self.app.dm = Mock()
self.app.tm = Mock({'loadLocked': True})
self.app.load_lock_dict[oid] = object()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
self.assertEqual(len(self.app.event_queue), 1)
self.checkNoPacketSent(conn)
self.assertEqual(len(self.app.dm.mockGetNamedCalls('getObject')), 0)
def test_24_askObject2(self):
# invalid serial / tid / packet not found
self.app.dm = Mock({'getObject': None})
conn = self.getFakeConnection()
oid = self.getOID(1)
tid = self.getNextTID()
serial = self.getNextTID()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
calls = self.app.dm.mockGetNamedCalls('getObject')
self.assertEqual(len(self.app.event_queue), 0)
self.assertEqual(len(calls), 1)
calls[0].checkArgs(oid, serial, tid)
self.checkErrorPacket(conn)
def test_24_askObject3(self):
oid = self.getOID(1)
tid = self.getNextTID()
serial = self.getNextTID()
next_serial = self.getNextTID()
H = "0" * 20
# object found => answer
self.app.dm = Mock({'getObject': (serial, next_serial, 0, H, '', None)})
conn = self.getFakeConnection()
self.assertEqual(len(self.app.event_queue), 0)
self.operation.askObject(conn, oid=oid, serial=serial, tid=tid)
self.assertEqual(len(self.app.event_queue), 0)
self.checkAnswerObject(conn)
def test_25_askTIDsFrom(self):
# normal case => answer
conn = self.getFakeConnection()
self.app.dm = Mock({'getReplicationTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getPartitions': 1})
tid = self.getNextTID()
tid2 = self.getNextTID()
self.operation.askTIDsFrom(conn, tid, tid2, 2, [1])
calls = self.app.dm.mockGetNamedCalls('getReplicationTIDList')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(tid, tid2, 2, 1)
self.checkAnswerTidsFrom(conn)
def test_26_askObjectHistoryFrom(self):
min_oid = self.getOID(2)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = 4
partition = 8
num_partitions = 16
tid = self.getNextTID()
conn = self.getFakeConnection()
self.app.dm = Mock({'getObjectHistoryFrom': {min_oid: [tid]},})
self.app.pt = Mock({
'getPartitions': num_partitions,
})
self.operation.askObjectHistoryFrom(conn, min_oid, min_serial,
max_serial, length, partition)
self.checkAnswerObjectHistoryFrom(conn)
calls = self.app.dm.mockGetNamedCalls('getObjectHistoryFrom')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, max_serial, length, partition)
def test_askCheckTIDRange(self):
count = 1
tid_checksum = "1" * 20
min_tid = self.getNextTID()
num_partitions = 4
length = 5
partition = 6
max_tid = self.getNextTID()
self.app.dm = Mock({'checkTIDRange': (count, tid_checksum, max_tid)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection()
self.operation.askCheckTIDRange(conn, min_tid, max_tid, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkTIDRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_tid, max_tid, length, partition)
pmin_tid, plength, pcount, ptid_checksum, pmax_tid = \
self.checkAnswerPacket(conn, Packets.AnswerCheckTIDRange,
decode=True)
self.assertEqual(min_tid, pmin_tid)
self.assertEqual(length, plength)
self.assertEqual(count, pcount)
self.assertEqual(tid_checksum, ptid_checksum)
self.assertEqual(max_tid, pmax_tid)
def test_askCheckSerialRange(self):
count = 1
oid_checksum = "2" * 20
min_oid = self.getOID(1)
num_partitions = 4
length = 5
partition = 6
serial_checksum = "3" * 20
min_serial = self.getNextTID()
max_serial = self.getNextTID()
max_oid = self.getOID(2)
self.app.dm = Mock({'checkSerialRange': (count, oid_checksum, max_oid,
serial_checksum, max_serial)})
self.app.pt = Mock({'getPartitions': num_partitions})
conn = self.getFakeConnection()
self.operation.askCheckSerialRange(conn, min_oid, min_serial,
max_serial, length, partition)
calls = self.app.dm.mockGetNamedCalls('checkSerialRange')
self.assertEqual(len(calls), 1)
calls[0].checkArgs(min_oid, min_serial, max_serial, length, partition)
pmin_oid, pmin_serial, plength, pcount, poid_checksum, pmax_oid, \
pserial_checksum, pmax_serial = self.checkAnswerPacket(conn,
Packets.AnswerCheckSerialRange, decode=True)
self.assertEqual(min_oid, pmin_oid)
self.assertEqual(min_serial, pmin_serial)
self.assertEqual(length, plength)
self.assertEqual(count, pcount)
self.assertEqual(oid_checksum, poid_checksum)
self.assertEqual(max_oid, pmax_oid)
self.assertEqual(serial_checksum, pserial_checksum)
self.assertEqual(max_serial, pmax_serial)
if __name__ == "__main__":
unittest.main()
...@@ -35,11 +35,6 @@ class StorageMySQSLdbTests(StorageDBTests): ...@@ -35,11 +35,6 @@ class StorageMySQSLdbTests(StorageDBTests):
db.setup(reset) db.setup(reset)
return db return db
def checkCalledQuery(self, query=None, call=0):
self.assertTrue(len(self.db.conn.mockGetNamedCalls('query')) > call)
call = self.db.conn.mockGetNamedCalls('query')[call]
call.checkArgs('BEGIN')
def test_MySQLDatabaseManagerInit(self): def test_MySQLDatabaseManagerInit(self):
db = MySQLDatabaseManager('%s@%s' % (NEO_SQL_USER, NEO_SQL_DATABASE), db = MySQLDatabaseManager('%s@%s' % (NEO_SQL_USER, NEO_SQL_DATABASE),
0) 0)
...@@ -48,30 +43,6 @@ class StorageMySQSLdbTests(StorageDBTests): ...@@ -48,30 +43,6 @@ class StorageMySQSLdbTests(StorageDBTests):
self.assertEqual(db.user, NEO_SQL_USER) self.assertEqual(db.user, NEO_SQL_USER)
# & connect # & connect
self.assertTrue(isinstance(db.conn, MySQLdb.connection)) self.assertTrue(isinstance(db.conn, MySQLdb.connection))
self.assertFalse(db.isUnderTransaction())
def test_begin(self):
# no current transaction
self.db.conn = Mock({ })
self.assertFalse(self.db.isUnderTransaction())
self.db.begin()
self.checkCalledQuery(query='COMMIT')
self.assertTrue(self.db.isUnderTransaction())
def test_commit(self):
self.db.conn = Mock()
self.db.begin()
self.db.commit()
self.assertEqual(len(self.db.conn.mockGetNamedCalls('commit')), 1)
self.assertFalse(self.db.isUnderTransaction())
def test_rollback(self):
# rollback called and no current transaction
self.db.conn = Mock({ })
self.db.under_transaction = True
self.db.rollback()
self.assertEqual(len(self.db.conn.mockGetNamedCalls('rollback')), 1)
self.assertFalse(self.db.isUnderTransaction())
def test_query1(self): def test_query1(self):
# fake result object # fake result object
......
...@@ -16,15 +16,13 @@ ...@@ -16,15 +16,13 @@
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import unittest import unittest
from mock import Mock
from .testStorageDBTests import StorageDBTests from .testStorageDBTests import StorageDBTests
from neo.storage.database.btree import BTreeDatabaseManager from neo.storage.database.sqlite import SQLiteDatabaseManager
class StorageBTreeTests(StorageDBTests): class StorageSQLiteTests(StorageDBTests):
def getDB(self, reset=0): def getDB(self, reset=0):
# db manager db = SQLiteDatabaseManager(':memory:', 0)
db = BTreeDatabaseManager('', 0)
db.setup(reset) db.setup(reset)
return db return db
......
...@@ -68,40 +68,6 @@ class StorageVerificationHandlerTests(NeoUnitTestBase): ...@@ -68,40 +68,6 @@ class StorageVerificationHandlerTests(NeoUnitTestBase):
# nothing happens # nothing happens
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
def test_07_askLastIDs(self):
conn = self.getClientConnection()
last_ptid = self.getPTID(1)
last_oid = self.getOID(2)
self.app.pt = Mock({'getID': last_ptid})
class DummyDM(object):
def getLastOID(self):
raise KeyError
getLastTID = getLastOID
self.app.dm = DummyDM()
self.verification.askLastIDs(conn)
oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True)
self.assertEqual(oid, None)
self.assertEqual(tid, None)
self.assertEqual(ptid, last_ptid)
# return value stored in db
conn = self.getClientConnection()
self.app.dm = Mock({
'getLastOID': last_oid,
'getLastTID': p64(4),
})
self.verification.askLastIDs(conn)
oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True)
self.assertEqual(oid, last_oid)
self.assertEqual(u64(tid), 4)
self.assertEqual(ptid, self.app.pt.getID())
call_list = self.app.dm.mockGetNamedCalls('getLastOID')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs()
call_list = self.app.dm.mockGetNamedCalls('getLastTID')
self.assertEqual(len(call_list), 1)
call_list[0].checkArgs()
def test_08_askPartitionTable(self): def test_08_askPartitionTable(self):
node = self.app.nm.createStorage( node = self.app.nm.createStorage(
address=("127.7.9.9", 1), address=("127.7.9.9", 1),
......
...@@ -160,16 +160,6 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -160,16 +160,6 @@ class ProtocolTests(NeoUnitTestBase):
p = Packets.AskLastIDs() p = Packets.AskLastIDs()
self.assertEqual(p.decode(), ()) self.assertEqual(p.decode(), ())
def test_19_answerLastIDs(self):
oid = self.getNextTID()
tid = self.getNextTID()
ptid = self.getPTID()
p = Packets.AnswerLastIDs(oid, tid, ptid)
loid, ltid, lptid = p.decode()
self.assertEqual(loid, oid)
self.assertEqual(ltid, tid)
self.assertEqual(lptid, ptid)
def test_20_askPartitionTable(self): def test_20_askPartitionTable(self):
self.assertEqual(Packets.AskPartitionTable().decode(), ()) self.assertEqual(Packets.AskPartitionTable().decode(), ())
...@@ -638,40 +628,16 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -638,40 +628,16 @@ class ProtocolTests(NeoUnitTestBase):
def test_AskTIDsFrom(self): def test_AskTIDsFrom(self):
tid = self.getNextTID() tid = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
p = Packets.AskTIDsFrom(tid, tid2, 1000, [5]) p = Packets.AskTIDsFrom(tid, tid2, 1000, 5)
min_tid, max_tid, length, partition = p.decode() min_tid, max_tid, length, partition = p.decode()
self.assertEqual(min_tid, tid) self.assertEqual(min_tid, tid)
self.assertEqual(max_tid, tid2) self.assertEqual(max_tid, tid2)
self.assertEqual(length, 1000) self.assertEqual(length, 1000)
self.assertEqual(partition, [5]) self.assertEqual(partition, 5)
def test_AnswerTIDsFrom(self): def test_AnswerTIDsFrom(self):
self._test_AnswerTIDs(Packets.AnswerTIDsFrom) self._test_AnswerTIDs(Packets.AnswerTIDsFrom)
def test_AskObjectHistoryFrom(self):
oid = self.getOID(1)
min_serial = self.getNextTID()
max_serial = self.getNextTID()
length = 5
partition = 4
p = Packets.AskObjectHistoryFrom(oid, min_serial, max_serial, length,
partition)
p_oid, p_min_serial, p_max_serial, p_length, p_partition = p.decode()
self.assertEqual(p_oid, oid)
self.assertEqual(p_min_serial, min_serial)
self.assertEqual(p_max_serial, max_serial)
self.assertEqual(p_length, length)
self.assertEqual(p_partition, partition)
def test_AnswerObjectHistoryFrom(self):
object_dict = {}
for int_oid in xrange(4):
object_dict[self.getOID(int_oid)] = [self.getNextTID() \
for _ in xrange(5)]
p = Packets.AnswerObjectHistoryFrom(object_dict)
p_object_dict = p.decode()[0]
self.assertEqual(object_dict, p_object_dict)
def test_AskCheckTIDRange(self): def test_AskCheckTIDRange(self):
min_tid = self.getNextTID() min_tid = self.getNextTID()
max_tid = self.getNextTID() max_tid = self.getNextTID()
......
...@@ -32,8 +32,9 @@ from neo.lib.connection import BaseConnection, Connection ...@@ -32,8 +32,9 @@ from neo.lib.connection import BaseConnection, Connection
from neo.lib.connector import SocketConnector, \ from neo.lib.connector import SocketConnector, \
ConnectorConnectionRefusedException, ConnectorTryAgainException ConnectorConnectionRefusedException, ConnectorTryAgainException
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes, \
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList UUID_NAMESPACES, INVALID_UUID
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList, p64
from .. import NeoTestBase, getTempDirectory, setupMySQLdb, \ from .. import NeoTestBase, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
...@@ -293,38 +294,18 @@ class StorageApplication(ServerNode, neo.storage.app.Application): ...@@ -293,38 +294,18 @@ class StorageApplication(ServerNode, neo.storage.app.Application):
pass pass
def switchTables(self): def switchTables(self):
adapter = self._init_args['getAdapter'] with self.dm as q:
dm = self.dm
if adapter == 'BTree':
dm._obj, dm._tobj = dm._tobj, dm._obj
dm._trans, dm._ttrans = dm._ttrans, dm._trans
uncommitted_data = dm._uncommitted_data
for checksum, (_, _, index) in dm._data.iteritems():
uncommitted_data[checksum] = len(index)
index.clear()
elif adapter == 'MySQL':
q = dm.query
dm.begin()
for table in ('trans', 'obj'): for table in ('trans', 'obj'):
q('RENAME TABLE %s to tmp' % table) q('ALTER TABLE %s RENAME TO tmp' % table)
q('RENAME TABLE t%s to %s' % (table, table)) q('ALTER TABLE t%s RENAME TO %s' % (table, table))
q('RENAME TABLE tmp to t%s' % table) q('ALTER TABLE tmp RENAME TO t%s' % table)
dm.commit()
else:
assert False
def getDataLockInfo(self): def getDataLockInfo(self):
adapter = self._init_args['getAdapter']
dm = self.dm dm = self.dm
if adapter == 'BTree':
checksum_dict = dict((x, x) for x in dm._data)
elif adapter == 'MySQL':
checksum_dict = dict(dm.query("SELECT id, hash FROM data")) checksum_dict = dict(dm.query("SELECT id, hash FROM data"))
else:
assert False
assert set(dm._uncommitted_data).issubset(checksum_dict) assert set(dm._uncommitted_data).issubset(checksum_dict)
get = dm._uncommitted_data.get get = dm._uncommitted_data.get
return dict((v, get(k, 0)) for k, v in checksum_dict.iteritems()) return dict((str(v), get(k, 0)) for k, v in checksum_dict.iteritems())
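# With the BTree backend gone, both remaining adapters share this SQL path:
# "with self.dm as q" runs the renames inside one transaction and yields the
# manager's query callable, and ALTER TABLE ... RENAME TO is presumably used
# because, unlike MySQL's RENAME TABLE, it is understood by SQLite as well.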
class ClientApplication(Node, neo.client.app.Application): class ClientApplication(Node, neo.client.app.Application):
...@@ -406,13 +387,15 @@ class Patch(object): ...@@ -406,13 +387,15 @@ class Patch(object):
class ConnectionFilter(object): class ConnectionFilter(object):
filtered_count = 0
def __init__(self, *conns): def __init__(self, *conns):
self.filter_dict = {} self.filter_dict = {}
self.lock = threading.Lock() self.lock = threading.Lock()
self.conn_list = [(conn, self._patch(conn)) for conn in conns] self.conn_list = [(conn, self._patch(conn)) for conn in conns]
def _patch(self, conn): def _patch(self, conn):
assert '_addPacket' not in conn.__dict__ assert '_addPacket' not in conn.__dict__, "already patched"
lock = self.lock lock = self.lock
filter_dict = self.filter_dict filter_dict = self.filter_dict
orig = conn.__class__._addPacket orig = conn.__class__._addPacket
...@@ -423,6 +406,7 @@ class ConnectionFilter(object): ...@@ -423,6 +406,7 @@ class ConnectionFilter(object):
if not queue: if not queue:
for filter in filter_dict: for filter in filter_dict:
if filter(conn, packet): if filter(conn, packet):
self.filtered_count += 1
break break
else: else:
return orig(conn, packet) return orig(conn, packet)
...@@ -551,8 +535,8 @@ class NEOCluster(object): ...@@ -551,8 +535,8 @@ class NEOCluster(object):
SocketConnector.send = cls.SocketConnector_send SocketConnector.send = cls.SocketConnector_send
Storage.setupLog = setupLog Storage.setupLog = setupLog
def __init__(self, master_count=1, partitions=1, replicas=0, def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'), adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
storage_count=None, db_list=None, clear_databases=True, storage_count=None, db_list=None, clear_databases=True,
db_user=DB_USER, db_password='', verbose=None): db_user=DB_USER, db_password='', verbose=None):
if verbose is not None: if verbose is not None:
...@@ -570,6 +554,10 @@ class NEOCluster(object): ...@@ -570,6 +554,10 @@ class NEOCluster(object):
weak_self = weakref.proxy(self) weak_self = weakref.proxy(self)
kw = dict(cluster=weak_self, getReplicas=replicas, getAdapter=adapter, kw = dict(cluster=weak_self, getReplicas=replicas, getAdapter=adapter,
getPartitions=partitions, getReset=clear_databases) getPartitions=partitions, getReset=clear_databases)
if upstream is not None:
self.upstream = weakref.proxy(upstream)
kw.update(getUpstreamCluster=upstream.name,
getUpstreamMasters=parseMasterList(upstream.master_nodes))
self.master_list = [MasterApplication(address=x, **kw) self.master_list = [MasterApplication(address=x, **kw)
for x in master_list] for x in master_list]
if db_list is None: if db_list is None:
...@@ -581,8 +569,8 @@ class NEOCluster(object): ...@@ -581,8 +569,8 @@ class NEOCluster(object):
if adapter == 'MySQL': if adapter == 'MySQL':
setupMySQLdb(db_list, db_user, db_password, clear_databases) setupMySQLdb(db_list, db_user, db_password, clear_databases)
db = '%s:%s@%%s' % (db_user, db_password) db = '%s:%s@%%s' % (db_user, db_password)
elif adapter == 'BTree': elif adapter == 'SQLite':
db = '%s' db = os.path.join(getTempDirectory(), '%s.sqlite')
else: else:
assert False, adapter assert False, adapter
self.storage_list = [StorageApplication(getDatabase=db % x, **kw) self.storage_list = [StorageApplication(getDatabase=db % x, **kw)
...@@ -607,6 +595,11 @@ class NEOCluster(object): ...@@ -607,6 +595,11 @@ class NEOCluster(object):
return admin return admin
### ###
@property
def primary_master(self):
master, = [master for master in self.master_list if master.primary]
return master
def reset(self, clear_database=False): def reset(self, clear_database=False):
for node_type in 'master', 'storage', 'admin': for node_type in 'master', 'storage', 'admin':
kw = {} kw = {}
...@@ -635,7 +628,7 @@ class NEOCluster(object): ...@@ -635,7 +628,7 @@ class NEOCluster(object):
self._startCluster() self._startCluster()
self.tic() self.tic()
state = self.neoctl.getClusterState() state = self.neoctl.getClusterState()
assert state == ClusterStates.RUNNING, state assert state in (ClusterStates.RUNNING, ClusterStates.BACKINGUP), state
self.enableStorageList(storage_list) self.enableStorageList(storage_list)
def _startCluster(self): def _startCluster(self):
...@@ -644,6 +637,7 @@ class NEOCluster(object): ...@@ -644,6 +637,7 @@ class NEOCluster(object):
except RuntimeError: except RuntimeError:
self.tic() self.tic()
if self.neoctl.getClusterState() not in ( if self.neoctl.getClusterState() not in (
ClusterStates.BACKINGUP,
ClusterStates.RUNNING, ClusterStates.RUNNING,
ClusterStates.VERIFYING, ClusterStates.VERIFYING,
): ):
...@@ -704,7 +698,7 @@ class NEOCluster(object): ...@@ -704,7 +698,7 @@ class NEOCluster(object):
self.client.setPoll(True) self.client.setPoll(True)
return Storage.Storage(None, self.name, _app=self.client, **kw) return Storage.Storage(None, self.name, _app=self.client, **kw)
def populate(self, dummy_zodb=None, random=random): def importZODB(self, dummy_zodb=None, random=random):
if dummy_zodb is None: if dummy_zodb is None:
from ..stat_zodb import PROD1 from ..stat_zodb import PROD1
dummy_zodb = PROD1(random) dummy_zodb = PROD1(random)
...@@ -713,6 +707,20 @@ class NEOCluster(object): ...@@ -713,6 +707,20 @@ class NEOCluster(object):
return lambda count: self.getZODBStorage().importFrom( return lambda count: self.getZODBStorage().importFrom(
as_storage(count), preindex=preindex) as_storage(count), preindex=preindex)
def populate(self, transaction_list, tid=lambda i: p64(i+1),
oid=lambda i: p64(i+1)):
storage = self.getZODBStorage()
tid_dict = {}
for i, oid_list in enumerate(transaction_list):
txn = transaction.Transaction()
storage.tpc_begin(txn, tid(i))
for o in oid_list:
storage.store(oid(o), tid_dict.get(o), repr((i, o)), '', txn)
storage.tpc_vote(txn)
i = storage.tpc_finish(txn)
for o in oid_list:
tid_dict[o] = i
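# populate() gives tests a known, repeatable data set without going through a
# ZODB import: each element of transaction_list is a list of small integers
# used as OIDs (p64(i+1) with the default mapping) and committed as one
# transaction. For example, testReplicationAbortedBySource below calls
#     cluster.populate([range(6)] * 3)
# which commits 3 transactions, each storing the same 6 objects.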
def getTransaction(self): def getTransaction(self):
txn = transaction.TransactionManager() txn = transaction.TransactionManager()
return txn, self.db.open(transaction_manager=txn) return txn, self.db.open(transaction_manager=txn)
...@@ -774,3 +782,28 @@ class NEOThreadedTest(NeoTestBase): ...@@ -774,3 +782,28 @@ class NEOThreadedTest(NeoTestBase):
etype, value, tb = self.__exc_info etype, value, tb = self.__exc_info
del self.__exc_info del self.__exc_info
raise etype, value, tb raise etype, value, tb
def predictable_random(seed=None):
# Because there are 2 running threads when the client works, we can't
# patch neo.client.pool (and the cluster should have only 1 storage).
from neo.master import backup_app
from neo.storage import replicator
def decorator(wrapped):
def wrapper(*args, **kw):
s = repr(time.time()) if seed is None else seed
neo.lib.logging.info("using seed %r", s)
r = random.Random(s)
try:
MasterApplication.getNewUUID = lambda self, node_type: (
super(MasterApplication, self).getNewUUID(node_type)
if node_type == NodeTypes.CLIENT else
UUID_NAMESPACES[node_type] + ''.join(
chr(r.randrange(256)) for _ in xrange(15)))
backup_app.random = replicator.random = r
return wrapped(*args, **kw)
finally:
del MasterApplication.getNewUUID
backup_app.random = replicator.random = random
return wraps(wrapped)(wrapper)
return decorator
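# Typical use (see testBackupNodeLost below): decorating a test with
# @predictable_random() seeds a private random.Random, logs the seed so a
# failing run can be replayed, injects the generator into
# neo.master.backup_app and neo.storage.replicator, and makes non-client
# UUIDs deterministic by overriding MasterApplication.getNewUUID.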
...@@ -26,8 +26,7 @@ from neo.storage.transactions import TransactionManager, \ ...@@ -26,8 +26,7 @@ from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError DelayedError, ConflictError
from neo.lib.connection import MTClientConnection from neo.lib.connection import MTClientConnection
from neo.lib.protocol import NodeStates, Packets, ZERO_TID from neo.lib.protocol import NodeStates, Packets, ZERO_TID
from . import NEOCluster, NEOThreadedTest, \ from . import NEOCluster, NEOThreadedTest, Patch
Patch, ConnectionFilter
from neo.lib.util import makeChecksum from neo.lib.util import makeChecksum
from neo.client.pool import CELL_CONNECTED, CELL_GOOD from neo.client.pool import CELL_CONNECTED, CELL_GOOD
......
#
# Copyright (c) 2011 Nexedi SARL and Contributors. All Rights Reserved.
# Julien Muchembled <jm@nexedi.com>
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
import random
import sys
import time
import threading
import transaction
import unittest
import neo.lib
from neo.storage.transactions import TransactionManager, \
DelayedError, ConflictError
from neo.lib.connection import MTClientConnection
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID
from . import NEOCluster, NEOThreadedTest, Patch, predictable_random
from neo.client.pool import CELL_CONNECTED, CELL_GOOD
class ReplicationTests(NEOThreadedTest):
def checksumPartition(self, storage, partition):
dm = storage.dm
args = ZERO_TID, MAX_TID, None, partition
return dm.checkTIDRange(*args), dm.checkSerialRange(ZERO_TID, *args)
def checkPartitionReplicated(self, source, destination, partition):
self.assertEqual(self.checksumPartition(source, partition),
self.checksumPartition(destination, partition))
def checkBackup(self, cluster):
upstream_pt = cluster.upstream.primary_master.pt
pt = cluster.primary_master.pt
np = pt.getPartitions()
self.assertEqual(np, upstream_pt.getPartitions())
checked = 0
source_dict = dict((x.uuid, x) for x in cluster.upstream.storage_list)
for storage in cluster.storage_list:
self.assertEqual(np, storage.pt.getPartitions())
for partition in pt.getAssignedPartitionList(storage.uuid):
cell_list = upstream_pt.getCellList(partition, readable=True)
source = source_dict[random.choice(cell_list).getUUID()]
self.checkPartitionReplicated(source, storage, partition)
checked += 1
return checked
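# checkBackup() returns the number of (storage, partition) cells verified,
# i.e. one per assigned cell of the backup partition table. Hence the expected
# counts below: partitions=7 with replicas=1 gives 7 * 2 = 14 cells, and
# partitions=4 with replicas=2 gives 4 * 3 = 12.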
def testBackupNormalCase(self):
upstream = NEOCluster(partitions=7, replicas=1, storage_count=3)
try:
upstream.start()
importZODB = upstream.importZODB()
importZODB(3)
upstream.client.setPoll(0)
backup = NEOCluster(partitions=7, replicas=1, storage_count=5,
upstream=upstream)
try:
backup.start()
# Initialize & catch up.
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
# Normal case, following upstream cluster closely.
importZODB(17)
upstream.client.setPoll(0)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
# Check that a backup cluster can be restarted.
finally:
backup.stop()
backup.reset()
try:
backup.start()
self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.BACKINGUP)
importZODB(17)
upstream.client.setPoll(0)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
# Stop backing up, nothing truncated.
backup.neoctl.setClusterState(ClusterStates.STOPPING_BACKUP)
backup.tic()
self.assertEqual(14, self.checkBackup(backup))
self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.RUNNING)
finally:
backup.stop()
finally:
upstream.stop()
@predictable_random()
def testBackupNodeLost(self):
"""Check backup cluster can recover after random connection loss
- backup master disconnected from upstream master
- primary storage disconnected from backup master
- non-primary storage disconnected from backup master
"""
from neo.master.backup_app import random
def fetchObjects(orig, min_tid=None, min_oid=ZERO_OID):
if min_tid is None:
counts[0] += 1
if counts[0] > 1:
orig.im_self.app.master_conn.close()
return orig(min_tid, min_oid)
def onTransactionCommitted(orig, txn):
counts[0] += 1
if counts[0] > 1:
node_list = orig.im_self.nm.getClientList(only_identified=True)
node_list.remove(txn.getNode())
node_list[0].getConnection().close()
return orig(txn)
upstream = NEOCluster(partitions=4, replicas=0, storage_count=1)
try:
upstream.start()
importZODB = upstream.importZODB(random=random)
backup = NEOCluster(partitions=4, replicas=2, storage_count=4,
upstream=upstream)
try:
backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
backup.tic()
storage_list = [x.uuid for x in backup.storage_list]
slave = set(xrange(len(storage_list))).difference
for event in xrange(10):
counts = [0]
if event == 5:
p = Patch(upstream.master.tm,
_on_commit=onTransactionCommitted)
else:
primary_dict = {}
for k, v in sorted(backup.master.backup_app
.primary_partition_dict.iteritems()):
primary_dict.setdefault(storage_list.index(v._uuid),
[]).append(k)
if event % 2:
storage = slave(primary_dict).pop()
else:
storage, partition_list = primary_dict.popitem()
# Populate until the found storage performs
# a second replication partially and aborts.
p = Patch(backup.storage_list[storage].replicator,
fetchObjects=fetchObjects)
try:
importZODB(lambda x: counts[0] > 1)
finally:
del p
upstream.client.setPoll(0)
backup.tic()
self.assertEqual(12, self.checkBackup(backup))
finally:
backup.stop()
finally:
upstream.stop()
def testReplicationAbortedBySource(self):
"""
Check that a feeding node aborts replication when its partition is
dropped, and that the out-of-date node finishes to replicate from
another source.
Here are the different states of partitions over time:
pt: 0: U|U|U
pt: 0: UO|UO|UO
pt: 0: FOO|UO.|U.O # node 1 replicates from node 0
pt: 0: .OU|UO.|U.O # here node 0 lost partition 0
# and node 1 must switch to node 2
pt: 0: .OU|UO.|U.U
pt: 0: .OU|UU.|U.U
pt: 0: .UU|UU.|U.U
"""
def connected(orig, *args, **kw):
patch[0] = s1.filterConnection(s0)
patch[0].add(delayAskFetch,
Patch(s0.dm, changePartitionTable=changePartitionTable))
return orig(*args, **kw)
def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and packet.decode()[0] == offset
def changePartitionTable(orig, ptid, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
patch[0].remove(delayAskFetch)
# XXX: this is currently not done by
# default for performance reason
orig.im_self.dropPartitions((offset,))
return orig(ptid, cell_list)
cluster = NEOCluster(partitions=3, replicas=1, storage_count=3)
s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
try:
cluster.start([s0])
cluster.populate([range(6)] * 3)
cluster.client.setPoll(0)
s1.start()
s2.start()
cluster.tic()
cluster.neoctl.enableStorageList([s1.uuid, s2.uuid])
offset, = [offset for offset, row in enumerate(
cluster.master.pt.partition_list)
for cell in row if cell.isFeeding()]
patch = [Patch(s1.replicator, fetchTransactions=connected)]
try:
cluster.tic()
self.assertEqual(1, patch[0].filtered_count)
patch[0]()
finally:
del patch[:]
cluster.tic()
self.checkPartitionReplicated(s1, s2, offset)
finally:
cluster.stop()
cluster.reset(True)
if __name__ == "__main__":
unittest.main()
...@@ -29,7 +29,7 @@ extras_require = { ...@@ -29,7 +29,7 @@ extras_require = {
'client': ['ZODB3'], # ZODB3 >= 3.10 'client': ['ZODB3'], # ZODB3 >= 3.10
'ctl': [], 'ctl': [],
'master': [], 'master': [],
'storage-btree': ['ZODB3'], 'storage-sqlite': [],
'storage-mysqldb': ['MySQL-python'], 'storage-mysqldb': ['MySQL-python'],
} }
extras_require['tests'] = ['zope.testing', 'psutil', extras_require['tests'] = ['zope.testing', 'psutil',
......
...@@ -78,13 +78,14 @@ def main(): ...@@ -78,13 +78,14 @@ def main():
if subprocess.call((os.path.join(bin, 'buildout'), '-v'), if subprocess.call((os.path.join(bin, 'buildout'), '-v'),
cwd=test_home): cwd=test_home):
continue continue
title = '[%s:%s-g%s:%s]' % (branch, for backend in 'SQLite', 'MySQL':
os.environ['NEO_TESTS_ADAPTER'] = backend
title = '[%s:%s-g%s:%s:%s]' % (branch,
git('rev-list', '--topo-order', '--count', revision), git('rev-list', '--topo-order', '--count', revision),
revision[:7], os.path.basename(test_home)) revision[:7], os.path.basename(test_home), backend)
if tests: if tests:
subprocess.call([os.path.join(bin, 'neotestrunner'), subprocess.call([os.path.join(bin, 'neotestrunner'),
'-' + tests, '--title', '-' + tests, '--title', 'NEO tests ' + title,
'NEO tests ' + title,
] + sys.argv[1:arg_count]) ] + sys.argv[1:arg_count])
if 'm' in tasks: if 'm' in tasks:
subprocess.call([os.path.join(bin, 'python'), subprocess.call([os.path.join(bin, 'python'),
......