Commit 8ba42463 authored by Julien Muchembled's avatar Julien Muchembled

Merge protocol v0

parents a33c624c 2b9e14e8
Change History
==============
1.12 (2019-04-28)
-----------------
Most changes in this version focus on the ability to migrate a big ZODB to
NEO efficiently and reliably, which required changes in the protocol.
See testSplitAndMakeResilientUsingClone for an example scenario.
Better cluster management:
- New --new-nid storage option for fast cloning.
- The number of wanted replicas is now a property of the database, which is
modifiable when the cluster is running, and reported by `neoctl print pt`.
- Better error reporting from the master to neoctl for denied requests.
- tweak: do not touch cells of nodes that are intended to be dropped.
- tweak: do not crash when trying to remove all nodes.
- tweak: new neoctl option to ask the master to simulate.
- neoctl: better display of full partition tables.
- master: reject drop/tweak commands that could lead to an unwanted state.
Importer:
- Fix possible data loss on writeback.
- v1.9 broke replication (as source) once the import is finished.
- Speed up startup when the import is already finished.
- Fix closure of ZODB, and also do it when the import is finished.
- Fix hidden "maximum recursion depth exceeded" at startup.
- Fix resumption when using SQLite.
- v1.10 broke resumption when there are new transactions since the import
started.
MySQL:
- Better support of RocksDB by specifying column families.
- Fix handling of connection strings (--database) without credentials.
1.11 (2019-03-11)
-----------------
......
...@@ -18,10 +18,8 @@ from neo.lib import logging ...@@ -18,10 +18,8 @@ from neo.lib import logging
from neo.lib.app import BaseApplication, buildOptionParser from neo.lib.app import BaseApplication, buildOptionParser
from neo.lib.connection import ListeningConnection from neo.lib.connection import ListeningConnection
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from .handler import AdminEventHandler, MasterEventHandler, \ from .handler import AdminEventHandler, MasterEventHandler
MasterRequestEventHandler
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from neo.lib.pt import PartitionTable
from neo.lib.protocol import ClusterStates, Errors, NodeTypes, Packets from neo.lib.protocol import ClusterStates, Errors, NodeTypes, Packets
from neo.lib.debug import register as registerLiveDebugger from neo.lib.debug import register as registerLiveDebugger
...@@ -36,8 +34,8 @@ class Application(BaseApplication): ...@@ -36,8 +34,8 @@ class Application(BaseApplication):
cls.addCommonServerOptions('admin', '127.0.0.1:9999') cls.addCommonServerOptions('admin', '127.0.0.1:9999')
_ = _.group('admin') _ = _.group('admin')
_.int('u', 'uuid', _.int('i', 'nid',
help="specify an UUID to use for this process (testing purpose)") help="specify an NID to use for this process (testing purpose)")
def __init__(self, config): def __init__(self, config):
super(Application, self).__init__( super(Application, self).__init__(
...@@ -53,9 +51,8 @@ class Application(BaseApplication): ...@@ -53,9 +51,8 @@ class Application(BaseApplication):
# The partition table is initialized after getting the number of # The partition table is initialized after getting the number of
# partitions. # partitions.
self.pt = None self.pt = None
self.uuid = config.get('uuid') self.uuid = config.get('nid')
logging.node(self.name, self.uuid) logging.node(self.name, self.uuid)
self.request_handler = MasterRequestEventHandler(self)
self.master_event_handler = MasterEventHandler(self) self.master_event_handler = MasterEventHandler(self)
self.cluster_state = None self.cluster_state = None
self.reset() self.reset()
...@@ -66,7 +63,6 @@ class Application(BaseApplication): ...@@ -66,7 +63,6 @@ class Application(BaseApplication):
super(Application, self).close() super(Application, self).close()
def reset(self): def reset(self):
self.bootstrapped = False
self.master_conn = None self.master_conn = None
self.master_node = None self.master_node = None
...@@ -117,40 +113,20 @@ class Application(BaseApplication): ...@@ -117,40 +113,20 @@ class Application(BaseApplication):
self.cluster_state = None self.cluster_state = None
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, NodeTypes.ADMIN, self.server) bootstrap = BootstrapManager(self, NodeTypes.ADMIN, self.server)
self.master_node, self.master_conn, num_partitions, num_replicas = \ self.master_node, self.master_conn = bootstrap.getPrimaryConnection()
bootstrap.getPrimaryConnection()
if self.pt is None:
self.pt = PartitionTable(num_partitions, num_replicas)
elif self.pt.getPartitions() != num_partitions:
# XXX: shouldn't we recover instead of raising ?
raise RuntimeError('the number of partitions is inconsistent')
elif self.pt.getReplicas() != num_replicas:
# XXX: shouldn't we recover instead of raising ?
raise RuntimeError('the number of replicas is inconsistent')
# passive handler # passive handler
self.master_conn.setHandler(self.master_event_handler) self.master_conn.setHandler(self.master_event_handler)
self.master_conn.ask(Packets.AskClusterState()) self.master_conn.ask(Packets.AskClusterState())
self.master_conn.ask(Packets.AskPartitionTable())
def sendPartitionTable(self, conn, min_offset, max_offset, uuid): def sendPartitionTable(self, conn, min_offset, max_offset, uuid):
# we have a pt pt = self.pt
self.pt.log()
row_list = []
if max_offset == 0: if max_offset == 0:
max_offset = self.pt.getPartitions() max_offset = pt.getPartitions()
try: try:
for offset in xrange(min_offset, max_offset): row_list = map(pt.getRow, xrange(min_offset, max_offset))
row = []
try:
for cell in self.pt.getCellList(offset):
if uuid is None or cell.getUUID() == uuid:
row.append((cell.getUUID(), cell.getState()))
except TypeError:
pass
row_list.append((offset, row))
except IndexError: except IndexError:
conn.send(Errors.ProtocolError('invalid partition table offset')) conn.send(Errors.ProtocolError('invalid partition table offset'))
else: else:
conn.answer(Packets.AnswerPartitionList(self.pt.getID(), row_list)) conn.answer(Packets.AnswerPartitionList(
pt.getID(), pt.getReplicas(), row_list))
...@@ -17,30 +17,49 @@ ...@@ -17,30 +17,49 @@
from neo.lib import logging, protocol from neo.lib import logging, protocol
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import uuid_str, Packets from neo.lib.protocol import uuid_str, Packets
from neo.lib.pt import PartitionTable
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
def check_primary_master(func): def AdminEventHandlerType(name, bases, d):
def wrapper(self, *args, **kw): def check_primary_master(func):
if self.app.bootstrapped: def wrapper(self, *args, **kw):
return func(self, *args, **kw) if self.app.master_conn is not None:
raise protocol.NotReadyError('Not connected to a primary master.') return func(self, *args, **kw)
return wrapper raise protocol.NotReadyError('Not connected to a primary master.')
return wrapper
def forward_ask(klass):
return check_primary_master(lambda self, conn, *args, **kw: def forward_ask(klass):
self.app.master_conn.ask(klass(*args, **kw), return lambda self, conn, *args: self.app.master_conn.ask(
conn=conn, msg_id=conn.getPeerId())) klass(*args), conn=conn, msg_id=conn.getPeerId())
del d['__metaclass__']
for x in (
Packets.AddPendingNodes,
Packets.AskLastIDs,
Packets.AskLastTransaction,
Packets.AskRecovery,
Packets.CheckReplicas,
Packets.Repair,
Packets.SetClusterState,
Packets.SetNodeState,
Packets.SetNumReplicas,
Packets.Truncate,
Packets.TweakPartitionTable,
):
d[x.handler_method_name] = forward_ask(x)
return type(name, bases, {k: v if k[0] == '_' else check_primary_master(v)
for k, v in d.iteritems()})
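The factory above replaces the hand-written forward_ask assignments of the previous version: every public method of AdminEventHandler, hand-written or generated, is wrapped with check_primary_master, and one forwarding method per packet type is synthesized from handler_method_name. A minimal, self-contained sketch of the same idiom, using illustrative names (guard, FORWARDED, Handler) and Python 3 metaclass syntax rather than NEO's API:

    def guard(func):
        def wrapper(self, *args, **kw):
            if not self._ready:           # stands in for "app.master_conn is not None"
                raise RuntimeError('Not connected to a primary master.')
            return func(self, *args, **kw)
        return wrapper

    FORWARDED = ('setClusterState', 'tweakPartitionTable')   # stand-ins for Packets.*

    def HandlerType(name, bases, d):
        def forward(message):
            # generated methods just relay their arguments upstream
            return lambda self, *args: self._sent.append((message, args))
        for message in FORWARDED:
            d[message] = forward(message)
        # wrap every public callable (hand-written or generated) with the guard
        return type(name, bases, {k: guard(v) if callable(v) and not k.startswith('_')
                                  else v for k, v in d.items()})

    class Handler(metaclass=HandlerType):  # the diff uses Py2's __metaclass__ instead
        _ready = True
        _sent = []
        def askClusterState(self):
            return 'RUNNING'

    h = Handler()
    print(h.askClusterState())             # 'RUNNING' (guard passes)
    h.setClusterState('STOPPING')          # synthesized forwarder, also guarded
    print(h._sent)                         # [('setClusterState', ('STOPPING',))]

The matching dispatch() change in MasterEventHandler further down relays the master's answer straight back to the originating connection under the client's original msg_id, which is what makes the dedicated MasterRequestEventHandler unnecessary.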
class AdminEventHandler(EventHandler): class AdminEventHandler(EventHandler):
"""This class deals with events for administrating cluster.""" """This class deals with events for administrating cluster."""
@check_primary_master __metaclass__ = AdminEventHandlerType
def askPartitionList(self, conn, min_offset, max_offset, uuid): def askPartitionList(self, conn, min_offset, max_offset, uuid):
logging.info("ask partition list from %s to %s for %s", logging.info("ask partition list from %s to %s for %s",
min_offset, max_offset, uuid_str(uuid)) min_offset, max_offset, uuid_str(uuid))
self.app.sendPartitionTable(conn, min_offset, max_offset, uuid) self.app.sendPartitionTable(conn, min_offset, max_offset, uuid)
@check_primary_master
def askNodeList(self, conn, node_type): def askNodeList(self, conn, node_type):
if node_type is None: if node_type is None:
node_type = 'all' node_type = 'all'
...@@ -53,36 +72,22 @@ class AdminEventHandler(EventHandler): ...@@ -53,36 +72,22 @@ class AdminEventHandler(EventHandler):
p = Packets.AnswerNodeList(node_information_list) p = Packets.AnswerNodeList(node_information_list)
conn.answer(p) conn.answer(p)
@check_primary_master
def askClusterState(self, conn): def askClusterState(self, conn):
conn.answer(Packets.AnswerClusterState(self.app.cluster_state)) conn.answer(Packets.AnswerClusterState(self.app.cluster_state))
@check_primary_master
def askPrimary(self, conn): def askPrimary(self, conn):
master_node = self.app.master_node master_node = self.app.master_node
conn.answer(Packets.AnswerPrimary(master_node.getUUID())) conn.answer(Packets.AnswerPrimary(master_node.getUUID()))
@check_primary_master
def flushLog(self, conn): def flushLog(self, conn):
self.app.master_conn.send(Packets.FlushLog()) self.app.master_conn.send(Packets.FlushLog())
super(AdminEventHandler, self).flushLog(conn) super(AdminEventHandler, self).flushLog(conn)
askLastIDs = forward_ask(Packets.AskLastIDs)
askLastTransaction = forward_ask(Packets.AskLastTransaction)
addPendingNodes = forward_ask(Packets.AddPendingNodes)
askRecovery = forward_ask(Packets.AskRecovery)
tweakPartitionTable = forward_ask(Packets.TweakPartitionTable)
setClusterState = forward_ask(Packets.SetClusterState)
setNodeState = forward_ask(Packets.SetNodeState)
checkReplicas = forward_ask(Packets.CheckReplicas)
truncate = forward_ask(Packets.Truncate)
repair = forward_ask(Packets.Repair)
class MasterEventHandler(EventHandler): class MasterEventHandler(EventHandler):
""" This class is just used to dispatch message to right handler""" """ This class is just used to dispatch message to right handler"""
def _connectionLost(self, conn): def connectionClosed(self, conn):
app = self.app app = self.app
if app.listening_conn: # if running if app.listening_conn: # if running
assert app.master_conn in (conn, None) assert app.master_conn in (conn, None)
...@@ -91,42 +96,21 @@ class MasterEventHandler(EventHandler): ...@@ -91,42 +96,21 @@ class MasterEventHandler(EventHandler):
app.uuid = None app.uuid = None
raise PrimaryFailure raise PrimaryFailure
def connectionFailed(self, conn):
self._connectionLost(conn)
def connectionClosed(self, conn):
self._connectionLost(conn)
def dispatch(self, conn, packet, kw={}): def dispatch(self, conn, packet, kw={}):
if 'conn' in kw: forward = kw.get('conn')
# expected answer if forward is None:
if packet.isResponse():
packet.setId(kw['msg_id'])
kw['conn'].answer(packet)
else:
self.app.request_handler.dispatch(conn, packet, kw)
else:
# unexpected answers and notifications
super(MasterEventHandler, self).dispatch(conn, packet, kw) super(MasterEventHandler, self).dispatch(conn, packet, kw)
else:
forward.send(packet, kw['msg_id'])
def answerClusterState(self, conn, state): def answerClusterState(self, conn, state):
self.app.cluster_state = state self.app.cluster_state = state
def notifyPartitionChanges(self, conn, ptid, cell_list): notifyClusterInformation = answerClusterState
self.app.pt.update(ptid, cell_list, self.app.nm)
def answerPartitionTable(self, conn, ptid, row_list):
self.app.pt.load(ptid, row_list, self.app.nm)
self.app.bootstrapped = True
def sendPartitionTable(self, conn, ptid, row_list):
if self.app.bootstrapped:
self.app.pt.load(ptid, row_list, self.app.nm)
def notifyClusterInformation(self, conn, cluster_state):
self.app.cluster_state = cluster_state
def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
pt = self.app.pt = object.__new__(PartitionTable)
pt.load(ptid, num_replicas, row_list, self.app.nm)
class MasterRequestEventHandler(EventHandler): def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
""" This class handle all answer from primary master node""" self.app.pt.update(ptid, num_replicas, cell_list, self.app.nm)
# XXX: to be deleted ?
...@@ -13,6 +13,13 @@ ...@@ -13,6 +13,13 @@
############################################################################## ##############################################################################
def patch(): def patch():
# For msgpack & Py2/ZODB5.
try:
from zodbpickle import binary
binary._pack = bytes.__str__
except ImportError:
pass
from hashlib import md5 from hashlib import md5
from ZODB.Connection import Connection from ZODB.Connection import Connection
......
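The one-line binary._pack = bytes.__str__ patch ties in with the serializer added later in this commit: NEO's msgpack Packer is created with strict_types=True, so a bytes subclass such as zodbpickle.binary is not packed as a plain byte string but handed to the Packer's default callback, which calls the object's _pack() method (see the Unpacker factory in protocol.py further down). A small sketch of that interaction, with a hypothetical Binary class standing in for zodbpickle.binary:

    import msgpack

    class Binary(bytes):
        """Stand-in for zodbpickle.binary: a bytes subclass."""

    # Roughly what bytes.__str__ provides on Python 2: the raw byte content.
    Binary._pack = lambda self: bytes(self)

    def default(obj):
        # With strict_types=True, subclasses of bytes are treated as unsupported
        # and routed here, mirroring how NEO relies on obj._pack().
        return obj._pack()

    packer = msgpack.Packer(default=default, strict_types=True, use_bin_type=True)
    blob = packer.pack(Binary(b'\x00\x01'))
    print(msgpack.unpackb(blob))            # b'\x00\x01'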
...@@ -18,6 +18,7 @@ import heapq ...@@ -18,6 +18,7 @@ import heapq
import random import random
import time import time
from collections import defaultdict from collections import defaultdict
try: try:
from ZODB._compat import dumps, loads, _protocol from ZODB._compat import dumps, loads, _protocol
except ImportError: except ImportError:
...@@ -76,11 +77,11 @@ class Application(ThreadedApplication): ...@@ -76,11 +77,11 @@ class Application(ThreadedApplication):
self.primary_master_node = None self.primary_master_node = None
self.trying_master_node = None self.trying_master_node = None
# no self-assigned UUID, primary master will supply us one # no self-assigned NID, primary master will supply us one
self._cache = ClientCache() if cache_size is None else \ self._cache = ClientCache() if cache_size is None else \
ClientCache(max_size=cache_size) ClientCache(max_size=cache_size)
self._loading = defaultdict(lambda: (Lock(), [])) self._loading = defaultdict(lambda: (Lock(), []))
self.new_oid_list = () self.new_oids = ()
self.last_oid = '\0' * 8 self.last_oid = '\0' * 8
self.storage_event_handler = storage.StorageEventHandler(self) self.storage_event_handler = storage.StorageEventHandler(self)
self.storage_bootstrap_handler = storage.StorageBootstrapHandler(self) self.storage_bootstrap_handler = storage.StorageBootstrapHandler(self)
...@@ -181,7 +182,7 @@ class Application(ThreadedApplication): ...@@ -181,7 +182,7 @@ class Application(ThreadedApplication):
with self._connecting_to_master_node: with self._connecting_to_master_node:
result = self.master_conn result = self.master_conn
if result is None: if result is None:
self.new_oid_list = () self.new_oids = ()
result = self.master_conn = self._connectToPrimaryNode() result = self.master_conn = self._connectToPrimaryNode()
return result return result
...@@ -220,8 +221,8 @@ class Application(ThreadedApplication): ...@@ -220,8 +221,8 @@ class Application(ThreadedApplication):
self.notifications_handler, self.notifications_handler,
node=node, node=node,
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
p = Packets.RequestIdentification( p = Packets.RequestIdentification(NodeTypes.CLIENT,
NodeTypes.CLIENT, self.uuid, None, self.name, (), None) self.uuid, None, self.name, None, (), ())
try: try:
ask(conn, p, handler=handler) ask(conn, p, handler=handler)
except ConnectionClosed: except ConnectionClosed:
...@@ -238,7 +239,6 @@ class Application(ThreadedApplication): ...@@ -238,7 +239,6 @@ class Application(ThreadedApplication):
# operational. Might raise ConnectionClosed so that the new # operational. Might raise ConnectionClosed so that the new
# primary can be looked-up again. # primary can be looked-up again.
logging.info('Initializing from master') logging.info('Initializing from master')
ask(conn, Packets.AskPartitionTable(), handler=handler)
ask(conn, Packets.AskLastTransaction(), handler=handler) ask(conn, Packets.AskLastTransaction(), handler=handler)
if self.pt.operational(): if self.pt.operational():
break break
...@@ -264,7 +264,7 @@ class Application(ThreadedApplication): ...@@ -264,7 +264,7 @@ class Application(ThreadedApplication):
conn = MTClientConnection(self, self.storage_event_handler, node, conn = MTClientConnection(self, self.storage_event_handler, node,
dispatcher=self.dispatcher) dispatcher=self.dispatcher)
p = Packets.RequestIdentification(NodeTypes.CLIENT, p = Packets.RequestIdentification(NodeTypes.CLIENT,
self.uuid, None, self.name, (), self.id_timestamp) self.uuid, None, self.name, self.id_timestamp, (), ())
try: try:
self._ask(conn, p, handler=self.storage_bootstrap_handler) self._ask(conn, p, handler=self.storage_bootstrap_handler)
except ConnectionClosed: except ConnectionClosed:
...@@ -306,15 +306,19 @@ class Application(ThreadedApplication): ...@@ -306,15 +306,19 @@ class Application(ThreadedApplication):
"""Get a new OID.""" """Get a new OID."""
self._oid_lock_acquire() self._oid_lock_acquire()
try: try:
if not self.new_oid_list: for oid in self.new_oids:
break
else:
# Get new oid list from master node # Get new oid list from master node
# we manage a list of oid here to prevent # we manage a list of oid here to prevent
# from asking too many time new oid one by one # from asking too many time new oid one by one
# from master node # from master node
self._askPrimary(Packets.AskNewOIDs(100)) self._askPrimary(Packets.AskNewOIDs(100))
if not self.new_oid_list: for oid in self.new_oids:
break
else:
raise NEOStorageError('new_oid failed') raise NEOStorageError('new_oid failed')
self.last_oid = oid = self.new_oid_list.pop() self.last_oid = oid
return oid return oid
finally: finally:
self._oid_lock_release() self._oid_lock_release()
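The rewritten new_oid relies on Python's for/else: it first tries to take an OID from the iterator left by the last AnswerNewOIDs, and only when that iterator is exhausted does it ask the primary master for a fresh batch of 100 and retry. A self-contained sketch of that control flow (the _ask_master stub is hypothetical; in NEO the answer handler sets app.new_oids = iter(oid_list)):

    class OIDPool(object):
        """Sketch of the new_oids iterator plus for/else refill logic of new_oid()."""
        def __init__(self):
            self.new_oids = iter(())          # like Application.new_oids
            self._next_batch = 0

        def _ask_master(self, count):
            # Stand-in for self._askPrimary(Packets.AskNewOIDs(100)).
            start = self._next_batch
            self._next_batch += count
            self.new_oids = iter(range(start, start + count))

        def new_oid(self):
            for oid in self.new_oids:         # fast path: pool not exhausted
                break
            else:                             # pool exhausted: refill once, retry
                self._ask_master(100)
                for oid in self.new_oids:
                    break
                else:
                    raise RuntimeError('new_oid failed')
            return oid

    pool = OIDPool()
    print(pool.new_oid(), pool.new_oid())     # 0 1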
...@@ -611,7 +615,7 @@ class Application(ThreadedApplication): ...@@ -611,7 +615,7 @@ class Application(ThreadedApplication):
# user and description are cast to str in case they're unicode. # user and description are cast to str in case they're unicode.
# BBB: This is not required anymore with recent ZODB. # BBB: This is not required anymore with recent ZODB.
packet = Packets.AskStoreTransaction(ttid, str(transaction.user), packet = Packets.AskStoreTransaction(ttid, str(transaction.user),
str(transaction.description), ext, txn_context.cache_dict) str(transaction.description), ext, list(txn_context.cache_dict))
queue = txn_context.queue queue = txn_context.queue
conn_dict = txn_context.conn_dict conn_dict = txn_context.conn_dict
# Ask in parallel all involved storage nodes to commit object metadata. # Ask in parallel all involved storage nodes to commit object metadata.
...@@ -696,7 +700,7 @@ class Application(ThreadedApplication): ...@@ -696,7 +700,7 @@ class Application(ThreadedApplication):
else: else:
try: try:
notify(Packets.AbortTransaction(txn_context.ttid, notify(Packets.AbortTransaction(txn_context.ttid,
txn_context.conn_dict)) list(txn_context.conn_dict)))
except ConnectionClosed: except ConnectionClosed:
pass pass
# We don't need to flush queue, as it won't be reused by future # We don't need to flush queue, as it won't be reused by future
...@@ -731,7 +735,8 @@ class Application(ThreadedApplication): ...@@ -731,7 +735,8 @@ class Application(ThreadedApplication):
for oid in checked_list: for oid in checked_list:
del cache_dict[oid] del cache_dict[oid]
ttid = txn_context.ttid ttid = txn_context.ttid
p = Packets.AskFinishTransaction(ttid, cache_dict, checked_list) p = Packets.AskFinishTransaction(ttid, list(cache_dict),
checked_list)
try: try:
tid = self._askPrimary(p, cache_dict=cache_dict, callback=f) tid = self._askPrimary(p, cache_dict=cache_dict, callback=f)
assert tid assert tid
...@@ -765,17 +770,11 @@ class Application(ThreadedApplication): ...@@ -765,17 +770,11 @@ class Application(ThreadedApplication):
def undo(self, undone_tid, txn): def undo(self, undone_tid, txn):
txn_context = self._txn_container.get(txn) txn_context = self._txn_container.get(txn)
txn_info, txn_ext = self._getTransactionInformation(undone_tid) txn_info, txn_ext = self._getTransactionInformation(undone_tid)
txn_oid_list = txn_info['oids']
# Regroup objects per partition, to ask a minimum set of storage. # Regroup objects per partition, to ask a minimum set of storage.
partition_oid_dict = {} partition_oid_dict = defaultdict(list)
for oid in txn_oid_list: for oid in txn_info['oids']:
partition = self.pt.getPartition(oid) partition_oid_dict[self.pt.getPartition(oid)].append(oid)
try:
oid_list = partition_oid_dict[partition]
except KeyError:
oid_list = partition_oid_dict[partition] = []
oid_list.append(oid)
# Ask storage the undo serial (serial at which object's previous data # Ask storage the undo serial (serial at which object's previous data
# is) # is)
...@@ -817,8 +816,8 @@ class Application(ThreadedApplication): ...@@ -817,8 +816,8 @@ class Application(ThreadedApplication):
raise UndoError('non-undoable transaction') raise UndoError('non-undoable transaction')
# Send undo data to all storage nodes. # Send undo data to all storage nodes.
for oid in txn_oid_list: for oid, (current_serial, undo_serial, is_current) in \
current_serial, undo_serial, is_current = undo_object_tid_dict[oid] undo_object_tid_dict.iteritems():
if is_current: if is_current:
data = None data = None
else: else:
...@@ -852,7 +851,7 @@ class Application(ThreadedApplication): ...@@ -852,7 +851,7 @@ class Application(ThreadedApplication):
self._store(txn_context, oid, current_serial, data, undo_serial) self._store(txn_context, oid, current_serial, data, undo_serial)
self.waitStoreResponses(txn_context) self.waitStoreResponses(txn_context)
return None, txn_oid_list return None, list(undo_object_tid_dict)
def _getTransactionInformation(self, tid): def _getTransactionInformation(self, tid):
return self._askStorageForRead(tid, return self._askStorageForRead(tid,
...@@ -933,9 +932,9 @@ class Application(ThreadedApplication): ...@@ -933,9 +932,9 @@ class Application(ThreadedApplication):
for serial, size in self._askStorageForRead(oid, packet): for serial, size in self._askStorageForRead(oid, packet):
txn_info, txn_ext = self._getTransactionInformation(serial) txn_info, txn_ext = self._getTransactionInformation(serial)
# create history dict # create history dict
txn_info.pop('id') del txn_info['id']
txn_info.pop('oids') del txn_info['oids']
txn_info.pop('packed') del txn_info['packed']
txn_info['tid'] = serial txn_info['tid'] = serial
txn_info['version'] = '' txn_info['version'] = ''
txn_info['size'] = size txn_info['size'] = size
......
...@@ -26,10 +26,6 @@ from ..exception import NEOStorageError ...@@ -26,10 +26,6 @@ from ..exception import NEOStorageError
class PrimaryBootstrapHandler(AnswerBaseHandler): class PrimaryBootstrapHandler(AnswerBaseHandler):
""" Bootstrap handler used when looking for the primary master """ """ Bootstrap handler used when looking for the primary master """
def answerPartitionTable(self, conn, ptid, row_list):
assert row_list
self.app.pt.load(ptid, row_list, self.app.nm)
def answerLastTransaction(*args): def answerLastTransaction(*args):
pass pass
...@@ -42,9 +38,6 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -42,9 +38,6 @@ class PrimaryNotificationsHandler(MTEventHandler):
except PrimaryElected, e: except PrimaryElected, e:
self.app.primary_master_node, = e.args self.app.primary_master_node, = e.args
def _acceptIdentification(self, node, num_partitions, num_replicas):
self.app.pt = PartitionTable(num_partitions, num_replicas)
def answerLastTransaction(self, conn, ltid): def answerLastTransaction(self, conn, ltid):
app = self.app app = self.app
app_last_tid = app.__dict__.get('last_tid', '') app_last_tid = app.__dict__.get('last_tid', '')
...@@ -131,9 +124,12 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -131,9 +124,12 @@ class PrimaryNotificationsHandler(MTEventHandler):
if db is not None: if db is not None:
db.invalidate(tid, oid_list) db.invalidate(tid, oid_list)
def notifyPartitionChanges(self, conn, ptid, cell_list): def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
if self.app.pt.filled(): pt = self.app.pt = object.__new__(PartitionTable)
self.app.pt.update(ptid, cell_list, self.app.nm) pt.load(ptid, num_replicas, row_list, self.app.nm)
def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
self.app.pt.update(ptid, num_replicas, cell_list, self.app.nm)
def notifyNodeInformation(self, conn, timestamp, node_list): def notifyNodeInformation(self, conn, timestamp, node_list):
super(PrimaryNotificationsHandler, self).notifyNodeInformation( super(PrimaryNotificationsHandler, self).notifyNodeInformation(
...@@ -161,8 +157,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler): ...@@ -161,8 +157,7 @@ class PrimaryAnswersHandler(AnswerBaseHandler):
self.app.setHandlerData(ttid) self.app.setHandlerData(ttid)
def answerNewOIDs(self, conn, oid_list): def answerNewOIDs(self, conn, oid_list):
oid_list.reverse() self.app.new_oids = iter(oid_list)
self.app.new_oid_list = oid_list
def incompleteTransaction(self, conn, message): def incompleteTransaction(self, conn, message):
raise NEOStorageError("storage nodes for which vote failed can not be" raise NEOStorageError("storage nodes for which vote failed can not be"
......
...@@ -26,7 +26,7 @@ from .exception import NEOStorageError ...@@ -26,7 +26,7 @@ from .exception import NEOStorageError
class _WakeupPacket(object): class _WakeupPacket(object):
handler_method_name = 'pong' handler_method_name = 'pong'
decode = tuple _args = ()
getId = int getId = int
class Transaction(object): class Transaction(object):
......
...@@ -26,7 +26,7 @@ class BootstrapManager(EventHandler): ...@@ -26,7 +26,7 @@ class BootstrapManager(EventHandler):
Manage the bootstrap stage, lookup for the primary master then connect to it Manage the bootstrap stage, lookup for the primary master then connect to it
""" """
def __init__(self, app, node_type, server=None, devpath=()): def __init__(self, app, node_type, server=None, devpath=(), new_nid=()):
""" """
Manage the bootstrap stage of a non-master node, it lookup for the Manage the bootstrap stage of a non-master node, it lookup for the
primary master node, connect to it then returns when the master node primary master node, connect to it then returns when the master node
...@@ -34,9 +34,8 @@ class BootstrapManager(EventHandler): ...@@ -34,9 +34,8 @@ class BootstrapManager(EventHandler):
""" """
self.server = server self.server = server
self.devpath = devpath self.devpath = devpath
self.new_nid = new_nid
self.node_type = node_type self.node_type = node_type
self.num_replicas = None
self.num_partitions = None
app.nm.reset() app.nm.reset()
uuid = property(lambda self: self.app.uuid) uuid = property(lambda self: self.app.uuid)
...@@ -44,7 +43,7 @@ class BootstrapManager(EventHandler): ...@@ -44,7 +43,7 @@ class BootstrapManager(EventHandler):
def connectionCompleted(self, conn): def connectionCompleted(self, conn):
EventHandler.connectionCompleted(self, conn) EventHandler.connectionCompleted(self, conn)
conn.ask(Packets.RequestIdentification(self.node_type, self.uuid, conn.ask(Packets.RequestIdentification(self.node_type, self.uuid,
self.server, self.app.name, self.devpath, None)) self.server, self.app.name, None, self.devpath, self.new_nid))
def connectionFailed(self, conn): def connectionFailed(self, conn):
EventHandler.connectionFailed(self, conn) EventHandler.connectionFailed(self, conn)
...@@ -53,10 +52,8 @@ class BootstrapManager(EventHandler): ...@@ -53,10 +52,8 @@ class BootstrapManager(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
self.current = None self.current = None
def _acceptIdentification(self, node, num_partitions, num_replicas): def _acceptIdentification(self, node):
assert self.current is node, (self.current, node) assert self.current is node, (self.current, node)
self.num_partitions = num_partitions
self.num_replicas = num_replicas
def getPrimaryConnection(self): def getPrimaryConnection(self):
""" """
...@@ -73,8 +70,7 @@ class BootstrapManager(EventHandler): ...@@ -73,8 +70,7 @@ class BootstrapManager(EventHandler):
try: try:
while self.current: while self.current:
if self.current.isIdentified(): if self.current.isIdentified():
return (self.current, self.current.getConnection(), return self.current, self.current.getConnection()
self.num_partitions, self.num_replicas)
poll(1) poll(1)
except PrimaryElected, e: except PrimaryElected, e:
if self.current: if self.current:
......
...@@ -16,12 +16,19 @@ ...@@ -16,12 +16,19 @@
from functools import wraps from functools import wraps
from time import time from time import time
import msgpack
from msgpack.exceptions import UnpackValueError
from . import attributeTracker, logging from . import attributeTracker, logging
from .connector import ConnectorException, ConnectorDelayedConnection from .connector import ConnectorException, ConnectorDelayedConnection
from .locking import RLock from .locking import RLock
from .protocol import uuid_str, Errors, PacketMalformedError, Packets from .protocol import uuid_str, Errors, PacketMalformedError, Packets, \
from .util import dummy_read_buffer, ReadBuffer Unpacker
@apply
class dummy_read_buffer(msgpack.Unpacker):
def feed(self, _):
pass
class ConnectionClosed(Exception): class ConnectionClosed(Exception):
pass pass
...@@ -209,7 +216,7 @@ class BaseConnection(object): ...@@ -209,7 +216,7 @@ class BaseConnection(object):
def _getReprInfo(self): def _getReprInfo(self):
r = [ r = [
('uuid', uuid_str(self.getUUID())), ('nid', uuid_str(self.getUUID())),
('address', ('[%s]:%s' if ':' in self.addr[0] else '%s:%s') ('address', ('[%s]:%s' if ':' in self.addr[0] else '%s:%s')
% self.addr if self.addr else '?'), % self.addr if self.addr else '?'),
('handler', self.getHandler()), ('handler', self.getHandler()),
...@@ -291,7 +298,7 @@ class ListeningConnection(BaseConnection): ...@@ -291,7 +298,7 @@ class ListeningConnection(BaseConnection):
# message. # message.
else: else:
conn._connected() conn._connected()
self.em.addWriter(conn) # for ENCODED_VERSION self.em.addWriter(conn) # for HANDSHAKE_PACKET
def getAddress(self): def getAddress(self):
return self.connector.getAddress() return self.connector.getAddress()
...@@ -310,12 +317,12 @@ class Connection(BaseConnection): ...@@ -310,12 +317,12 @@ class Connection(BaseConnection):
client = False client = False
server = False server = False
peer_id = None peer_id = None
_parser_state = None _total_unpacked = 0
_timeout = None _timeout = None
def __init__(self, event_manager, *args, **kw): def __init__(self, event_manager, *args, **kw):
BaseConnection.__init__(self, event_manager, *args, **kw) BaseConnection.__init__(self, event_manager, *args, **kw)
self.read_buf = ReadBuffer() self.read_buf = Unpacker()
self.cur_id = 0 self.cur_id = 0
self.aborted = False self.aborted = False
self.uuid = None self.uuid = None
...@@ -425,42 +432,39 @@ class Connection(BaseConnection): ...@@ -425,42 +432,39 @@ class Connection(BaseConnection):
self._closure() self._closure()
def _parse(self): def _parse(self):
read = self.read_buf.read from .protocol import HANDSHAKE_PACKET, MAGIC_SIZE, Packets
version = read(4) read_buf = self.read_buf
if version is None: handshake = read_buf.read_bytes(len(HANDSHAKE_PACKET))
return if handshake != HANDSHAKE_PACKET:
from .protocol import (ENCODED_VERSION, MAX_PACKET_SIZE, if HANDSHAKE_PACKET.startswith(handshake): # unlikely so tested last
PACKET_HEADER_FORMAT, Packets) # Not enough data and there's no API to know it in advance.
if version != ENCODED_VERSION: # Put it back.
logging.warning('Protocol version mismatch with %r', self) read_buf.feed(handshake)
return
if HANDSHAKE_PACKET.startswith(handshake[:MAGIC_SIZE]):
logging.warning('Protocol version mismatch with %r', self)
else:
logging.debug('Rejecting non-NEO %r', self)
raise ConnectorException raise ConnectorException
header_size = PACKET_HEADER_FORMAT.size read_next = read_buf.next
unpack = PACKET_HEADER_FORMAT.unpack read_pos = read_buf.tell
def parse(): def parse():
state = self._parser_state try:
if state is None: msg_id, msg_type, args = read_next()
header = read(header_size) except StopIteration:
if header is None: return
return except UnpackValueError as e:
msg_id, msg_type, msg_len = unpack(header) raise PacketMalformedError(str(e))
try: try:
packet_klass = Packets[msg_type] packet_klass = Packets[msg_type]
except KeyError: except KeyError:
raise PacketMalformedError('Unknown packet type') raise PacketMalformedError('Unknown packet type')
if msg_len > MAX_PACKET_SIZE: pos = read_pos()
raise PacketMalformedError('message too big (%d)' % msg_len) packet = packet_klass(*args)
else: packet.setId(msg_id)
msg_id, packet_klass, msg_len = state packet.size = pos - self._total_unpacked
data = read(msg_len) self._total_unpacked = pos
if data is None: return packet
# Not enough.
if state is None:
self._parser_state = msg_id, packet_klass, msg_len
else:
self._parser_state = None
packet = packet_klass()
packet.setContent(msg_id, data)
return packet
self._parse = parse self._parse = parse
return parse() return parse()
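With msgpack there is no fixed-size header to read first: raw socket data is fed into an Unpacker, next() yields one complete (msg_id, msg_type, args) triple per packet or raises StopIteration when more bytes are needed, and tell() gives the stream position used to compute packet.size. A standalone sketch of that streaming loop, independent of NEO's connection classes (Python 3, msgpack-python):

    import msgpack

    unpacker = msgpack.Unpacker(use_list=False)   # plays the role of Connection.read_buf
    total_unpacked = 0

    def feed_and_parse(data):
        """Feed raw socket bytes; return the packets decoded so far as
        (msg_id, msg_type, args, size) tuples."""
        global total_unpacked
        unpacker.feed(data)
        packets = []
        while True:
            try:
                msg_id, msg_type, args = next(unpacker)  # StopIteration: need more data
            except StopIteration:
                return packets
            pos = unpacker.tell()                        # bytes consumed so far
            packets.append((msg_id, msg_type, args, pos - total_unpacked))
            total_unpacked = pos

    wire = msgpack.packb((1, 0x8000, ('ok',))) + msgpack.packb((2, 5, ()))
    print(feed_and_parse(wire[:7]))   # []  - 7 bytes are not yet a whole packet
    print(feed_and_parse(wire[7:]))   # both packets, each with its encoded size

The handshake is consumed just before this loop with read_bytes(len(HANDSHAKE_PACKET)); if only a prefix has arrived, it is simply fed back into the same Unpacker.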
...@@ -513,7 +517,7 @@ class Connection(BaseConnection): ...@@ -513,7 +517,7 @@ class Connection(BaseConnection):
def close(self): def close(self):
if self.connector is None: if self.connector is None:
assert self._on_close is None assert self._on_close is None
assert not self.read_buf assert not self.read_buf.read_bytes(1)
assert not self.isPending() assert not self.isPending()
return return
# process the network events with the last registered handler to # process the network events with the last registered handler to
...@@ -524,7 +528,7 @@ class Connection(BaseConnection): ...@@ -524,7 +528,7 @@ class Connection(BaseConnection):
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
self.read_buf.clear() self.read_buf = dummy_read_buffer
try: try:
if self.connecting: if self.connecting:
handler.connectionFailed(self) handler.connectionFailed(self)
......
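On close(), the connection swaps its real read buffer for dummy_read_buffer, a single msgpack.Unpacker whose feed() drops everything, so data received after the close is silently discarded; @apply is just the Python 2 builtin used as a decorator to instantiate the class in place. An equivalent Python 3 spelling:

    import msgpack

    class _DummyReadBuffer(msgpack.Unpacker):
        """Read buffer installed after close(): anything fed to it is dropped."""
        def feed(self, data):
            pass

    dummy_read_buffer = _DummyReadBuffer()        # what @apply produces on Python 2

    dummy_read_buffer.feed(b'\x93\x01\x02\x03')   # silently ignored
    print(list(dummy_read_buffer))                # [] - nothing ever becomes readable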
...@@ -19,7 +19,7 @@ import ssl ...@@ -19,7 +19,7 @@ import ssl
import errno import errno
from time import time from time import time
from . import logging from . import logging
from .protocol import ENCODED_VERSION from .protocol import HANDSHAKE_PACKET
# Global connector registry. # Global connector registry.
# Fill by calling registerConnectorHandler. # Fill by calling registerConnectorHandler.
...@@ -74,15 +74,14 @@ class SocketConnector(object): ...@@ -74,15 +74,14 @@ class SocketConnector(object):
s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
# disable Nagle algorithm to reduce latency # disable Nagle algorithm to reduce latency
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.queued = [ENCODED_VERSION] self.queued = [HANDSHAKE_PACKET]
self.queue_size = len(ENCODED_VERSION) self.queue_size = len(HANDSHAKE_PACKET)
return self return self
def queue(self, data): def queue(self, data):
was_empty = not self.queued was_empty = not self.queued
self.queued += data self.queued.append(data)
for data in data: self.queue_size += len(data)
self.queue_size += len(data)
return was_empty return was_empty
def _error(self, op, exc=None): def _error(self, op, exc=None):
...@@ -172,7 +171,7 @@ class SocketConnector(object): ...@@ -172,7 +171,7 @@ class SocketConnector(object):
except socket.error, e: except socket.error, e:
self._error('recv', e) self._error('recv', e)
if data: if data:
read_buf.append(data) read_buf.feed(data)
return return
self._error('recv') self._error('recv')
...@@ -278,7 +277,7 @@ class _SSL: ...@@ -278,7 +277,7 @@ class _SSL:
def receive(self, read_buf): def receive(self, read_buf):
try: try:
while 1: while 1:
read_buf.append(self.socket.recv(4096)) read_buf.feed(self.socket.recv(4096))
except ssl.SSLWantReadError: except ssl.SSLWantReadError:
pass pass
except socket.error, e: except socket.error, e:
......
...@@ -23,7 +23,7 @@ NOBODY = [] ...@@ -23,7 +23,7 @@ NOBODY = []
class _ConnectionClosed(object): class _ConnectionClosed(object):
handler_method_name = 'connectionClosed' handler_method_name = 'connectionClosed'
decode = tuple _args = ()
class getId(object): class getId(object):
def __eq__(self, other): def __eq__(self, other):
......
...@@ -26,6 +26,9 @@ from .protocol import (NodeStates, NodeTypes, Packets, uuid_str, ...@@ -26,6 +26,9 @@ from .protocol import (NodeStates, NodeTypes, Packets, uuid_str,
from .util import cached_property from .util import cached_property
class AnswerDenied(Exception):
"""Helper exception to stop packet processing and answer a Denied error"""
class DelayEvent(Exception): class DelayEvent(Exception):
pass pass
...@@ -68,7 +71,7 @@ class EventHandler(object): ...@@ -68,7 +71,7 @@ class EventHandler(object):
method = getattr(self, packet.handler_method_name) method = getattr(self, packet.handler_method_name)
except AttributeError: except AttributeError:
raise UnexpectedPacketError('no handler found') raise UnexpectedPacketError('no handler found')
args = packet.decode() or () args = packet._args
method(conn, *args, **kw) method(conn, *args, **kw)
except DelayEvent, e: except DelayEvent, e:
assert not kw, kw assert not kw, kw
...@@ -76,9 +79,6 @@ class EventHandler(object): ...@@ -76,9 +79,6 @@ class EventHandler(object):
except UnexpectedPacketError, e: except UnexpectedPacketError, e:
if not conn.isClosed(): if not conn.isClosed():
self.__unexpectedPacket(conn, packet, *e.args) self.__unexpectedPacket(conn, packet, *e.args)
except PacketMalformedError, e:
logging.error('malformed packet from %r: %s', conn, e)
conn.close()
except NotReadyError, message: except NotReadyError, message:
if not conn.isClosed(): if not conn.isClosed():
if not message.args: if not message.args:
...@@ -98,6 +98,8 @@ class EventHandler(object): ...@@ -98,6 +98,8 @@ class EventHandler(object):
% (m.im_class.__module__, m.im_class.__name__, m.__name__))) % (m.im_class.__module__, m.im_class.__name__, m.__name__)))
except NonReadableCell, e: except NonReadableCell, e:
conn.answer(Errors.NonReadableCell()) conn.answer(Errors.NonReadableCell())
except AnswerDenied, e:
conn.answer(Errors.Denied(str(e)))
except AssertionError: except AssertionError:
e = sys.exc_info() e = sys.exc_info()
try: try:
...@@ -160,8 +162,7 @@ class EventHandler(object): ...@@ -160,8 +162,7 @@ class EventHandler(object):
def _acceptIdentification(*args): def _acceptIdentification(*args):
pass pass
def acceptIdentification(self, conn, node_type, uuid, def acceptIdentification(self, conn, node_type, uuid, your_uuid):
num_partitions, num_replicas, your_uuid):
app = self.app app = self.app
node = app.nm.getByAddress(conn.getAddress()) node = app.nm.getByAddress(conn.getAddress())
assert node.getConnection() is conn, (node.getConnection(), conn) assert node.getConnection() is conn, (node.getConnection(), conn)
...@@ -180,7 +181,7 @@ class EventHandler(object): ...@@ -180,7 +181,7 @@ class EventHandler(object):
elif node.getUUID() != uuid or app.uuid != your_uuid != None: elif node.getUUID() != uuid or app.uuid != your_uuid != None:
raise ProtocolError('invalid uuids') raise ProtocolError('invalid uuids')
node.setIdentified() node.setIdentified()
self._acceptIdentification(node, num_partitions, num_replicas) self._acceptIdentification(node)
return return
conn.close() conn.close()
......
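Since packets now carry their already-decoded _args tuple, EventHandler.dispatch reduces to looking up packet.handler_method_name and applying that tuple, and handler exceptions such as the new AnswerDenied are converted into error answers in one place. A compact sketch of this dispatch-by-name pattern, with made-up packet and handler names:

    class Denied(Exception):
        """Like AnswerDenied: stop processing and answer an error instead."""

    class AskDropNode(object):
        handler_method_name = 'askDropNode'
        def __init__(self, *args):
            self._args = args                  # msgpack already decoded these

    class Handler(object):
        def askDropNode(self, conn, nid):
            if nid == 0:
                raise Denied('cannot drop the last storage node')
            conn.append(('dropped', nid))

        def dispatch(self, conn, packet):
            method = getattr(self, packet.handler_method_name)
            try:
                method(conn, *packet._args)    # args = packet._args in the diff
            except Denied as e:
                # stands in for conn.answer(Errors.Denied(str(e)))
                conn.append(('error', str(e)))

    conn = []                                  # a list stands in for a Connection
    h = Handler()
    h.dispatch(conn, AskDropNode(3))
    h.dispatch(conn, AskDropNode(0))
    print(conn)   # [('dropped', 3), ('error', 'cannot drop the last storage node')]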
...@@ -74,7 +74,7 @@ def implements(obj, ignore=()): ...@@ -74,7 +74,7 @@ def implements(obj, ignore=()):
assert not wrong_signature, wrong_signature assert not wrong_signature, wrong_signature
return obj return obj
def _set_code(func): def _stub(func):
args, varargs, varkw, _ = inspect.getargspec(func) args, varargs, varkw, _ = inspect.getargspec(func)
if varargs: if varargs:
args.append("*" + varargs) args.append("*" + varargs)
...@@ -82,16 +82,25 @@ def _set_code(func): ...@@ -82,16 +82,25 @@ def _set_code(func):
args.append("**" + varkw) args.append("**" + varkw)
exec "def %s(%s): raise NotImplementedError\nf = %s" % ( exec "def %s(%s): raise NotImplementedError\nf = %s" % (
func.__name__, ",".join(args), func.__name__) func.__name__, ",".join(args), func.__name__)
func.func_code = f.func_code return f
def abstract(func): def abstract(func):
_set_code(func) f = _stub(func)
func.__abstract__ = 1 f.__abstract__ = 1
return func f.__defaults__ = func.__defaults__
f.__doc__ = func.__doc__
return f
def requires(*args): def requires(*args):
for func in args: for func in args:
_set_code(func) # Tolerate useless abstract decoration on required method (e.g. it
# simplifies the implementation of a fallback decorator), but remove
# marker since it does not need to be implemented if it's required
# by a method that is overridden.
try:
del func.__abstract__
except AttributeError:
func.__code__ = _stub(func).__code__
def decorator(func): def decorator(func):
func.__requires__ = args func.__requires__ = args
return func return func
......
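_stub builds, from a real function, a replacement with the same signature whose body only raises NotImplementedError; abstract returns that stub (preserving docstring and defaults), and requires removes the abstract marker again when the method is only required by one that is overridden. A rough Python 3 equivalent of the stub-building part, using inspect in the same way:

    import inspect

    def make_stub(func):
        """Return a function with func's signature that only raises NotImplementedError."""
        spec = inspect.getfullargspec(func)
        args = list(spec.args)
        if spec.varargs:
            args.append('*' + spec.varargs)
        if spec.varkw:
            args.append('**' + spec.varkw)
        ns = {}
        exec("def %s(%s): raise NotImplementedError"
             % (func.__name__, ", ".join(args)), ns)
        f = ns[func.__name__]
        f.__doc__ = func.__doc__
        f.__defaults__ = func.__defaults__
        return f

    def load(self, oid, tid=None):
        """Load an object revision."""

    stub = make_stub(load)
    print(inspect.signature(stub))   # (self, oid, tid=None)
    try:
        stub(None, 1)
    except NotImplementedError:
        print('abstract')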
...@@ -152,7 +152,8 @@ class NEOLogger(Logger): ...@@ -152,7 +152,8 @@ class NEOLogger(Logger):
def _setup(self, filename=None, reset=False): def _setup(self, filename=None, reset=False):
from . import protocol as p from . import protocol as p
global uuid_str global packb, uuid_str
packb = p.packb
uuid_str = p.uuid_str uuid_str = p.uuid_str
if self._db is not None: if self._db is not None:
self._db.close() self._db.close()
...@@ -250,7 +251,7 @@ class NEOLogger(Logger): ...@@ -250,7 +251,7 @@ class NEOLogger(Logger):
'>' if r.outgoing else '<', uuid_str(r.uuid), ip, port) '>' if r.outgoing else '<', uuid_str(r.uuid), ip, port)
msg = r.msg msg = r.msg
if msg is not None: if msg is not None:
msg = buffer(msg) msg = buffer(msg if type(msg) is bytes else packb(msg))
q = "INSERT INTO packet VALUES (?,?,?,?,?,?)" q = "INSERT INTO packet VALUES (?,?,?,?,?,?)"
x = [r.created, nid, r.msg_id, r.code, peer, msg] x = [r.created, nid, r.msg_id, r.code, peer, msg]
else: else:
...@@ -299,9 +300,14 @@ class NEOLogger(Logger): ...@@ -299,9 +300,14 @@ class NEOLogger(Logger):
def packet(self, connection, packet, outgoing): def packet(self, connection, packet, outgoing):
if self._db is not None: if self._db is not None:
body = packet._body if self._max_packet and self._max_packet < packet.size:
if self._max_packet and self._max_packet < len(body): args = None
body = None else:
args = packet._args
try:
hash(args)
except TypeError:
args = packb(args)
self._queue(PacketRecord( self._queue(PacketRecord(
created=time(), created=time(),
msg_id=packet._id, msg_id=packet._id,
...@@ -309,7 +315,7 @@ class NEOLogger(Logger): ...@@ -309,7 +315,7 @@ class NEOLogger(Logger):
outgoing=outgoing, outgoing=outgoing,
uuid=connection.getUUID(), uuid=connection.getUUID(),
addr=connection.getAddress(), addr=connection.getAddress(),
msg=body)) msg=args))
def node(self, *cluster_nid): def node(self, *cluster_nid):
name = self.name and str(self.name) name = self.name and str(self.name)
......
...@@ -486,7 +486,7 @@ class NodeManager(EventQueue): ...@@ -486,7 +486,7 @@ class NodeManager(EventQueue):
# For the first notification, we receive a full list of nodes from # For the first notification, we receive a full list of nodes from
# the master. Remove all unknown nodes from a previous connection. # the master. Remove all unknown nodes from a previous connection.
for node in self._node_set.difference(added_list): for node in self._node_set.difference(added_list):
if app.pt.dropNode(node): if not node.isStorage() or app.pt.dropNode(node):
self.remove(node) self.remove(node)
self.log() self.log()
self.executeQueuedEvents() self.executeQueuedEvents()
......
...@@ -14,27 +14,63 @@ ...@@ -14,27 +14,63 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys import threading
import traceback from functools import partial
from cStringIO import StringIO from msgpack import packb
from struct import Struct
# The protocol version must be increased whenever upgrading a node may require # The protocol version must be increased whenever upgrading a node may require
# to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and # to upgrade other nodes.
# the high order byte 0 is different from TLS Handshake (0x16). PROTOCOL_VERSION = 0
PROTOCOL_VERSION = 5 # By encoding the handshake packet with msgpack, the whole NEO stream can be
ENCODED_VERSION = Struct('!L').pack(PROTOCOL_VERSION) # decoded with msgpack. The first byte is 0x92, which is different from TLS
# Handshake (0x16).
HANDSHAKE_PACKET = packb(('NEO', PROTOCOL_VERSION))
# Used to distinguish non-NEO stream from version mismatch.
MAGIC_SIZE = len(HANDSHAKE_PACKET) - len(packb(PROTOCOL_VERSION))
# Avoid memory errors on corrupted data.
MAX_PACKET_SIZE = 0x4000000
PACKET_HEADER_FORMAT = Struct('!LHL')
RESPONSE_MASK = 0x8000 RESPONSE_MASK = 0x8000
# Avoid some memory errors on corrupted data.
# Before we use msgpack, we limited the size of a whole packet. That's not
# possible anymore because the size is not known in advance. Packets bigger
# than the buffer size are possible (e.g. a huge list of small items) and for
# that we could compare the stream position (Unpacker.tell); it's not worth it.
UNPACK_BUFFER_SIZE = 0x4000000
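Because the handshake itself is now a msgpack value, the first byte of a NEO stream is 0x92 (a two-element fixarray), which cannot be mistaken for a TLS ClientHello (0x16), and MAGIC_SIZE covers the constant 'NEO' prefix, i.e. the handshake minus the packed version number, so a version mismatch can be told apart from a non-NEO stream. A quick Python 3 check of both facts:

    import msgpack

    PROTOCOL_VERSION = 0
    HANDSHAKE_PACKET = msgpack.packb(('NEO', PROTOCOL_VERSION))
    MAGIC_SIZE = len(HANDSHAKE_PACKET) - len(msgpack.packb(PROTOCOL_VERSION))

    print(HANDSHAKE_PACKET)                # b'\x92\xa3NEO\x00'
    print(hex(HANDSHAKE_PACKET[0]))        # 0x92: fixarray(2), != 0x16 (TLS Handshake)
    print(HANDSHAKE_PACKET[:MAGIC_SIZE])   # b'\x92\xa3NEO'
    # A peer speaking another protocol version still starts with the same
    # MAGIC_SIZE bytes, which is how a version mismatch is distinguished
    # from a non-NEO stream.
    print(msgpack.packb(('NEO', 1))[:MAGIC_SIZE] == HANDSHAKE_PACKET[:MAGIC_SIZE])  # True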
@apply
def Unpacker():
global registerExtType, packb
from msgpack import ExtType, unpackb, Packer, Unpacker
ext_type_dict = []
kw = dict(use_bin_type=True)
pack_ext = Packer(**kw).pack
def registerExtType(getstate, make):
code = len(ext_type_dict)
ext_type_dict.append(lambda data: make(unpackb(data, use_list=False)))
return lambda obj: ExtType(code, pack_ext(getstate(obj)))
iterable_types = set, tuple
def default(obj):
try:
pack = obj._pack
except AttributeError:
assert type(obj) in iterable_types, type(obj)
return list(obj)
return pack()
lock = threading.Lock()
pack = Packer(default, strict_types=True, **kw).pack
def packb(obj):
with lock: # in case that 'default' is called
return pack(obj)
return partial(Unpacker, use_list=False, max_buffer_size=UNPACK_BUFFER_SIZE,
ext_hook=lambda code, data: ext_type_dict[code](data))
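registerExtType assigns msgpack ExtType codes in registration order: the returned function turns an object into ExtType(code, packed_state), and the Unpacker's ext_hook looks the code up and rebuilds the object from its unpacked state. A self-contained round trip in the same style, with a toy Point class playing the role of NEO's enum items:

    import msgpack
    from msgpack import ExtType

    ext_type_dict = []                       # code -> rebuild function, as in the diff
    pack_ext = msgpack.Packer().pack

    def registerExtType(getstate, make):
        code = len(ext_type_dict)            # codes follow registration order
        ext_type_dict.append(lambda data: make(msgpack.unpackb(data, use_list=False)))
        return lambda obj: ExtType(code, pack_ext(getstate(obj)))

    class Point(object):                     # toy type standing in for an Enum item
        def __init__(self, x, y):
            self.x, self.y = x, y

    pack_point = registerExtType(lambda p: (p.x, p.y), lambda state: Point(*state))

    def default(obj):                        # called by the Packer for unknown types
        return pack_point(obj)

    blob = msgpack.Packer(default=default, strict_types=True).pack([Point(1, 2), 'ok'])
    decoded = msgpack.unpackb(blob, use_list=False,
                              ext_hook=lambda code, data: ext_type_dict[code](data))
    print(decoded[0].x, decoded[0].y, decoded[1])   # 1 2 ok

This is also why the enum declarations below are reordered "Enum types first, sorted alphabetically": each enum's ext code is its registration order, so the declaration order is part of the wire format.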
class Enum(tuple): class Enum(tuple):
class Item(int): class Item(int):
__slots__ = '_name', '_enum' __slots__ = '_name', '_enum', '_pack'
def __str__(self): def __str__(self):
return self._name return self._name
def __repr__(self): def __repr__(self):
...@@ -49,30 +85,38 @@ class Enum(tuple): ...@@ -49,30 +85,38 @@ class Enum(tuple):
names = func.func_code.co_names names = func.func_code.co_names
self = tuple.__new__(cls, map(cls.Item, xrange(len(names)))) self = tuple.__new__(cls, map(cls.Item, xrange(len(names))))
self._name = func.__name__ self._name = func.__name__
pack = registerExtType(int, self.__getitem__)
for item, name in zip(self, names): for item, name in zip(self, names):
setattr(self, name, item) setattr(self, name, item)
item._name = name item._name = name
item._enum = self item._enum = self
item._pack = (lambda x: lambda: x)(pack(item))
return self return self
def __repr__(self): def __repr__(self):
return "<Enum %s>" % self._name return "<Enum %s>" % self._name
# The order of extension type is important.
# Enum types first, sorted alphabetically.
@Enum @Enum
def ErrorCodes(): def CellStates():
ACK # Write-only cell. Last transactions are missing because storage is/was down
NOT_READY # for a while, or because it is new for the partition. It usually becomes
OID_NOT_FOUND # UP_TO_DATE when replication is done.
TID_NOT_FOUND OUT_OF_DATE
OID_DOES_NOT_EXIST # Normal state: cell is writable/readable, and it isn't planned to drop it.
PROTOCOL_ERROR UP_TO_DATE
REPLICATION_ERROR # Same as UP_TO_DATE, except that it will be discarded as soon as another
CHECKING_ERROR # node finishes to replicate it. It means a partition is moved from 1 node
BACKEND_NOT_IMPLEMENTED # to another. It is also discarded immediately if out-of-date.
NON_READABLE_CELL FEEDING
READ_ONLY_ACCESS # A check revealed that data differs from other replicas. Cell is neither
INCOMPLETE_TRANSACTION # readable nor writable.
CORRUPTED
# Not really a state: only used in network packets to tell storages to drop
# partitions.
DISCARDED
@Enum @Enum
def ClusterStates(): def ClusterStates():
...@@ -107,11 +151,20 @@ def ClusterStates(): ...@@ -107,11 +151,20 @@ def ClusterStates():
STOPPING_BACKUP STOPPING_BACKUP
@Enum @Enum
def NodeTypes(): def ErrorCodes():
MASTER ACK
STORAGE DENIED
CLIENT NOT_READY
ADMIN OID_NOT_FOUND
TID_NOT_FOUND
OID_DOES_NOT_EXIST
PROTOCOL_ERROR
REPLICATION_ERROR
CHECKING_ERROR
BACKEND_NOT_IMPLEMENTED
NON_READABLE_CELL
READ_ONLY_ACCESS
INCOMPLETE_TRANSACTION
@Enum @Enum
def NodeStates(): def NodeStates():
...@@ -121,23 +174,11 @@ def NodeStates(): ...@@ -121,23 +174,11 @@ def NodeStates():
PENDING PENDING
@Enum @Enum
def CellStates(): def NodeTypes():
# Write-only cell. Last transactions are missing because storage is/was down MASTER
# for a while, or because it is new for the partition. It usually becomes STORAGE
# UP_TO_DATE when replication is done. CLIENT
OUT_OF_DATE ADMIN
# Normal state: cell is writable/readable, and it isn't planned to drop it.
UP_TO_DATE
# Same as UP_TO_DATE, except that it will be discarded as soon as another
# node finishes to replicate it. It means a partition is moved from 1 node
# to another. It is also discarded immediately if out-of-date.
FEEDING
# A check revealed that data differs from other replicas. Cell is neither
# readable nor writable.
CORRUPTED
# Not really a state: only used in network packets to tell storages to drop
# partitions.
DISCARDED
# used for logging # used for logging
node_state_prefix_dict = { node_state_prefix_dict = {
...@@ -212,45 +253,22 @@ class NonReadableCell(Exception): ...@@ -212,45 +253,22 @@ class NonReadableCell(Exception):
On such event, the client must retry, preferably another cell. On such event, the client must retry, preferably another cell.
""" """
class Packet(object): class Packet(object):
""" """
Base class for any packet definition. The _fmt class attribute must be Base class for any packet definition.
defined for any non-empty packet.
""" """
_ignore_when_closed = False _ignore_when_closed = False
_request = None _request = None
_answer = None _answer = None
_body = None
_code = None _code = None
_fmt = None
_id = None _id = None
nodelay = True nodelay = True
poll_thread = False poll_thread = False
def __init__(self, *args): def __init__(self, *args):
assert self._code is not None, "Packet class not registered" assert self._code is not None, "Packet class not registered"
if args: self._args = args
buf = StringIO()
self._fmt.encode(buf.write, args)
self._body = buf.getvalue()
else:
self._body = ''
def decode(self):
assert self._body is not None
if self._fmt is None:
return ()
buf = StringIO(self._body)
try:
return self._fmt.decode(buf.read)
except ParseError, msg:
name = self.__class__.__name__
raise PacketMalformedError("%s fail (%s)" % (name, msg))
def setContent(self, msg_id, body):
""" Register the packet content for future decoding """
self._id = msg_id
self._body = body
def setId(self, value): def setId(self, value):
self._id = value self._id = value
...@@ -259,14 +277,11 @@ class Packet(object): ...@@ -259,14 +277,11 @@ class Packet(object):
assert self._id is not None, "No identifier applied on the packet" assert self._id is not None, "No identifier applied on the packet"
return self._id return self._id
def encode(self): def encode(self, packb=packb):
""" Encode a packet as a string to send it over the network """ """ Encode a packet as a string to send it over the network """
content = self._body r = packb((self._id, self._code, self._args))
return (PACKET_HEADER_FORMAT.pack(self._id, self._code, len(content)), self.size = len(r)
content) return r
def __len__(self):
return PACKET_HEADER_FORMAT.size + len(self._body)
def __repr__(self): def __repr__(self):
return '%s[%r]' % (self.__class__.__name__, self._id) return '%s[%r]' % (self.__class__.__name__, self._id)
...@@ -279,10 +294,10 @@ class Packet(object): ...@@ -279,10 +294,10 @@ class Packet(object):
return self._code == other._code return self._code == other._code
def isError(self): def isError(self):
return isinstance(self, Error) return self._code == RESPONSE_MASK
def isResponse(self): def isResponse(self):
return self._code & RESPONSE_MASK == RESPONSE_MASK return self._code & RESPONSE_MASK
def getAnswerClass(self): def getAnswerClass(self):
return self._answer return self._answer
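A packet on the wire is now simply the msgpack encoding of the (msg_id, code, args) triple, and the high bit of the code (RESPONSE_MASK) distinguishes answers from requests, with the bare mask reserved for Error. A small sketch of these conventions (the packet codes below are made up for illustration):

    import msgpack

    RESPONSE_MASK = 0x8000

    class Packet(object):
        _code = None                          # set by the registry in NEO
        def __init__(self, *args):
            self._args = args
            self._id = None
        def setId(self, value):
            self._id = value
        def encode(self):
            r = msgpack.packb((self._id, self._code, self._args))
            self.size = len(r)                # kept for logging/statistics
            return r
        def isResponse(self):
            return bool(self._code & RESPONSE_MASK)
        def isError(self):
            return self._code == RESPONSE_MASK

    class AskClusterState(Packet):            # hypothetical code, for illustration
        _code = 0x0005
    class AnswerClusterState(Packet):
        _code = 0x0005 | RESPONSE_MASK

    p = AnswerClusterState('RUNNING')
    p.setId(42)
    wire = p.encode()
    print(msgpack.unpackb(wire, use_list=False))   # (42, 32773, ('RUNNING',))
    print(p.isResponse(), p.isError(), p.size == len(wire))  # True False True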
...@@ -294,1530 +309,530 @@ class Packet(object): ...@@ -294,1530 +309,530 @@ class Packet(object):
""" """
return self._ignore_when_closed return self._ignore_when_closed
class ParseError(Exception):
"""
An exception that encapsulate another and build the 'path' of the
packet item that generate the error.
"""
def __init__(self, item, trace):
Exception.__init__(self)
self._trace = trace
self._items = [item]
def append(self, item):
self._items.append(item)
def __repr__(self):
chain = '/'.join([item.getName() for item in reversed(self._items)])
return 'at %s:\n%s' % (chain, self._trace)
__str__ = __repr__
# packet parsers
class PItem(object):
"""
Base class for any packet item, _encode and _decode must be overridden
by subclasses.
"""
def __init__(self, name):
self._name = name
def __repr__(self):
return self.__class__.__name__
def getName(self):
return self._name
def _trace(self, method, *args):
try:
return method(*args)
except ParseError, e:
# trace and forward exception
e.append(self)
raise
except Exception:
# original exception, encapsulate it
trace = ''.join(traceback.format_exception(*sys.exc_info())[2:])
raise ParseError(self, trace)
def encode(self, writer, items):
return self._trace(self._encode, writer, items)
def decode(self, reader):
return self._trace(self._decode, reader)
def _encode(self, writer, items):
raise NotImplementedError, self.__class__.__name__
def _decode(self, reader): class PacketRegistryFactory(dict):
raise NotImplementedError, self.__class__.__name__
class PStruct(PItem): def __call__(self, name, base, d):
""" for k, v in d.items():
Aggregate other items if isinstance(v, type) and issubclass(v, Packet):
""" v.__name__ = k
def __init__(self, name, *items): v.handler_method_name = k[0].lower() + k[1:]
PItem.__init__(self, name) # this builds a "singleton"
self._items = items return type('PacketRegistry', base, d)(self)
def _encode(self, writer, items): def register(self, doc, ignore_when_closed=None, request=False, error=False,
assert len(self._items) == len(items), (items, self._items) _base=(Packet,), **kw):
for item, value in zip(self._items, items): """ Register a packet in the packet registry """
item.encode(writer, value) code = len(self)
if doc is None:
def _decode(self, reader): self[code] = None
return tuple([item.decode(reader) for item in self._items]) return # None registered only to skip a code number (for compatibility)
if error and not request:
class PStructItem(PItem): assert not code
""" code = RESPONSE_MASK
A single value encoded with struct kw.update(__doc__=doc, _code=code)
""" packet = type('', _base, kw)
def __init__(self, name): # register the request
PItem.__init__(self, name) self[code] = packet
struct = Struct(self._fmt) if request:
self.pack = struct.pack if ignore_when_closed is None:
self.unpack = struct.unpack # By default, on a closed connection:
self.size = struct.size # - request: ignore
# - answer: keep
def _encode(self, writer, value): # - notification: keep
writer(self.pack(value)) packet._ignore_when_closed = True
else:
def _decode(self, reader): assert ignore_when_closed is False
return self.unpack(reader(self.size))[0] if error:
packet._answer = self[RESPONSE_MASK]
class PStructItemOrNone(PStructItem): else:
# build a class for the answer
def _encode(self, writer, value): code |= RESPONSE_MASK
return writer(self._None if value is None else self.pack(value)) kw['_code'] = code
answer = packet._answer = self[code] = type('', _base, kw)
def _decode(self, reader): return packet, answer
value = reader(self.size)
return None if value == self._None else self.unpack(value)[0]
class POption(PStruct):
def _encode(self, writer, value):
if value is None:
writer('\0')
else: else:
writer('\1') assert ignore_when_closed is None
PStruct._encode(self, writer, value) return packet
def _decode(self, reader):
if '\0\1'.index(reader(1)):
return PStruct._decode(self, reader)
class PList(PStructItem): class Packets(dict):
""" """
A list of homogeneous items Packet registry that checks packet code uniqueness and provides an index
""" """
_fmt = '!L' __metaclass__ = PacketRegistryFactory()
notify = __metaclass__.register
request = partial(notify, request=True)
def __init__(self, name, item): Error = notify("""
PStructItem.__init__(self, name) Error is a special type of message, because this can be sent against
self._item = item any other message, even if such a message does not expect a reply
usually.
def _encode(self, writer, items): :nodes: * -> *
writer(self.pack(len(items))) """, error=True)
item = self._item
for value in items:
item.encode(writer, value)
def _decode(self, reader): RequestIdentification, AcceptIdentification = request("""
length = self.unpack(reader(self.size))[0] Request a node identification. This must be the first packet for any
item = self._item connection.
return [item.decode(reader) for _ in xrange(length)]
class PDict(PStructItem): :nodes: * -> *
""" """, poll_thread=True)
A dictionary with custom key and value formats
"""
_fmt = '!L'
def __init__(self, name, key, value):
PStructItem.__init__(self, name)
self._key = key
self._value = value
def _encode(self, writer, item):
assert isinstance(item , dict), (type(item), item)
writer(self.pack(len(item)))
key, value = self._key, self._value
for k, v in item.iteritems():
key.encode(writer, k)
value.encode(writer, v)
def _decode(self, reader):
length = self.unpack(reader(self.size))[0]
key, value = self._key, self._value
new_dict = {}
for _ in xrange(length):
k = key.decode(reader)
v = value.decode(reader)
new_dict[k] = v
return new_dict
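
PList and PDict share one wire layout: a '!L' big-endian count followed by the item codec repeated that many times. As an illustration only (not part of this diff; all names below are invented), the same layout for a dict of 8-byte keys and 4-byte values:

# Illustrative sketch only (not from the NEO source): the count-prefixed
# layout described by PList/PDict, here for a dict mapping 8-byte unsigned
# keys to 4-byte unsigned values.
import struct
from io import BytesIO

_count = struct.Struct('!L')   # 4-byte big-endian count, like _fmt = '!L'
_key = struct.Struct('!Q')
_value = struct.Struct('!L')

def encode_dict(d):
    out = [_count.pack(len(d))]
    for k, v in d.items():
        out.append(_key.pack(k))
        out.append(_value.pack(v))
    return b''.join(out)

def decode_dict(data):
    read = BytesIO(data).read
    count, = _count.unpack(read(_count.size))
    return dict((_key.unpack(read(_key.size))[0],
                 _value.unpack(read(_value.size))[0])
                for _ in range(count))

assert decode_dict(encode_dict({1: 2, 3: 4})) == {1: 2, 3: 4}
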
class PEnum(PStructItem):
"""
Encapsulate an enumeration value
"""
_fmt = 'b'
def __init__(self, name, enum): Ping, Pong = request("""
PStructItem.__init__(self, name) Empty request used as network barrier.
self._enum = enum
def _encode(self, writer, item): :nodes: * -> *
if item is None: """)
item = -1
writer(self.pack(item))
def _decode(self, reader): CloseClient = notify("""
code = self.unpack(reader(self.size))[0] Tell peer that it can close the connection if it has finished with us.
if code == -1:
return None
try:
return self._enum[code]
except KeyError:
enum = self._enum.__class__.__name__
raise ValueError, 'Invalid code for %s enum: %r' % (enum, code)
class PString(PStructItem): :nodes: * -> *
""" """)
A variable-length string
"""
_fmt = '!L'
def _encode(self, writer, value): AskPrimary, AnswerPrimary = request("""
writer(self.pack(len(value))) Ask node identifier of the current primary master.
writer(value)
def _decode(self, reader): :nodes: ctl -> A
length = self.unpack(reader(self.size))[0] """)
return reader(length)
class PAddress(PString): NotPrimaryMaster = notify("""
""" Notify peer that I'm not the primary master. Attach any extra
A host address (IPv4/IPv6) information to help the peer joining the cluster.
"""
def __init__(self, name): :nodes: SM -> *
PString.__init__(self, name) """)
self._port = Struct('!H')
def _encode(self, writer, address): NotifyNodeInformation = notify("""
if address: Notify information about one or more nodes.
host, port = address
PString._encode(self, writer, host)
writer(self._port.pack(port))
else:
PString._encode(self, writer, '')
def _decode(self, reader): :nodes: M -> *
host = PString._decode(self, reader) """)
if host:
p = self._port
return host, p.unpack(reader(p.size))[0]
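
PAddress encodes an address as a length-prefixed host string followed by a 2-byte port, with an empty host standing for None. A standalone sketch of that convention (illustrative only; the function names are invented):

# Illustrative sketch of the PAddress layout: length-prefixed host string,
# then a 2-byte big-endian port; an empty host encodes a missing address.
import struct
from io import BytesIO

_len = struct.Struct('!L')
_port = struct.Struct('!H')

def encode_address(address):
    host, port = address if address else ('', 0)
    host = host.encode('utf-8')
    data = _len.pack(len(host)) + host
    if address:
        data += _port.pack(port)
    return data

def decode_address(data):
    read = BytesIO(data).read
    length, = _len.unpack(read(_len.size))
    host = read(length).decode('utf-8')
    if host:
        return host, _port.unpack(read(_port.size))[0]
    return None

assert decode_address(encode_address(('127.0.0.1', 10000))) == ('127.0.0.1', 10000)
assert decode_address(encode_address(None)) is None
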
class PBoolean(PStructItem): AskRecovery, AnswerRecovery = request("""
""" Ask storage nodes data needed by master to recover.
A boolean value, encoded as a single byte Reused by `neoctl print ids`.
"""
_fmt = '!?'
class PNumber(PStructItem): :nodes: M -> S; ctl -> A -> M
""" """)
An integer number (4-bytes length)
"""
_fmt = '!L'
class PIndex(PStructItem): AskLastIDs, AnswerLastIDs = request("""
""" Ask the last OID/TID so that a master can initialize its
A big integer to define indexes in a huge list. TransactionManager. Reused by `neoctl print ids`.
"""
_fmt = '!Q'
class PPTID(PStructItemOrNone): :nodes: M -> S; ctl -> A -> M
""" """)
A None value means an invalid PTID
"""
_fmt = '!Q'
_None = Struct(_fmt).pack(0)
class PChecksum(PItem): AskPartitionTable, AnswerPartitionTable = request("""
""" Ask storage node the remaining data needed by master to recover.
A hash (SHA1)
"""
def _encode(self, writer, checksum):
assert len(checksum) == 20, (len(checksum), checksum)
writer(checksum)
def _decode(self, reader): :nodes: M -> S
return reader(20) """)
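
Per the docstring, the 20-byte checksum is a SHA-1 digest; for illustration, a value of that shape can be produced with hashlib (the helper name below is invented):

# Illustrative only: PChecksum carries a raw 20-byte digest, which is the
# size of a SHA-1 hash; this is how such a value can be produced.
import hashlib

def data_checksum(data):
    digest = hashlib.sha1(data).digest()
    assert len(digest) == 20
    return digest

assert len(data_checksum(b'some object data')) == 20
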
class PSignedNull(PStructItemOrNone): SendPartitionTable = notify("""
_fmt = '!l' Send the full partition table to admin/client/storage nodes on
_None = Struct(_fmt).pack(0) connection.
class PUUID(PSignedNull): :nodes: M -> A, C, S
""" """)
An UUID (node identifier, 4-bytes signed integer)
"""
class PTID(PItem): NotifyPartitionChanges = notify("""
""" Notify about changes in the partition table.
A transaction identifier
"""
def _encode(self, writer, tid):
if tid is None:
tid = INVALID_TID
assert len(tid) == 8, (len(tid), tid)
writer(tid)
def _decode(self, reader):
tid = reader(8)
if tid == INVALID_TID:
tid = None
return tid
# same definition, for now
POID = PTID
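
PTID and POID always send exactly 8 bytes and use INVALID_TID as the on-wire stand-in for None. A minimal sketch of that convention, assuming INVALID_TID is the all-0xff sentinel defined earlier in this module (illustrative only):

# Illustrative sketch of the PTID/POID convention: fixed 8-byte values,
# with an all-0xff sentinel assumed to stand for None (INVALID_TID).
INVALID_TID = b'\xff' * 8

def encode_tid(tid):
    if tid is None:
        tid = INVALID_TID
    assert len(tid) == 8, tid
    return tid

def decode_tid(data):
    assert len(data) == 8, data
    return None if data == INVALID_TID else data

assert decode_tid(encode_tid(None)) is None
assert decode_tid(encode_tid(b'\x00' * 7 + b'\x01')) == b'\x00' * 7 + b'\x01'
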
class PFloat(PStructItemOrNone):
"""
A float number (8-bytes length)
"""
_fmt = '!d'
_None = '\xff' * 8
# common definitions
PFEmpty = PStruct('no_content')
PFNodeType = PEnum('type', NodeTypes)
PFNodeState = PEnum('state', NodeStates)
PFCellState = PEnum('state', CellStates)
PFNodeList = PList('node_list',
PStruct('node',
PFNodeType,
PAddress('address'),
PUUID('uuid'),
PFNodeState,
PFloat('id_timestamp'),
),
)
PFCellList = PList('cell_list',
PStruct('cell',
PUUID('uuid'),
PFCellState,
),
)
PFRowList = PList('row_list',
PStruct('row',
PNumber('offset'),
PFCellList,
),
)
PFHistoryList = PList('history_list',
PStruct('history_entry',
PTID('serial'),
PNumber('size'),
),
)
PFUUIDList = PList('uuid_list',
PUUID('uuid'),
)
PFTidList = PList('tid_list',
PTID('tid'),
)
PFOidList = PList('oid_list',
POID('oid'),
)
# packets definition
class Error(Packet):
"""
Error is a special type of message, because this can be sent against
any other message, even if such a message does not expect a reply
usually.
:nodes: * -> * :nodes: M -> *
""" """)
_fmt = PStruct('error',
PNumber('code'),
PString('message'),
)
class Ping(Packet): StartOperation = notify("""
""" Tell a storage node to start operation. Before this message,
Empty request used as network barrier. it must only communicate with the primary master.
:nodes: * -> * :nodes: M -> S
""" """)
_answer = PFEmpty
class CloseClient(Packet): StopOperation = notify("""
""" Notify that the cluster is not operational anymore.
Tell peer that it can close the connection if it has finished with us. Any operation between nodes must be aborted.
:nodes: * -> * :nodes: M -> S, C
""" """)
class RequestIdentification(Packet): AskUnfinishedTransactions, AnswerUnfinishedTransactions = request("""
""" Ask unfinished transactions, which will be replicated
Request a node identification. This must be the first packet for any when they're finished.
connection.
:nodes: * -> * :nodes: S -> M
""" """)
poll_thread = True
_fmt = PStruct('request_identification',
PFNodeType,
PUUID('uuid'),
PAddress('address'),
PString('name'),
PList('devpath', PString('devid')),
PFloat('id_timestamp'),
)
_answer = PStruct('accept_identification',
PFNodeType,
PUUID('my_uuid'),
PNumber('num_partitions'),
PNumber('num_replicas'),
PUUID('your_uuid'),
)
class PrimaryMaster(Packet):
"""
Ask node identifier of the current primary master.
:nodes: ctl -> A AskLockedTransactions, AnswerLockedTransactions = request("""
""" Ask locked transactions to replay committed transactions
_answer = PStruct('answer_primary', that haven't been unlocked.
PUUID('primary_uuid'),
)
class NotPrimaryMaster(Packet): :nodes: M -> S
""" """)
Notify peer that I'm not the primary master. Attach any extra information
to help the peer joining the cluster.
:nodes: SM -> * AskFinalTID, AnswerFinalTID = request("""
""" Return final tid if ttid has been committed, to recover from certain
_fmt = PStruct('not_primary_master', failures during tpc_finish.
PSignedNull('primary'),
PList('known_master_list',
PAddress('address'),
),
)
class Recovery(Packet):
"""
Ask storage nodes data needed by master to recover.
Reused by `neoctl print ids`.
:nodes: M -> S; ctl -> A -> M :nodes: M -> S; C -> M, S
""" """)
_answer = PStruct('answer_recovery',
PPTID('ptid'),
PTID('backup_tid'),
PTID('truncate_tid'),
)
class LastIDs(Packet): ValidateTransaction = notify("""
""" Do replay a committed transaction that was not unlocked.
Ask the last OID/TID so that a master can initialize its TransactionManager.
Reused by `neoctl print ids`.
:nodes: M -> S; ctl -> A -> M :nodes: M -> S
""" """)
_answer = PStruct('answer_last_ids',
POID('last_oid'),
PTID('last_tid'),
)
class PartitionTable(Packet): AskBeginTransaction, AnswerBeginTransaction = request("""
""" Ask to begin a new transaction. This maps to `tpc_begin`.
Ask storage node the remaining data needed by master to recover.
This is also how the clients get the full partition table on connection.
:nodes: M -> S; C -> M :nodes: C -> M
""" """)
_answer = PStruct('answer_partition_table',
PPTID('ptid'),
PFRowList,
)
class NotifyPartitionTable(Packet): FailedVote = request("""
""" Report storage nodes for which vote failed.
Send the full partition table to admin/storage nodes on connection. True is returned if it's still possible to finish the transaction.
:nodes: M -> A, S :nodes: C -> M
""" """, error=True)
_fmt = PStruct('send_partition_table',
PPTID('ptid'),
PFRowList,
)
class PartitionChanges(Packet): AskFinishTransaction, AnswerTransactionFinished = request("""
""" Finish a transaction. Return the TID of the committed transaction.
Notify about changes in the partition table. This maps to `tpc_finish`.
:nodes: M -> * :nodes: C -> M
""" """, ignore_when_closed=False, poll_thread=True)
_fmt = PStruct('notify_partition_changes',
PPTID('ptid'),
PList('cell_list',
PStruct('cell',
PNumber('offset'),
PUUID('uuid'),
PFCellState,
),
),
)
class StartOperation(Packet):
"""
Tell a storage node to start operation. Before this message, it must only
communicate with the primary master.
:nodes: M -> S AskLockInformation, AnswerInformationLocked = request("""
""" Commit a transaction. The new data is read-locked.
_fmt = PStruct('start_operation',
# XXX: Is this boolean needed ? Maybe this
# can be deduced from cluster state.
PBoolean('backup'),
)
class StopOperation(Packet): :nodes: M -> S
""" """, ignore_when_closed=False)
Notify that the cluster is not operational anymore. Any operation between
nodes must be aborted.
:nodes: M -> S, C InvalidateObjects = notify("""
""" Notify about a new transaction modifying objects,
invalidating client caches.
class UnfinishedTransactions(Packet): :nodes: M -> C
""" """)
Ask unfinished transactions, which will be replicated when they're finished.
:nodes: S -> M NotifyUnlockInformation = notify("""
""" Notify about a successfully committed transaction. The new data can be
_fmt = PStruct('ask_unfinished_transactions', unlocked.
PList('row_list',
PNumber('offset'),
),
)
_answer = PStruct('answer_unfinished_transactions',
PTID('max_tid'),
PList('tid_list',
PTID('unfinished_tid'),
),
)
class LockedTransactions(Packet):
"""
Ask locked transactions to replay committed transactions that haven't been
unlocked.
:nodes: M -> S :nodes: M -> S
""" """)
_answer = PStruct('answer_locked_transactions',
PDict('tid_dict',
PTID('ttid'),
PTID('tid'),
),
)
class FinalTID(Packet):
"""
Return final tid if ttid has been committed, to recover from certain
failures during tpc_finish.
:nodes: M -> S; C -> M, S AskNewOIDs, AnswerNewOIDs = request("""
""" Ask new OIDs to create objects.
_fmt = PStruct('final_tid',
PTID('ttid'),
)
_answer = PStruct('final_tid', :nodes: C -> M
PTID('tid'), """)
)
class ValidateTransaction(Packet): NotifyDeadlock = notify("""
""" Ask master to generate a new TTID that will be used by the client to
Do replay a committed transaction that was not unlocked. solve a deadlock by rebasing the transaction on top of concurrent
changes.
:nodes: M -> S :nodes: S -> M -> C
""" """)
_fmt = PStruct('validate_transaction',
PTID('ttid'),
PTID('tid'),
)
class BeginTransaction(Packet): AskRebaseTransaction, AnswerRebaseTransaction = request("""
""" Rebase a transaction to solve a deadlock.
Ask to begin a new transaction. This maps to `tpc_begin`.
:nodes: C -> M :nodes: C -> S
""" """)
_fmt = PStruct('ask_begin_transaction',
PTID('tid'),
)
_answer = PStruct('answer_begin_transaction', AskRebaseObject, AnswerRebaseObject = request("""
PTID('tid'), Rebase an object change to solve a deadlock.
)
class FailedVote(Packet): :nodes: C -> S
"""
Report storage nodes for which vote failed.
True is returned if it's still possible to finish the transaction.
:nodes: C -> M XXX: It is a request packet to simplify the implementation. For more
""" efficiency, this should be turned into a notification, and the
_fmt = PStruct('failed_vote', RebaseTransaction should be answered once all objects are rebased
PTID('tid'), (so that the client can still wait on something).
PFUUIDList, """, data_path=(1, 0, 2, 0))
)
_answer = Error AskStoreObject, AnswerStoreObject = request("""
Ask to create/modify an object. This maps to `store`.
class FinishTransaction(Packet): As for IStorage, 'serial' is ZERO_TID for new objects.
"""
Finish a transaction. Return the TID of the committed transaction.
This maps to `tpc_finish`.
:nodes: C -> M :nodes: C -> S
""" """, data_path=(0, 2))
poll_thread = True
_fmt = PStruct('ask_finish_transaction',
PTID('tid'),
PFOidList,
PList('checked_list',
POID('oid'),
),
)
_answer = PStruct('answer_information_locked',
PTID('ttid'),
PTID('tid'),
)
class NotifyTransactionFinished(Packet):
"""
Notify that a transaction blocking a replication is now finished.
:nodes: M -> S AbortTransaction = notify("""
""" Abort a transaction. This maps to `tpc_abort`.
_fmt = PStruct('notify_transaction_finished',
PTID('ttid'),
PTID('max_tid'),
)
class LockInformation(Packet): :nodes: C -> S; C -> M -> S
""" """)
Commit a transaction. The new data is read-locked.
:nodes: M -> S AskStoreTransaction, AnswerStoreTransaction = request("""
""" Ask to store a transaction. Implies vote.
_fmt = PStruct('ask_lock_informations',
PTID('ttid'),
PTID('tid'),
)
_answer = PStruct('answer_information_locked', :nodes: C -> S
PTID('ttid'), """)
)
class InvalidateObjects(Packet): AskVoteTransaction, AnswerVoteTransaction = request("""
""" Ask to vote a transaction.
Notify about a new transaction modifying objects,
invalidating client caches.
:nodes: M -> C :nodes: C -> S
""" """)
_fmt = PStruct('ask_finish_transaction',
PTID('tid'),
PFOidList,
)
class UnlockInformation(Packet): AskObject, AnswerObject = request("""
""" Ask a stored object by its OID, optionally at/before a specific tid.
Notify about a successfully committed transaction. The new data can be This maps to `load/loadBefore/loadSerial`.
unlocked.
:nodes: M -> S :nodes: C -> S
""" """, data_path=(1, 3))
_fmt = PStruct('notify_unlock_information',
PTID('ttid'),
)
class GenerateOIDs(Packet): AskTIDs, AnswerTIDs = request("""
""" Ask for TIDs between a range of offsets. The order of TIDs is
Ask new OIDs to create objects. descending, and the range is [first, last). This maps to `undoLog`.
:nodes: C -> M :nodes: C -> S
""" """)
_fmt = PStruct('ask_new_oids',
PNumber('num_oids'),
)
_answer = PStruct('answer_new_oids', AskTransactionInformation, AnswerTransactionInformation = request("""
PFOidList, Ask for transaction metadata.
)
class Deadlock(Packet): :nodes: C -> S
""" """)
Ask master to generate a new TTID that will be used by the client to solve
a deadlock by rebasing the transaction on top of concurrent changes.
:nodes: S -> M -> C AskObjectHistory, AnswerObjectHistory = request("""
""" Ask history information for a given object. The order of serials is
_fmt = PStruct('notify_deadlock', descending, and the range is [first, last]. This maps to `history`.
PTID('ttid'),
PTID('locking_tid'),
)
class RebaseTransaction(Packet):
"""
Rebase a transaction to solve a deadlock.
:nodes: C -> S
"""
_fmt = PStruct('ask_rebase_transaction',
PTID('ttid'),
PTID('locking_tid'),
)
_answer = PStruct('answer_rebase_transaction', :nodes: C -> S
PFOidList, """)
)
class RebaseObject(Packet): AskPartitionList, AnswerPartitionList = request("""
""" Ask information about partitions.
Rebase an object change to solve a deadlock.
:nodes: C -> S :nodes: ctl -> A
""")
XXX: It is a request packet to simplify the implementation. For more AskNodeList, AnswerNodeList = request("""
efficiency, this should be turned into a notification, and the Ask information about nodes.
RebaseTransaction should be answered once all objects are rebased
(so that the client can still wait on something).
"""
_fmt = PStruct('ask_rebase_object',
PTID('ttid'),
PTID('oid'),
)
_answer = PStruct('answer_rebase_object',
POption('conflict',
PTID('serial'),
PTID('conflict_serial'),
POption('data',
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
),
)
)
class StoreObject(Packet):
"""
Ask to create/modify an object. This maps to `store`.
As for IStorage, 'serial' is ZERO_TID for new objects. :nodes: ctl -> A
""")
:nodes: C -> S SetNodeState = request("""
""" Change the state of a node.
_fmt = PStruct('ask_store_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
PTID('tid'),
)
_answer = PStruct('answer_store_object',
PTID('conflict'),
)
class AbortTransaction(Packet):
"""
Abort a transaction. This maps to `tpc_abort`.
:nodes: C -> S; C -> M -> S
"""
_fmt = PStruct('abort_transaction',
PTID('tid'),
PFUUIDList, # unused for * -> S
)
class StoreTransaction(Packet):
"""
Ask to store a transaction. Implies vote.
:nodes: C -> S
"""
_fmt = PStruct('ask_store_transaction',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PFOidList,
)
_answer = PFEmpty
class VoteTransaction(Packet):
"""
Ask to vote a transaction.
:nodes: C -> S
"""
_fmt = PStruct('ask_vote_transaction',
PTID('tid'),
)
_answer = PFEmpty
class GetObject(Packet): :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
Ask a stored object by its OID, optionally at/before a specific tid.
This maps to `load/loadBefore/loadSerial`.
:nodes: C -> S AddPendingNodes = request("""
""" Mark given pending nodes as running, for future inclusion when tweaking
_fmt = PStruct('ask_object', the partition table.
POID('oid'),
PTID('at'),
PTID('before'),
)
_answer = PStruct('answer_object',
POID('oid'),
PTID('serial_start'),
PTID('serial_end'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
class TIDList(Packet):
"""
Ask for TIDs between a range of offsets. The order of TIDs is descending,
and the range is [first, last). This maps to `undoLog`.
:nodes: C -> S :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
_fmt = PStruct('ask_tids',
PIndex('first'),
PIndex('last'),
PNumber('partition'),
)
_answer = PStruct('answer_tids', TweakPartitionTable, AnswerTweakPartitionTable = request("""
PFTidList, Ask the master to balance the partition table, optionally excluding
) specific nodes in anticipation of removing them.
class TIDListFrom(Packet): :nodes: ctl -> A -> M
""" """)
Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
Used by `iterator`.
:nodes: C -> S SetNumReplicas = request("""
""" Set the number of replicas.
_fmt = PStruct('tid_list_from',
PTID('min_tid'),
PTID('max_tid'),
PNumber('length'),
PNumber('partition'),
)
_answer = PStruct('answer_tids',
PFTidList,
)
class TransactionInformation(Packet):
"""
Ask for transaction metadata.
:nodes: C -> S :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
_fmt = PStruct('ask_transaction_information',
PTID('tid'),
)
_answer = PStruct('answer_transaction_information',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PBoolean('packed'),
PFOidList,
)
class ObjectHistory(Packet):
"""
Ask history information for a given object. The order of serials is
descending, and the range is [first, last]. This maps to `history`.
:nodes: C -> S SetClusterState = request("""
""" Set the cluster state.
_fmt = PStruct('ask_object_history',
POID('oid'),
PIndex('first'),
PIndex('last'),
)
_answer = PStruct('answer_object_history',
POID('oid'),
PFHistoryList,
)
class PartitionList(Packet):
"""
Ask information about partitions.
:nodes: ctl -> A :nodes: ctl -> A -> M
""" """, error=True, ignore_when_closed=False)
_fmt = PStruct('ask_partition_list',
PNumber('min_offset'),
PNumber('max_offset'),
PUUID('uuid'),
)
_answer = PStruct('answer_partition_list',
PPTID('ptid'),
PFRowList,
)
class NodeList(Packet):
"""
Ask information about nodes.
:nodes: ctl -> A Repair = request("""
""" Ask storage nodes to repair their databases.
_fmt = PStruct('ask_node_list',
PFNodeType,
)
_answer = PStruct('answer_node_list', :nodes: ctl -> A -> M
PFNodeList, """, error=True)
)
class SetNodeState(Packet): NotifyRepair = notify("""
""" Repair is translated to this message, asking a specific storage node to
Change the state of a node. repair its database.
:nodes: ctl -> A -> M :nodes: M -> S
""" """)
_fmt = PStruct('set_node_state',
PUUID('uuid'),
PFNodeState,
)
_answer = Error NotifyClusterInformation = notify("""
Notify about a cluster state change.
class AddPendingNodes(Packet): :nodes: M -> *
""" """)
Mark given pending nodes as running, for future inclusion when tweaking
the partition table.
:nodes: ctl -> A -> M AskClusterState, AnswerClusterState = request("""
""" Ask the state of the cluster
_fmt = PStruct('add_pending_nodes',
PFUUIDList,
)
_answer = Error :nodes: ctl -> A; A -> M
""")
class TweakPartitionTable(Packet): AskObjectUndoSerial, AnswerObjectUndoSerial = request("""
""" Ask storage the serial where object data is when undoing given
Ask the master to balance the partition table, optionally excluding transaction, for a list of OIDs.
specific nodes in anticipation of removing them.
:nodes: ctl -> A -> M Answer a dict mapping oids to 3-tuples:
""" current_serial (TID)
_fmt = PStruct('tweak_partition_table', The latest serial visible to the undoing transaction.
PFUUIDList, undo_serial (TID)
) Where undone data is (tid at which data is before given undo).
is_current (bool)
_answer = Error If current_serial's data is current on storage.
class NotifyNodeInformation(Packet):
"""
Notify information about one or more nodes.
:nodes: M -> *
"""
_fmt = PStruct('notify_node_informations',
PFloat('id_timestamp'),
PFNodeList,
)
class SetClusterState(Packet):
"""
Set the cluster state.
:nodes: ctl -> A -> M
"""
_fmt = PStruct('set_cluster_state',
PEnum('state', ClusterStates),
)
_answer = Error
class Repair(Packet):
"""
Ask storage nodes to repair their databases.
:nodes: ctl -> A -> M
"""
_flags = map(PBoolean, ('dry_run',
# 'prune_orphan' (commented because it's the only option for the moment)
))
_fmt = PStruct('repair',
PFUUIDList,
*_flags)
_answer = Error :nodes: C -> S
""")
class RepairOne(Packet): AskTIDsFrom, AnswerTIDsFrom = request("""
""" Ask for length TIDs starting at min_tid. The order of TIDs is ascending.
Repair is translated to this message, asking a specific storage node to Used by `iterator`.
repair its database.
:nodes: M -> S :nodes: C -> S
""" """)
_fmt = PStruct('repair', *Repair._flags)
class ClusterInformation(Packet): AskPack, AnswerPack = request("""
""" Request a pack at given TID.
Notify about a cluster state change.
:nodes: M -> * :nodes: C -> M -> S
""" """, ignore_when_closed=False)
_fmt = PStruct('notify_cluster_information',
PEnum('state', ClusterStates),
)
class ClusterState(Packet): CheckReplicas = request("""
""" Ask the cluster to search for mismatches between replicas, metadata
Ask the state of the cluster only, and optionally within a specific range. Reference nodes can be
specified.
:nodes: ctl -> A; A -> M :nodes: ctl -> A -> M
""" """, error=True)
_answer = PStruct('answer_cluster_state', CheckPartition = notify("""
PEnum('state', ClusterStates), Ask a storage node to compare a partition with all other nodes.
) Like for CheckReplicas, only metadata are checked, optionally within a
specific range. A reference node can be specified.
class ObjectUndoSerial(Packet): :nodes: M -> S
""" """)
Ask storage the serial where object data is when undoing given transaction,
for a list of OIDs.
object_tid_dict has the following format: AskCheckTIDRange, AnswerCheckTIDRange = request("""
key: oid Ask some stats about a range of transactions.
value: 3-tuple Used to know if there are differences between a replicating node and
current_serial (TID) reference node.
The latest serial visible to the undoing transaction.
undo_serial (TID)
Where undone data is (tid at which data is before given undo).
is_current (bool)
If current_serial's data is current on storage.
:nodes: C -> S :nodes: S -> S
""" """)
_fmt = PStruct('ask_undo_transaction',
PTID('tid'),
PTID('ltid'),
PTID('undone_tid'),
PFOidList,
)
_answer = PStruct('answer_undo_transaction',
PDict('object_tid_dict',
POID('oid'),
PStruct('object_tid_value',
PTID('current_serial'),
PTID('undo_serial'),
PBoolean('is_current'),
),
),
)
class CheckCurrentSerial(Packet):
"""
Check if given serial is current for the given oid, and lock it so that
this state is not altered until transaction ends.
This maps to `checkCurrentSerialInTransaction`.
:nodes: C -> S AskCheckSerialRange, AnswerCheckSerialRange = request("""
""" Ask some stats about a range of object history.
_fmt = PStruct('ask_check_current_serial', Used to know if there are differences between a replicating node and
PTID('tid'), reference node.
POID('oid'),
PTID('serial'),
)
_answer = StoreObject._answer :nodes: S -> S
""")
class Pack(Packet): NotifyPartitionCorrupted = notify("""
""" Notify that mismatches were found while check replicas for a partition.
Request a pack at given TID.
:nodes: C -> M -> S :nodes: S -> M
""" """)
_fmt = PStruct('ask_pack',
PTID('tid'),
)
_answer = PStruct('answer_pack', NotifyReady = notify("""
PBoolean('status'), Notify that we're ready to serve requests.
)
class CheckReplicas(Packet): :nodes: S -> M
""" """)
Ask the cluster to search for mismatches between replicas, metadata only,
and optionally within a specific range. Reference nodes can be specified.
:nodes: ctl -> A -> M AskLastTransaction, AnswerLastTransaction = request("""
""" Ask last committed TID.
_fmt = PStruct('check_replicas',
PDict('partition_dict',
PNumber('partition'),
PUUID('source'),
),
PTID('min_tid'),
PTID('max_tid'),
)
_answer = Error
class CheckPartition(Packet):
"""
Ask a storage node to compare a partition with all other nodes.
Like for CheckReplicas, only metadata are checked, optionally within a
specific range. A reference node can be specified.
:nodes: M -> S :nodes: C -> M; ctl -> A -> M
""" """, poll_thread=True)
_fmt = PStruct('check_partition',
PNumber('partition'),
PStruct('source',
PString('upstream_name'),
PAddress('address'),
),
PTID('min_tid'),
PTID('max_tid'),
)
class CheckTIDRange(Packet):
"""
Ask some stats about a range of transactions.
Used to know if there are differences between a replicating node and
reference node.
:nodes: S -> S AskCheckCurrentSerial, AnswerCheckCurrentSerial = request("""
""" Check if given serial is current for the given oid, and lock it so that
_fmt = PStruct('ask_check_tid_range', this state is not altered until transaction ends.
PNumber('partition'), This maps to `checkCurrentSerialInTransaction`.
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
)
_answer = PStruct('answer_check_tid_range',
PNumber('count'),
PChecksum('checksum'),
PTID('max_tid'),
)
class CheckSerialRange(Packet):
"""
Ask some stats about a range of object history.
Used to know if there are differences between a replicating node and
reference node.
:nodes: S -> S :nodes: C -> S
""" """)
_fmt = PStruct('ask_check_serial_range',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
)
_answer = PStruct('answer_check_serial_range',
PNumber('count'),
PChecksum('tid_checksum'),
PTID('max_tid'),
PChecksum('oid_checksum'),
POID('max_oid'),
)
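
Both check packets reduce a range of rows to a small summary (count, checksum, max) so that a replicating node and a reference node can compare partitions without shipping the data itself. One possible way to build such a summary, shown for the fields of AnswerCheckTIDRange (illustrative only; this is not the storage backend's actual query):

# Illustrative only: reduce a sorted, non-empty list of 8-byte TIDs to the
# (count, checksum, max_tid) summary that AnswerCheckTIDRange carries.
import hashlib

def summarize_tid_range(tid_list):
    return (len(tid_list),
            hashlib.sha1(b''.join(tid_list)).digest(),
            tid_list[-1])

tids = [b'\x00' * 7 + b'\x01', b'\x00' * 7 + b'\x02']
assert summarize_tid_range(tids) == (
    2, hashlib.sha1(tids[0] + tids[1]).digest(), tids[1])
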
class PartitionCorrupted(Packet):
"""
Notify that mismatches were found while checking replicas for a partition.
:nodes: S -> M NotifyTransactionFinished = notify("""
""" Notify that a transaction blocking a replication is now finished.
_fmt = PStruct('partition_corrupted',
PNumber('partition'),
PList('cell_list',
PUUID('uuid'),
),
)
class LastTransaction(Packet):
"""
Ask last committed TID.
:nodes: C -> M; ctl -> A -> M :nodes: M -> S
""" """)
poll_thread = True
_answer = PStruct('answer_last_transaction', Replicate = notify("""
PTID('tid'), Notify a storage node to replicate partitions up to given 'tid'
) and from given sources.
class NotifyReady(Packet): args: tid, upstream_name, {partition: address}
""" - upstream_name: replicate from an upstream cluster
Notify that we're ready to serve requests. - address: address of the source storage node, or None if there's
no new data up to 'tid' for the given partition
:nodes: S -> M :nodes: M -> S
""" """)
class FetchTransactions(Packet): NotifyReplicationDone = notify("""
""" Notify the master node that a partition has been successfully
Ask a storage node to send all transaction data we don't have, replicated from a storage to another.
and reply with the list of transactions we should not have.
:nodes: S -> S :nodes: S -> M
""" """)
_fmt = PStruct('ask_transaction_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
PFTidList, # already known transactions
)
_answer = PStruct('answer_transaction_list',
PTID('pack_tid'),
PTID('next_tid'),
PFTidList, # transactions to delete
)
class AddTransaction(Packet):
"""
Send metadata of a transaction to a node that does not have them.
:nodes: S -> S AskFetchTransactions, AnswerFetchTransactions = request("""
""" Ask a storage node to send all transaction data we don't have,
nodelay = False and reply with the list of transactions we should not have.
_fmt = PStruct('add_transaction',
PTID('tid'),
PString('user'),
PString('description'),
PString('extension'),
PBoolean('packed'),
PTID('ttid'),
PFOidList,
)
class FetchObjects(Packet):
"""
Ask a storage node to send object records we don't have,
and reply with the list of records we should not have.
:nodes: S -> S :nodes: S -> S
""" """)
_fmt = PStruct('ask_object_list',
PNumber('partition'),
PNumber('length'),
PTID('min_tid'),
PTID('max_tid'),
POID('min_oid'),
PDict('object_dict', # already known objects
PTID('serial'),
PFOidList,
),
)
_answer = PStruct('answer_object_list',
PTID('pack_tid'),
PTID('next_tid'),
POID('next_oid'),
PDict('object_dict', # objects to delete
PTID('serial'),
PFOidList,
),
)
class AddObject(Packet):
"""
Send an object record to a node that does not have it.
:nodes: S -> S AskFetchObjects, AnswerFetchObjects = request("""
""" Ask a storage node to send object records we don't have,
nodelay = False and reply with the list of records we should not have.
_fmt = PStruct('add_object',
POID('oid'),
PTID('serial'),
PBoolean('compression'),
PChecksum('checksum'),
PString('data'),
PTID('data_serial'),
)
class Replicate(Packet):
"""
Notify a storage node to replicate partitions up to given 'tid'
and from given sources.
- upstream_name: replicate from an upstream cluster :nodes: S -> S
- address: address of the source storage node, or None if there's no new """)
data up to 'tid' for the given partition
:nodes: M -> S AddTransaction = notify("""
""" Send metadata of a transaction to a node that does not have them.
_fmt = PStruct('replicate',
PTID('tid'),
PString('upstream_name'),
PDict('source_dict',
PNumber('partition'),
PAddress('address'),
)
)
class ReplicationDone(Packet):
"""
Notify the master node that a partition has been successfully replicated
from a storage to another.
:nodes: S -> M :nodes: S -> S
""" """, nodelay=False)
_fmt = PStruct('notify_replication_done',
PNumber('offset'),
PTID('tid'),
)
class Truncate(Packet): AddObject = notify("""
""" Send an object record to a node that does not have it.
Request DB to be truncated. Also used to leave backup mode.
:nodes: ctl -> A -> M; M -> S :nodes: S -> S
""" """, nodelay=False, data_path=(0, 2))
_fmt = PStruct('truncate',
PTID('tid'),
)
_answer = Error Truncate = request("""
Request DB to be truncated. Also used to leave backup mode.
class FlushLog(Packet): :nodes: ctl -> A -> M; M -> S
""" """, error=True)
Request all nodes to flush their logs.
:nodes: ctl -> A -> M -> * FlushLog = notify("""
""" Request all nodes to flush their logs.
:nodes: ctl -> A -> M -> *
""")
_next_code = 0 del notify, request
def register(request, ignore_when_closed=None):
""" Register a packet in the packet registry """
global _next_code
code = _next_code
assert code < RESPONSE_MASK
_next_code = code + 1
if request is Error:
code |= RESPONSE_MASK
# register the request
request._code = code
answer = request._answer
if ignore_when_closed is None:
# By default, on a closed connection:
# - request: ignore
# - answer: keep
# - notification: keep
ignore_when_closed = answer is not None
request._ignore_when_closed = ignore_when_closed
if answer in (Error, None):
return request
# build a class for the answer
answer = type('Answer' + request.__name__, (Packet, ), {})
answer._fmt = request._answer
answer.poll_thread = request.poll_thread
answer._request = request
assert answer._code is None, "Answer of %s is already used" % (request, )
answer._code = code | RESPONSE_MASK
request._answer = answer
return request, answer
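
register() above and the new registry both follow the same numbering scheme: requests and notifications take sequential codes, and an answer reuses its request's code with RESPONSE_MASK set. A standalone sketch of that pairing, assuming RESPONSE_MASK is the high bit of a 16-bit code (illustrative; the class and names below are invented):

# Illustrative sketch of the code numbering used by both registries.
RESPONSE_MASK = 0x8000   # assumed value: high bit of a 16-bit packet code

class DemoRegistry(object):
    def __init__(self):
        self.by_code = {}
        self._next_code = 0

    def notify(self, name):
        code = self._next_code
        assert code < RESPONSE_MASK
        self._next_code = code + 1
        self.by_code[code] = name
        return code

    def request(self, name):
        code = self.notify(name)
        answer_code = code | RESPONSE_MASK
        self.by_code[answer_code] = 'Answer' + name[len('Ask'):]
        return code, answer_code

reg = DemoRegistry()
ask, answer = reg.request('AskBeginTransaction')
assert answer == ask | RESPONSE_MASK
assert reg.by_code[answer] == 'AnswerBeginTransaction'
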
class Packets(dict):
"""
Packet registry that checks packet code uniqueness and provides an index
"""
def __metaclass__(name, base, d):
# this builds a "singleton"
cls = type('PacketRegistry', base, d)()
for k, v in d.iteritems():
if isinstance(v, type) and issubclass(v, Packet):
v.handler_method_name = k[0].lower() + k[1:]
cls[v._code] = v
return cls
Error = register(
Error)
RequestIdentification, AcceptIdentification = register(
RequestIdentification, ignore_when_closed=True)
Ping, Pong = register(
Ping)
CloseClient = register(
CloseClient)
AskPrimary, AnswerPrimary = register(
PrimaryMaster)
NotPrimaryMaster = register(
NotPrimaryMaster)
NotifyNodeInformation = register(
NotifyNodeInformation)
AskRecovery, AnswerRecovery = register(
Recovery)
AskLastIDs, AnswerLastIDs = register(
LastIDs)
AskPartitionTable, AnswerPartitionTable = register(
PartitionTable)
SendPartitionTable = register(
NotifyPartitionTable)
NotifyPartitionChanges = register(
PartitionChanges)
StartOperation = register(
StartOperation)
StopOperation = register(
StopOperation)
AskUnfinishedTransactions, AnswerUnfinishedTransactions = register(
UnfinishedTransactions)
AskLockedTransactions, AnswerLockedTransactions = register(
LockedTransactions)
AskFinalTID, AnswerFinalTID = register(
FinalTID)
ValidateTransaction = register(
ValidateTransaction)
AskBeginTransaction, AnswerBeginTransaction = register(
BeginTransaction)
FailedVote = register(
FailedVote)
AskFinishTransaction, AnswerTransactionFinished = register(
FinishTransaction, ignore_when_closed=False)
AskLockInformation, AnswerInformationLocked = register(
LockInformation, ignore_when_closed=False)
InvalidateObjects = register(
InvalidateObjects)
NotifyUnlockInformation = register(
UnlockInformation)
AskNewOIDs, AnswerNewOIDs = register(
GenerateOIDs)
NotifyDeadlock = register(
Deadlock)
AskRebaseTransaction, AnswerRebaseTransaction = register(
RebaseTransaction)
AskRebaseObject, AnswerRebaseObject = register(
RebaseObject)
AskStoreObject, AnswerStoreObject = register(
StoreObject)
AbortTransaction = register(
AbortTransaction)
AskStoreTransaction, AnswerStoreTransaction = register(
StoreTransaction)
AskVoteTransaction, AnswerVoteTransaction = register(
VoteTransaction)
AskObject, AnswerObject = register(
GetObject)
AskTIDs, AnswerTIDs = register(
TIDList)
AskTransactionInformation, AnswerTransactionInformation = register(
TransactionInformation)
AskObjectHistory, AnswerObjectHistory = register(
ObjectHistory)
AskPartitionList, AnswerPartitionList = register(
PartitionList)
AskNodeList, AnswerNodeList = register(
NodeList)
SetNodeState = register(
SetNodeState, ignore_when_closed=False)
AddPendingNodes = register(
AddPendingNodes, ignore_when_closed=False)
TweakPartitionTable = register(
TweakPartitionTable, ignore_when_closed=False)
SetClusterState = register(
SetClusterState, ignore_when_closed=False)
Repair = register(
Repair)
NotifyRepair = register(
RepairOne)
NotifyClusterInformation = register(
ClusterInformation)
AskClusterState, AnswerClusterState = register(
ClusterState)
AskObjectUndoSerial, AnswerObjectUndoSerial = register(
ObjectUndoSerial)
AskTIDsFrom, AnswerTIDsFrom = register(
TIDListFrom)
AskPack, AnswerPack = register(
Pack, ignore_when_closed=False)
CheckReplicas = register(
CheckReplicas)
CheckPartition = register(
CheckPartition)
AskCheckTIDRange, AnswerCheckTIDRange = register(
CheckTIDRange)
AskCheckSerialRange, AnswerCheckSerialRange = register(
CheckSerialRange)
NotifyPartitionCorrupted = register(
PartitionCorrupted)
NotifyReady = register(
NotifyReady)
AskLastTransaction, AnswerLastTransaction = register(
LastTransaction)
AskCheckCurrentSerial, AnswerCheckCurrentSerial = register(
CheckCurrentSerial)
NotifyTransactionFinished = register(
NotifyTransactionFinished)
Replicate = register(
Replicate)
NotifyReplicationDone = register(
ReplicationDone)
AskFetchTransactions, AnswerFetchTransactions = register(
FetchTransactions)
AskFetchObjects, AnswerFetchObjects = register(
FetchObjects)
AddTransaction = register(
AddTransaction)
AddObject = register(
AddObject)
Truncate = register(
Truncate)
FlushLog = register(
FlushLog)
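
In both registries, handler_method_name is the packet class name with its first letter lower-cased, which is how an incoming packet is dispatched to the matching handler method. A sketch of that dispatch convention (illustrative; DemoHandler is made up):

# Illustrative sketch of dispatch via handler_method_name:
# 'AskBeginTransaction' -> handler.askBeginTransaction(conn, *args).
def handler_method_name(packet_name):
    return packet_name[0].lower() + packet_name[1:]

class DemoHandler(object):                    # made up, not from the source
    def askBeginTransaction(self, conn, tid):
        return 'begin', tid

def dispatch(handler, packet_name, conn, *args):
    return getattr(handler, handler_method_name(packet_name))(conn, *args)

assert dispatch(DemoHandler(), 'AskBeginTransaction', None, None) == ('begin', None)
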
def Errors(): def Errors():
registry_dict = {} registry_dict = {}
handler_method_name_dict = {} handler_method_name_dict = {}
Error = Packets.Error
def register_error(code): def register_error(code):
return lambda self, message='': Error(code, message) return lambda self, message='': Error(code, message)
for error in ErrorCodes: for error in ErrorCodes:
...@@ -1836,19 +851,20 @@ from operator import itemgetter ...@@ -1836,19 +851,20 @@ from operator import itemgetter
def formatNodeList(node_list, prefix='', _sort_key=itemgetter(2)): def formatNodeList(node_list, prefix='', _sort_key=itemgetter(2)):
if node_list: if node_list:
node_list.sort(key=_sort_key)
node_list = [( node_list = [(
uuid_str(uuid), str(node_type), uuid_str(uuid), str(node_type),
('[%s]:%s' if ':' in addr[0] else '%s:%s') ('[%s]:%s' if ':' in addr[0] else '%s:%s')
% addr if addr else '', str(state), % addr if addr else '', str(state),
str(id_timestamp and datetime.utcfromtimestamp(id_timestamp))) str(id_timestamp and datetime.utcfromtimestamp(id_timestamp)))
for node_type, addr, uuid, state, id_timestamp in node_list] for node_type, addr, uuid, state, id_timestamp
in sorted(node_list, key=_sort_key)]
t = ''.join('%%-%us | ' % max(len(x[i]) for x in node_list) t = ''.join('%%-%us | ' % max(len(x[i]) for x in node_list)
for i in xrange(len(node_list[0]) - 1)) for i in xrange(len(node_list[0]) - 1))
return map((prefix + t + '%s').__mod__, node_list) return map((prefix + t + '%s').__mod__, node_list)
return () return ()
NotifyNodeInformation._neolog = staticmethod(lambda timestamp, node_list:
    ((timestamp,), formatNodeList(node_list, ' ! ')))
Error._neolog = staticmethod(lambda *args: ((), ("%s (%s)" % args,)))
Packets.NotifyNodeInformation._neolog = staticmethod(
    lambda timestamp, node_list:
    ((timestamp,), formatNodeList(node_list, ' ! ')))
Packets.Error._neolog = staticmethod(lambda *args: ((), ("%s (%s)" % args,)))
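
formatNodeList aligns its columns by building the format string from the widest value of each column ('%%-%us | ' expands to a '%-<width>s | ' placeholder). The same trick in a standalone sketch (illustrative only):

# Illustrative only: the column-width trick used by formatNodeList.
def format_table(rows):
    widths = [max(len(row[i]) for row in rows) for i in range(len(rows[0]) - 1)]
    fmt = ''.join('%%-%us | ' % w for w in widths) + '%s'
    return [fmt % row for row in rows]

for line in format_table([('S1', 'STORAGE', 'RUNNING'),
                          ('M1', 'MASTER', 'RUNNING')]):
    print(line)
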
...@@ -86,15 +86,9 @@ class PartitionTable(object): ...@@ -86,15 +86,9 @@ class PartitionTable(object):
'a cell became non-readable whereas all cells were readable' 'a cell became non-readable whereas all cells were readable'
def __init__(self, num_partitions, num_replicas): def __init__(self, num_partitions, num_replicas):
self._id = None
self.np = num_partitions self.np = num_partitions
self.nr = num_replicas self.nr = num_replicas
self.num_filled_rows = 0 self.clear()
# Note: don't use [[]] * num_partition construct, as it duplicates
# instance *references*, so the outer list contains really just one
# inner list instance.
self.partition_list = [[] for _ in xrange(num_partitions)]
self.count_dict = {}
def getID(self): def getID(self):
return self._id return self._id
...@@ -113,7 +107,16 @@ class PartitionTable(object): ...@@ -113,7 +107,16 @@ class PartitionTable(object):
# instance *references*, so the outer list contains really just one # instance *references*, so the outer list contains really just one
# inner list instance. # inner list instance.
self.partition_list = [[] for _ in xrange(self.np)] self.partition_list = [[] for _ in xrange(self.np)]
self.count_dict.clear() self.count_dict = {}
def addNodeList(self, node_list):
"""Add nodes"""
added_list = []
for node in node_list:
if node not in self.count_dict:
self.count_dict[node] = 0
added_list.append(node)
return added_list
def getAssignedPartitionList(self, uuid): def getAssignedPartitionList(self, uuid):
""" Return the partition assigned to the specified UUID """ """ Return the partition assigned to the specified UUID """
...@@ -203,31 +206,31 @@ class PartitionTable(object): ...@@ -203,31 +206,31 @@ class PartitionTable(object):
del self.count_dict[node] del self.count_dict[node]
return not count return not count
def load(self, ptid, row_list, nm): def _load(self, ptid, num_replicas, row_list, getByUUID):
self.__init__(len(row_list), num_replicas)
self._id = ptid
for offset, row in enumerate(row_list):
for uuid, state in row:
node = getByUUID(uuid)
self._setCell(offset, node, state)
def load(self, ptid, num_replicas, row_list, nm):
""" """
Load the partition table with the specified PTID, discard all previous Load the partition table with the specified PTID, discard all previous
content. content.
""" """
self.clear() self._load(ptid, num_replicas, row_list, nm.getByUUID)
self._id = ptid
for offset, row in row_list:
if offset >= self.getPartitions():
raise IndexError
for uuid, state in row:
node = nm.getByUUID(uuid)
# the node must be known by the node manager
assert node is not None
self._setCell(offset, node, state)
logging.debug('partition table loaded (ptid=%s)', ptid) logging.debug('partition table loaded (ptid=%s)', ptid)
self.log() self.log()
def update(self, ptid, cell_list, nm): def update(self, ptid, num_replicas, cell_list, nm):
""" """
Update the partition with the cell list supplied. If a node Update the partition with the cell list supplied. If a node
is not known, it is created in the node manager and set as unavailable is not known, it is created in the node manager and set as unavailable
""" """
assert self._id < ptid, (self._id, ptid) assert self._id < ptid, (self._id, ptid)
self._id = ptid self._id = ptid
self.nr = num_replicas
readable_list = [] readable_list = []
for row in self.partition_list: for row in self.partition_list:
if not all(cell.isReadable() for cell in row): if not all(cell.isReadable() for cell in row):
...@@ -310,14 +313,11 @@ class PartitionTable(object): ...@@ -310,14 +313,11 @@ class PartitionTable(object):
return True return True
def getRow(self, offset): def getRow(self, offset):
row = self.partition_list[offset] return [(cell.getUUID(), cell.getState())
if row is None: for cell in self.partition_list[offset]]
return []
return [(cell.getUUID(), cell.getState()) for cell in row]
def getRowList(self): def getRowList(self):
getRow = self.getRow return map(self.getRow, xrange(self.np))
return [(x, getRow(x)) for x in xrange(self.np)]
class MTPartitionTable(PartitionTable): class MTPartitionTable(PartitionTable):
""" Thread-safe aware version of the partition table, override only methods """ Thread-safe aware version of the partition table, override only methods
......
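
A partition table with np partitions assigns an object to row u64(oid) % np, and each row holds one cell per replica. A minimal sketch of that mapping (illustrative; u64 below is a local stand-in for neo.lib.util.u64):

# Illustrative sketch: mapping an 8-byte OID to a partition offset.
import struct

def u64(packed):
    # stand-in for neo.lib.util.u64: big-endian 8-byte unpack
    return struct.unpack('!Q', packed)[0]

def get_partition(oid, num_partitions):
    return u64(oid) % num_partitions

oid = b'\x00\x00\x00\x00\x00\x00\x01\x05'   # 0x105 == 261
assert get_partition(oid, 100) == 61
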
...@@ -166,65 +166,6 @@ def parseMasterList(masters): ...@@ -166,65 +166,6 @@ def parseMasterList(masters):
return map(parseNodeAddress, masters.split()) return map(parseNodeAddress, masters.split())
class ReadBuffer(object):
"""
Implementation of a lazy buffer. Main purpose is to reduce useless
copies of data by storing chunks and join them only when the requested
size is available.
TODO: For better performance, use:
- socket.recv_into (64kiB blocks)
- struct.unpack_from
- and a circular buffer of dynamic size (initial size:
twice the length passed to socket.recv_into ?)
"""
def __init__(self):
self.size = 0
self.content = deque()
def append(self, data):
""" Append some data and compute the new buffer size """
self.size += len(data)
self.content.append(data)
def __len__(self):
""" Return the current buffer size """
return self.size
def read(self, size):
""" Read and consume size bytes """
if self.size < size:
return None
self.size -= size
chunk_list = []
pop_chunk = self.content.popleft
append_data = chunk_list.append
to_read = size
# select required chunks
while to_read > 0:
chunk_data = pop_chunk()
to_read -= len(chunk_data)
append_data(chunk_data)
if to_read < 0:
# too many bytes consumed, cut the last chunk
last_chunk = chunk_list[-1]
keep, let = last_chunk[:to_read], last_chunk[to_read:]
self.content.appendleft(let)
chunk_list[-1] = keep
# join all chunks (one copy)
data = ''.join(chunk_list)
assert len(data) == size
return data
def clear(self):
""" Erase all buffer content """
self.size = 0
self.content.clear()
dummy_read_buffer = ReadBuffer()
dummy_read_buffer.append = lambda _: None
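
The ReadBuffer shown above only returns data once enough bytes have accumulated, and splits the last chunk when a read ends inside it. A short usage sketch, assuming the ReadBuffer class above (and its deque import) is available:

# Usage sketch for the ReadBuffer class above (illustrative only).
buf = ReadBuffer()
buf.append('hello ')
buf.append('world')
assert buf.read(20) is None      # not enough data yet: nothing is consumed
assert buf.read(8) == 'hello wo' # joins chunks and cuts the second one
assert len(buf) == 3             # 'rld' remains buffered
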
class cached_property(object): class cached_property(object):
""" """
A property that is only computed once per instance and then replaces itself A property that is only computed once per instance and then replaces itself
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import sys import sys
from collections import defaultdict from collections import defaultdict
from functools import partial
from time import time from time import time
from neo.lib import logging, util from neo.lib import logging, util
...@@ -76,13 +77,11 @@ class Application(BaseApplication): ...@@ -76,13 +77,11 @@ class Application(BaseApplication):
@classmethod @classmethod
def _buildOptionParser(cls): def _buildOptionParser(cls):
_ = cls.option_parser parser = cls.option_parser
_.description = "NEO Master node" parser.description = "NEO Master node"
cls.addCommonServerOptions('master', '127.0.0.1:10000', '') cls.addCommonServerOptions('master', '127.0.0.1:10000', '')
_ = _.group('master') _ = parser.group('master')
_.int('r', 'replicas', default=0, help="replicas number")
_.int('p', 'partitions', default=100, help="partitions number")
_.int('A', 'autostart', _.int('A', 'autostart',
help="minimum number of pending storage nodes to automatically" help="minimum number of pending storage nodes to automatically"
" start new cluster (to avoid unwanted recreation of the" " start new cluster (to avoid unwanted recreation of the"
...@@ -91,8 +90,12 @@ class Application(BaseApplication): ...@@ -91,8 +90,12 @@ class Application(BaseApplication):
help='the name of cluster to backup') help='the name of cluster to backup')
_('M', 'upstream-masters', parse=util.parseMasterList, _('M', 'upstream-masters', parse=util.parseMasterList,
help='list of master nodes in the cluster to backup') help='list of master nodes in the cluster to backup')
_.int('u', 'uuid', _.int('i', 'nid',
help="specify an UUID to use for this process (testing purpose)") help="specify an NID to use for this process (testing purpose)")
_ = parser.group('database creation')
_.int('r', 'replicas', default=0, help="replicas number")
_.int('p', 'partitions', default=100, help="partitions number")
def __init__(self, config): def __init__(self, config):
super(Application, self).__init__( super(Application, self).__init__(
...@@ -108,7 +111,7 @@ class Application(BaseApplication): ...@@ -108,7 +111,7 @@ class Application(BaseApplication):
for master_address in config['masters']: for master_address in config['masters']:
self.nm.createMaster(address=master_address) self.nm.createMaster(address=master_address)
self._node = self.nm.createMaster(address=self.server, self._node = self.nm.createMaster(address=self.server,
uuid=config.get('uuid')) uuid=config.get('nid'))
logging.node(self.name, self.uuid) logging.node(self.name, self.uuid)
logging.debug('IP address is %s, port is %d', *self.server) logging.debug('IP address is %s, port is %d', *self.server)
...@@ -117,14 +120,14 @@ class Application(BaseApplication): ...@@ -117,14 +120,14 @@ class Application(BaseApplication):
replicas = config['replicas'] replicas = config['replicas']
partitions = config['partitions'] partitions = config['partitions']
if replicas < 0: if replicas < 0:
raise RuntimeError, 'replicas must be a positive integer' sys.exit('replicas must be a positive integer')
if partitions <= 0: if partitions <= 0:
raise RuntimeError, 'partitions must be more than zero' sys.exit('partitions must be more than zero')
self.pt = PartitionTable(partitions, replicas)
logging.info('Configuration:') logging.info('Configuration:')
logging.info('Partitions: %d', partitions) logging.info('Partitions: %d', partitions)
logging.info('Replicas : %d', replicas) logging.info('Replicas : %d', replicas)
logging.info('Name : %s', self.name) logging.info('Name : %s', self.name)
self.newPartitionTable = partial(PartitionTable, partitions, replicas)
self.listening_conn = None self.listening_conn = None
self.cluster_state = None self.cluster_state = None
...@@ -196,7 +199,7 @@ class Application(BaseApplication): ...@@ -196,7 +199,7 @@ class Application(BaseApplication):
node_dict[NodeTypes.MASTER].append(node_info) node_dict[NodeTypes.MASTER].append(node_info)
return node_dict return node_dict
def broadcastNodesInformation(self, node_list, exclude=None): def broadcastNodesInformation(self, node_list):
""" """
Broadcast changes for a set of nodes Broadcast changes for a set of nodes
Send only one packet per connection to reduce bandwidth Send only one packet per connection to reduce bandwidth
...@@ -209,20 +212,26 @@ class Application(BaseApplication): ...@@ -209,20 +212,26 @@ class Application(BaseApplication):
# We don't skip pending storage nodes because we don't send them # We don't skip pending storage nodes because we don't send them
# the full list of nodes when they're added, and it's also quite # the full list of nodes when they're added, and it's also quite
# useful to notify them about new masters. # useful to notify them about new masters.
if node_list and node is not exclude: if node_list:
node.send(Packets.NotifyNodeInformation(now, node_list)) node.send(Packets.NotifyNodeInformation(now, node_list))
def broadcastPartitionChanges(self, cell_list): def broadcastPartitionChanges(self, cell_list, num_replicas=None):
"""Broadcast a Notify Partition Changes packet.""" """Broadcast a Notify Partition Changes packet."""
if cell_list: pt = self.pt
ptid = self.pt.setNextID() if num_replicas is not None:
self.pt.logUpdated() pt.setReplicas(num_replicas)
packet = Packets.NotifyPartitionChanges(ptid, cell_list) elif cell_list:
for node in self.nm.getIdentifiedList(): num_replicas = pt.getReplicas()
# As for broadcastNodesInformation, we don't send the full PT else:
# when pending storage nodes are added, so keep them notified. return
if not node.isMaster(): packet = Packets.NotifyPartitionChanges(
node.send(packet) pt.setNextID(), num_replicas, cell_list)
pt.logUpdated()
for node in self.nm.getIdentifiedList():
# As for broadcastNodesInformation, we don't send the full PT
# when pending storage nodes are added, so keep them notified.
if not node.isMaster():
node.send(packet)
def provideService(self): def provideService(self):
""" """
...@@ -437,16 +446,7 @@ class Application(BaseApplication): ...@@ -437,16 +446,7 @@ class Application(BaseApplication):
conn.send(notification_packet) conn.send(notification_packet)
elif conn.isServer(): elif conn.isServer():
continue continue
if node.isClient(): if node.isMaster():
if state == ClusterStates.RUNNING:
handler = self.client_service_handler
elif state == ClusterStates.BACKINGUP:
handler = self.client_ro_service_handler
else:
if state != ClusterStates.STOPPING:
conn.abort()
continue
elif node.isMaster():
if state == ClusterStates.RECOVERING: if state == ClusterStates.RECOVERING:
handler = self.election_handler handler = self.election_handler
else: else:
...@@ -454,10 +454,16 @@ class Application(BaseApplication): ...@@ -454,10 +454,16 @@ class Application(BaseApplication):
elif node.isStorage() and storage_handler: elif node.isStorage() and storage_handler:
handler = storage_handler handler = storage_handler
else: else:
# There's a single handler type for admins.
# Client can't change handler without being first disconnected.
assert state in (
ClusterStates.STOPPING,
ClusterStates.STOPPING_BACKUP,
) or not node.isClient(), (state, node)
continue # keep handler continue # keep handler
if type(handler) is not type(conn.getLastHandler()): if type(handler) is not type(conn.getLastHandler()):
conn.setHandler(handler) conn.setHandler(handler)
handler.connectionCompleted(conn, new=False) handler.handlerSwitched(conn, new=False)
self.cluster_state = state self.cluster_state = state
def getNewUUID(self, uuid, address, node_type): def getNewUUID(self, uuid, address, node_type):
...@@ -578,7 +584,9 @@ class Application(BaseApplication): ...@@ -578,7 +584,9 @@ class Application(BaseApplication):
self.tm.executeQueuedEvents() self.tm.executeQueuedEvents()
def startStorage(self, node): def startStorage(self, node):
node.send(Packets.StartOperation(self.backup_tid)) # XXX: Is this boolean 'backup' field needed ?
# Maybe this can be deduced from cluster state.
node.send(Packets.StartOperation(bool(self.backup_tid)))
uuid = node.getUUID() uuid = node.getUUID()
assert uuid not in self.storage_starting_set assert uuid not in self.storage_starting_set
if uuid not in self.storage_ready_dict: if uuid not in self.storage_ready_dict:
......
...@@ -65,6 +65,7 @@ There is no conflict of node id between the 2 clusters: ...@@ -65,6 +65,7 @@ There is no conflict of node id between the 2 clusters:
class BackupApplication(object): class BackupApplication(object):
pt = None pt = None
server = None # like in BaseApplication
uuid = None uuid = None
def __init__(self, app, name, master_addresses): def __init__(self, app, name, master_addresses):
...@@ -111,17 +112,12 @@ class BackupApplication(object): ...@@ -111,17 +112,12 @@ class BackupApplication(object):
else: else:
break break
poll(1) poll(1)
node, conn, num_partitions, num_replicas = \ node, conn = bootstrap.getPrimaryConnection()
bootstrap.getPrimaryConnection()
try: try:
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node del bootstrap, node
if num_partitions != pt.getPartitions():
raise RuntimeError("inconsistent number of partitions")
self.ignore_invalidations = True self.ignore_invalidations = True
self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskPartitionTable())
conn.ask(Packets.AskLastTransaction()) conn.ask(Packets.AskLastTransaction())
# debug variable to log how big 'tid_list' can be. # debug variable to log how big 'tid_list' can be.
self.debug_tid_count = 0 self.debug_tid_count = 0
......
...@@ -23,10 +23,6 @@ from neo.lib.protocol import Packets ...@@ -23,10 +23,6 @@ from neo.lib.protocol import Packets
class MasterHandler(EventHandler): class MasterHandler(EventHandler):
"""This class implements a generic part of the event handlers.""" """This class implements a generic part of the event handlers."""
def connectionCompleted(self, conn, new=None):
if new is None:
super(MasterHandler, self).connectionCompleted(conn)
def connectionLost(self, conn, new_state=None): def connectionLost(self, conn, new_state=None):
if self.app.listening_conn: # if running if self.app.listening_conn: # if running
self._connectionLost(conn) self._connectionLost(conn)
...@@ -59,17 +55,20 @@ class MasterHandler(EventHandler): ...@@ -59,17 +55,20 @@ class MasterHandler(EventHandler):
+ app.getNodeInformationDict(node_list)[node.getType()]) + app.getNodeInformationDict(node_list)[node.getType()])
conn.send(Packets.NotifyNodeInformation(monotonic_time(), node_list)) conn.send(Packets.NotifyNodeInformation(monotonic_time(), node_list))
def askPartitionTable(self, conn): def handlerSwitched(self, conn, new):
pt = self.app.pt pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList())) # Except storages during recovery and secondary masters, all nodes
# receives the full partition table as soon as they're identified.
# It is also sent in 2 other cases:
# - to admins during recovery, whenever a newer PT is loaded;
# - to storage when switching from recovery to verification.
# After that, non-master nodes only receive incremental updates.
conn.send(Packets.SendPartitionTable(
pt.getID(), pt.getReplicas(), pt.getRowList()))
class BaseServiceHandler(MasterHandler): class BaseServiceHandler(MasterHandler):
"""This class deals with events for a service phase.""" """Common handler class for storage nodes."""
def connectionCompleted(self, conn, new):
pt = self.app.pt
conn.send(Packets.SendPartitionTable(pt.getID(), pt.getRowList()))
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
......
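
The handlerSwitched comment above describes the distribution model: a full partition table is pushed once when a node's handler is switched, and everything after that arrives as incremental NotifyPartitionChanges updates. Below is a minimal, self-contained sketch of that full-then-incremental pattern; ToyPartitionTable and the string cell states are invented for illustration and are not NEO classes.

# Illustrative sketch only: a toy mirror of the partition table that first
# takes a full table (as with SendPartitionTable) and then applies
# incremental cell updates (as with NotifyPartitionChanges).
class ToyPartitionTable(object):

    def __init__(self):
        self.ptid = None
        self.rows = {}          # offset -> {nid: cell state}

    def load_full(self, ptid, row_list):
        # A full table replaces any previous content.
        self.ptid = ptid
        self.rows = {offset: dict(row) for offset, row in enumerate(row_list)}

    def update(self, ptid, cell_list):
        # An incremental update only touches the listed cells.
        assert self.ptid is not None and self.ptid < ptid, (self.ptid, ptid)
        self.ptid = ptid
        for offset, nid, state in cell_list:
            if state == 'DISCARDED':
                self.rows[offset].pop(nid, None)
            else:
                self.rows[offset][nid] = state

pt = ToyPartitionTable()
pt.load_full(1, [[(1, 'UP_TO_DATE')], [(2, 'UP_TO_DATE')]])
pt.update(2, [(0, 2, 'OUT_OF_DATE')])
assert pt.rows == {0: {1: 'UP_TO_DATE', 2: 'OUT_OF_DATE'}, 1: {2: 'UP_TO_DATE'}}
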
...@@ -15,14 +15,16 @@ ...@@ -15,14 +15,16 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import random import random
from functools import wraps
from . import MasterHandler from . import MasterHandler
from ..app import monotonic_time, StateChangedException from ..app import monotonic_time, StateChangedException
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import StoppedOperation from neo.lib.exception import StoppedOperation
from neo.lib.handler import AnswerDenied
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from neo.lib.protocol import ClusterStates, Errors, \ from neo.lib.protocol import ClusterStates, Errors, \
NodeStates, NodeTypes, Packets, ProtocolError, uuid_str NodeStates, NodeTypes, Packets, uuid_str
from neo.lib.util import dump from neo.lib.util import dump
CLUSTER_STATE_WORKFLOW = { CLUSTER_STATE_WORKFLOW = {
...@@ -38,9 +40,25 @@ NODE_STATE_WORKFLOW = { ...@@ -38,9 +40,25 @@ NODE_STATE_WORKFLOW = {
NodeTypes.STORAGE: (NodeStates.DOWN, NodeStates.UNKNOWN), NodeTypes.STORAGE: (NodeStates.DOWN, NodeStates.UNKNOWN),
} }
def check_state(*states):
def decorator(wrapped):
def wrapper(self, *args):
state = self.app.getClusterState()
if state not in states:
raise AnswerDenied('%s RPC can not be used in %s state'
% (wrapped.__name__, state))
wrapped(self, *args)
return wraps(wrapped)(wrapper)
return decorator
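
To make the intent of check_state concrete, here is a self-contained sketch of how such a guard is used. DummyApp, DummyHandler and the plain string cluster states are stand-ins invented for this example; only the decorator body is the one above.

# Self-contained sketch of the check_state pattern, with dummy stand-ins
# for the application and handler (the decorator body is the one above).
from functools import wraps

class AnswerDenied(Exception):
    """Stand-in for neo.lib.handler.AnswerDenied."""

def check_state(*states):
    def decorator(wrapped):
        def wrapper(self, *args):
            state = self.app.getClusterState()
            if state not in states:
                raise AnswerDenied('%s RPC can not be used in %s state'
                                   % (wrapped.__name__, state))
            wrapped(self, *args)
        return wraps(wrapped)(wrapper)
    return decorator

class DummyApp(object):
    def __init__(self, state):
        self.state = state
    def getClusterState(self):
        return self.state

class DummyHandler(object):
    def __init__(self, app):
        self.app = app
    @check_state('RUNNING')
    def truncate(self, conn, tid):
        conn.append(tid)        # stands for conn.answer(...)

answers = []
DummyHandler(DummyApp('RUNNING')).truncate(answers, 123)
assert answers == [123]
try:
    DummyHandler(DummyApp('RECOVERING')).truncate(answers, 456)
except AnswerDenied:
    pass
else:
    raise AssertionError('expected AnswerDenied')
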
class AdministrationHandler(MasterHandler): class AdministrationHandler(MasterHandler):
"""This class deals with messages from the admin node only""" """This class deals with messages from the admin node only"""
def handlerSwitched(self, conn, new):
assert new
super(AdministrationHandler, self).handlerSwitched(conn, new)
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
node = self.app.nm.getByUUID(conn.getUUID()) node = self.app.nm.getByUUID(conn.getUUID())
if node is not None: if node is not None:
...@@ -58,30 +76,28 @@ class AdministrationHandler(MasterHandler): ...@@ -58,30 +76,28 @@ class AdministrationHandler(MasterHandler):
# check request # check request
try: try:
if app.cluster_state not in CLUSTER_STATE_WORKFLOW[state]: if app.cluster_state not in CLUSTER_STATE_WORKFLOW[state]:
raise ProtocolError('Can not switch to this state') raise AnswerDenied('Can not switch to this state')
except KeyError: except KeyError:
if state != ClusterStates.STOPPING: if state != ClusterStates.STOPPING:
raise ProtocolError('Invalid state requested') raise AnswerDenied('Invalid state requested')
# change state # change state
if state == ClusterStates.VERIFYING: if state == ClusterStates.VERIFYING:
storage_list = app.nm.getStorageList(only_identified=True) storage_list = app.nm.getStorageList(only_identified=True)
if not storage_list: if not storage_list:
raise ProtocolError('Cannot exit recovery without any ' raise AnswerDenied(
'storage node') 'Cannot exit recovery without any storage node')
for node in storage_list: for node in storage_list:
assert node.isPending(), node assert node.isPending(), node
if node.getConnection().isPending(): if node.getConnection().isPending():
# XXX: It's wrong to use ProtocolError here. We must reply raise AnswerDenied(
# less aggressively because the admin has no way to 'Cannot exit recovery now: node %r is entering cluster'
# know that there's still pending activity. % node,)
raise ProtocolError('Cannot exit recovery now: node %r is '
'entering cluster' % (node, ))
app._startup_allowed = True app._startup_allowed = True
state = app.cluster_state state = app.cluster_state
elif state == ClusterStates.STARTING_BACKUP: elif state == ClusterStates.STARTING_BACKUP:
if app.tm.hasPending() or app.nm.getClientList(True): if app.tm.hasPending() or app.nm.getClientList(True):
raise ProtocolError("Can not switch to %s state with pending" raise AnswerDenied("Can not switch to %s state with pending"
" transactions or connected clients" % state) " transactions or connected clients" % state)
conn.answer(Errors.Ack('Cluster state changed')) conn.answer(Errors.Ack('Cluster state changed'))
...@@ -93,21 +109,24 @@ class AdministrationHandler(MasterHandler): ...@@ -93,21 +109,24 @@ class AdministrationHandler(MasterHandler):
app = self.app app = self.app
node = app.nm.getByUUID(uuid) node = app.nm.getByUUID(uuid)
if node is None: if node is None:
raise ProtocolError('unknown node') raise AnswerDenied('unknown node')
if state not in NODE_STATE_WORKFLOW.get(node.getType(), ()): if state not in NODE_STATE_WORKFLOW.get(node.getType(), ()):
raise ProtocolError('can not switch node to this state') raise AnswerDenied('can not switch node to %s state' % state)
if uuid == app.uuid: if uuid == app.uuid:
raise ProtocolError('can not kill primary master node') raise AnswerDenied('can not kill primary master node')
state_changed = state != node.getState() state_changed = state != node.getState()
message = ('state changed' if state_changed else message = ('state changed' if state_changed else
'node already in %s state' % state) 'node already in %s state' % state)
if node.isStorage(): if node.isStorage():
keep = state == NodeStates.DOWN keep = state == NodeStates.DOWN
if node.isRunning() and not keep:
raise AnswerDenied(
"a running node must be stopped before removal")
try: try:
cell_list = app.pt.dropNodeList([node], keep) cell_list = app.pt.dropNodeList([node], keep)
except PartitionTableException, e: except PartitionTableException, e:
raise ProtocolError(str(e)) raise AnswerDenied(str(e))
node.setState(state) node.setState(state)
if node.isConnected(): if node.isConnected():
# notify itself so it can shutdown # notify itself so it can shutdown
...@@ -134,16 +153,17 @@ class AdministrationHandler(MasterHandler): ...@@ -134,16 +153,17 @@ class AdministrationHandler(MasterHandler):
monotonic_time(), [node.asTuple()])) monotonic_time(), [node.asTuple()]))
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
# XXX: Would it be safe to allow more states ?
__change_pt_rpc = check_state(
ClusterStates.RUNNING,
ClusterStates.STARTING_BACKUP,
ClusterStates.BACKINGUP)
@__change_pt_rpc
def addPendingNodes(self, conn, uuid_list): def addPendingNodes(self, conn, uuid_list):
uuids = ', '.join(map(uuid_str, uuid_list)) uuids = ', '.join(map(uuid_str, uuid_list))
logging.debug('Add nodes %s', uuids) logging.debug('Add nodes %s', uuids)
app = self.app app = self.app
state = app.getClusterState()
# XXX: Would it be safe to allow more states ?
if state not in (ClusterStates.RUNNING,
ClusterStates.STARTING_BACKUP,
ClusterStates.BACKINGUP):
raise ProtocolError('Can not add nodes in %s state' % state)
# take all pending nodes # take all pending nodes
node_list = list(app.pt.addNodeList(node node_list = list(app.pt.addNodeList(node
for node in app.nm.getStorageList() for node in app.nm.getStorageList()
...@@ -165,31 +185,50 @@ class AdministrationHandler(MasterHandler): ...@@ -165,31 +185,50 @@ class AdministrationHandler(MasterHandler):
for uuid in uuid_list: for uuid in uuid_list:
node = getByUUID(uuid) node = getByUUID(uuid)
if node is None or not (node.isStorage() and node.isIdentified()): if node is None or not (node.isStorage() and node.isIdentified()):
raise ProtocolError("invalid storage node %s" % uuid_str(uuid)) raise AnswerDenied("invalid storage node %s" % uuid_str(uuid))
node_list.append(node) node_list.append(node)
repair = Packets.NotifyRepair(*args) repair = Packets.NotifyRepair(*args)
for node in node_list: for node in node_list:
node.send(repair) node.send(repair)
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
def tweakPartitionTable(self, conn, uuid_list): @__change_pt_rpc
app = self.app def setNumReplicas(self, conn, num_replicas):
state = app.getClusterState() self.app.broadcastPartitionChanges((), num_replicas)
# XXX: Would it be safe to allow more states ?
if state not in (ClusterStates.RUNNING,
ClusterStates.STARTING_BACKUP,
ClusterStates.BACKINGUP):
raise ProtocolError('Can not tweak partition table in %s state'
% state)
app.broadcastPartitionChanges(app.pt.tweak([node
for node in app.nm.getStorageList()
if node.getUUID() in uuid_list or not node.isRunning()]))
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
def truncate(self, conn, tid): @__change_pt_rpc
def tweakPartitionTable(self, conn, dry_run, uuid_list):
app = self.app app = self.app
if app.cluster_state != ClusterStates.RUNNING: drop_list = []
raise ProtocolError('Can not truncate in this state') for node in app.nm.getStorageList():
if node.getUUID() in uuid_list or node.isPending():
drop_list.append(node)
elif not node.isRunning():
drop_list.append(node)
raise AnswerDenied(
'tweak: down nodes must be listed explicitly')
if dry_run:
pt = object.__new__(app.pt.__class__)
new_nodes = pt.load(app.pt.getID(), app.pt.getReplicas(),
app.pt.getRowList(), app.nm)
assert not new_nodes
pt.addNodeList(node
for node, count in app.pt.count_dict.iteritems()
if not count)
else:
pt = app.pt
try:
changed_list = pt.tweak(drop_list)
except PartitionTableException, e:
raise AnswerDenied(str(e))
if not dry_run:
app.broadcastPartitionChanges(changed_list)
conn.answer(Packets.AnswerTweakPartitionTable(
bool(changed_list), pt.getRowList()))
@check_state(ClusterStates.RUNNING)
def truncate(self, conn, tid):
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
raise StoppedOperation(tid) raise StoppedOperation(tid)
...@@ -237,3 +276,5 @@ class AdministrationHandler(MasterHandler): ...@@ -237,3 +276,5 @@ class AdministrationHandler(MasterHandler):
node.send(Packets.CheckPartition( node.send(Packets.CheckPartition(
offset, source, min_tid, max_tid)) offset, source, min_tid, max_tid))
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
del __change_pt_rpc
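
The dry_run branch of tweakPartitionTable above follows a simulate-on-a-copy pattern: build a scratch partition table, tweak that copy, and only broadcast changes when the run is not a dry run. A toy version of the same pattern over a plain dict (the tweak here just drops the requested nodes; it is not NEO's balancing algorithm):

# Toy illustration of the dry-run pattern used by tweakPartitionTable:
# simulate the change on a disposable copy and only publish when not dry_run.
import copy

def tweak(assignment, drop_nodes, dry_run, publish):
    """assignment: {offset: set(node_ids)}; returns (changed, resulting table)."""
    table = copy.deepcopy(assignment) if dry_run else assignment
    changed = []
    for offset, nodes in table.items():
        for node in drop_nodes & nodes:
            nodes.discard(node)
            changed.append((offset, node, 'DISCARDED'))
    if changed and not dry_run:
        publish(changed)        # stands for app.broadcastPartitionChanges()
    return bool(changed), table

published = []
real = {0: {1, 2}, 1: {2, 3}}
changed, simulated = tweak(real, {2}, True, published.append)
assert changed and not published and real == {0: {1, 2}, 1: {2, 3}}
changed, _ = tweak(real, {2}, False, published.append)
assert published and real == {0: {1}, 1: {3}}
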
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import ZERO_TID from neo.lib.protocol import ZERO_TID
from neo.lib.pt import PartitionTable
class BackupHandler(EventHandler): class BackupHandler(EventHandler):
"""Handler dedicated to upstream master during BACKINGUP state""" """Handler dedicated to upstream master during BACKINGUP state"""
...@@ -25,12 +26,15 @@ class BackupHandler(EventHandler): ...@@ -25,12 +26,15 @@ class BackupHandler(EventHandler):
if self.app.app.listening_conn: # if running if self.app.app.listening_conn: # if running
raise PrimaryFailure('connection lost') raise PrimaryFailure('connection lost')
def answerPartitionTable(self, conn, ptid, row_list): def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
self.app.pt.load(ptid, row_list, self.app.nm) app = self.app
pt = app.pt = object.__new__(PartitionTable)
pt.load(ptid, num_replicas, row_list, self.app.nm)
if pt.getPartitions() != app.app.pt.getPartitions():
raise RuntimeError("inconsistent number of partitions")
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
if self.app.pt.filled(): self.app.pt.update(ptid, num_replicas, cell_list, self.app.nm)
self.app.pt.update(ptid, cell_list, self.app.nm)
def answerLastTransaction(self, conn, tid): def answerLastTransaction(self, conn, tid):
app = self.app app = self.app
......
...@@ -22,6 +22,10 @@ from . import MasterHandler ...@@ -22,6 +22,10 @@ from . import MasterHandler
class ClientServiceHandler(MasterHandler): class ClientServiceHandler(MasterHandler):
""" Handler dedicated to client during service state """ """ Handler dedicated to client during service state """
def handlerSwitched(self, conn, new):
assert new
super(ClientServiceHandler, self).handlerSwitched(conn, new)
def _connectionLost(self, conn): def _connectionLost(self, conn):
# cancel its transactions and forgot the node # cancel its transactions and forgot the node
app = self.app app = self.app
......
...@@ -17,14 +17,14 @@ ...@@ -17,14 +17,14 @@
from neo.lib import logging from neo.lib import logging
from neo.lib.exception import PrimaryElected from neo.lib.exception import PrimaryElected
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, \
NotReadyError, Packets, ProtocolError, uuid_str NodeTypes, NotReadyError, Packets, ProtocolError, uuid_str
from ..app import monotonic_time from ..app import monotonic_time
class IdentificationHandler(EventHandler): class IdentificationHandler(EventHandler):
def requestIdentification(self, conn, node_type, uuid, def requestIdentification(self, conn, node_type, uuid,
address, name, devpath, id_timestamp): address, name, id_timestamp, devpath, new_nid):
app = self.app app = self.app
self.checkClusterName(name) self.checkClusterName(name)
if address == app.server: if address == app.server:
...@@ -77,6 +77,16 @@ class IdentificationHandler(EventHandler): ...@@ -77,6 +77,16 @@ class IdentificationHandler(EventHandler):
manager = app manager = app
state, handler = manager.identifyStorageNode( state, handler = manager.identifyStorageNode(
uuid is not None and node is not None) uuid is not None and node is not None)
if not address:
if app.cluster_state == ClusterStates.RECOVERING:
raise NotReadyError
if uuid or not new_nid:
raise ProtocolError
state = NodeStates.DOWN
# We'll let the storage node close the connection. If we
# aborted it at the end of the method, BootstrapManager
# (which is used by storage nodes) could see the closure
# and try to reconnect to a master.
human_readable_node_type = ' storage (%s) ' % (state, ) human_readable_node_type = ' storage (%s) ' % (state, )
elif node_type == NodeTypes.MASTER: elif node_type == NodeTypes.MASTER:
if app.election: if app.election:
...@@ -105,24 +115,27 @@ class IdentificationHandler(EventHandler): ...@@ -105,24 +115,27 @@ class IdentificationHandler(EventHandler):
node.devpath = tuple(devpath) node.devpath = tuple(devpath)
node.id_timestamp = monotonic_time() node.id_timestamp = monotonic_time()
node.setState(state) node.setState(state)
app.broadcastNodesInformation([node])
if new_nid:
changed_list = []
for offset in new_nid:
changed_list.append((offset, uuid, CellStates.OUT_OF_DATE))
app.pt._setCell(offset, node, CellStates.OUT_OF_DATE)
app.broadcastPartitionChanges(changed_list)
conn.setHandler(handler) conn.setHandler(handler)
node.setConnection(conn, not node.isIdentified()) node.setConnection(conn, not node.isIdentified())
app.broadcastNodesInformation([node], node)
conn.answer(Packets.AcceptIdentification( conn.answer(Packets.AcceptIdentification(
NodeTypes.MASTER, NodeTypes.MASTER,
app.uuid, app.uuid,
app.pt.getPartitions(),
app.pt.getReplicas(),
uuid)) uuid))
handler._notifyNodeInformation(conn) handler._notifyNodeInformation(conn)
handler.connectionCompleted(conn, True) handler.handlerSwitched(conn, True)
class SecondaryIdentificationHandler(EventHandler): class SecondaryIdentificationHandler(EventHandler):
def requestIdentification(self, conn, node_type, uuid, def requestIdentification(self, conn, node_type, uuid,
address, name, devpath, id_timestamp): address, name, id_timestamp, devpath, new_nid):
app = self.app app = self.app
self.checkClusterName(name) self.checkClusterName(name)
if address == app.server: if address == app.server:
......
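
Regarding the new_nid branch above (the fast-cloning path introduced by --new-nid): the master registers the clone as an OUT_OF_DATE cell of every partition it copied and broadcasts only those cell changes, so replication can then catch the clone up. The bookkeeping in isolation, using plain dicts and string states as stand-ins for NEO's partition table and CellStates:

# Toy illustration of the new_nid handling above: the cloned node becomes an
# OUT_OF_DATE cell of every partition it copied, and only those changes are
# broadcast.
def register_clone(rows, new_nid_offsets, clone_nid):
    changed_list = []
    for offset in new_nid_offsets:
        changed_list.append((offset, clone_nid, 'OUT_OF_DATE'))
        rows[offset][clone_nid] = 'OUT_OF_DATE'
    return changed_list   # what would be passed to broadcastPartitionChanges()

rows = {0: {1: 'UP_TO_DATE'}, 1: {1: 'UP_TO_DATE'}, 2: {2: 'UP_TO_DATE'}}
assert register_clone(rows, [0, 1], 3) == [
    (0, 3, 'OUT_OF_DATE'), (1, 3, 'OUT_OF_DATE')]
assert rows[0] == {1: 'UP_TO_DATE', 3: 'OUT_OF_DATE'}
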
...@@ -23,6 +23,9 @@ from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets ...@@ -23,6 +23,9 @@ from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
class SecondaryHandler(MasterHandler): class SecondaryHandler(MasterHandler):
"""Handler used by primary to handle secondary masters""" """Handler used by primary to handle secondary masters"""
def handlerSwitched(self, conn, new):
pass
def _connectionLost(self, conn): def _connectionLost(self, conn):
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
...@@ -30,21 +33,20 @@ class SecondaryHandler(MasterHandler): ...@@ -30,21 +33,20 @@ class SecondaryHandler(MasterHandler):
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
class ElectionHandler(MasterHandler): class ElectionHandler(SecondaryHandler):
"""Handler used by primary to handle secondary masters during election""" """Handler used by primary to handle secondary masters during election"""
def connectionCompleted(self, conn, new=None): def connectionCompleted(self, conn):
if new is None: super(ElectionHandler, self).connectionCompleted(conn)
super(ElectionHandler, self).connectionCompleted(conn) app = self.app
app = self.app conn.ask(Packets.RequestIdentification(NodeTypes.MASTER,
conn.ask(Packets.RequestIdentification(NodeTypes.MASTER, app.uuid, app.server, app.name, app.election, (), ()))
app.uuid, app.server, app.name, (), app.election))
def connectionFailed(self, conn): def connectionFailed(self, conn):
super(ElectionHandler, self).connectionFailed(conn) super(ElectionHandler, self).connectionFailed(conn)
self.connectionLost(conn) self.connectionLost(conn)
def _acceptIdentification(self, node, *args): def _acceptIdentification(self, node):
raise PrimaryElected(node) raise PrimaryElected(node)
def _connectionLost(self, *args): def _connectionLost(self, *args):
...@@ -66,7 +68,7 @@ class ElectionHandler(MasterHandler): ...@@ -66,7 +68,7 @@ class ElectionHandler(MasterHandler):
class PrimaryHandler(ElectionHandler): class PrimaryHandler(ElectionHandler):
"""Handler used by secondaries to handle primary master""" """Handler used by secondaries to handle primary master"""
def _acceptIdentification(self, node, num_partitions, num_replicas): def _acceptIdentification(self, node):
assert self.app.primary_master is node, (self.app.primary_master, node) assert self.app.primary_master is node, (self.app.primary_master, node)
def _connectionLost(self, conn): def _connectionLost(self, conn):
......
...@@ -26,10 +26,10 @@ from . import BaseServiceHandler ...@@ -26,10 +26,10 @@ from . import BaseServiceHandler
class StorageServiceHandler(BaseServiceHandler): class StorageServiceHandler(BaseServiceHandler):
""" Handler dedicated to storages during service state """ """ Handler dedicated to storages during service state """
def connectionCompleted(self, conn, new): def handlerSwitched(self, conn, new):
app = self.app app = self.app
if new: if new:
super(StorageServiceHandler, self).connectionCompleted(conn, new) super(StorageServiceHandler, self).handlerSwitched(conn, new)
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
if node.isRunning(): # node may be PENDING if node.isRunning(): # node may be PENDING
app.startStorage(node) app.startStorage(node)
......
...@@ -56,6 +56,10 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -56,6 +56,10 @@ class PartitionTable(neo.lib.pt.PartitionTable):
self._id += 1 self._id += 1
return self._id return self._id
def setReplicas(self, num_replicas):
assert num_replicas >= 0, num_replicas
self.nr = num_replicas
def make(self, node_list): def make(self, node_list):
"""Make a new partition table from scratch.""" """Make a new partition table from scratch."""
assert self._id is None and node_list, (self._id, node_list) assert self._id is None and node_list, (self._id, node_list)
...@@ -108,26 +112,19 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -108,26 +112,19 @@ class PartitionTable(neo.lib.pt.PartitionTable):
self.num_filled_rows = len(filter(None, self.partition_list)) self.num_filled_rows = len(filter(None, self.partition_list))
return change_list return change_list
def load(self, ptid, row_list, nm): def load(self, ptid, num_replicas, row_list, nm):
""" """
Load a partition table from a storage node during the recovery. Load a partition table from a storage node during the recovery.
Return the new storage nodes registered Return the new storage nodes registered
""" """
# check offsets
for offset, _row in row_list:
if offset >= self.getPartitions():
raise IndexError, offset
# store the partition table
self.clear()
self._id = ptid
new_nodes = [] new_nodes = []
for offset, row in row_list: def getByUUID(nid):
for uuid, state in row: node = nm.getByUUID(nid)
node = nm.getByUUID(uuid) if node is None:
if node is None: node = nm.createStorage(uuid=nid)
node = nm.createStorage(uuid=uuid) new_nodes.append(node.asTuple())
new_nodes.append(node.asTuple()) return node
self._setCell(offset, node, state) self._load(ptid, num_replicas, row_list, getByUUID)
return new_nodes return new_nodes
def setUpToDate(self, node, offset): def setUpToDate(self, node, offset):
...@@ -166,15 +163,6 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -166,15 +163,6 @@ class PartitionTable(neo.lib.pt.PartitionTable):
return cell_list return cell_list
def addNodeList(self, node_list):
"""Add nodes"""
added_list = []
for node in node_list:
if node not in self.count_dict:
self.count_dict[node] = 0
added_list.append(node)
return added_list
def tweak(self, drop_list=()): def tweak(self, drop_list=()):
"""Optimize partition table """Optimize partition table
...@@ -183,7 +171,8 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -183,7 +171,8 @@ class PartitionTable(neo.lib.pt.PartitionTable):
few readable cells, some cells are instead marked as FEEDING. This is few readable cells, some cells are instead marked as FEEDING. This is
a preliminary step to drop these nodes, otherwise the partition table a preliminary step to drop these nodes, otherwise the partition table
could become non-operational. could become non-operational.
- Other nodes must have the same number of cells, off by 1. In fact, the code touching these cells is disabled (see NOTE below).
- Other nodes must have the same number of non-feeding cells, off by 1.
- When a transaction creates new objects (oids are roughly allocated - When a transaction creates new objects (oids are roughly allocated
sequentially), we expect better performance by maximizing the number sequentially), we expect better performance by maximizing the number
of involved nodes (i.e. parallelizing writes). of involved nodes (i.e. parallelizing writes).
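
A concrete reading of the "same number of non-feeding cells, off by 1" constraint above: after a tweak, the per-node cell counts may differ by at most one. A tiny illustrative checker (not the real balancing code):

# Illustrative check of the balance property described in the tweak()
# docstring: every node holds the same number of (non-feeding) cells, off by 1.
def is_balanced(assignment):
    """assignment: {offset: iterable of node ids holding a readable cell}"""
    counts = {}
    for nodes in assignment.values():
        for node in nodes:
            counts[node] = counts.get(node, 0) + 1
    return max(counts.values()) - min(counts.values()) <= 1 if counts else True

# 4 partitions, 1 replica (2 copies each), 3 nodes: 8 cells over 3 nodes can
# only be split 3/3/2, which is still "off by 1".
assert is_balanced({0: 'AB', 1: 'BC', 2: 'CA', 3: 'AB'})
assert not is_balanced({0: 'AB', 1: 'AB', 2: 'AB', 3: 'CA'})
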
...@@ -232,6 +221,8 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -232,6 +221,8 @@ class PartitionTable(neo.lib.pt.PartitionTable):
# Collect some data in a usable form for the rest of the method. # Collect some data in a usable form for the rest of the method.
node_list = {node: {} for node in self.count_dict node_list = {node: {} for node in self.count_dict
if node not in drop_list} if node not in drop_list}
if not node_list:
raise neo.lib.pt.PartitionTableException("Can't remove all nodes.")
drop_list = defaultdict(list) drop_list = defaultdict(list)
for offset, row in enumerate(self.partition_list): for offset, row in enumerate(self.partition_list):
for cell in row: for cell in row:
...@@ -420,6 +411,22 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -420,6 +411,22 @@ class PartitionTable(neo.lib.pt.PartitionTable):
outdated_list[offset] -= 1 outdated_list[offset] -= 1
for offset, cell in cell_dict.iteritems(): for offset, cell in cell_dict.iteritems():
discard_list[offset].append(cell) discard_list[offset].append(cell)
# NOTE: The following line disables the next 2 lines, which actually
# causes cells in drop_list to be discarded, now or later;
# drop_list could be renamed into ignore_list.
# 1. Deleting data partition per partition is a lot of work, so
# why ask nodes in drop_list to do that when the goal is
# simply to trash the whole underlying database?
# 2. By excluding nodes from a tweak, it becomes possible to have
# parts of the partition table that are tweaked differently.
# This may require temporarily changing the number of
# replicas for the part being tweaked. In the future, this
# number may be specified in the 'tweak' command, to avoid
# race conditions with setUpToDate().
# Overall, a common use case is when importing a ZODB to NEO,
# to keep the initial importing node up until the database is
# split and replicated to the final nodes.
drop_list = {}
for offset, drop_list in drop_list.iteritems(): for offset, drop_list in drop_list.iteritems():
discard_list[offset] += drop_list discard_list[offset] += drop_list
# We have sorted cells to discard in order to first deallocate nodes # We have sorted cells to discard in order to first deallocate nodes
......
...@@ -28,7 +28,7 @@ class RecoveryManager(MasterHandler): ...@@ -28,7 +28,7 @@ class RecoveryManager(MasterHandler):
def __init__(self, app): def __init__(self, app):
# The target node's uuid to request next. # The target node's uuid to request next.
self.target_ptid = None self.target_ptid = 0
self.ask_pt = [] self.ask_pt = []
self.backup_tid_dict = {} self.backup_tid_dict = {}
self.truncate_dict = {} self.truncate_dict = {}
...@@ -52,9 +52,8 @@ class RecoveryManager(MasterHandler): ...@@ -52,9 +52,8 @@ class RecoveryManager(MasterHandler):
""" """
logging.info('begin the recovery of the status') logging.info('begin the recovery of the status')
app = self.app app = self.app
pt = app.pt pt = app.pt = app.newPartitionTable()
app.changeClusterState(ClusterStates.RECOVERING) app.changeClusterState(ClusterStates.RECOVERING)
pt.clear()
self.try_secondary = True self.try_secondary = True
...@@ -113,7 +112,7 @@ class RecoveryManager(MasterHandler): ...@@ -113,7 +112,7 @@ class RecoveryManager(MasterHandler):
for node in node_list: for node in node_list:
conn = node.getConnection() conn = node.getConnection()
conn.send(truncate) conn.send(truncate)
self.connectionCompleted(conn, False) self.handlerSwitched(conn, False)
continue continue
node_list = pt.getConnectedNodeList() node_list = pt.getConnectedNodeList()
break break
...@@ -140,12 +139,12 @@ class RecoveryManager(MasterHandler): ...@@ -140,12 +139,12 @@ class RecoveryManager(MasterHandler):
logging.info('creating a new partition table') logging.info('creating a new partition table')
pt.make(node_list) pt.make(node_list)
self._notifyAdmins(Packets.SendPartitionTable( self._notifyAdmins(Packets.SendPartitionTable(
pt.getID(), pt.getRowList())) pt.getID(), pt.getReplicas(), pt.getRowList()))
else: else:
cell_list = pt.outdate() cell_list = pt.outdate()
if cell_list: if cell_list:
self._notifyAdmins(Packets.NotifyPartitionChanges( self._notifyAdmins(Packets.NotifyPartitionChanges(
pt.setNextID(), cell_list)) pt.setNextID(), pt.getReplicas(), cell_list))
if app.backup_tid: if app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict) pt.setBackupTidDict(self.backup_tid_dict)
app.backup_tid = pt.getBackupTid() app.backup_tid = pt.getBackupTid()
...@@ -175,16 +174,16 @@ class RecoveryManager(MasterHandler): ...@@ -175,16 +174,16 @@ class RecoveryManager(MasterHandler):
if node is None or node.getState() == new_state: if node is None or node.getState() == new_state:
return return
node.setState(new_state) node.setState(new_state)
# broadcast to all so that admin nodes gets informed
self.app.broadcastNodesInformation([node]) self.app.broadcastNodesInformation([node])
def connectionCompleted(self, conn, new): def handlerSwitched(self, conn, new):
# ask the last IDs to perform the recovery # ask the last IDs to perform the recovery
conn.ask(Packets.AskRecovery()) conn.ask(Packets.AskRecovery())
def answerRecovery(self, conn, ptid, backup_tid, truncate_tid): def answerRecovery(self, conn, ptid, backup_tid, truncate_tid):
uuid = conn.getUUID() uuid = conn.getUUID()
if self.target_ptid <= ptid: # ptid is None if the node has an empty partition table.
if ptid and self.target_ptid <= ptid:
# Maybe a newer partition table. # Maybe a newer partition table.
if self.target_ptid == ptid and self.ask_pt: if self.target_ptid == ptid and self.ask_pt:
# Another node is already asked. # Another node is already asked.
...@@ -197,17 +196,14 @@ class RecoveryManager(MasterHandler): ...@@ -197,17 +196,14 @@ class RecoveryManager(MasterHandler):
self.backup_tid_dict[uuid] = backup_tid self.backup_tid_dict[uuid] = backup_tid
self.truncate_dict[uuid] = truncate_tid self.truncate_dict[uuid] = truncate_tid
def answerPartitionTable(self, conn, ptid, row_list): def answerPartitionTable(self, conn, ptid, num_replicas, row_list):
# If this is not from a target node, ignore it. # If this is not from a target node, ignore it.
if ptid == self.target_ptid: if ptid == self.target_ptid:
app = self.app app = self.app
try: new_nodes = app.pt.load(ptid, num_replicas, row_list, app.nm)
new_nodes = app.pt.load(ptid, row_list, app.nm)
except IndexError:
raise ProtocolError('Invalid offset')
self._notifyAdmins( self._notifyAdmins(
Packets.NotifyNodeInformation(monotonic_time(), new_nodes), Packets.NotifyNodeInformation(monotonic_time(), new_nodes),
Packets.SendPartitionTable(ptid, row_list)) Packets.SendPartitionTable(ptid, num_replicas, row_list))
self.ask_pt = () self.ask_pt = ()
uuid = conn.getUUID() uuid = conn.getUUID()
app.backup_tid = self.backup_tid_dict[uuid] app.backup_tid = self.backup_tid_dict[uuid]
......
...@@ -16,9 +16,11 @@ ...@@ -16,9 +16,11 @@
import sys import sys
from .neoctl import NeoCTL, NotReadyException from .neoctl import NeoCTL, NotReadyException
from neo.lib.node import NodeManager
from neo.lib.pt import PartitionTable
from neo.lib.util import p64, u64, tidFromTime, timeStringFromTID from neo.lib.util import p64, u64, tidFromTime, timeStringFromTID
from neo.lib.protocol import uuid_str, formatNodeList, \ from neo.lib.protocol import uuid_str, formatNodeList, \
ClusterStates, NodeTypes, UUID_NAMESPACES, ZERO_TID ClusterStates, NodeStates, NodeTypes, UUID_NAMESPACES, ZERO_TID
action_dict = { action_dict = {
'print': { 'print': {
...@@ -30,6 +32,7 @@ action_dict = { ...@@ -30,6 +32,7 @@ action_dict = {
}, },
'set': { 'set': {
'cluster': 'setClusterState', 'cluster': 'setClusterState',
'replicas': 'setNumReplicas',
}, },
'check': 'checkReplicas', 'check': 'checkReplicas',
'start': 'startCluster', 'start': 'startCluster',
...@@ -46,6 +49,11 @@ uuid_int = (lambda ns: lambda uuid: ...@@ -46,6 +49,11 @@ uuid_int = (lambda ns: lambda uuid:
(ns[uuid[0]] << 24) + int(uuid[1:]) (ns[uuid[0]] << 24) + int(uuid[1:])
)({str(k)[0]: v for k, v in UUID_NAMESPACES.iteritems()}) )({str(k)[0]: v for k, v in UUID_NAMESPACES.iteritems()})
class dummy_app:
id_timestamp = uuid = 0
class TerminalNeoCTL(object): class TerminalNeoCTL(object):
def __init__(self, *args, **kw): def __init__(self, *args, **kw):
self.neoctl = NeoCTL(*args, **kw) self.neoctl = NeoCTL(*args, **kw)
...@@ -67,6 +75,15 @@ class TerminalNeoCTL(object): ...@@ -67,6 +75,15 @@ class TerminalNeoCTL(object):
asNode = staticmethod(uuid_int) asNode = staticmethod(uuid_int)
def formatPartitionTable(self, row_list):
nm = NodeManager()
nm.update(dummy_app, 1,
self.neoctl.getNodeList(node_type=NodeTypes.STORAGE))
pt = object.__new__(PartitionTable)
pt._load(None, None, row_list, nm.getByUUID)
pt.addNodeList(nm.getByStateList(NodeStates.RUNNING))
return '\n'.join(line[4:] for line in pt._format())
def formatRowList(self, row_list): def formatRowList(self, row_list):
return '\n'.join('%03d |%s' % (offset, return '\n'.join('%03d |%s' % (offset,
''.join(' %s - %s |' % (uuid_str(uuid), state) ''.join(' %s - %s |' % (uuid_str(uuid), state)
...@@ -105,10 +122,12 @@ class TerminalNeoCTL(object): ...@@ -105,10 +122,12 @@ class TerminalNeoCTL(object):
max_offset = int(max_offset) max_offset = int(max_offset)
if node is not None: if node is not None:
node = self.asNode(node) node = self.asNode(node)
ptid, row_list = self.neoctl.getPartitionRowList( ptid, num_replicas, row_list = self.neoctl.getPartitionRowList(
min_offset=min_offset, max_offset=max_offset, node=node) min_offset=min_offset, max_offset=max_offset, node=node)
# TODO: return ptid return '# ptid: %s, replicas: %s\n%s' % (ptid, num_replicas,
return self.formatRowList(row_list) self.formatRowList(enumerate(row_list, min_offset))
if min_offset or max_offset else
self.formatPartitionTable(row_list))
def getNodeList(self, params): def getNodeList(self, params):
""" """
...@@ -140,6 +159,18 @@ class TerminalNeoCTL(object): ...@@ -140,6 +159,18 @@ class TerminalNeoCTL(object):
assert len(params) == 1 assert len(params) == 1
return self.neoctl.setClusterState(self.asClusterState(params[0])) return self.neoctl.setClusterState(self.asClusterState(params[0]))
def setNumReplicas(self, params):
"""
Set number of replicas.
Parameters: nr
nr: non-negative integer (0 means no redundancy)
"""
assert len(params) == 1
nr = int(params[0])
if nr < 0:
sys.exit('invalid number of replicas')
return self.neoctl.setNumReplicas(nr)
def startCluster(self, params): def startCluster(self, params):
""" """
Starts cluster operation after a startup. Starts cluster operation after a startup.
...@@ -167,10 +198,18 @@ class TerminalNeoCTL(object): ...@@ -167,10 +198,18 @@ class TerminalNeoCTL(object):
def tweakPartitionTable(self, params): def tweakPartitionTable(self, params):
""" """
Optimize partition table. Optimize partition table.
No partition will be assigned to specified storage nodes. No change is done to the specified/down storage nodes and they don't
Parameters: [node [...]] count as replicas. The purpose of listing nodes is usually to drop
them once the data is replicated to other nodes.
Parameters: [-n] [node [...]]
-n: dry run
""" """
return self.neoctl.tweakPartitionTable(map(self.asNode, params)) dry_run = params[0] == '-n'
changed, row_list = self.neoctl.tweakPartitionTable(
map(self.asNode, params[dry_run:]), dry_run)
if changed:
return self.formatPartitionTable(row_list)
return 'No change done.'
def killNode(self, params): def killNode(self, params):
""" """
...@@ -203,7 +242,7 @@ class TerminalNeoCTL(object): ...@@ -203,7 +242,7 @@ class TerminalNeoCTL(object):
node: if "all", ask all connected storage nodes to repair, node: if "all", ask all connected storage nodes to repair,
otherwise, only the given list of storage nodes. otherwise, only the given list of storage nodes.
""" """
dry_run = "01".index(params.pop(0)) dry_run = bool("01".index(params.pop(0)))
return self.neoctl.repair(self._getStorageList(params), dry_run) return self.neoctl.repair(self._getStorageList(params), dry_run)
def truncate(self, params): def truncate(self, params):
......
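
A small detail in the neoctl tweakPartitionTable above: dry_run is a boolean used directly as a slice index, so params[dry_run:] drops the leading '-n' exactly when it is present. In isolation (the 'S1'/'S2' node names are just placeholders):

# bool-as-index idiom used by tweakPartitionTable() above:
# True slices off the leading '-n', False keeps the whole list.
def split_tweak_params(params):
    dry_run = params[0] == '-n'
    return dry_run, params[dry_run:]

assert split_tweak_params(['-n', 'S1', 'S2']) == (True, ['S1', 'S2'])
assert split_tweak_params(['S1']) == (False, ['S1'])
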
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import sys
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import ErrorCodes, Packets from neo.lib.protocol import ErrorCodes, Packets
...@@ -44,8 +45,8 @@ class CommandEventHandler(EventHandler): ...@@ -44,8 +45,8 @@ class CommandEventHandler(EventHandler):
def ack(self, conn, msg): def ack(self, conn, msg):
self.__respond((Packets.Error, ErrorCodes.ACK, msg)) self.__respond((Packets.Error, ErrorCodes.ACK, msg))
def protocolError(self, conn, msg): def denied(self, conn, msg):
self.__respond((Packets.Error, ErrorCodes.PROTOCOL_ERROR, msg)) sys.exit(msg)
def notReady(self, conn, msg): def notReady(self, conn, msg):
self.__respond((Packets.Error, ErrorCodes.NOT_READY, msg)) self.__respond((Packets.Error, ErrorCodes.NOT_READY, msg))
...@@ -62,3 +63,4 @@ class CommandEventHandler(EventHandler): ...@@ -62,3 +63,4 @@ class CommandEventHandler(EventHandler):
answerLastIDs = __answer(Packets.AnswerLastIDs) answerLastIDs = __answer(Packets.AnswerLastIDs)
answerLastTransaction = __answer(Packets.AnswerLastTransaction) answerLastTransaction = __answer(Packets.AnswerLastTransaction)
answerRecovery = __answer(Packets.AnswerRecovery) answerRecovery = __answer(Packets.AnswerRecovery)
answerTweakPartitionTable = __answer(Packets.AnswerTweakPartitionTable)
...@@ -91,8 +91,14 @@ class NeoCTL(BaseApplication): ...@@ -91,8 +91,14 @@ class NeoCTL(BaseApplication):
raise RuntimeError(response) raise RuntimeError(response)
return response[2] return response[2]
def tweakPartitionTable(self, uuid_list=()): def tweakPartitionTable(self, uuid_list=(), dry_run=False):
response = self.__ask(Packets.TweakPartitionTable(uuid_list)) response = self.__ask(Packets.TweakPartitionTable(dry_run, uuid_list))
if response[0] != Packets.AnswerTweakPartitionTable:
raise RuntimeError(response)
return response[1:]
def setNumReplicas(self, nr):
response = self.__ask(Packets.SetNumReplicas(nr))
if response[0] != Packets.Error or response[1] != ErrorCodes.ACK: if response[0] != Packets.Error or response[1] != ErrorCodes.ACK:
raise RuntimeError(response) raise RuntimeError(response)
return response[2] return response[2]
...@@ -163,7 +169,7 @@ class NeoCTL(BaseApplication): ...@@ -163,7 +169,7 @@ class NeoCTL(BaseApplication):
response = self.__ask(packet) response = self.__ask(packet)
if response[0] != Packets.AnswerPartitionList: if response[0] != Packets.AnswerPartitionList:
raise RuntimeError(response) raise RuntimeError(response)
return response[1:3] # ptid, row_list return response[1:]
def startCluster(self): def startCluster(self):
""" """
......
...@@ -157,27 +157,49 @@ class Log(object): ...@@ -157,27 +157,49 @@ class Log(object):
for x in 'uuid_str', 'Packets', 'PacketMalformedError': for x in 'uuid_str', 'Packets', 'PacketMalformedError':
setattr(self, x, g[x]) setattr(self, x, g[x])
x = {} x = {}
try:
Unpacker = g['Unpacker']
except KeyError:
unpackb = None
else:
from msgpack import ExtraData, UnpackException
def unpackb(data):
u = Unpacker()
u.feed(data)
data = u.unpack()
if u.read_bytes(1):
raise ExtraData
return data
self.PacketMalformedError = UnpackException
self.unpackb = unpackb
if self._decode > 1: if self._decode > 1:
PStruct = g['PStruct'] try:
PBoolean = g['PBoolean'] PStruct = g['PStruct']
def hasData(item): except KeyError:
items = item._items for p in self.Packets.itervalues():
for i, item in enumerate(items): data_path = getattr(p, 'data_path', (None,))
if isinstance(item, PStruct): if p._code >> 15 == data_path[0]:
j = hasData(item) x[p._code] = data_path[1:]
if j: else:
return (i,) + j PBoolean = g['PBoolean']
elif (isinstance(item, PBoolean) def hasData(item):
and item._name == 'compression' items = item._items
and i + 2 < len(items) for i, item in enumerate(items):
and items[i+2]._name == 'data'): if isinstance(item, PStruct):
return i, j = hasData(item)
for p in self.Packets.itervalues(): if j:
if p._fmt is not None: return (i,) + j
path = hasData(p._fmt) elif (isinstance(item, PBoolean)
if path: and item._name == 'compression'
assert not hasattr(p, '_neolog'), p and i + 2 < len(items)
x[p._code] = path and items[i+2]._name == 'data'):
return i,
for p in self.Packets.itervalues():
if p._fmt is not None:
path = hasData(p._fmt)
if path:
assert not hasattr(p, '_neolog'), p
x[p._code] = path
self._getDataPath = x.get self._getDataPath = x.get
try: try:
...@@ -215,11 +237,13 @@ class Log(object): ...@@ -215,11 +237,13 @@ class Log(object):
if body is not None: if body is not None:
log = getattr(p, '_neolog', None) log = getattr(p, '_neolog', None)
if log or self._decode: if log or self._decode:
p = p()
p._id = msg_id
p._body = body
try: try:
args = p.decode() if self.unpackb:
args = self.unpackb(body)
else:
p = p()
p._body = body
args = p.decode()
except self.PacketMalformedError: except self.PacketMalformedError:
msg.append("Can't decode packet") msg.append("Can't decode packet")
else: else:
......
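
The neolog change above decodes packet bodies with msgpack whenever the loaded protocol module provides an Unpacker, and rejects any trailing bytes. The same decode-strictly pattern can be demonstrated with the msgpack library alone; the ValueError below replaces msgpack's ExtraData just to keep the sketch self-contained.

# Standalone demonstration of the msgpack-based decoding used above:
# feed the raw body to an Unpacker, take one object, and treat any
# remaining bytes as an error.
import msgpack

def unpackb_strict(data):
    u = msgpack.Unpacker()
    u.feed(data)
    obj = u.unpack()
    if u.read_bytes(1):
        raise ValueError('trailing data after packet body')
    return obj

body = msgpack.packb(['AskLastTransaction'])
assert unpackb_strict(body) == ['AskLastTransaction']
try:
    unpackb_strict(body + b'\x00')
except ValueError:
    pass
else:
    raise AssertionError('expected trailing-data error')
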
...@@ -51,13 +51,11 @@ UNIT_TEST_MODULES = [ ...@@ -51,13 +51,11 @@ UNIT_TEST_MODULES = [
'neo.tests.master.testClientHandler', 'neo.tests.master.testClientHandler',
'neo.tests.master.testMasterApp', 'neo.tests.master.testMasterApp',
'neo.tests.master.testMasterPT', 'neo.tests.master.testMasterPT',
'neo.tests.master.testRecovery',
'neo.tests.master.testStorageHandler', 'neo.tests.master.testStorageHandler',
'neo.tests.master.testTransactions', 'neo.tests.master.testTransactions',
# storage application # storage application
'neo.tests.storage.testClientHandler', 'neo.tests.storage.testClientHandler',
'neo.tests.storage.testMasterHandler', 'neo.tests.storage.testMasterHandler',
'neo.tests.storage.testStorageApp',
'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), 'neo.tests.storage.testStorage' + os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
'neo.tests.storage.testTransactions', 'neo.tests.storage.testTransactions',
# client application # client application
...@@ -298,6 +296,13 @@ class TestRunner(BenchmarkRunner): ...@@ -298,6 +296,13 @@ class TestRunner(BenchmarkRunner):
x('-S', '--stop-on-success', action='store_true', default=None, x('-S', '--stop-on-success', action='store_true', default=None,
help='Opposite of --stop-on-error: stop as soon as a test' help='Opposite of --stop-on-error: stop as soon as a test'
' passes. Details about errors are not printed at exit.') ' passes. Details about errors are not printed at exit.')
x = parser.add_mutually_exclusive_group().add_argument
x('-p', '--dump-protocol', const=True,
dest='protocol', action='store_const',
help='Dump schema of protocol instead of checking it.')
x('-P', '--no-check-protocol', const=False,
dest='protocol', action='store_const',
help='Do not check schema of protocol.')
_('-r', '--readable-tid', action='store_true', _('-r', '--readable-tid', action='store_true',
help='Change master behaviour to generate readable TIDs for easier' help='Change master behaviour to generate readable TIDs for easier'
' debugging (rather than from current time).') ' debugging (rather than from current time).')
...@@ -347,6 +352,7 @@ Environment Variables: ...@@ -347,6 +352,7 @@ Environment Variables:
coverage = args.coverage, coverage = args.coverage,
cov_unit = args.cov_unit, cov_unit = args.cov_unit,
only = args.only, only = args.only,
protocol = args.protocol,
stop_on_success = args.stop_on_success, stop_on_success = args.stop_on_success,
readable_tid = args.readable_tid, readable_tid = args.readable_tid,
) )
...@@ -374,19 +380,26 @@ Environment Variables: ...@@ -374,19 +380,26 @@ Environment Variables:
self.__coverage.save() self.__coverage.save()
del self.__coverage del self.__coverage
orig(self, success) orig(self, success)
try: if config.protocol is False:
for _ in xrange(config.loop): from contextlib import nested
if config.unit: protocol_checker = nested()
runner.run('Unit tests', UNIT_TEST_MODULES, only) else:
if config.functional: from neo.tests.protocol_checker import protocolChecker
runner.run('Functional tests', FUNC_TEST_MODULES, only) protocol_checker = protocolChecker(config.protocol)
if config.zodb: with protocol_checker:
runner.run('ZODB tests', ZODB_TEST_MODULES, only) try:
except KeyboardInterrupt: for _ in xrange(config.loop):
config['mail_to'] = None if config.unit:
traceback.print_exc() runner.run('Unit tests', UNIT_TEST_MODULES, only)
except StopOnSuccess: if config.functional:
pass runner.run('Functional tests', FUNC_TEST_MODULES, only)
if config.zodb:
runner.run('ZODB tests', ZODB_TEST_MODULES, only)
except KeyboardInterrupt:
config['mail_to'] = None
traceback.print_exc()
except StopOnSuccess:
pass
if config.coverage: if config.coverage:
coverage.stop() coverage.stop()
if coverage.neotestrunner: if coverage.neotestrunner:
......
...@@ -63,11 +63,16 @@ class Application(BaseApplication): ...@@ -63,11 +63,16 @@ class Application(BaseApplication):
help="do not delete data of discarded cells, which is useful for" help="do not delete data of discarded cells, which is useful for"
" big databases because the current implementation is" " big databases because the current implementation is"
" inefficient (this option should disappear in the future)") " inefficient (this option should disappear in the future)")
_.bool('new-nid',
help="request a new NID from a cluster that is already"
" operational, update the database with the new NID and exit,"
" which makes easier to quickly set up a replica by copying"
" the database of another node while it was stopped")
_ = parser.group('database creation') _ = parser.group('database creation')
_.int('u', 'uuid', _.int('i', 'nid',
help="specify an UUID to use for this process. Previously" help="specify an NID to use for this process. Previously"
" assigned UUID takes precedence (i.e. you should" " assigned NID takes precedence (i.e. you should"
" always use reset with this switch)") " always use reset with this switch)")
_('e', 'engine', help="database engine (MySQL only)") _('e', 'engine', help="database engine (MySQL only)")
_.bool('dedup', _.bool('dedup',
...@@ -118,10 +123,16 @@ class Application(BaseApplication): ...@@ -118,10 +123,16 @@ class Application(BaseApplication):
self.loadConfiguration() self.loadConfiguration()
self.devpath = self.dm.getTopologyPath() self.devpath = self.dm.getTopologyPath()
# force node uuid from command line argument, for testing purpose only if config.get('new_nid'):
if 'uuid' in config: self.new_nid = [x[0] for x in self.dm.iterAssignedCells()]
self.uuid = config['uuid'] if not self.new_nid:
logging.node(self.name, self.uuid) sys.exit('database is empty')
self.uuid = None
else:
self.new_nid = ()
if 'nid' in config: # for testing purpose only
self.uuid = config['nid']
logging.node(self.name, self.uuid)
registerLiveDebugger(on_log=self.log) registerLiveDebugger(on_log=self.log)
...@@ -158,36 +169,27 @@ class Application(BaseApplication): ...@@ -158,36 +169,27 @@ class Application(BaseApplication):
# load configuration # load configuration
self.uuid = dm.getUUID() self.uuid = dm.getUUID()
logging.node(self.name, self.uuid) logging.node(self.name, self.uuid)
num_partitions = dm.getNumPartitions()
num_replicas = dm.getNumReplicas()
ptid = dm.getPTID()
# check partition table configuration
if num_partitions is not None and num_replicas is not None:
if num_partitions <= 0:
raise RuntimeError, 'partitions must be more than zero'
# create a partition table
self.pt = PartitionTable(num_partitions, num_replicas)
logging.info('Configuration loaded:') logging.info('Configuration loaded:')
logging.info('PTID : %s', dump(ptid)) logging.info('PTID : %s', dump(dm.getPTID()))
logging.info('Name : %s', self.name) logging.info('Name : %s', self.name)
logging.info('Partitions: %s', num_partitions)
logging.info('Replicas : %s', num_replicas)
def loadPartitionTable(self): def loadPartitionTable(self):
"""Load a partition table from the database.""" """Load a partition table from the database."""
self.pt.clear()
ptid = self.dm.getPTID() ptid = self.dm.getPTID()
if ptid is None: if ptid is None:
self.pt = PartitionTable(0, 0)
return return
cell_list = [] row_list = []
for offset, uuid, state in self.dm.getPartitionTable(): for offset, uuid, state in self.dm.getPartitionTable():
while len(row_list) <= offset:
row_list.append([])
# register unknown nodes # register unknown nodes
if self.nm.getByUUID(uuid) is None: if self.nm.getByUUID(uuid) is None:
self.nm.createStorage(uuid=uuid) self.nm.createStorage(uuid=uuid)
cell_list.append((offset, uuid, CellStates[state])) row_list[offset].append((uuid, CellStates[state]))
self.pt.update(ptid, cell_list, self.nm) self.pt = object.__new__(PartitionTable)
self.pt.load(ptid, self.dm.getNumReplicas(), row_list, self.nm)
def run(self): def run(self):
try: try:
...@@ -247,29 +249,16 @@ class Application(BaseApplication): ...@@ -247,29 +249,16 @@ class Application(BaseApplication):
Note that I do not accept any connection from non-master nodes Note that I do not accept any connection from non-master nodes
at this stage.""" at this stage."""
pt = self.pt
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, NodeTypes.STORAGE, self.server, bootstrap = BootstrapManager(self, NodeTypes.STORAGE,
self.devpath) None if self.new_nid else self.server,
self.master_node, self.master_conn, num_partitions, num_replicas = \ self.devpath, self.new_nid)
bootstrap.getPrimaryConnection() self.master_node, self.master_conn = bootstrap.getPrimaryConnection()
self.dm.setUUID(self.uuid) self.dm.setUUID(self.uuid)
# Reload a partition table from the database. This is necessary # Reload a partition table from the database,
# when a previous primary master died while sending a partition # in case that we're in RECOVERING phase.
# table, because the table might be incomplete. self.loadPartitionTable()
if pt is not None:
self.loadPartitionTable()
if num_partitions != pt.getPartitions():
raise RuntimeError('the number of partitions is inconsistent')
if pt is None or pt.getReplicas() != num_replicas:
# changing number of replicas is not an issue
self.dm.setNumPartitions(num_partitions)
self.dm.setNumReplicas(num_replicas)
self.pt = PartitionTable(num_partitions, num_replicas)
self.loadPartitionTable()
def initialize(self): def initialize(self):
logging.debug('initializing...') logging.debug('initializing...')
......
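
loadPartitionTable above now regroups the flat (offset, nid, state) records returned by the database into one row per partition before calling PartitionTable.load. That regrouping step on its own, with plain tuples instead of NEO types:

# Regroup flat (offset, nid, state) records, as returned by the database,
# into one row per partition offset -- the shape expected by load().
def rows_from_cells(cell_records):
    row_list = []
    for offset, nid, state in cell_records:
        while len(row_list) <= offset:
            row_list.append([])
        row_list[offset].append((nid, state))
    return row_list

cells = [(0, 1, 'UP_TO_DATE'), (1, 2, 'UP_TO_DATE'), (0, 2, 'FEEDING')]
assert rows_from_cells(cells) == [
    [(1, 'UP_TO_DATE'), (2, 'FEEDING')],
    [(2, 'UP_TO_DATE')],
]
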
...@@ -51,7 +51,7 @@ class Checker(object): ...@@ -51,7 +51,7 @@ class Checker(object):
else: else:
conn = ClientConnection(app, StorageOperationHandler(app), node) conn = ClientConnection(app, StorageOperationHandler(app), node)
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
uuid, app.server, name, (), app.id_timestamp)) uuid, app.server, name, app.id_timestamp, (), ()))
self.conn_dict[conn] = node.isIdentified() self.conn_dict[conn] = node.isIdentified()
conn_set = set(self.conn_dict) conn_set = set(self.conn_dict)
conn_set.discard(None) conn_set.discard(None)
......
...@@ -33,7 +33,7 @@ from ZODB.FileStorage import FileStorage ...@@ -33,7 +33,7 @@ from ZODB.FileStorage import FileStorage
from ..app import option_defaults from ..app import option_defaults
from . import buildDatabaseManager, DatabaseFailure from . import buildDatabaseManager, DatabaseFailure
from .manager import DatabaseManager from .manager import DatabaseManager, Fallback
from neo.lib import compress, logging, patch, util from neo.lib import compress, logging, patch, util
from neo.lib.interfaces import implements from neo.lib.interfaces import implements
from neo.lib.protocol import BackendNotImplemented, MAX_TID from neo.lib.protocol import BackendNotImplemented, MAX_TID
...@@ -216,7 +216,7 @@ class ZODB(object): ...@@ -216,7 +216,7 @@ class ZODB(object):
self._connect = _connect self._connect = _connect
config = section.config config = section.config
if 'read_only' in config.getSectionAttributes(): if 'read_only' in config.getSectionAttributes():
has_next_oid = config.read_only = hasattr(self, 'next_oid') has_next_oid = config.read_only = 'next_oid' in self.__dict__
if not has_next_oid: if not has_next_oid:
import gc import gc
# This will reopen read-only as soon as we know the last oid. # This will reopen read-only as soon as we know the last oid.
...@@ -378,8 +378,8 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -378,8 +378,8 @@ class ImporterDatabaseManager(DatabaseManager):
conf = self._conf conf = self._conf
db = self.db = buildDatabaseManager(conf['adapter'], db = self.db = buildDatabaseManager(conf['adapter'],
(conf['database'], conf.get('engine'), conf['wait'])) (conf['database'], conf.get('engine'), conf['wait']))
for x in """getConfiguration _setConfiguration setNumPartitions for x in """getConfiguration _setConfiguration _getMaxPartition
query erase getPartitionTable _iterAssignedCells query erase getPartitionTable iterAssignedCells
updateCellTID getUnfinishedTIDDict dropUnfinishedData updateCellTID getUnfinishedTIDDict dropUnfinishedData
abortTransaction storeTransaction lockTransaction abortTransaction storeTransaction lockTransaction
loadData storeData getOrphanList _pruneData deferCommit loadData storeData getOrphanList _pruneData deferCommit
...@@ -396,9 +396,16 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -396,9 +396,16 @@ class ImporterDatabaseManager(DatabaseManager):
self._writeback.committed() self._writeback.committed()
self.commit = db.commit = commit self.commit = db.commit = commit
def _updateReadable(self): def _updateReadable(*_):
raise AssertionError raise AssertionError
def setUUID(self, nid):
old_nid = self.getUUID()
if old_nid:
assert old_nid == nid, (old_nid, nid)
else:
self.setConfiguration('nid', str(nid))
def changePartitionTable(self, *args, **kw): def changePartitionTable(self, *args, **kw):
self.db.changePartitionTable(*args, **kw) self.db.changePartitionTable(*args, **kw)
if self._writeback: if self._writeback:
...@@ -413,7 +420,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -413,7 +420,7 @@ class ImporterDatabaseManager(DatabaseManager):
if self._writeback: if self._writeback:
self._writeback.close() self._writeback.close()
self.db.close() self.db.close()
if isinstance(self.zodb, list): # _setup called if isinstance(self.zodb, tuple): # _setup called
for zodb in self.zodb: for zodb in self.zodb:
zodb.close() zodb.close()
...@@ -436,9 +443,13 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -436,9 +443,13 @@ class ImporterDatabaseManager(DatabaseManager):
self.zodb_ltid = max(x.ltid for x in self.zodb) self.zodb_ltid = max(x.ltid for x in self.zodb)
zodb = self.zodb[-1] zodb = self.zodb[-1]
self.zodb_loid = zodb.shift_oid + zodb.next_oid - 1 self.zodb_loid = zodb.shift_oid + zodb.next_oid - 1
self.zodb_tid = self.db.getLastTID(self.zodb_ltid) or 0 self.zodb_tid = self._getMaxPartition() is not None and \
if callable(self._import): self.db.getLastTID(self.zodb_ltid) or 0
self._import = self._import() if callable(self._import): # XXX: why ?
if self.zodb_tid == self.zodb_ltid:
self._finished()
else:
self._import = self._import()
def doOperation(self, app): def doOperation(self, app):
if self._import: if self._import:
...@@ -498,12 +509,19 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -498,12 +509,19 @@ class ImporterDatabaseManager(DatabaseManager):
if process: if process:
process.join() process.join()
self.commit() self.commit()
self._finished()
def _finished(self):
logging.warning("All data are imported. You should change" logging.warning("All data are imported. You should change"
" your configuration to use the native backend and restart.") " your configuration to use the native backend and restart.")
self._import = None self._import = None
for x in """getObject getReplicationTIDList getReplicationObjectList for x in """getObject getReplicationTIDList getReplicationObjectList
_fetchObject _getDataTID getLastObjectTID
""".split(): """.split():
setattr(self, x, getattr(self.db, x)) setattr(self, x, getattr(self.db, x))
for zodb in self.zodb:
zodb.close()
self.zodb = None
def _iter_zodb(self, zodb_list): def _iter_zodb(self, zodb_list):
util.setproctitle('neostorage: import') util.setproctitle('neostorage: import')
...@@ -556,6 +574,19 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -556,6 +574,19 @@ class ImporterDatabaseManager(DatabaseManager):
return (max(tid, util.p64(self.zodb_ltid)), return (max(tid, util.p64(self.zodb_ltid)),
max(oid, util.p64(self.zodb_loid))) max(oid, util.p64(self.zodb_loid)))
def _getObject(self, oid, tid=None, before_tid=None):
p64 = util.p64
r = self.getObject(p64(oid),
None if tid is None else p64(tid),
None if before_tid is None else p64(before_tid))
if r:
serial, next_serial, compression, checksum, data, data_serial = r
u64 = util.u64
return (u64(serial),
next_serial and u64(next_serial),
compression, checksum, data,
data_serial and u64(data_serial))
def getObject(self, oid, tid=None, before_tid=None): def getObject(self, oid, tid=None, before_tid=None):
u64 = util.u64 u64 = util.u64
u_oid = u64(oid) u_oid = u64(oid)
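
The new _getObject wrapper above only translates between the packed 8-byte identifiers used by the replication code and the plain integers used internally by the importer. For reference, p64/u64 are essentially the standard 64-bit big-endian conversions, roughly:

# Rough equivalent of the p64/u64 helpers (64-bit big-endian packing),
# shown here only to clarify what _getObject converts back and forth.
from struct import pack, unpack

def p64(n):
    return pack('>Q', n)

def u64(s):
    return unpack('>Q', s)[0]

assert u64(p64(2 ** 63 + 1)) == 2 ** 63 + 1
assert p64(1) == b'\x00\x00\x00\x00\x00\x00\x00\x01'
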
...@@ -623,7 +654,11 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -623,7 +654,11 @@ class ImporterDatabaseManager(DatabaseManager):
def _deleteRange(self, partition, min_tid=None, max_tid=None): def _deleteRange(self, partition, min_tid=None, max_tid=None):
# Even if everything is imported, we can't truncate below # Even if everything is imported, we can't truncate below
# because it would import again if we restart with this backend. # because it would import again if we restart with this backend.
if min_tid < self.zodb_ltid: # This is also incompatible with writeback, because ZODB has
# no API to truncate.
if min_tid < self.zodb_ltid or self._writeback:
# XXX: That's too late to fail. The master should ask storage nodes
# whether truncation is possible before going further.
raise NotImplementedError raise NotImplementedError
self.db._deleteRange(partition, min_tid, max_tid) self.db._deleteRange(partition, min_tid, max_tid)
...@@ -667,6 +702,12 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -667,6 +702,12 @@ class ImporterDatabaseManager(DatabaseManager):
length, partition) length, partition)
return r return r
def _fetchObject(*_):
raise AssertionError
getLastObjectTID = Fallback.getLastObjectTID.__func__
_getDataTID = Fallback._getDataTID.__func__
def getObjectHistory(self, *args, **kw): def getObjectHistory(self, *args, **kw):
raise BackendNotImplemented(self.getObjectHistory) raise BackendNotImplemented(self.getObjectHistory)
...@@ -678,6 +719,7 @@ class WriteBack(object): ...@@ -678,6 +719,7 @@ class WriteBack(object):
_changed = False _changed = False
_process = None _process = None
chunk_size = 100
def __init__(self, db, storage): def __init__(self, db, storage):
self._db = db self._db = db
...@@ -705,7 +747,7 @@ class WriteBack(object): ...@@ -705,7 +747,7 @@ class WriteBack(object):
self._event = Event() self._event = Event()
self._idle = Event() self._idle = Event()
self._stop = Event() self._stop = Event()
self._np = self._db.getNumPartitions() self._np = 1 + self._db._getMaxPartition()
self._db = cPickle.dumps(self._db, 2) self._db = cPickle.dumps(self._db, 2)
self._process = Process(target=self._run) self._process = Process(target=self._run)
self._process.daemon = True self._process.daemon = True
...@@ -737,7 +779,6 @@ class WriteBack(object): ...@@ -737,7 +779,6 @@ class WriteBack(object):
def iterator(self): def iterator(self):
db = self._db db = self._db
np = self._np np = self._np
chunk_size = max(2, 1000 // np)
offset_list = xrange(np) offset_list = xrange(np)
while 1: while 1:
with db: with db:
...@@ -748,23 +789,26 @@ class WriteBack(object): ...@@ -748,23 +789,26 @@ class WriteBack(object):
if np == len(db._readable_set): if np == len(db._readable_set):
while 1: while 1:
tid_list = [] tid_list = []
loop = False max_tid = MAX_TID
for offset in offset_list: for offset in offset_list:
x = db.getReplicationTIDList( x = db.getReplicationTIDList(
self.min_tid, MAX_TID, chunk_size, offset) self.min_tid, max_tid, self.chunk_size, offset)
tid_list += x tid_list += x
if len(x) == chunk_size: if len(x) == self.chunk_size:
loop = True max_tid = x[-1]
if tid_list: if not tid_list:
tid_list.sort() break
for tid in tid_list: tid_list.sort()
if self._stop.is_set(): for tid in tid_list:
return if self._stop.is_set():
yield TransactionRecord(db, tid) return
yield TransactionRecord(db, tid)
if tid == max_tid:
break
else:
self.min_tid = util.add64(tid, 1) self.min_tid = util.add64(tid, 1)
if loop: break
continue self.min_tid = util.add64(tid, 1)
break
if not self._event.is_set(): if not self._event.is_set():
self._idle.set() self._idle.set()
self._event.wait() self._event.wait()
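The rewritten iterator above replaces the old per-round `loop` flag and computed chunk size with a fixed chunk_size class attribute and a shrinking upper bound: whenever a partition returns a full chunk, max_tid is lowered to the last TID of that chunk, so nothing beyond it is replayed before the next round. A standalone sketch of that pagination scheme, using plain integers instead of 8-byte TIDs and a placeholder fetch function:

    def iter_tids(fetch, partitions, min_tid, max_tid, chunk_size):
        # fetch(min_tid, max_tid, limit, partition) -> ascending list of at
        # most `limit` TIDs of `partition` within [min_tid, max_tid].
        while True:
            bound = max_tid
            tids = []
            for p in partitions:
                x = fetch(min_tid, bound, chunk_size, p)
                tids += x
                if len(x) == chunk_size:
                    # The partition may have more rows after x[-1]: lower the
                    # bound so nothing beyond it is yielded this round.
                    bound = x[-1]
            if not tids:
                return
            tids.sort()
            for tid in tids:
                yield tid
                if tid == bound:
                    break
            else:
                # No partition filled a whole chunk: everything up to
                # max_tid has been seen.
                return
            min_tid = tid + 1  # resume just after the bound (util.add64 in NEO)

    # toy check: two partitions, chunks of 2, nothing skipped or duplicated
    data = {0: [1, 3, 5, 7], 1: [2, 4, 6]}
    fetch = lambda lo, hi, n, p: [t for t in data[p] if lo <= t <= hi][:n]
    assert list(iter_tids(fetch, (0, 1), 0, 10 ** 6, 2)) == [1, 2, 3, 4, 5, 6, 7]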
...@@ -785,7 +829,10 @@ class TransactionRecord(BaseStorage.TransactionRecord): ...@@ -785,7 +829,10 @@ class TransactionRecord(BaseStorage.TransactionRecord):
def __iter__(self): def __iter__(self):
tid = self.tid tid = self.tid
for oid in self._oid_list: for oid in self._oid_list:
_, compression, _, data, data_tid = self._db.fetchObject(oid, tid) r = self._db.fetchObject(oid, tid)
if r is None: # checkCurrentSerialInTransaction
continue
_, compression, _, data, data_tid = r
if data is not None: if data is not None:
data = compress.decompress_list[compression](data) data = compress.decompress_list[compression](data)
yield BaseStorage.DataRecord(oid, tid, data, data_tid) yield BaseStorage.DataRecord(oid, tid, data, data_tid)
...@@ -26,22 +26,9 @@ from . import DatabaseFailure ...@@ -26,22 +26,9 @@ from . import DatabaseFailure
READABLE = CellStates.UP_TO_DATE, CellStates.FEEDING READABLE = CellStates.UP_TO_DATE, CellStates.FEEDING
def lazymethod(func):
def getter(self):
cls = self.__class__
name = func.__name__
assert name not in cls.__dict__
setattr(cls, name, func(self))
return getattr(self, name)
return property(getter, doc=func.__doc__)
def fallback(func): def fallback(func):
def warn(self): setattr(Fallback, func.__name__, func)
logging.info("Fallback to generic/slow implementation of %s." return abstract(func)
" It should be overridden by backend storage (%s).",
func.__name__, self.__class__.__name__)
return func
return lazymethod(wraps(func)(warn))
def splitOIDField(tid, oids): def splitOIDField(tid, oids):
if len(oids) % 8: if len(oids) % 8:
...@@ -52,6 +39,9 @@ def splitOIDField(tid, oids): ...@@ -52,6 +39,9 @@ def splitOIDField(tid, oids):
class CreationUndone(Exception): class CreationUndone(Exception):
pass pass
class Fallback(object):
pass
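A minimal sketch of how the new fallback/Fallback pair is meant to be used (the `abstract` decorator below is only a stand-in, assumed to raise when the method is not overridden; Python 2 semantics, matching the code base): the generic implementation is kept on Fallback, the attribute on the manager becomes abstract, and a backend either provides its own optimized version or opts back into the generic one, as ImporterDatabaseManager does above with Fallback.getLastObjectTID.__func__.

    class Fallback(object):
        """Collects the generic (slow) implementations registered by @fallback."""

    def abstract(func):  # stand-in for the real decorator
        def not_implemented(self, *args, **kw):
            raise NotImplementedError(func.__name__)
        not_implemented.__name__ = func.__name__
        return not_implemented

    def fallback(func):
        setattr(Fallback, func.__name__, func)  # keep the generic version around
        return abstract(func)                   # and make the method abstract

    class Base(object):
        @fallback
        def _getDataTID(self, oid, tid=None):
            return 'generic', oid, tid          # slow path usable by any backend

    class FastBackend(Base):
        def _getDataTID(self, oid, tid=None):   # optimized override
            return 'fast', oid, tid

    class SimpleBackend(Base):
        # explicitly reuse the generic implementation (unbound method -> function)
        _getDataTID = Fallback._getDataTID.__func__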
class DatabaseManager(object): class DatabaseManager(object):
"""This class only describes an interface for database managers.""" """This class only describes an interface for database managers."""
...@@ -102,25 +92,24 @@ class DatabaseManager(object): ...@@ -102,25 +92,24 @@ class DatabaseManager(object):
finally: finally:
db.close() db.close()
_cached_attr_list = (
'_readable_set', '_getPartition', '_getReadablePartition')
def __getattr__(self, attr): def __getattr__(self, attr):
if attr in ('_readable_set', '_getPartition', '_getReadablePartition'): if attr in self._cached_attr_list:
self._updateReadable() self._updateReadable()
return self.__getattribute__(attr) return self.__getattribute__(attr)
def _partitionTableChanged(self):
try:
del (self._readable_set,
self._getPartition,
self._getReadablePartition)
except AttributeError:
pass
def __enter__(self): def __enter__(self):
assert not self.LOCK, "not a secondary connection" assert not self.LOCK, "not a secondary connection"
# XXX: All config caching should be done in this class, # XXX: All config caching should be done in this class,
# rather than in backend classes. # rather than in backend classes.
self._config.clear() self._config.clear()
self._partitionTableChanged() try:
for attr in self._cached_attr_list:
delattr(self, attr)
except AttributeError:
pass
def __exit__(self, t, v, tb): def __exit__(self, t, v, tb):
if v is None: if v is None:
...@@ -180,6 +169,10 @@ class DatabaseManager(object): ...@@ -180,6 +169,10 @@ class DatabaseManager(object):
def erase(self): def erase(self):
"""""" """"""
def restore(self, dump): # for tests
self.erase()
self._restore(dump)
def _setup(self, dedup=False): def _setup(self, dedup=False):
"""To be overridden by the backend to set up a database """To be overridden by the backend to set up a database
...@@ -271,6 +264,18 @@ class DatabaseManager(object): ...@@ -271,6 +264,18 @@ class DatabaseManager(object):
def _setConfiguration(self, key, value): def _setConfiguration(self, key, value):
"""""" """"""
def _changePartitionTable(self, cell_list, reset=False):
"""Change a part of a partition table. The list of cells is
a tuple of tuples, each of which consists of an offset (row ID),
the NID of a storage node, and a cell state. If reset is True,
existing data is first thrown away.
"""
def _getPartitionTable(self):
"""Return a whole partition table as a sequence of rows. Each row
is again a tuple of an offset (row ID), the NID of a storage
node, and a cell state."""
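To make the two docstrings above concrete, a toy example of the rows involved (NIDs and state codes are made up for illustration; real states come from CellStates and the encoding used by each backend):

    UP_TO_DATE, OUT_OF_DATE = 1, 2   # placeholder state codes, not the real ones

    # _getPartitionTable() -> one (offset, nid, state) row per assigned cell
    pt_rows = [
        (0, 1, UP_TO_DATE), (0, 2, UP_TO_DATE),   # partition 0 on nodes 1 and 2
        (1, 1, UP_TO_DATE), (1, 2, OUT_OF_DATE),  # node 2 still catching up
    ]

    # _changePartitionTable() receives the same kind of tuples; with reset=True
    # the stored table is cleared before these cells are applied:
    cell_list = [(1, 2, UP_TO_DATE)]              # node 2 caught up on partition 1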
def getUUID(self): def getUUID(self):
""" """
Load a NID from a database. Load a NID from a database.
...@@ -279,26 +284,19 @@ class DatabaseManager(object): ...@@ -279,26 +284,19 @@ class DatabaseManager(object):
if nid is not None: if nid is not None:
return int(nid) return int(nid)
@requires(_changePartitionTable, _getPartitionTable)
def setUUID(self, nid): def setUUID(self, nid):
""" """
Store a NID into a database. Store a NID into a database.
""" """
self.setConfiguration('nid', str(nid)) old_nid = self.getUUID()
if nid != old_nid:
def getNumPartitions(self): if old_nid:
""" self._changePartitionTable((offset, x, tid)
Load the number of partitions from a database. for offset, x, tid in self._getPartitionTable()
""" if x == old_nid
n = self.getConfiguration('partitions') for x, tid in ((x, None), (nid, tid)))
if n is not None: self.setConfiguration('nid', str(nid))
return int(n)
def setNumPartitions(self, num_partitions):
"""
Store the number of partitions into a database.
"""
self.setConfiguration('partitions', num_partitions)
self._partitionTableChanged()
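Reading the rewritten setUUID above: when the node is restarted with a different NID (the --new-nid fast-cloning path from the changelog), every cell recorded for the old NID appears to be dropped and recreated under the new one before the new NID is stored. A hedged sketch of that renumbering, with made-up NIDs and the assumption that a None state means "remove the cell":

    def renumber_cells(rows, old_nid, new_nid):
        # rows: (offset, nid, state) tuples as returned by _getPartitionTable().
        # For every cell of old_nid, emit one entry dropping it and one
        # recreating it under new_nid with the same state.
        for offset, nid, state in rows:
            if nid == old_nid:
                yield offset, old_nid, None
                yield offset, new_nid, state

    assert list(renumber_cells([(0, 5, 1), (1, 7, 1)], 5, 9)) == [
        (0, 5, None), (0, 9, 1)]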
def getNumReplicas(self): def getNumReplicas(self):
""" """
...@@ -308,12 +306,6 @@ class DatabaseManager(object): ...@@ -308,12 +306,6 @@ class DatabaseManager(object):
if n is not None: if n is not None:
return int(n) return int(n)
def setNumReplicas(self, num_replicas):
"""
Store the number of replicas into a database.
"""
self.setConfiguration('replicas', num_replicas)
def getName(self): def getName(self):
""" """
Load a name from a database. Load a name from a database.
...@@ -374,8 +366,9 @@ class DatabaseManager(object): ...@@ -374,8 +366,9 @@ class DatabaseManager(object):
tids are in unpacked format. tids are in unpacked format.
""" """
if self.getNumPartitions(): x = self._readable_set
return max(map(self._getLastTID, self._readable_set)) if x:
return max(self._getLastTID(x, max_tid) for x in x)
def _getLastIDs(self, partition): def _getLastIDs(self, partition):
"""Return max(tid) & max(oid) for objects of given partition """Return max(tid) & max(oid) for objects of given partition
...@@ -395,7 +388,7 @@ class DatabaseManager(object): ...@@ -395,7 +388,7 @@ class DatabaseManager(object):
x = self._readable_set x = self._readable_set
if x: if x:
tid, oid = zip(*map(self._getLastIDs, x)) tid, oid = zip(*map(self._getLastIDs, x))
tid = max(self.getLastTID(None), max(tid)) tid = max(self.getLastTID(), max(tid))
oid = max(oid) oid = max(oid)
return (None if tid is None else util.p64(tid), return (None if tid is None else util.p64(tid),
None if oid is None else util.p64(oid)) None if oid is None else util.p64(oid))
...@@ -490,6 +483,7 @@ class DatabaseManager(object): ...@@ -490,6 +483,7 @@ class DatabaseManager(object):
None if data_serial is None else util.p64(data_serial)) None if data_serial is None else util.p64(data_serial))
@fallback @fallback
@requires(_getObject)
def _fetchObject(self, oid, tid): def _fetchObject(self, oid, tid):
"""Specialized version of _getObject, for replication""" """Specialized version of _getObject, for replication"""
r = self._getObject(oid, tid) r = self._getObject(oid, tid)
...@@ -511,13 +505,8 @@ class DatabaseManager(object): ...@@ -511,13 +505,8 @@ class DatabaseManager(object):
return (util.p64(serial), compression, checksum, data, return (util.p64(serial), compression, checksum, data,
None if data_serial is None else util.p64(data_serial)) None if data_serial is None else util.p64(data_serial))
def _getPartitionTable(self):
"""Return a whole partition table as a sequence of rows. Each row
is again a tuple of an offset (row ID), the NID of a storage
node, and a cell state."""
@requires(_getPartitionTable) @requires(_getPartitionTable)
def _iterAssignedCells(self): def iterAssignedCells(self):
my_nid = self.getUUID() my_nid = self.getUUID()
return ((offset, tid) for offset, nid, tid in self._getPartitionTable() return ((offset, tid) for offset, nid, tid in self._getPartitionTable()
if my_nid == nid) if my_nid == nid)
...@@ -537,24 +526,19 @@ class DatabaseManager(object): ...@@ -537,24 +526,19 @@ class DatabaseManager(object):
finally: finally:
readable_set.remove(offset) readable_set.remove(offset)
def _changePartitionTable(self, cell_list, reset=False): def _getDataLastId(self, partition):
"""Change a part of a partition table. The list of cells is """
a tuple of tuples, each of which consists of an offset (row ID),
the NID of a storage node, and a cell state. If reset is True,
existing data is first thrown away.
""" """
def _getDataLastId(self, partition): def _getMaxPartition(self):
""" """
""" """
@requires(_getDataLastId) @requires(_getDataLastId, _getMaxPartition)
def _updateReadable(self): def _updateReadable(self, reset=True):
try: if reset:
readable_set = self.__dict__['_readable_set']
except KeyError:
readable_set = self._readable_set = set() readable_set = self._readable_set = set()
np = self.getNumPartitions() np = 1 + self._getMaxPartition()
def _getPartition(x, np=np): def _getPartition(x, np=np):
return x % np return x % np
def _getReadablePartition(x, np=np, r=readable_set): def _getReadablePartition(x, np=np, r=readable_set):
...@@ -569,14 +553,15 @@ class DatabaseManager(object): ...@@ -569,14 +553,15 @@ class DatabaseManager(object):
i = self._getDataLastId(p) i = self._getDataLastId(p)
d.append(p << 48 if i is None else i + 1) d.append(p << 48 if i is None else i + 1)
else: else:
readable_set = self._readable_set
readable_set.clear() readable_set.clear()
readable_set.update(x[0] for x in self._iterAssignedCells() readable_set.update(x[0] for x in self.iterAssignedCells()
if -x[1] in READABLE) if -x[1] in READABLE)
@requires(_changePartitionTable, _getLastIDs, _getLastTID) @requires(_changePartitionTable, _getLastIDs, _getLastTID)
def changePartitionTable(self, ptid, cell_list, reset=False): def changePartitionTable(self, ptid, num_replicas, cell_list, reset=False):
my_nid = self.getUUID() my_nid = self.getUUID()
pt = dict(self._iterAssignedCells()) pt = dict(self.iterAssignedCells())
# In backup mode, the last transactions of a readable cell may be # In backup mode, the last transactions of a readable cell may be
# incomplete. # incomplete.
backup_tid = self.getBackupTID() backup_tid = self.getBackupTID()
...@@ -595,13 +580,14 @@ class DatabaseManager(object): ...@@ -595,13 +580,14 @@ class DatabaseManager(object):
outofdate_tid(offset))) outofdate_tid(offset)))
for offset, nid, state in cell_list] for offset, nid, state in cell_list]
self._changePartitionTable(cell_list, reset) self._changePartitionTable(cell_list, reset)
self._updateReadable() self._updateReadable(reset)
assert isinstance(ptid, (int, long)), ptid assert isinstance(ptid, (int, long)), ptid
self._setConfiguration('ptid', str(ptid)) self._setConfiguration('ptid', str(ptid))
self._setConfiguration('replicas', str(num_replicas))
@requires(_changePartitionTable) @requires(_changePartitionTable)
def updateCellTID(self, partition, tid): def updateCellTID(self, partition, tid):
t, = (t for p, t in self._iterAssignedCells() if p == partition) t, = (t for p, t in self.iterAssignedCells() if p == partition)
if t < 0: if t < 0:
return return
tid = util.u64(tid) tid = util.u64(tid)
...@@ -623,7 +609,7 @@ class DatabaseManager(object): ...@@ -623,7 +609,7 @@ class DatabaseManager(object):
next_tid = util.u64(backup_tid) next_tid = util.u64(backup_tid)
if next_tid: if next_tid:
next_tid += 1 next_tid += 1
for offset, tid in self._iterAssignedCells(): for offset, tid in self.iterAssignedCells():
if tid >= 0: # OUT_OF_DATE if tid >= 0: # OUT_OF_DATE
yield offset, p64(tid and tid + 1) yield offset, p64(tid and tid + 1)
elif -tid in READABLE: elif -tid in READABLE:
...@@ -743,6 +729,7 @@ class DatabaseManager(object): ...@@ -743,6 +729,7 @@ class DatabaseManager(object):
return self._pruneData(data_id_list) return self._pruneData(data_id_list)
@fallback @fallback
@requires(_getObject)
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
""" """
Return a 2-tuple: Return a 2-tuple:
...@@ -809,6 +796,9 @@ class DatabaseManager(object): ...@@ -809,6 +796,9 @@ class DatabaseManager(object):
oid, current_tid) oid, current_tid)
return current_tid, current_tid return current_tid, current_tid
return current_tid, tid return current_tid, tid
found_undone_tid, undone_data_tid = getDataTID(tid=undone_tid)
if found_undone_tid is None:
return
if transaction_object: if transaction_object:
try: try:
current_tid = current_data_tid = u64(transaction_object[2]) current_tid = current_data_tid = u64(transaction_object[2])
...@@ -818,8 +808,6 @@ class DatabaseManager(object): ...@@ -818,8 +808,6 @@ class DatabaseManager(object):
current_tid, current_data_tid = getDataTID(before_tid=ltid) current_tid, current_data_tid = getDataTID(before_tid=ltid)
if current_tid is None: if current_tid is None:
return None, None, False return None, None, False
found_undone_tid, undone_data_tid = getDataTID(tid=undone_tid)
assert found_undone_tid is not None, (oid, undone_tid)
is_current = undone_data_tid in (current_data_tid, tid) is_current = undone_data_tid in (current_data_tid, tid)
# Load object data as it was before given transaction. # Load object data as it was before given transaction.
# It can be None, in which case it means we are undoing object # It can be None, in which case it means we are undoing object
...@@ -865,7 +853,7 @@ class DatabaseManager(object): ...@@ -865,7 +853,7 @@ class DatabaseManager(object):
assert tid, tid assert tid, tid
cell_list = [] cell_list = []
my_nid = self.getUUID() my_nid = self.getUUID()
for partition, state in self._iterAssignedCells(): for partition, state in self.iterAssignedCells():
if state > tid: if state > tid:
cell_list.append((partition, my_nid, tid)) cell_list.append((partition, my_nid, tid))
self._deleteRange(partition, tid) self._deleteRange(partition, tid)
......
...@@ -117,9 +117,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -117,9 +117,11 @@ class MySQLDatabaseManager(DatabaseManager):
return super(MySQLDatabaseManager, self).__getattr__(attr) return super(MySQLDatabaseManager, self).__getattr__(attr)
def _tryConnect(self): def _tryConnect(self):
kwd = {'db' : self.db, 'user' : self.user} kwd = {'db' : self.db}
if self.passwd is not None: if self.user:
kwd['passwd'] = self.passwd kwd['user'] = self.user
if self.passwd is not None:
kwd['passwd'] = self.passwd
if self.socket: if self.socket:
kwd['unix_socket'] = os.path.expanduser(self.socket) kwd['unix_socket'] = os.path.expanduser(self.socket)
logging.info('connecting to MySQL on the database %s with user %s', logging.info('connecting to MySQL on the database %s with user %s',
...@@ -198,6 +200,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -198,6 +200,7 @@ class MySQLDatabaseManager(DatabaseManager):
self._connect() self._connect()
def _commit(self): def _commit(self):
# XXX: Should we translate OperationalError into MysqlError ?
self.conn.commit() self.conn.commit()
self._active = 0 self._active = 0
...@@ -270,6 +273,12 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -270,6 +273,12 @@ class MySQLDatabaseManager(DatabaseManager):
" ELSE 1-state" " ELSE 1-state"
" END as tid") " END as tid")
# Let's wait for a more important change to clean up,
# so that users can still downgrade.
if 0:
def _migrate4(self, schema_dict):
self._setConfiguration('partitions', None)
def _setup(self, dedup=False): def _setup(self, dedup=False):
self._config.clear() self._config.clear()
q = self.query q = self.query
...@@ -295,6 +304,12 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -295,6 +304,12 @@ class MySQLDatabaseManager(DatabaseManager):
p += """ PARTITION BY LIST (`partition`) ( p += """ PARTITION BY LIST (`partition`) (
PARTITION dummy VALUES IN (NULL))""" PARTITION dummy VALUES IN (NULL))"""
if engine == "RocksDB":
cf = lambda name, rev=False: " COMMENT '%scf_neo_%s'" % (
'rev:' if rev else '', name)
else:
cf = lambda *_: ''
# The table "trans" stores information on committed transactions. # The table "trans" stores information on committed transactions.
schema_dict['trans'] = """CREATE TABLE %s ( schema_dict['trans'] = """CREATE TABLE %s (
`partition` SMALLINT UNSIGNED NOT NULL, `partition` SMALLINT UNSIGNED NOT NULL,
...@@ -305,8 +320,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -305,8 +320,8 @@ class MySQLDatabaseManager(DatabaseManager):
description BLOB NOT NULL, description BLOB NOT NULL,
ext BLOB NOT NULL, ext BLOB NOT NULL,
ttid BIGINT UNSIGNED NOT NULL, ttid BIGINT UNSIGNED NOT NULL,
PRIMARY KEY (`partition`, tid) PRIMARY KEY (`partition`, tid){}
) ENGINE=""" + p ) ENGINE={}""".format(cf('append_meta'), p)
# The table "obj" stores committed object metadata. # The table "obj" stores committed object metadata.
schema_dict['obj'] = """CREATE TABLE %s ( schema_dict['obj'] = """CREATE TABLE %s (
...@@ -315,10 +330,11 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -315,10 +330,11 @@ class MySQLDatabaseManager(DatabaseManager):
tid BIGINT UNSIGNED NOT NULL, tid BIGINT UNSIGNED NOT NULL,
data_id BIGINT UNSIGNED NULL, data_id BIGINT UNSIGNED NULL,
value_tid BIGINT UNSIGNED NULL, value_tid BIGINT UNSIGNED NULL,
PRIMARY KEY (`partition`, oid, tid), PRIMARY KEY (`partition`, oid, tid){},
KEY tid (`partition`, tid, oid), KEY tid (`partition`, tid, oid){},
KEY (data_id) KEY (data_id){}
) ENGINE=""" + p ) ENGINE={}""".format(cf('obj_pk', True),
cf('append_meta'), cf('append_meta'), p)
if engine == "TokuDB": if engine == "TokuDB":
engine += " compression='tokudb_uncompressed'" engine += " compression='tokudb_uncompressed'"
...@@ -326,18 +342,21 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -326,18 +342,21 @@ class MySQLDatabaseManager(DatabaseManager):
# The table "data" stores object data. # The table "data" stores object data.
# We'd like to have partial index on 'hash' column (e.g. hash(4)) # We'd like to have partial index on 'hash' column (e.g. hash(4))
# but 'UNIQUE' constraint would not work as expected. # but 'UNIQUE' constraint would not work as expected.
schema_dict['data'] = """CREATE TABLE %%s ( schema_dict['data'] = """CREATE TABLE %s (
id BIGINT UNSIGNED NOT NULL PRIMARY KEY, id BIGINT UNSIGNED NOT NULL,
hash BINARY(20) NOT NULL, hash BINARY(20) NOT NULL,
compression TINYINT UNSIGNED NULL, compression TINYINT UNSIGNED NULL,
value MEDIUMBLOB NOT NULL%s value MEDIUMBLOB NOT NULL,
) ENGINE=%s""" % (""", PRIMARY KEY (id){}{}
UNIQUE (hash, compression)""" if dedup else "", engine) ) ENGINE={}""".format(cf('append'), """,
UNIQUE (hash, compression)""" + cf('no_comp') if dedup else "",
engine)
schema_dict['bigdata'] = """CREATE TABLE %s ( schema_dict['bigdata'] = """CREATE TABLE %s (
id INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, id INT UNSIGNED NOT NULL AUTO_INCREMENT,
value MEDIUMBLOB NOT NULL value MEDIUMBLOB NOT NULL,
) ENGINE=""" + engine PRIMARY KEY (id){}
) ENGINE={}""".format(cf('append'), p)
# The table "ttrans" stores information on uncommitted transactions. # The table "ttrans" stores information on uncommitted transactions.
schema_dict['ttrans'] = """CREATE TABLE %s ( schema_dict['ttrans'] = """CREATE TABLE %s (
...@@ -348,8 +367,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -348,8 +367,9 @@ class MySQLDatabaseManager(DatabaseManager):
user BLOB NOT NULL, user BLOB NOT NULL,
description BLOB NOT NULL, description BLOB NOT NULL,
ext BLOB NOT NULL, ext BLOB NOT NULL,
ttid BIGINT UNSIGNED NOT NULL ttid BIGINT UNSIGNED NOT NULL,
) ENGINE=""" + engine PRIMARY KEY (ttid){}
) ENGINE={}""".format(cf('no_comp'), p)
# The table "tobj" stores uncommitted object metadata. # The table "tobj" stores uncommitted object metadata.
schema_dict['tobj'] = """CREATE TABLE %s ( schema_dict['tobj'] = """CREATE TABLE %s (
...@@ -358,8 +378,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -358,8 +378,8 @@ class MySQLDatabaseManager(DatabaseManager):
tid BIGINT UNSIGNED NOT NULL, tid BIGINT UNSIGNED NOT NULL,
data_id BIGINT UNSIGNED NULL, data_id BIGINT UNSIGNED NULL,
value_tid BIGINT UNSIGNED NULL, value_tid BIGINT UNSIGNED NULL,
PRIMARY KEY (tid, oid) PRIMARY KEY (tid, oid){}
) ENGINE=""" + engine ) ENGINE={}""".format(cf('no_comp'), p)
if self.nonempty('config') is None: if self.nonempty('config') is None:
q(schema_dict.pop('config') % 'config') q(schema_dict.pop('config') % 'config')
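The cf() helper introduced above only builds the COMMENT suffix appended to the keys and tables in these schemas; with any engine other than RocksDB it expands to an empty string, and the 'rev:' prefix (used for the primary key of obj) presumably asks MyRocks for a reverse-ordered column family. A quick illustration of the strings it produces:

    def cf(name, rev=False, engine='RocksDB'):
        # same expression as the lambda above, with the engine made explicit
        if engine != 'RocksDB':
            return ''
        return " COMMENT '%scf_neo_%s'" % ('rev:' if rev else '', name)

    print(cf('append_meta'))               # ->  COMMENT 'cf_neo_append_meta'
    print(cf('obj_pk', True))              # ->  COMMENT 'rev:cf_neo_obj_pk'
    print(cf('no_comp', engine='InnoDB'))  # ->  (empty string)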
...@@ -407,6 +427,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -407,6 +427,9 @@ class MySQLDatabaseManager(DatabaseManager):
q("ALTER TABLE config MODIFY value VARBINARY(%s) NULL" % len(value)) q("ALTER TABLE config MODIFY value VARBINARY(%s) NULL" % len(value))
q(sql) q(sql)
def _getMaxPartition(self):
return self.query("SELECT MAX(`partition`) FROM pt")[0][0]
def _getPartitionTable(self): def _getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
...@@ -965,7 +988,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -965,7 +988,7 @@ class MySQLDatabaseManager(DatabaseManager):
cmd += self._cmdline() cmd += self._cmdline()
return subprocess.check_output(cmd) return subprocess.check_output(cmd)
def restore(self, sql): def _restore(self, sql):
import subprocess import subprocess
cmd = ['mysql'] cmd = ['mysql']
cmd += self._cmdline() cmd += self._cmdline()
......
...@@ -79,6 +79,7 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -79,6 +79,7 @@ class SQLiteDatabaseManager(DatabaseManager):
def _connect(self): def _connect(self):
logging.info('connecting to SQLite database %r', self.db) logging.info('connecting to SQLite database %r', self.db)
self.conn = sqlite3.connect(self.db, check_same_thread=False) self.conn = sqlite3.connect(self.db, check_same_thread=False)
self.conn.text_factory = str
self.lock(self.db) self.lock(self.db)
if self.UNSAFE: if self.UNSAFE:
q = self.query q = self.query
...@@ -144,6 +145,12 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -144,6 +145,12 @@ class SQLiteDatabaseManager(DatabaseManager):
" WHEN 2 THEN -2" # FEEDING " WHEN 2 THEN -2" # FEEDING
" ELSE 1-state END") " ELSE 1-state END")
# Let's wait for a more important change to clean up,
# so that users can still downgrade.
if 0:
def _migrate4(self, schema_dict, index_dict):
self._setConfiguration('partitions', None)
def _setup(self, dedup=False): def _setup(self, dedup=False):
# BBB: SQLite has transactional DDL but before Python 3.6, # BBB: SQLite has transactional DDL but before Python 3.6,
# the binding automatically commits between such statements. # the binding automatically commits between such statements.
...@@ -265,6 +272,9 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -265,6 +272,9 @@ class SQLiteDatabaseManager(DatabaseManager):
else: else:
q("REPLACE INTO config VALUES (?,?)", (key, str(value))) q("REPLACE INTO config VALUES (?,?)", (key, str(value)))
def _getMaxPartition(self):
return self.query("SELECT MAX(`partition`) FROM pt").next()[0]
def _getPartitionTable(self): def _getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
...@@ -451,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -451,8 +461,12 @@ class SQLiteDatabaseManager(DatabaseManager):
return r return r
def loadData(self, data_id): def loadData(self, data_id):
return self.query("SELECT compression, hash, value" compression, checksum, data = self.query(
" FROM data WHERE id=?", (data_id,)).fetchone() "SELECT compression, hash, value FROM data WHERE id=?",
(data_id,)).fetchone()
if checksum:
return compression, str(checksum), str(data)
return compression, checksum, data
def _getDataTID(self, oid, tid=None, before_tid=None): def _getDataTID(self, oid, tid=None, before_tid=None):
partition = self._getReadablePartition(oid) partition = self._getReadablePartition(oid)
...@@ -712,5 +726,5 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -712,5 +726,5 @@ class SQLiteDatabaseManager(DatabaseManager):
main[-1:-1] = data main[-1:-1] = data
return '\n'.join(main) + '\n' return '\n'.join(main) + '\n'
def restore(self, sql): def _restore(self, sql):
self.conn.executescript(sql) self.conn.executescript(sql)
...@@ -65,14 +65,14 @@ class BaseMasterHandler(BaseHandler): ...@@ -65,14 +65,14 @@ class BaseMasterHandler(BaseHandler):
# See comment in ClientOperationHandler.connectionClosed # See comment in ClientOperationHandler.connectionClosed
self.app.tm.abortFor(uuid, even_if_voted=True) self.app.tm.abortFor(uuid, even_if_voted=True)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
the information is only about changes from the previous.""" the information is only about changes from the previous."""
app = self.app app = self.app
if ptid != 1 + app.pt.getID(): if ptid != 1 + app.pt.getID():
raise ProtocolError('wrong partition table id') raise ProtocolError('wrong partition table id')
app.pt.update(ptid, cell_list, app.nm) app.pt.update(ptid, num_replicas, cell_list, app.nm)
app.dm.changePartitionTable(ptid, cell_list) app.dm.changePartitionTable(ptid, num_replicas, cell_list)
if app.operational: if app.operational:
app.replicator.notifyPartitionChanges(cell_list) app.replicator.notifyPartitionChanges(cell_list)
app.dm.commit() app.dm.commit()
......
...@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler): ...@@ -53,7 +53,7 @@ class ClientOperationHandler(BaseHandler):
p = Errors.TidNotFound('%s does not exist' % dump(tid)) p = Errors.TidNotFound('%s does not exist' % dump(tid))
else: else:
p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3], p = Packets.AnswerTransactionInformation(tid, t[1], t[2], t[3],
t[4], t[0]) bool(t[4]), t[0])
conn.answer(p) conn.answer(p)
def getEventQueue(self): def getEventQueue(self):
...@@ -183,12 +183,13 @@ class ClientOperationHandler(BaseHandler): ...@@ -183,12 +183,13 @@ class ClientOperationHandler(BaseHandler):
getObjectFromTransaction = app.tm.getObjectFromTransaction getObjectFromTransaction = app.tm.getObjectFromTransaction
object_tid_dict = {} object_tid_dict = {}
for oid in oid_list: for oid in oid_list:
current_serial, undo_serial, is_current = findUndoTID(oid, ttid, r = findUndoTID(oid, ttid,
ltid, undone_tid, getObjectFromTransaction(ttid, oid)) ltid, undone_tid, getObjectFromTransaction(ttid, oid))
if current_serial is None: if r:
p = Errors.OidNotFound(dump(oid)) if not r[0]:
break p = Errors.OidNotFound(dump(oid))
object_tid_dict[oid] = (current_serial, undo_serial, is_current) break
object_tid_dict[oid] = r
else: else:
p = Packets.AnswerObjectUndoSerial(object_tid_dict) p = Packets.AnswerObjectUndoSerial(object_tid_dict)
conn.answer(p) conn.answer(p)
......
...@@ -32,7 +32,7 @@ class IdentificationHandler(EventHandler): ...@@ -32,7 +32,7 @@ class IdentificationHandler(EventHandler):
return self.app.nm return self.app.nm
def requestIdentification(self, conn, node_type, uuid, address, name, def requestIdentification(self, conn, node_type, uuid, address, name,
devpath, id_timestamp): id_timestamp, devpath, new_nid):
self.checkClusterName(name) self.checkClusterName(name)
app = self.app app = self.app
# reject any incoming connections if not ready # reject any incoming connections if not ready
...@@ -65,6 +65,5 @@ class IdentificationHandler(EventHandler): ...@@ -65,6 +65,5 @@ class IdentificationHandler(EventHandler):
conn.setHandler(handler) conn.setHandler(handler)
node.setConnection(conn, force) node.setConnection(conn, force)
# accept the identification and trigger an event # accept the identification and trigger an event
conn.answer(Packets.AcceptIdentification(NodeTypes.STORAGE, uuid and conn.answer(Packets.AcceptIdentification(
app.uuid, app.pt.getPartitions(), app.pt.getReplicas(), uuid)) NodeTypes.STORAGE, uuid and app.uuid, uuid))
handler.connectionCompleted(conn)
...@@ -20,10 +20,10 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID ...@@ -20,10 +20,10 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
class InitializationHandler(BaseMasterHandler): class InitializationHandler(BaseMasterHandler):
def sendPartitionTable(self, conn, ptid, row_list): def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
app = self.app app = self.app
pt = app.pt pt = app.pt
pt.load(ptid, row_list, app.nm) pt.load(ptid, num_replicas, row_list, app.nm)
if not pt.filled(): if not pt.filled():
raise ProtocolError('Partial partition table received') raise ProtocolError('Partial partition table received')
# Install the partition table into the database for persistence. # Install the partition table into the database for persistence.
...@@ -44,7 +44,7 @@ class InitializationHandler(BaseMasterHandler): ...@@ -44,7 +44,7 @@ class InitializationHandler(BaseMasterHandler):
logging.debug('drop data for partitions %r', unassigned) logging.debug('drop data for partitions %r', unassigned)
dm.dropPartitions(unassigned) dm.dropPartitions(unassigned)
dm.changePartitionTable(ptid, cell_list, reset=True) dm.changePartitionTable(ptid, num_replicas, cell_list, reset=True)
dm.commit() dm.commit()
def truncate(self, conn, tid): def truncate(self, conn, tid):
...@@ -68,7 +68,8 @@ class InitializationHandler(BaseMasterHandler): ...@@ -68,7 +68,8 @@ class InitializationHandler(BaseMasterHandler):
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList())) conn.answer(Packets.AnswerPartitionTable(
pt.getID(), pt.getReplicas(), pt.getRowList()))
def askLockedTransactions(self, conn): def askLockedTransactions(self, conn):
conn.answer(Packets.AnswerLockedTransactions( conn.answer(Packets.AnswerLockedTransactions(
......
...@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler): ...@@ -212,7 +212,7 @@ class StorageOperationHandler(EventHandler):
# Sending such packet does not mark the connection # Sending such packet does not mark the connection
# for writing if there's too little data in the buffer. # for writing if there's too little data in the buffer.
conn.send(Packets.AddTransaction(tid, user, conn.send(Packets.AddTransaction(tid, user,
desc, ext, packed, ttid, oid_list), msg_id) desc, ext, bool(packed), ttid, oid_list), msg_id)
# To avoid delaying several connections simultaneously, # To avoid delaying several connections simultaneously,
# and also prevent the backend from scanning different # and also prevent the backend from scanning different
# parts of the DB at the same time, we ask the # parts of the DB at the same time, we ask the
...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler): ...@@ -248,7 +248,7 @@ class StorageOperationHandler(EventHandler):
for serial, oid in object_list: for serial, oid in object_list:
oid_set = object_dict.get(serial) oid_set = object_dict.get(serial)
if oid_set: if oid_set:
if type(oid_set) is list: if type(oid_set) is tuple:
object_dict[serial] = oid_set = set(oid_set) object_dict[serial] = oid_set = set(oid_set)
if oid in oid_set: if oid in oid_set:
oid_set.remove(oid) oid_set.remove(oid)
......
...@@ -350,7 +350,7 @@ class Replicator(object): ...@@ -350,7 +350,7 @@ class Replicator(object):
try: try:
conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE, conn.ask(Packets.RequestIdentification(NodeTypes.STORAGE,
None if name else app.uuid, app.server, name or app.name, None if name else app.uuid, app.server, name or app.name,
(), app.id_timestamp)) app.id_timestamp, (), ()))
except ConnectionClosed: except ConnectionClosed:
if previous_node is self.current_node: if previous_node is self.current_node:
return return
......
...@@ -98,9 +98,12 @@ class TransactionManager(EventQueue): ...@@ -98,9 +98,12 @@ class TransactionManager(EventQueue):
self._load_lock_dict = {} self._load_lock_dict = {}
self._replicated = {} self._replicated = {}
self._replicating = set() self._replicating = set()
def getPartition(self, oid):
from neo.lib.util import u64 from neo.lib.util import u64
np = app.pt.getPartitions() np = self._app.pt.getPartitions()
self.getPartition = lambda oid: u64(oid) % np self.getPartition = lambda oid: u64(oid) % np
return self.getPartition(oid)
def discarded(self, offset_list): def discarded(self, offset_list):
self._replicating.difference_update(offset_list) self._replicating.difference_update(offset_list)
......
...@@ -21,6 +21,7 @@ import gc ...@@ -21,6 +21,7 @@ import gc
import os import os
import random import random
import socket import socket
import subprocess
import sys import sys
import tempfile import tempfile
import unittest import unittest
...@@ -41,7 +42,7 @@ from .mock import Mock ...@@ -41,7 +42,7 @@ from .mock import Mock
from neo.lib import debug, logging, protocol from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import cached_property from neo.lib.util import cached_property
from time import time from time import time, sleep
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
try: try:
...@@ -72,6 +73,9 @@ DB_ADMIN = os.getenv('NEO_DB_ADMIN', 'root') ...@@ -72,6 +73,9 @@ DB_ADMIN = os.getenv('NEO_DB_ADMIN', 'root')
DB_PASSWD = os.getenv('NEO_DB_PASSWD', '') DB_PASSWD = os.getenv('NEO_DB_PASSWD', '')
DB_USER = os.getenv('NEO_DB_USER', 'test') DB_USER = os.getenv('NEO_DB_USER', 'test')
DB_SOCKET = os.getenv('NEO_DB_SOCKET', '') DB_SOCKET = os.getenv('NEO_DB_SOCKET', '')
DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db')
DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld')
DB_MYCNF = os.getenv('NEO_DB_MYCNF')
IP_VERSION_FORMAT_DICT = { IP_VERSION_FORMAT_DICT = {
socket.AF_INET: '127.0.0.1', socket.AF_INET: '127.0.0.1',
...@@ -134,8 +138,12 @@ def getTempDirectory(): ...@@ -134,8 +138,12 @@ def getTempDirectory():
print 'Using temp directory %r.' % temp_dir print 'Using temp directory %r.' % temp_dir
return temp_dir return temp_dir
def setupMySQLdb(db_list, user=DB_USER, password='', clear_databases=True): def setupMySQLdb(db_list, clear_databases=True):
if mysql_pool:
return mysql_pool.setup(db_list, clear_databases)
from MySQLdb.constants.ER import BAD_DB_ERROR from MySQLdb.constants.ER import BAD_DB_ERROR
user = DB_USER
password = ''
kw = {'unix_socket': os.path.expanduser(DB_SOCKET)} if DB_SOCKET else {} kw = {'unix_socket': os.path.expanduser(DB_SOCKET)} if DB_SOCKET else {}
conn = MySQLdb.connect(user=DB_ADMIN, passwd=DB_PASSWD, **kw) conn = MySQLdb.connect(user=DB_ADMIN, passwd=DB_PASSWD, **kw)
cursor = conn.cursor() cursor = conn.cursor()
...@@ -154,6 +162,88 @@ def setupMySQLdb(db_list, user=DB_USER, password='', clear_databases=True): ...@@ -154,6 +162,88 @@ def setupMySQLdb(db_list, user=DB_USER, password='', clear_databases=True):
cursor.close() cursor.close()
conn.commit() conn.commit()
conn.close() conn.close()
return '{}:{}@%s{}'.format(user, password, DB_SOCKET).__mod__
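The value returned above is a callable (a bound str.__mod__) that maps a database name to the --database string handed to storage nodes. For instance, with the default test credentials and no NEO_DB_SOCKET:

    user, password, DB_SOCKET = 'test', '', ''
    db_template = '{}:{}@%s{}'.format(user, password, DB_SOCKET).__mod__
    assert db_template('test_neo0') == 'test:@test_neo0'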
class MySQLPool(object):
def __init__(self, pool_dir=None):
self._args = {}
self._mysqld_dict = {}
if not pool_dir:
pool_dir = getTempDirectory()
self._base = pool_dir + os.sep
self._sock_template = os.path.join(pool_dir, '%s', 'mysql.sock')
def __del__(self):
self.kill(*self._mysqld_dict)
def setup(self, db_list, clear_databases):
start_list = set(db_list).difference(self._mysqld_dict)
if start_list:
start_list = sorted(start_list)
x = []
with open(os.devnull, 'wb') as f:
for db in start_list:
base = self._base + db
datadir = os.path.join(base, 'datadir')
sock = self._sock_template % db
tmpdir = os.path.join(base, 'tmp')
args = [DB_INSTALL,
'--defaults-file=' + DB_MYCNF,
'--datadir=' + datadir,
'--socket=' + sock,
'--tmpdir=' + tmpdir,
'--log_error=' + os.path.join(base, 'error.log')]
if os.path.exists(datadir):
try:
os.remove(sock)
except OSError, e:
if e.errno != errno.ENOENT:
raise
else:
os.makedirs(tmpdir)
x.append(subprocess.Popen(args,
stdout=f, stderr=subprocess.STDOUT))
args[0] = DB_MYSQLD
self._args[db] = args
for x in x:
x = x.wait()
if x:
raise subprocess.CalledProcessError(x, DB_INSTALL)
self.start(*start_list)
for db in start_list:
sock = self._sock_template % db
p = self._mysqld_dict[db]
while not os.path.exists(sock):
sleep(1)
x = p.poll()
if x is not None:
raise subprocess.CalledProcessError(x, DB_MYSQLD)
for db in db_list:
db = MySQLdb.connect(unix_socket=self._sock_template % db,
user='root')
if clear_databases:
db.query('DROP DATABASE IF EXISTS neo')
db.query('CREATE DATABASE IF NOT EXISTS neo')
db.close()
return ('root@neo' + self._sock_template).__mod__
def start(self, *db, **kw):
assert set(db).isdisjoint(self._mysqld_dict)
for db in db:
self._mysqld_dict[db] = subprocess.Popen(self._args[db], **kw)
def kill(self, *db):
processes = []
for db in db:
p = self._mysqld_dict.pop(db)
processes.append(p)
p.kill()
for p in processes:
p.wait()
mysql_pool = MySQLPool() if DB_MYCNF else None
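So, when NEO_DB_MYCNF is set, setupMySQLdb() delegates to this pool, which installs and starts one private mysqld per requested database and hands back a 'root@neo<socket>' connection-string template; otherwise the historical behaviour of talking to a shared server through NEO_DB_ADMIN/NEO_DB_SOCKET is kept. A hedged sketch of what the returned template expands to (paths are illustrative, POSIX separators assumed):

    import os

    pool_dir = '/tmp/neo_tests'  # stands in for getTempDirectory()
    sock_template = os.path.join(pool_dir, '%s', 'mysql.sock')
    db_template = ('root@neo' + sock_template).__mod__
    assert db_template('test_neo0') == 'root@neo/tmp/neo_tests/test_neo0/mysql.sock'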
def ImporterConfigParser(adapter, zodb, **kw): def ImporterConfigParser(adapter, zodb, **kw):
cfg = SafeConfigParser() cfg = SafeConfigParser()
...@@ -244,13 +334,15 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -244,13 +334,15 @@ class NeoUnitTestBase(NeoTestBase):
""" create empty databases """ """ create empty databases """
adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL') adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL')
if adapter == 'MySQL': if adapter == 'MySQL':
setupMySQLdb([prefix + str(i) for i in xrange(number)]) db_template = setupMySQLdb(
[prefix + str(i) for i in xrange(number)])
self.db_template = lambda i: db_template(prefix + str(i))
elif adapter == 'SQLite': elif adapter == 'SQLite':
temp_dir = getTempDirectory() self.db_template = os.path.join(getTempDirectory(),
prefix + '%s.sqlite').__mod__
for i in xrange(number): for i in xrange(number):
try: try:
os.remove(os.path.join(temp_dir, os.remove(self.db_template(i))
'%s%s.sqlite' % (prefix, i)))
except OSError, e: except OSError, e:
if e.errno != errno.ENOENT: if e.errno != errno.ENOENT:
raise raise
...@@ -274,21 +366,14 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -274,21 +366,14 @@ class NeoUnitTestBase(NeoTestBase):
def getStorageConfiguration(self, cluster='main', master_number=2, def getStorageConfiguration(self, cluster='main', master_number=2,
index=0, prefix=DB_PREFIX, uuid=None): index=0, prefix=DB_PREFIX, uuid=None):
assert master_number >= 1 and master_number <= 10 assert master_number >= 1 and master_number <= 10
assert index >= 0 and index <= 9
masters = [(buildUrlFromString(self.local_ip), masters = [(buildUrlFromString(self.local_ip),
10010 + i) for i in xrange(master_number)] 10010 + i) for i in xrange(master_number)]
adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL') adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL')
if adapter == 'MySQL':
db = '%s@%s%s%s' % (DB_USER, prefix, index, DB_SOCKET)
elif adapter == 'SQLite':
db = os.path.join(getTempDirectory(), 'test_neo%s.sqlite' % index)
else:
assert False, adapter
return { return {
'cluster': cluster, 'cluster': cluster,
'bind': (masters[0], 10020 + index), 'bind': (masters[0], 10020 + index),
'masters': masters, 'masters': masters,
'database': db, 'database': self.db_template(index),
'uuid': uuid, 'uuid': uuid,
'adapter': adapter, 'adapter': adapter,
'wait': 0, 'wait': 0,
......
...@@ -36,7 +36,7 @@ from neo.lib import logging ...@@ -36,7 +36,7 @@ from neo.lib import logging
from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
UUID_NAMESPACES UUID_NAMESPACES
from neo.lib.util import dump, setproctitle from neo.lib.util import dump, setproctitle
from .. import (ADDRESS_TYPE, DB_SOCKET, DB_USER, IP_VERSION_FORMAT_DICT, SSL, from .. import (ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, SSL,
buildUrlFromString, cluster, getTempDirectory, setupMySQLdb, buildUrlFromString, cluster, getTempDirectory, setupMySQLdb,
ImporterConfigParser, NeoTestBase, Patch) ImporterConfigParser, NeoTestBase, Patch)
from neo.client.Storage import Storage from neo.client.Storage import Storage
...@@ -282,7 +282,7 @@ class NEOProcess(Process): ...@@ -282,7 +282,7 @@ class NEOProcess(Process):
def _args(self): def _args(self):
args = super(NEOProcess, self)._args() args = super(NEOProcess, self)._args()
if self.uuid: if self.uuid:
args[:0] = '--uuid', str(self.uuid) args[:0] = '--nid', str(self.uuid)
return args return args
def run(self): def run(self):
...@@ -306,11 +306,11 @@ class NEOCluster(object): ...@@ -306,11 +306,11 @@ class NEOCluster(object):
SSL = None SSL = None
def __init__(self, db_list, master_count=1, partitions=1, replicas=0, def __init__(self, db_list, master_count=1, partitions=1, replicas=0,
db_user=DB_USER, db_password='', name=None, name=None,
cleanup_on_delete=False, temp_dir=None, clear_databases=True, cleanup_on_delete=False, temp_dir=None, clear_databases=True,
adapter=os.getenv('NEO_TESTS_ADAPTER'), adapter=os.getenv('NEO_TESTS_ADAPTER'),
address_type=ADDRESS_TYPE, bind_ip=None, logger=True, address_type=ADDRESS_TYPE, bind_ip=None, logger=True,
importer=None): importer=None, storage_kw={}):
if not adapter: if not adapter:
adapter = 'MySQL' adapter = 'MySQL'
self.adapter = adapter self.adapter = adapter
...@@ -322,20 +322,28 @@ class NEOCluster(object): ...@@ -322,20 +322,28 @@ class NEOCluster(object):
temp_dir = tempfile.mkdtemp(prefix='neo_') temp_dir = tempfile.mkdtemp(prefix='neo_')
print 'Using temp directory ' + temp_dir print 'Using temp directory ' + temp_dir
if adapter == 'MySQL': if adapter == 'MySQL':
self.db_user = db_user self.db_template = setupMySQLdb(db_list, clear_databases)
self.db_password = db_password
self.db_template = ('%s:%s@%%s%s' % (db_user, db_password,
DB_SOCKET)).__mod__
elif adapter == 'SQLite': elif adapter == 'SQLite':
self.db_template = (lambda t: lambda db: self.db_template = (lambda t: lambda db:
':memory:' if db is None else db if os.sep in db else t % db ':memory:' if db is None else db if os.sep in db else t % db
)(os.path.join(temp_dir, '%s.sqlite')) )(os.path.join(temp_dir, '%s.sqlite'))
if clear_databases:
for db in self.db_list:
if db is None:
continue
db = self.db_template(db)
try:
os.remove(db)
except OSError, e:
if e.errno != errno.ENOENT:
raise
else:
logging.debug('%r deleted', db)
else: else:
assert False, adapter assert False, adapter
self.address_type = address_type self.address_type = address_type
self.local_ip = local_ip = bind_ip or \ self.local_ip = local_ip = bind_ip or \
IP_VERSION_FORMAT_DICT[self.address_type] IP_VERSION_FORMAT_DICT[self.address_type]
self.setupDB(clear_databases)
if importer: if importer:
cfg = ImporterConfigParser(adapter, **importer) cfg = ImporterConfigParser(adapter, **importer)
cfg.set("neo", "database", self.db_template(*db_list)) cfg.set("neo", "database", self.db_template(*db_list))
...@@ -364,7 +372,8 @@ class NEOCluster(object): ...@@ -364,7 +372,8 @@ class NEOCluster(object):
# create storage nodes # create storage nodes
for i, db in enumerate(db_list): for i, db in enumerate(db_list):
self._newProcess(NodeTypes.STORAGE, logger and 'storage_%u' % i, self._newProcess(NodeTypes.STORAGE, logger and 'storage_%u' % i,
0, adapter=adapter, database=self.db_template(db)) 0, adapter=adapter, database=self.db_template(db),
**storage_kw)
# create neoctl # create neoctl
self.neoctl = NeoCTL((self.local_ip, admin_port), ssl=self.SSL) self.neoctl = NeoCTL((self.local_ip, admin_port), ssl=self.SSL)
...@@ -382,23 +391,10 @@ class NEOCluster(object): ...@@ -382,23 +391,10 @@ class NEOCluster(object):
self.process_dict.setdefault(node_type, []).append( self.process_dict.setdefault(node_type, []).append(
NEOProcess(command_dict[node_type], uuid=uuid, **kw)) NEOProcess(command_dict[node_type], uuid=uuid, **kw))
def setupDB(self, clear_databases=True): def resetDB(self):
if self.adapter == 'MySQL': for db in self.db_list:
setupMySQLdb(self.db_list, self.db_user, self.db_password, dm = buildDatabaseManager(self.adapter, (self.db_template(db),))
clear_databases) dm.setup(True)
elif self.adapter == 'SQLite':
if clear_databases:
for db in self.db_list:
if db is None:
continue
db = self.db_template(db)
try:
os.remove(db)
except OSError, e:
if e.errno != errno.ENOENT:
raise
else:
logging.debug('%r deleted', db)
def run(self, except_storages=()): def run(self, except_storages=()):
""" Start cluster processes except some storage nodes """ """ Start cluster processes except some storage nodes """
...@@ -437,7 +433,7 @@ class NEOCluster(object): ...@@ -437,7 +433,7 @@ class NEOCluster(object):
pending_count += 1 pending_count += 1
if pending_count == target[0]: if pending_count == target[0]:
neoctl.startCluster() neoctl.startCluster()
except (NotReadyException, RuntimeError): except (NotReadyException, SystemExit):
pass pass
if not pdb.wait(test, MAX_START_TIME): if not pdb.wait(test, MAX_START_TIME):
raise AssertionError('Timeout when starting cluster') raise AssertionError('Timeout when starting cluster')
...@@ -449,7 +445,7 @@ class NEOCluster(object): ...@@ -449,7 +445,7 @@ class NEOCluster(object):
def start(last_try): def start(last_try):
try: try:
self.neoctl.startCluster() self.neoctl.startCluster()
except (NotReadyException, RuntimeError), e: except (NotReadyException, SystemExit), e:
return False, e return False, e
return True, None return True, None
self.expectCondition(start) self.expectCondition(start)
...@@ -653,10 +649,10 @@ class NEOCluster(object): ...@@ -653,10 +649,10 @@ class NEOCluster(object):
def expectOudatedCells(self, number, *args, **kw): def expectOudatedCells(self, number, *args, **kw):
def callback(last_try): def callback(last_try):
row_list = self.neoctl.getPartitionRowList()[1] row_list = self.neoctl.getPartitionRowList()[2]
number_of_outdated = 0 number_of_outdated = 0
for row in row_list: for row in row_list:
for cell in row[1]: for cell in row:
if cell[1] == CellStates.OUT_OF_DATE: if cell[1] == CellStates.OUT_OF_DATE:
number_of_outdated += 1 number_of_outdated += 1
return number_of_outdated == number, number_of_outdated return number_of_outdated == number, number_of_outdated
...@@ -664,10 +660,10 @@ class NEOCluster(object): ...@@ -664,10 +660,10 @@ class NEOCluster(object):
def expectAssignedCells(self, process, number, *args, **kw): def expectAssignedCells(self, process, number, *args, **kw):
def callback(last_try): def callback(last_try):
row_list = self.neoctl.getPartitionRowList()[1] row_list = self.neoctl.getPartitionRowList()[2]
assigned_cells_number = 0 assigned_cells_number = 0
for row in row_list: for row in row_list:
for cell in row[1]: for cell in row:
if cell[0] == process.getUUID(): if cell[0] == process.getUUID():
assigned_cells_number += 1 assigned_cells_number += 1
return assigned_cells_number == number, assigned_cells_number return assigned_cells_number == number, assigned_cells_number
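Both callbacks above reflect the protocol change: neoctl.getPartitionRowList() now also returns the number of replicas, and each row is a plain sequence of (nid, state) cells whose partition offset is given by its position. A hedged sketch of consuming the new layout (names are descriptive only):

    def count_cells_in_state(row_list, wanted_state):
        # row_list: one sequence of (nid, state) cells per partition offset
        n = 0
        for offset, row in enumerate(row_list):
            for nid, state in row:
                if state == wanted_state:
                    n += 1
        return n

    # ptid, num_replicas, row_list = neoctl.getPartitionRowList()
    # count_cells_in_state(row_list, CellStates.OUT_OF_DATE)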
......
...@@ -62,8 +62,6 @@ class ClientTests(NEOFunctionalTest): ...@@ -62,8 +62,6 @@ class ClientTests(NEOFunctionalTest):
NEOFunctionalTest._tearDown(self, success) NEOFunctionalTest._tearDown(self, success)
def __setup(self): def __setup(self):
# start cluster
self.neo.setupDB()
self.neo.start() self.neo.start()
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
self.db = ZODB.DB(self.neo.getZODBStorage()) self.db = ZODB.DB(self.neo.getZODBStorage())
......
...@@ -71,7 +71,6 @@ class ClusterTests(NEOFunctionalTest): ...@@ -71,7 +71,6 @@ class ClusterTests(NEOFunctionalTest):
def testClusterBreaks(self): def testClusterBreaks(self):
self.neo = NEOCluster(['test_neo1'], self.neo = NEOCluster(['test_neo1'],
master_count=1, temp_dir=self.getTempDirectory()) master_count=1, temp_dir=self.getTempDirectory())
self.neo.setupDB()
self.neo.start() self.neo.start()
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0) self.neo.expectOudatedCells(number=0)
...@@ -82,7 +81,6 @@ class ClusterTests(NEOFunctionalTest): ...@@ -82,7 +81,6 @@ class ClusterTests(NEOFunctionalTest):
self.neo = NEOCluster(['test_neo1', 'test_neo2'], self.neo = NEOCluster(['test_neo1', 'test_neo2'],
partitions=2, master_count=1, replicas=0, partitions=2, master_count=1, replicas=0,
temp_dir=self.getTempDirectory()) temp_dir=self.getTempDirectory())
self.neo.setupDB()
self.neo.start() self.neo.start()
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0) self.neo.expectOudatedCells(number=0)
...@@ -93,7 +91,6 @@ class ClusterTests(NEOFunctionalTest): ...@@ -93,7 +91,6 @@ class ClusterTests(NEOFunctionalTest):
self.neo = NEOCluster(['test_neo1', 'test_neo2'], self.neo = NEOCluster(['test_neo1', 'test_neo2'],
partitions=2, replicas=1, master_count=1, partitions=2, replicas=1, master_count=1,
temp_dir=self.getTempDirectory()) temp_dir=self.getTempDirectory())
self.neo.setupDB()
self.neo.start() self.neo.start()
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
self.neo.expectOudatedCells(number=0) self.neo.expectOudatedCells(number=0)
......
...@@ -47,7 +47,7 @@ class MasterTests(NEOFunctionalTest): ...@@ -47,7 +47,7 @@ class MasterTests(NEOFunctionalTest):
break break
neoctl.killNode(uuid) neoctl.killNode(uuid)
self.neo.expectDead(master) self.neo.expectDead(master)
self.assertRaises(RuntimeError, neoctl.killNode, primary_uuid) self.assertRaises(SystemExit, neoctl.killNode, primary_uuid)
def testStoppingPrimaryWithTwoSecondaries(self): def testStoppingPrimaryWithTwoSecondaries(self):
# Wait for masters to stabilize # Wait for masters to stabilize
......
...@@ -172,7 +172,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -172,7 +172,7 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectOudatedCells(2) self.neo.expectOudatedCells(2)
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
self.assertRaises(RuntimeError, self.neo.neoctl.killNode, self.assertRaises(SystemExit, self.neo.neoctl.killNode,
started[1].getUUID()) started[1].getUUID())
started[1].stop() started[1].stop()
# Cluster not operational anymore. Only cells of second storage that # Cluster not operational anymore. Only cells of second storage that
...@@ -323,7 +323,7 @@ class StorageTests(NEOFunctionalTest): ...@@ -323,7 +323,7 @@ class StorageTests(NEOFunctionalTest):
self.neo.expectStorageUnknown(started[0]) self.neo.expectStorageUnknown(started[0])
self.neo.expectAssignedCells(started[0], 0) self.neo.expectAssignedCells(started[0], 0)
self.neo.expectAssignedCells(started[1], 10) self.neo.expectAssignedCells(started[1], 10)
self.assertRaises(RuntimeError, self.neo.neoctl.dropNode, self.assertRaises(SystemExit, self.neo.neoctl.dropNode,
started[1].getUUID()) started[1].getUUID())
self.neo.expectClusterRunning() self.neo.expectClusterRunning()
......
...@@ -30,8 +30,6 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -30,8 +30,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
config = self.getMasterConfiguration(master_number=1, replicas=1) config = self.getMasterConfiguration(master_number=1, replicas=1)
self.app = Application(config) self.app = Application(config)
self.app.em.close() self.app.em.close()
self.app.pt.clear()
self.app.pt.setID(1)
self.app.em = Mock() self.app.em = Mock()
self.app.loid = '\0' * 8 self.app.loid = '\0' * 8
self.app.tm.setLastTID('\0' * 8) self.app.tm.setLastTID('\0' * 8)
...@@ -73,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -73,7 +71,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0] ptid = self.checkAskPacket(storage_conn, Packets.AskPack)._args[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn) self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id) self.assertEqual(self.app.packing[1], peer_id)
...@@ -85,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -85,7 +83,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0] status = self.checkAnswerPacket(conn, Packets.AnswerPack)._args[0]
self.assertFalse(status) self.assertFalse(status)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -26,7 +26,6 @@ class MasterAppTests(NeoUnitTestBase): ...@@ -26,7 +26,6 @@ class MasterAppTests(NeoUnitTestBase):
# create an application object # create an application object
config = self.getMasterConfiguration() config = self.getMasterConfiguration()
self.app = Application(config) self.app = Application(config)
self.app.pt.clear()
def _tearDown(self, success): def _tearDown(self, success):
self.app.close() self.app.close()
......
...@@ -289,7 +289,9 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -289,7 +289,9 @@ class MasterPartitionTableTests(NeoUnitTestBase):
pt.addNodeList(sn[1:3]) pt.addNodeList(sn[1:3])
self.assertPartitionTable(pt, 'U..|U..|U..|U..|U..|U..|U..') self.assertPartitionTable(pt, 'U..|U..|U..|U..|U..|U..|U..')
self.update(pt, self.tweak(pt, sn[:1])) self.update(pt, self.tweak(pt, sn[:1]))
self.assertPartitionTable(pt, '.U.|..U|.U.|..U|.U.|..U|.U.') # See note in PartitionTable.tweak() about drop_list.
#self.assertPartitionTable(pt,'.U.|..U|.U.|..U|.U.|..U|.U.')
self.assertPartitionTable(pt, 'UU.|U.U|UU.|U.U|UU.|U.U|UU.')
def test_18_tweakBigPT(self): def test_18_tweakBigPT(self):
seed = repr(time.time()) seed = repr(time.time())
......
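The commented-out assertion above documents the behaviour change: tweak() now leaves the extra UP_TO_DATE cells in place instead of dropping them immediately (see the note about drop_list in PartitionTable.tweak()). For readers unfamiliar with the pattern strings passed to assertPartitionTable, here is a minimal decoding sketch; the meaning of the characters ('U' for an up-to-date cell, '.' for no cell, one character per storage node and one '|'-separated group per partition) is inferred from the assertions in this file, not from a documented API:

def decode_pt_pattern(pattern, node_names):
    # 'UU.|U.U' -> [(0, [('S1', 'U'), ('S2', 'U')]), (1, [('S1', 'U'), ('S3', 'U')])]
    rows = []
    for offset, row in enumerate(pattern.split('|')):
        cells = [(node_names[i], state)
                 for i, state in enumerate(row) if state != '.']
        rows.append((offset, cells))
    return rows

print decode_pt_pattern('UU.|U.U', ['S1', 'S2', 'S3'])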
#
# Copyright (C) 2009-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates, CellStates
from neo.master.recovery import RecoveryManager
from neo.master.app import Application
class MasterRecoveryTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
# create an application object
config = self.getMasterConfiguration()
self.app = Application(config)
self.app.pt.clear()
self.recovery = RecoveryManager(self.app)
self.app.unconnected_master_node_set = set()
self.app.negotiating_master_node_set = set()
for node in self.app.nm.getMasterList():
self.app.unconnected_master_node_set.add(node.getAddress())
node.setState(NodeStates.RUNNING)
# define some variables to simulate client and storage nodes
self.storage_port = 10021
self.master_port = 10011
def _tearDown(self, success):
self.app.close()
NeoUnitTestBase._tearDown(self, success)
# Common methods
def identifyToMasterNode(self, node_type=NodeTypes.STORAGE, ip="127.0.0.1",
port=10021):
"""Do first step of identification to MN
"""
address = (ip, port)
uuid = self.getNewUUID(node_type)
self.app.nm.createFromNodeType(node_type, address=address, uuid=uuid,
state=NodeStates.RUNNING)
return uuid
# Tests
def test_10_answerPartitionTable(self):
# XXX: This test does much less than it seems, because all 'for' loops
# iterate over empty lists. Currently, only testRecovery covers
# some paths in NodeManager._createNode: apart from that, we could
# delete it entirely.
recovery = self.recovery
uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port)
# not from target node, ignore
uuid = self.identifyToMasterNode(NodeTypes.STORAGE, port=self.storage_port)
conn = self.getFakeConnection(uuid, self.storage_port)
node = self.app.nm.getByUUID(conn.getUUID())
offset = 1
cell_list = [(offset, uuid, CellStates.UP_TO_DATE)]
cells = self.app.pt.getRow(offset)
for cell, state in cells:
self.assertEqual(state, CellStates.OUT_OF_DATE)
recovery.target_ptid = 2
node.setPending()
recovery.answerPartitionTable(conn, 1, cell_list)
cells = self.app.pt.getRow(offset)
for cell, state in cells:
self.assertEqual(state, CellStates.OUT_OF_DATE)
# from target node, taken into account
conn = self.getFakeConnection(uuid, self.storage_port)
offset = 1
cell_list = [(offset, ((uuid, CellStates.UP_TO_DATE,),),)]
cells = self.app.pt.getRow(offset)
for cell, state in cells:
self.assertEqual(state, CellStates.OUT_OF_DATE)
node.setPending()
recovery.answerPartitionTable(conn, None, cell_list)
cells = self.app.pt.getRow(offset)
for cell, state in cells:
self.assertEqual(state, CellStates.UP_TO_DATE)
# give a bad offset, must send error
self.recovery.target_uuid = uuid
conn = self.getFakeConnection(uuid, self.storage_port)
offset = 1000000
self.assertFalse(self.app.pt.hasOffset(offset))
cell_list = [(offset, ((uuid, NodeStates.UNKNOWN,),),)]
node.setPending()
self.checkProtocolErrorRaised(recovery.answerPartitionTable, conn,
2, cell_list)
if __name__ == '__main__':
unittest.main()
...@@ -18,8 +18,8 @@ import unittest ...@@ -18,8 +18,8 @@ import unittest
from ..mock import Mock from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, Packets from neo.lib.protocol import NodeTypes, Packets
from neo.master.handlers.storage import StorageServiceHandler
from neo.master.app import Application from neo.master.app import Application
from neo.master.handlers.storage import StorageServiceHandler
class MasterStorageHandlerTests(NeoUnitTestBase): class MasterStorageHandlerTests(NeoUnitTestBase):
...@@ -29,7 +29,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -29,7 +29,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
config = self.getMasterConfiguration(master_number=1, replicas=1) config = self.getMasterConfiguration(master_number=1, replicas=1)
self.app = Application(config) self.app = Application(config)
self.app.em.close() self.app.em.close()
self.app.pt.clear()
self.app.em = Mock() self.app.em = Mock()
self.service = StorageServiceHandler(self.app) self.service = StorageServiceHandler(self.app)
...@@ -73,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -73,7 +72,7 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.service.answerPack(conn2, False) self.service.answerPack(conn2, False)
packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack) packet = self.checkNotifyPacket(client_conn, Packets.AnswerPack)
# TODO: verify packet peer id # TODO: verify packet peer id
self.assertTrue(packet.decode()[0]) self.assertTrue(packet._args[0])
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
if __name__ == '__main__': if __name__ == '__main__':
......
# generated by running the whole test suite with -p
AbortTransaction(p64,[int])
AcceptIdentification(NodeTypes,?int,?int)
AddObject(p64,p64,int,bin,bin,?p64)
AddPendingNodes([int])
AddTransaction(p64,bin,bin,bin,bool,p64,[p64])
AnswerBeginTransaction(p64)
AnswerCheckCurrentSerial(?p64)
AnswerCheckSerialRange(int,bin,p64,bin,p64)
AnswerCheckTIDRange(int,bin,p64)
AnswerClusterState(?ClusterStates)
AnswerFetchObjects(?,?p64,?p64,{:})
AnswerFetchTransactions(?,?p64,[])
AnswerFinalTID(p64)
AnswerInformationLocked(p64)
AnswerLastIDs(?p64,?p64)
AnswerLastTransaction(p64)
AnswerLockedTransactions({p64:?p64})
AnswerNewOIDs([p64])
AnswerNodeList([(NodeTypes,?(bin,int),?int,NodeStates,?float)])
AnswerObject(p64,p64,?p64,?int,bin,bin,?p64)
AnswerObjectHistory(p64,[(p64,int)])
AnswerObjectUndoSerial({p64:(p64,?p64,bool)})
AnswerPack(bool)
AnswerPartitionList(int,int,[[(int,CellStates)]])
AnswerPartitionTable(int,int,[[(int,CellStates)]])
AnswerPrimary(int)
AnswerRebaseObject(?(p64,p64,?(int,bin,bin)))
AnswerRebaseTransaction([p64])
AnswerRecovery(?int,?p64,?p64)
AnswerStoreObject(?p64)
AnswerStoreTransaction()
AnswerTIDs([p64])
AnswerTIDsFrom([p64])
AnswerTransactionFinished(p64,p64)
AnswerTransactionInformation(p64,bin,bin,bin,bool,[p64])
AnswerTweakPartitionTable(bool,[[(int,CellStates)]])
AnswerUnfinishedTransactions(p64,[p64])
AnswerVoteTransaction()
AskBeginTransaction(?p64)
AskCheckCurrentSerial(p64,p64,p64)
AskCheckSerialRange(int,int,p64,p64,p64)
AskCheckTIDRange(int,int,p64,p64)
AskClusterState()
AskFetchObjects(int,int,p64,p64,p64,{p64:[p64]})
AskFetchTransactions(int,int,p64,p64,[p64])
AskFinalTID(p64)
AskFinishTransaction(p64,[p64],[p64])
AskLastIDs()
AskLastTransaction()
AskLockInformation(p64,p64)
AskLockedTransactions()
AskNewOIDs(int)
AskNodeList(NodeTypes)
AskObject(p64,?p64,?p64)
AskObjectHistory(p64,int,int)
AskObjectUndoSerial(p64,p64,p64,[p64])
AskPack(p64)
AskPartitionList(int,int,?)
AskPartitionTable()
AskPrimary()
AskRebaseObject(p64,p64)
AskRebaseTransaction(p64,p64)
AskRecovery()
AskStoreObject(p64,p64,int,bin,bin,?p64,?p64)
AskStoreTransaction(p64,bin,bin,bin,[p64])
AskTIDs(int,int,int)
AskTIDsFrom(p64,p64,int,int)
AskTransactionInformation(p64)
AskUnfinishedTransactions([int])
AskVoteTransaction(p64)
CheckPartition(int,(bin,?(bin,int)),p64,p64)
CheckReplicas({int:?int},p64,?)
Error(int,bin)
FailedVote(p64,[int])
InvalidateObjects(p64,[p64])
NotPrimaryMaster(?int,[(bin,int)])
NotifyClusterInformation(ClusterStates)
NotifyDeadlock(p64,p64)
NotifyNodeInformation(float,[(NodeTypes,?(bin,int),?int,NodeStates,?float)])
NotifyPartitionChanges(int,int,[(int,int,CellStates)])
NotifyPartitionCorrupted(int,[int])
NotifyReady()
NotifyRepair(bool)
NotifyReplicationDone(int,p64)
NotifyTransactionFinished(p64,p64)
NotifyUnlockInformation(p64)
Ping()
Pong()
Repair([int],bool)
Replicate(p64,bin,{int:?(bin,int)})
RequestIdentification(NodeTypes,?int,?(bin,int),bin,?float,any,[int])
SendPartitionTable(?int,int,[[(int,CellStates)]])
SetClusterState(ClusterStates)
SetNodeState(int,NodeStates)
SetNumReplicas(int)
StartOperation(bool)
StopOperation()
Truncate(p64)
TweakPartitionTable(bool,[int])
ValidateTransaction(p64,p64)
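The file above is a generated fixture listing one line per packet type with a compact argument signature. The notation is not documented anywhere else in this change; the following summary is inferred from the parser that follows (Argument.load/_load) and should be read as a best-effort gloss rather than a specification:

# Inferred meaning of the signature notation:
#   p64            8-byte packed integer (OIDs, TIDs, ptid)
#   bin / str      byte string / unicode string
#   int, bool, float, and enum names (NodeTypes, CellStates, ...) as such
#   ?X             X is optional (may be None)
#   [X]            list of X;  {K:V} dict;  (A,B,...) fixed-size record
#   a bare ?       a value whose type the test runs never pinned down
# Round-trip example, assuming Argument from the checker below is in scope:
name, arg = Argument.load('AnswerPack(bool)\n')
assert (name, '%r' % arg) == ('AnswerPack', '(bool)')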
# The use of ast is convoluted, and the result quite verbose,
# but that remains simpler than writing a parser from scratch.
import ast, os
from contextlib import contextmanager
from neo.lib.protocol import Packet, Enum
array = list, set, tuple
item = Enum.Item
class _ast(object):
def __getattr__(self, k):
v = lambda *args: getattr(ast, k)(lineno=0, col_offset=0, *args)
setattr(self, k, v)
return v
_ast = _ast()
class parseArgument(ast.NodeTransformer):
def visit_UnaryOp(self, node):
assert isinstance(node.op, ast.USub)
return _ast.Call(_ast.Name('option', ast.Load()),
[self.visit(node.operand)], [], None, None)
def visit_Name(self, node):
return _ast.Str(node.id.replace('_', '?'))
parseArgument = parseArgument().visit
class Argument(object):
merge = True
type = ''
option = False
@classmethod
def load(cls, arg):
arg = ast.parse(arg.rstrip()
.replace('?(', '-(').replace('?[', '-[').replace('?{', '-{')
.replace('?', '_').replace('[]', '[""]')
.replace('{:', '{"":').replace(':}', ':""}'),
mode="eval")
x = arg.body
name = x.func.id
arg.body = parseArgument(_ast.Tuple(x.args, ast.Load()))
return name, cls._load(eval(compile(arg, '', mode="eval"),
{'option': cls._option}))
@classmethod
def _load(cls, arg):
t = type(arg)
if t is cls:
return arg
x = object.__new__(cls)
if t is tuple:
x.type = map(cls._load, arg)
elif t is list:
x.type = cls._load(*arg),
elif t is dict:
(k, v), = arg.iteritems()
x.type = cls._load(k), cls._load(v)
else:
if arg.startswith('?'):
arg = arg[1:]
x.option = True
x.type = arg
return x
@classmethod
def _option(cls, arg):
arg = cls._load(arg)
arg.option = True
return arg
@classmethod
def _merge(cls, args):
if args:
x, = {cls(x) for x in args}
return x
return object.__new__(cls)
def __init__(self, arg, root=False):
if arg is None:
self.option = True
elif isinstance(arg, tuple) and (root or len(arg) > 1):
self.type = map(self.__class__, arg)
elif isinstance(arg, array):
self.type = self._merge(arg),
elif isinstance(arg, dict):
self.type = self._merge(arg), self._merge(arg.values())
else:
self.type = (('p64' if len(arg) == 8 else
'bin') if isinstance(arg, bytes) else
arg._enum._name if isinstance(arg, item) else
'str' if isinstance(arg, unicode) else
'int' if isinstance(arg, long) else
type(arg).__name__)
def __repr__(self):
x = self.type
if type(x) is tuple:
x = ('[%s]' if len(x) == 1 else '{%s:%s}') % x
elif type(x) is list:
x = '(%s)' % ','.join(map(repr, x))
return '?' + x if self.option else x
def __hash__(self):
return 0
def __eq__(self, other):
x = self.type
y = other.type
if x and y and x != 'any':
# Since we don't know whether an array is a fixed-size record of
# heterogeneous values or a collection of homogeneous values,
# we end up with the following complicated heuristic.
t = type(x)
if t is tuple:
if len(x) == 1 and type(y) is list:
z = set(x)
z.update(y)
if len(z) == 1:
x = y = tuple(z)
if self.merge:
self.type = x
elif t is list:
if type(y) is tuple and len(y) == 1:
z = set(y)
z.update(x)
if len(z) == 1:
x = y = tuple(z)
if self.merge:
self.type = x
t = tuple
elif t is str is type(y) and {x, y}.issuperset(('bin', 'p64')):
x = y = 'bin'
if self.merge:
self.type = x
if not (t is type(y) and (t is not tuple or
len(x) == len(y)) and x == y):
if not self.merge:
return False
self.type = 'any'
if self.merge:
if not x:
self.type = y
if not self.option:
self.option = other.option
elif y and not x or other.option and not self.option:
return False
return True
class FrozenArgument(Argument):
merge = False
@contextmanager
def protocolChecker(dump):
x = 'Packet(p64,?[(bin,{int:})],{:?(?,[])},?{?:float})'
assert x == '%s%r' % Argument.load(x)
assert not (FrozenArgument([]) == Argument([0]))
path = os.path.join(os.path.dirname(__file__), 'protocol')
if dump:
import threading
from multiprocessing import Lock
lock = Lock()
schema = {}
pid = os.getpid()
r, w = os.pipe()
def _check(name, arg):
try:
schema[name] == arg  # Argument.__eq__ (merge=True) folds arg into the stored entry
except KeyError:
schema[name] = arg
def check(name, args):
arg = Argument(args, True)
if pid == os.getpid():
_check(name, arg)
else:
with lock:
os.write(w, '%s%r\n' % (name, arg))
def check_thread(r):
for x in os.fdopen(r):
_check(*Argument.load(x))
check_thread = threading.Thread(target=check_thread, args=(r,))
check_thread.daemon = True
check_thread.start()
else:
with open(path) as p:
x = p.readline()
assert x[0] == '#', x
schema = dict(map(FrozenArgument.load, p))
def check(name, args):
arg = Argument(args, True)
if not (None is not schema.get(name) == arg):
raise Exception('invalid packet: %s%r' % (name, arg))
w = None
Packet_encode = Packet.__dict__['encode']
def encode(packet):
check(type(packet).__name__, packet._args)
return Packet_encode(packet)
Packet.encode = encode
try:
yield
finally:
Packet.encode = Packet_encode
if w:
os.close(w)
check_thread.join()
if dump:
with open(path, 'w') as p:
p.write('# generated by running the whole test suite with -p\n')
for x in sorted(schema.iteritems()):
p.write('%s%r\n' % x)
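protocolChecker above is meant to wrap a whole test run: it monkey-patches Packet.encode so that every packet actually sent is summarized with Argument and compared against the committed 'protocol' file (dump=False), or used to regenerate that file (dump=True, presumably what the test runner's -p option mentioned in the file header does). A minimal, hypothetical usage sketch, where run_all_tests stands in for the real runner:

def run_all_tests():
    # placeholder for whatever drives the full NEO test suite
    import unittest
    unittest.main(module=None, exit=False)

with protocolChecker(dump=False):   # raise on any packet not matching 'protocol'
    run_all_tests()
# with protocolChecker(dump=True): ...   # rewrite 'protocol' from observed packets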
...@@ -56,7 +56,7 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -56,7 +56,7 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
self.app.pt = Mock({'getID': 1}) self.app.pt = Mock({'getID': 1})
count = len(self.app.nm.getList()) count = len(self.app.nm.getList())
self.assertRaises(ProtocolError, self.operation.notifyPartitionChanges, self.assertRaises(ProtocolError, self.operation.notifyPartitionChanges,
conn, 0, ()) conn, 0, 0, ())
self.assertEqual(self.app.pt.getID(), 1) self.assertEqual(self.app.pt.getID(), 1)
self.assertEqual(len(self.app.nm.getList()), count) self.assertEqual(len(self.app.nm.getList()), count)
calls = self.app.replicator.mockGetNamedCalls('removePartition') calls = self.app.replicator.mockGetNamedCalls('removePartition')
...@@ -84,13 +84,13 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -84,13 +84,13 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
ptid = 2 ptid = 2
app.dm = Mock({ }) app.dm = Mock({ })
app.replicator = Mock({}) app.replicator = Mock({})
self.operation.notifyPartitionChanges(conn, ptid, cells) self.operation.notifyPartitionChanges(conn, ptid, 1, cells)
# ptid set # ptid set
self.assertEqual(app.pt.getID(), ptid) self.assertEqual(app.pt.getID(), ptid)
# dm call # dm call
calls = self.app.dm.mockGetNamedCalls('changePartitionTable') calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(ptid, cells) calls[0].checkArgs(ptid, 1, cells)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
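The extra integer threaded through notifyPartitionChanges and changePartitionTable in the hunks above (and in the storage tests further down) is assumed to be the replica count that is now stored together with the partition table, matching the new SetNumReplicas packet and the NotifyPartitionChanges(int,int,[...]) signature in the generated schema above:

# Assumed calling convention after this change (inferred from the test
# updates, not verified against the handler code itself):
#
#   handler.notifyPartitionChanges(conn, ptid, num_replicas, cell_list)
#   dm.changePartitionTable(ptid, num_replicas, cell_list)
#
# with cell_list a list of (offset, nid, CellStates.*) tuples.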
#
# Copyright (C) 2009-2019 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest
from ..mock import Mock
from .. import NeoUnitTestBase
from neo.storage.app import Application
from neo.lib.protocol import CellStates
from neo.lib.pt import PartitionTable
class StorageAppTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.prepareDatabase(number=1)
# create an application object
config = self.getStorageConfiguration(master_number=1)
self.app = Application(config)
def _tearDown(self, success):
self.app.close()
del self.app
super(StorageAppTests, self)._tearDown(success)
def test_01_loadPartitionTable(self):
self.app.dm = Mock({
'getPartitionTable': [],
})
self.assertEqual(self.app.pt, None)
num_partitions = 3
num_replicas = 2
self.app.pt = PartitionTable(num_partitions, num_replicas)
self.assertFalse(self.app.pt.getNodeSet())
self.assertFalse(self.app.pt.filled())
for x in xrange(num_partitions):
self.assertFalse(self.app.pt.hasOffset(x))
# load an empty table
self.app.loadPartitionTable()
self.assertFalse(self.app.pt.getNodeSet())
self.assertFalse(self.app.pt.filled())
for x in xrange(num_partitions):
self.assertFalse(self.app.pt.hasOffset(x))
# add some nodes; they will be removed when the table is loaded
master_uuid = self.getMasterUUID()
master = self.app.nm.createMaster(uuid=master_uuid)
storage_uuid = self.getStorageUUID()
storage = self.app.nm.createStorage(uuid=storage_uuid)
client_uuid = self.getClientUUID()
self.app.pt._setCell(0, master, CellStates.UP_TO_DATE)
self.app.pt._setCell(0, storage, CellStates.UP_TO_DATE)
self.assertEqual(len(self.app.pt.getNodeSet()), 2)
self.assertFalse(self.app.pt.filled())
for x in xrange(num_partitions):
if x == 0:
self.assertTrue(self.app.pt.hasOffset(x))
else:
self.assertFalse(self.app.pt.hasOffset(x))
# load an empty table, everything removed
self.app.loadPartitionTable()
self.assertFalse(self.app.pt.getNodeSet())
self.assertFalse(self.app.pt.filled())
for x in xrange(num_partitions):
self.assertFalse(self.app.pt.hasOffset(x))
# add some nodes
self.app.pt._setCell(0, master, CellStates.UP_TO_DATE)
self.app.pt._setCell(0, storage, CellStates.UP_TO_DATE)
self.assertEqual(len(self.app.pt.getNodeSet()), 2)
self.assertFalse(self.app.pt.filled())
for x in xrange(num_partitions):
if x == 0:
self.assertTrue(self.app.pt.hasOffset(x))
else:
self.assertFalse(self.app.pt.hasOffset(x))
# fill partition table
self.app.dm = Mock({
'getPartitionTable': [
(0, client_uuid, CellStates.UP_TO_DATE),
(1, client_uuid, CellStates.UP_TO_DATE),
(1, storage_uuid, CellStates.UP_TO_DATE),
(2, storage_uuid, CellStates.UP_TO_DATE),
(2, master_uuid, CellStates.UP_TO_DATE),
],
'getPTID': 1,
})
self.app.pt.clear()
self.app.loadPartitionTable()
self.assertTrue(self.app.pt.filled())
for x in xrange(num_partitions):
self.assertTrue(self.app.pt.hasOffset(x))
# check each row
cell_list = self.app.pt.getCellList(0)
self.assertEqual(len(cell_list), 1)
self.assertEqual(cell_list[0].getUUID(), client_uuid)
cell_list = self.app.pt.getCellList(1)
self.assertEqual(len(cell_list), 2)
self.assertTrue(cell_list[0].getUUID() in (client_uuid, storage_uuid))
self.assertTrue(cell_list[1].getUUID() in (client_uuid, storage_uuid))
cell_list = self.app.pt.getCellList(2)
self.assertEqual(len(cell_list), 2)
self.assertTrue(cell_list[0].getUUID() in (master_uuid, storage_uuid))
self.assertTrue(cell_list[1].getUUID() in (master_uuid, storage_uuid))
if __name__ == '__main__':
unittest.main()
...@@ -48,30 +48,15 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -48,30 +48,15 @@ class StorageDBTests(NeoUnitTestBase):
raise NotImplementedError raise NotImplementedError
def setNumPartitions(self, num_partitions, reset=0): def setNumPartitions(self, num_partitions, reset=0):
try: assert not hasattr(self, '_db')
db = self._db self._db = db = self.getDB(reset)
except AttributeError:
self._db = db = self.getDB(reset)
else:
if reset:
db.setup(reset)
else:
try:
n = db.getNumPartitions()
except KeyError:
n = 0
if num_partitions == n:
return
if num_partitions < n:
db.dropPartitions(n)
db.setNumPartitions(num_partitions)
self.assertEqual(num_partitions, db.getNumPartitions())
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
db.setUUID(uuid) db.setUUID(uuid)
self.assertEqual(uuid, db.getUUID()) self.assertEqual(uuid, db.getUUID())
db.changePartitionTable(1, db.changePartitionTable(1, 0,
[(i, uuid, CellStates.UP_TO_DATE) for i in xrange(num_partitions)], [(i, uuid, CellStates.UP_TO_DATE) for i in xrange(num_partitions)],
reset=True) reset=True)
self.assertEqual(num_partitions, 1 + db._getMaxPartition())
db.commit() db.commit()
def checkConfigEntry(self, get_call, set_call, value): def checkConfigEntry(self, get_call, set_call, value):
...@@ -102,16 +87,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -102,16 +87,6 @@ class StorageDBTests(NeoUnitTestBase):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getName, db.setName, 'TEST_NAME') self.checkConfigEntry(db.getName, db.setName, 'TEST_NAME')
def test_getPartitionTable(self):
db = self.getDB()
db.setNumPartitions(3)
uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
cell1 = (0, uuid1, CellStates.OUT_OF_DATE)
cell2 = (1, uuid1, CellStates.UP_TO_DATE)
db.changePartitionTable(1, [cell1, cell2], 1)
result = db.getPartitionTable()
self.assertEqual(set(result), {cell1, cell2})
def getOIDs(self, count): def getOIDs(self, count):
return map(p64, xrange(count)) return map(p64, xrange(count))
...@@ -202,52 +177,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -202,52 +177,6 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, before_tid=tid2), self.assertEqual(self.db.getObject(oid1, before_tid=tid2),
OBJECT_T1_NEXT) OBJECT_T1_NEXT)
def test_setPartitionTable(self):
db = self.getDB()
db.setNumPartitions(3)
ptid = 1
uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE
cell3 = 1, uuid, CellStates.DISCARDED
# no partition table
self.assertEqual(list(db.getPartitionTable()), [])
# set one
db.changePartitionTable(ptid, [cell1], 1)
result = db.getPartitionTable()
self.assertEqual(list(result), [cell1])
# then another
db.changePartitionTable(ptid, [cell2], 1)
result = db.getPartitionTable()
self.assertEqual(list(result), [cell2])
# drop discarded cells
db.changePartitionTable(ptid, [cell2, cell3], 1)
result = db.getPartitionTable()
self.assertEqual(list(result), [])
def test_changePartitionTable(self):
db = self.getDB()
db.setNumPartitions(3)
ptid = 1
uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE
cell3 = 1, uuid, CellStates.DISCARDED
# no partition table
self.assertEqual(list(db.getPartitionTable()), [])
# set one
db.changePartitionTable(ptid, [cell1])
result = db.getPartitionTable()
self.assertEqual(list(result), [cell1])
# add more entries
db.changePartitionTable(ptid, [cell2])
result = db.getPartitionTable()
self.assertEqual(set(result), {cell1, cell2})
# drop discarded cells
db.changePartitionTable(ptid, [cell2, cell3])
result = db.getPartitionTable()
self.assertEqual(list(result), [cell1])
def test_commitTransaction(self): def test_commitTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
......
...@@ -22,7 +22,7 @@ from MySQLdb.constants.ER import UNKNOWN_STORAGE_ENGINE ...@@ -22,7 +22,7 @@ from MySQLdb.constants.ER import UNKNOWN_STORAGE_ENGINE
from ..mock import Mock from ..mock import Mock
from neo.lib.protocol import ZERO_OID from neo.lib.protocol import ZERO_OID
from neo.lib.util import p64 from neo.lib.util import p64
from .. import DB_PREFIX, DB_SOCKET, DB_USER, Patch from .. import DB_PREFIX, DB_USER, Patch, setupMySQLdb
from .testStorageDBTests import StorageDBTests from .testStorageDBTests import StorageDBTests
from neo.storage.database import DatabaseFailure from neo.storage.database import DatabaseFailure
from neo.storage.database.mysqldb import MySQLDatabaseManager from neo.storage.database.mysqldb import MySQLDatabaseManager
...@@ -46,8 +46,8 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -46,8 +46,8 @@ class StorageMySQLdbTests(StorageDBTests):
engine = None engine = None
def _test_lockDatabase_open(self): def _test_lockDatabase_open(self):
self.prepareDatabase(number=1, prefix=DB_PREFIX) self.prepareDatabase(1)
database = '%s@%s0%s' % (DB_USER, DB_PREFIX, DB_SOCKET) database = self.db_template(0)
return MySQLDatabaseManager(database, self.engine) return MySQLDatabaseManager(database, self.engine)
def getDB(self, reset=0): def getDB(self, reset=0):
......
...@@ -19,12 +19,9 @@ class Handler(MasterEventHandler): ...@@ -19,12 +19,9 @@ class Handler(MasterEventHandler):
super(Handler, self).answerClusterState(conn, state) super(Handler, self).answerClusterState(conn, state)
self.app.refresh('state') self.app.refresh('state')
def answerPartitionTable(self, *args):
super(Handler, self).answerPartitionTable(*args)
self.app.refresh('pt')
def sendPartitionTable(self, *args): def sendPartitionTable(self, *args):
raise AssertionError super(Handler, self).sendPartitionTable(*args)
self.app.refresh('pt')
def notifyPartitionChanges(self, *args): def notifyPartitionChanges(self, *args):
super(Handler, self).notifyPartitionChanges(*args) super(Handler, self).notifyPartitionChanges(*args)
...@@ -50,6 +47,7 @@ class StressApplication(AdminApplication): ...@@ -50,6 +47,7 @@ class StressApplication(AdminApplication):
cluster_state = server = uuid = None cluster_state = server = uuid = None
listening_conn = True listening_conn = True
fault_probability = 1
restart_ratio = float('inf') # no firewall support restart_ratio = float('inf') # no firewall support
_stress = False _stress = False
...@@ -191,7 +189,7 @@ class StressApplication(AdminApplication): ...@@ -191,7 +189,7 @@ class StressApplication(AdminApplication):
self.loid = loid self.loid = loid
self.ltid = ltid self.ltid = ltid
self.em.setTimeout(int(time.time() + 1), self.askLastIDs) self.em.setTimeout(int(time.time() + 1), self.askLastIDs)
if self._stress: if self._stress and random.random() < self.fault_probability:
node_list = self.nm.getStorageList() node_list = self.nm.getStorageList()
random.shuffle(node_list) random.shuffle(node_list)
fw = [] fw = []
......
...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase): ...@@ -33,9 +33,9 @@ class HandlerTests(NeoUnitTestBase):
def getFakePacket(self): def getFakePacket(self):
p = Mock({ p = Mock({
'decode': (),
'__repr__': 'Fake Packet', '__repr__': 'Fake Packet',
}) })
p._args = ()
p.handler_method_name = 'fake_method' p.handler_method_name = 'fake_method'
return p return p
...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase): ...@@ -53,13 +53,6 @@ class HandlerTests(NeoUnitTestBase):
self.handler.dispatch(conn, packet) self.handler.dispatch(conn, packet)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
self.checkAborted(conn) self.checkAborted(conn)
# raise PacketMalformedError
conn.mockCalledMethods = {}
def fake(c):
raise PacketMalformedError('message')
self.setFakeMethod(fake)
self.handler.dispatch(conn, packet)
self.checkClosed(conn)
# raise NotReadyError # raise NotReadyError
conn.mockCalledMethods = {} conn.mockCalledMethods = {}
def fake(c): def fake(c):
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import unittest import unittest
import socket import socket
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.lib.util import ReadBuffer, parseNodeAddress from neo.lib.util import parseNodeAddress
class UtilTests(NeoUnitTestBase): class UtilTests(NeoUnitTestBase):
...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase): ...@@ -40,24 +40,6 @@ class UtilTests(NeoUnitTestBase):
self.assertIn(parseNodeAddress('localhost'), local_address(0)) self.assertIn(parseNodeAddress('localhost'), local_address(0))
self.assertIn(parseNodeAddress('localhost:10'), local_address(10)) self.assertIn(parseNodeAddress('localhost:10'), local_address(10))
def testReadBufferRead(self):
""" Append some chunk then consume the data """
buf = ReadBuffer()
self.assertEqual(len(buf), 0)
buf.append('abc')
self.assertEqual(len(buf), 3)
# not enough data
self.assertEqual(buf.read(4), None)
self.assertEqual(len(buf), 3)
buf.append('def')
# consume a part
self.assertEqual(len(buf), 6)
self.assertEqual(buf.read(4), 'abcd')
self.assertEqual(len(buf), 2)
# consume the rest
self.assertEqual(buf.read(3), None)
self.assertEqual(buf.read(2), 'ef')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -40,7 +40,7 @@ from neo.lib.util import cached_property, parseMasterList, p64 ...@@ -40,7 +40,7 @@ from neo.lib.util import cached_property, parseMasterList, p64
from neo.master.recovery import RecoveryManager from neo.master.recovery import RecoveryManager
from .. import (getTempDirectory, setupMySQLdb, from .. import (getTempDirectory, setupMySQLdb,
ImporterConfigParser, NeoTestBase, Patch, ImporterConfigParser, NeoTestBase, Patch,
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_SOCKET, DB_USER) ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX)
BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0 BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]) LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
...@@ -304,7 +304,13 @@ class TestSerialized(Serialized): ...@@ -304,7 +304,13 @@ class TestSerialized(Serialized):
class Node(object): class Node(object):
def getConnectionList(self, *peers): def getConnectionList(self, *peers):
addr = lambda c: c and (c.addr if c.is_server else c.getAddress()) def addr(c):
# Do not identify only by source address because 2 TCP connections
# can have the same source host:port to different destinations.
if c:
a = c.addr
b = c.getAddress()
return (b, a) if c.is_server else (ServerNode.resolv(a), b)
addr_set = {addr(c.connector) for peer in peers addr_set = {addr(c.connector) for peer in peers
for c in peer.em.connection_dict.itervalues() for c in peer.em.connection_dict.itervalues()
if isinstance(c, Connection)} if isinstance(c, Connection)}
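The rewritten addr() above keys a connection by a pair of endpoints rather than a single address, because in these in-process test clusters two client connections may report the same source host:port while talking to different peers. A stripped-down illustration of the idea (names and addresses are made up):

def connection_key(client_addr, server_addr):
    # order the pair consistently so both ends of one TCP connection agree
    return (server_addr, client_addr)

a = connection_key(('127.0.0.1', 4000), ('127.0.0.1', 9999))
b = connection_key(('127.0.0.1', 4000), ('127.0.0.1', 9998))
assert a != b   # same source, different destination: still distinguishable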
...@@ -377,7 +383,10 @@ class ServerNode(Node): ...@@ -377,7 +383,10 @@ class ServerNode(Node):
assert not self.is_alive() assert not self.is_alive()
init_args = self._init_args init_args = self._init_args
init_args['reset'] = False init_args['reset'] = False
assert set(kw).issubset(init_args), (kw, init_args) if __debug__:
x = set(kw).difference(init_args)
assert not x or x.issubset(self.option_parser.getOptionDict()), (
kw, init_args)
init_args.update(kw) init_args.update(kw)
self.close() self.close()
self.__init__(**init_args) self.__init__(**init_args)
...@@ -708,7 +717,7 @@ class NEOCluster(object): ...@@ -708,7 +717,7 @@ class NEOCluster(object):
def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None, def __init__(self, master_count=1, partitions=1, replicas=0, upstream=None,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'), adapter=os.getenv('NEO_TESTS_ADAPTER', 'SQLite'),
storage_count=None, db_list=None, clear_databases=True, storage_count=None, db_list=None, clear_databases=True,
db_user=DB_USER, db_password='', compress=True, compress=True,
importer=None, autostart=None, dedup=False, name=None): importer=None, autostart=None, dedup=False, name=None):
self.name = name or 'neo_%s' % self._allocate('name', self.name = name or 'neo_%s' % self._allocate('name',
lambda: random.randint(0, 100)) lambda: random.randint(0, 100))
...@@ -735,21 +744,20 @@ class NEOCluster(object): ...@@ -735,21 +744,20 @@ class NEOCluster(object):
db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index)) db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index))
for _ in xrange(storage_count)] for _ in xrange(storage_count)]
if adapter == 'MySQL': if adapter == 'MySQL':
setupMySQLdb(db_list, db_user, db_password, clear_databases) db = setupMySQLdb(db_list, clear_databases)
db = '%s:%s@%%s%s' % (db_user, db_password, DB_SOCKET)
elif adapter == 'SQLite': elif adapter == 'SQLite':
db = os.path.join(getTempDirectory(), '%s.sqlite') db = os.path.join(getTempDirectory(), '%s.sqlite').__mod__
else: else:
assert False, adapter assert False, adapter
if importer: if importer:
cfg = ImporterConfigParser(adapter, **importer) cfg = ImporterConfigParser(adapter, **importer)
cfg.set("neo", "database", db % tuple(db_list)) cfg.set("neo", "database", db(*db_list))
db = os.path.join(getTempDirectory(), '%s.conf') db = os.path.join(getTempDirectory(), '%s.conf').__mod__
with open(db % tuple(db_list), "w") as f: with open(db(*db_list), "w") as f:
cfg.write(f) cfg.write(f)
kw["adapter"] = "Importer" kw["adapter"] = "Importer"
kw['wait'] = 0 kw['wait'] = 0
self.storage_list = [StorageApplication(database=db % x, **kw) self.storage_list = [StorageApplication(database=db(x), **kw)
for x in db_list] for x in db_list]
self.admin_list = [AdminApplication(**kw)] self.admin_list = [AdminApplication(**kw)]
...@@ -805,7 +813,7 @@ class NEOCluster(object): ...@@ -805,7 +813,7 @@ class NEOCluster(object):
master_list = self.master_list master_list = self.master_list
if storage_list is None: if storage_list is None:
storage_list = self.storage_list storage_list = self.storage_list
def answerPartitionTable(release, orig, *args): def sendPartitionTable(release, orig, *args):
orig(*args) orig(*args)
release() release()
def dispatch(release, orig, handler, *args): def dispatch(release, orig, handler, *args):
...@@ -821,7 +829,7 @@ class NEOCluster(object): ...@@ -821,7 +829,7 @@ class NEOCluster(object):
if state in expected_state: if state in expected_state:
release() release()
with Serialized.until(MasterEventHandler, with Serialized.until(MasterEventHandler,
answerPartitionTable=answerPartitionTable) as tic1, \ sendPartitionTable=sendPartitionTable) as tic1, \
Serialized.until(RecoveryManager, dispatch=dispatch) as tic2, \ Serialized.until(RecoveryManager, dispatch=dispatch) as tic2, \
Serialized.until(MasterEventHandler, Serialized.until(MasterEventHandler,
notifyClusterInformation=notifyClusterInformation) as tic3: notifyClusterInformation=notifyClusterInformation) as tic3:
...@@ -846,9 +854,13 @@ class NEOCluster(object): ...@@ -846,9 +854,13 @@ class NEOCluster(object):
expected_state = (NodeStates.PENDING expected_state = (NodeStates.PENDING
if state == ClusterStates.RECOVERING if state == ClusterStates.RECOVERING
else NodeStates.RUNNING) else NodeStates.RUNNING)
for node in self.storage_list if storage_list is None else storage_list: for node, expected_state in (
storage_list if isinstance(storage_list, dict) else
dict.fromkeys(self.storage_list if storage_list is None else
storage_list, expected_state)
).iteritems():
state = self.getNodeState(node) state = self.getNodeState(node)
assert state == expected_state, (repr(node), state) assert state == expected_state, (repr(node), state, expected_state)
def stop(self, clear_database=False, __print_exc=traceback.print_exc, **kw): def stop(self, clear_database=False, __print_exc=traceback.print_exc, **kw):
if self.started: if self.started:
...@@ -922,7 +934,7 @@ class NEOCluster(object): ...@@ -922,7 +934,7 @@ class NEOCluster(object):
def startCluster(self): def startCluster(self):
try: try:
self.neoctl.startCluster() self.neoctl.startCluster()
except RuntimeError: except SystemExit:
Serialized.tic() Serialized.tic()
if self.neoctl.getClusterState() not in ( if self.neoctl.getClusterState() not in (
ClusterStates.BACKINGUP, ClusterStates.BACKINGUP,
...@@ -1001,18 +1013,18 @@ class NEOCluster(object): ...@@ -1001,18 +1013,18 @@ class NEOCluster(object):
"""Sort storages so that storage_list[i] has partition i for all i""" """Sort storages so that storage_list[i] has partition i for all i"""
pt = [{x.getUUID() for x in x} pt = [{x.getUUID() for x in x}
for x in self.primary_master.pt.partition_list] for x in self.primary_master.pt.partition_list]
n = len(self.storage_list)
r = [] r = []
x = [iter(pt[0])] x = [iter(pt[0])]
try: while 1:
while 1: try:
try: r.append(next(x[-1]))
r.append(next(x[-1])) except StopIteration:
except StopIteration: del r[-1], x[-1]
del r[-1], x[-1] else:
else: if len(r) == n:
x.append(iter(pt[len(r)].difference(r))) break
except IndexError: x.append(iter(pt[len(r)].difference(r)))
assert len(r) == len(self.storage_list)
x = {x.uuid: x for x in self.storage_list} x = {x.uuid: x for x in self.storage_list}
self.storage_list[:] = (x[r] for r in r) self.storage_list[:] = (x[r] for r in r)
return self.storage_list return self.storage_list
......
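The rewritten loop above is an iterative backtracking search: for partition i it tries each node that holds that partition and is not already used, backing up when it reaches a dead end, until every storage has been assigned a distinct partition. An equivalent recursive sketch (hypothetical helper, not part of the code base):

def assign(pt, n, chosen=()):
    # pt[i] is the set of node ids holding partition i; pick one distinct
    # node per partition, backtracking on dead ends.
    if len(chosen) == n:
        return list(chosen)
    for uuid in pt[len(chosen)] - set(chosen):
        result = assign(pt, n, chosen + (uuid,))
        if result is not None:
            return result
    return None

assert sorted(assign([{1, 2}, {1, 2}], 2)) == [1, 2]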
...@@ -42,6 +42,7 @@ from neo.lib.util import add64, makeChecksum, p64, u64 ...@@ -42,6 +42,7 @@ from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.transactions import Transaction from neo.client.transactions import Transaction
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.master.pt import PartitionTable
from neo.storage.database import DatabaseFailure from neo.storage.database import DatabaseFailure
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.storage.handlers.identification import IdentificationHandler from neo.storage.handlers.identification import IdentificationHandler
...@@ -148,6 +149,7 @@ class Test(NEOThreadedTest): ...@@ -148,6 +149,7 @@ class Test(NEOThreadedTest):
c.root()[0] = ob = PCounterWithResolution() c.root()[0] = ob = PCounterWithResolution()
t.commit() t.commit()
tids = [] tids = []
c.readCurrent(c.root())
for x in inc: for x in inc:
ob.value += x ob.value += x
t.commit() t.commit()
...@@ -424,6 +426,42 @@ class Test(NEOThreadedTest): ...@@ -424,6 +426,42 @@ class Test(NEOThreadedTest):
self.assertEqual([x['tid'] for x in c1.db().history(oid, size=10)], self.assertEqual([x['tid'] for x in c1.db().history(oid, size=10)],
[tid3, tid2, tid1, tid0]) [tid3, tid2, tid1, tid0])
@with_cluster()
def testSlowConflictResolution(self, cluster):
"""
Check that a slow conflict resolution does not always result in a new
conflict because a concurrent client keeps modifying the same object
quickly.
An idea to fix it is to take the lock before the second attempt to
resolve.
"""
t1, c1 = cluster.getTransaction()
c1.root()[''] = ob = PCounterWithResolution()
t1.commit()
l1 = threading.Lock(); l1.acquire()
l2 = threading.Lock(); l2.acquire()
conflicts = []
def _p_resolveConflict(orig, *args):
conflicts.append(get_ident())
l1.release(); l2.acquire()
return orig(*args)
with cluster.newClient(1) as db, Patch(PCounterWithResolution,
_p_resolveConflict=_p_resolveConflict):
t2, c2 = cluster.getTransaction(db)
c2.root()[''].value += 1
for i in xrange(10):
ob.value += 1
t1.commit()
if i:
l2.release()
else:
t = self.newThread(t2.commit)
l1.acquire()
l2.release()
t.join()
with self.expectedFailure(): \
self.assertIn(get_ident(), conflicts)
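The l1/l2 pair in the test above is the usual two-lock handshake for forcing a deterministic interleaving: the patched _p_resolveConflict releases l1 to signal that it has been entered, then blocks on l2 until the main thread has performed the next concurrent commit. A stripped-down sketch of the pattern, with purely illustrative names:

import threading

l1 = threading.Lock(); l1.acquire()   # worker -> main: "I reached the sync point"
l2 = threading.Lock(); l2.acquire()   # main -> worker: "you may continue"

def worker():
    l1.release()                      # signal the main thread
    l2.acquire()                      # wait for permission to go on

t = threading.Thread(target=worker)
t.start()
l1.acquire()                          # wait until the worker is at the sync point
# ... the main thread would do its concurrent work here ...
l2.release()                          # let the worker finish
t.join()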
@with_cluster() @with_cluster()
def testDelayedLoad(self, cluster): def testDelayedLoad(self, cluster):
""" """
...@@ -471,6 +509,7 @@ class Test(NEOThreadedTest): ...@@ -471,6 +509,7 @@ class Test(NEOThreadedTest):
self.assertFalse(conn.isClosed()) self.assertFalse(conn.isClosed())
getCellSortKey = cluster.client.getCellSortKey getCellSortKey = cluster.client.getCellSortKey
self.assertEqual(getCellSortKey(s0, good), 0) self.assertEqual(getCellSortKey(s0, good), 0)
cluster.neoctl.killNode(s0.getUUID())
cluster.neoctl.dropNode(s0.getUUID()) cluster.neoctl.dropNode(s0.getUUID())
self.assertEqual([s1], cluster.client.nm.getStorageList()) self.assertEqual([s1], cluster.client.nm.getStorageList())
self.assertTrue(conn.isClosed()) self.assertTrue(conn.isClosed())
...@@ -776,6 +815,7 @@ class Test(NEOThreadedTest): ...@@ -776,6 +815,7 @@ class Test(NEOThreadedTest):
checkNodeState(NodeStates.RUNNING) checkNodeState(NodeStates.RUNNING)
self.assertEqual([], cluster.getOutdatedCells()) self.assertEqual([], cluster.getOutdatedCells())
# drop one # drop one
cluster.neoctl.killNode(s1.uuid)
cluster.neoctl.dropNode(s1.uuid) cluster.neoctl.dropNode(s1.uuid)
checkNodeState(None) checkNodeState(None)
self.tic() # Let node state update reach remaining storage self.tic() # Let node state update reach remaining storage
...@@ -1123,6 +1163,10 @@ class Test(NEOThreadedTest): ...@@ -1123,6 +1163,10 @@ class Test(NEOThreadedTest):
# Check that the storage hasn't answered to the store, # Check that the storage hasn't answered to the store,
# which means that a lock is still taken for r['x'] by t2. # which means that a lock is still taken for r['x'] by t2.
self.tic() self.tic()
try:
txn = txn.data(c1)
except (AttributeError, KeyError): # BBB: ZODB < 5
pass
txn_context = cluster.client._txn_container.get(txn) txn_context = cluster.client._txn_container.get(txn)
empty = txn_context.queue.empty() empty = txn_context.queue.empty()
ll() ll()
...@@ -1202,7 +1246,7 @@ class Test(NEOThreadedTest): ...@@ -1202,7 +1246,7 @@ class Test(NEOThreadedTest):
# Also check that the master reset the last oid to a correct value. # Also check that the master reset the last oid to a correct value.
t.begin() t.begin()
self.assertEqual(1, u64(c.root()['x']._p_oid)) self.assertEqual(1, u64(c.root()['x']._p_oid))
self.assertFalse(cluster.client.new_oid_list) self.assertFalse(cluster.client.new_oids)
self.assertEqual(2, u64(cluster.client.new_oid())) self.assertEqual(2, u64(cluster.client.new_oid()))
@with_cluster() @with_cluster()
...@@ -1371,7 +1415,7 @@ class Test(NEOThreadedTest): ...@@ -1371,7 +1415,7 @@ class Test(NEOThreadedTest):
del conn._queue[:] # XXX del conn._queue[:] # XXX
conn.close() conn.close()
if 1: if 1:
with Patch(cluster.master.pt, make=make), \ with Patch(PartitionTable, make=make), \
Patch(InitializationHandler, Patch(InitializationHandler,
askPartitionTable=askPartitionTable) as p: askPartitionTable=askPartitionTable) as p:
cluster.start() cluster.start()
...@@ -1562,7 +1606,7 @@ class Test(NEOThreadedTest): ...@@ -1562,7 +1606,7 @@ class Test(NEOThreadedTest):
bad.append(s.getDataLockInfo()) bad.append(s.getDataLockInfo())
s.dm.commit() s.dm.commit()
def check(dry_run, expected): def check(dry_run, expected):
cluster.neoctl.repair(node_list, dry_run) cluster.neoctl.repair(node_list, bool(dry_run))
for e, s in zip(expected, cluster.storage_list): for e, s in zip(expected, cluster.storage_list):
while 1: while 1:
self.tic() self.tic()
...@@ -1902,18 +1946,7 @@ class Test(NEOThreadedTest): ...@@ -1902,18 +1946,7 @@ class Test(NEOThreadedTest):
x.value += 1 x.value += 1
c2.root()['x'].value += 2 c2.root()['x'].value += 2
TransactionalResource(t1, 1, tpc_begin=begin1) TransactionalResource(t1, 1, tpc_begin=begin1)
# BUG: Very rarely, getConnectionList returns more that 1 s1m, = s1.getConnectionList(cluster.master)
# connection ("too many values to unpack"), which is
# a mystery and impossible to reproduce:
# - 1st time: v1.8.1 on a test machine (no SSL)
# - last: current revision on my laptop (SSL),
# at the first iteration of this loop
_sm = list(s1.getConnectionList(cluster.master))
try:
s1m, = _sm
except ValueError:
self.fail((_sm, list(
s1.getConnectionList(cluster.master))))
try: try:
s1.em.removeReader(s1m) s1.em.removeReader(s1m)
with ConnectionFilter() as f, \ with ConnectionFilter() as f, \
...@@ -1979,7 +2012,7 @@ class Test(NEOThreadedTest): ...@@ -1979,7 +2012,7 @@ class Test(NEOThreadedTest):
except threading.ThreadError: except threading.ThreadError:
l[j].acquire() l[j].acquire()
threads[j-1].start() threads[j-1].start()
if x != 'StoreTransaction': if x != 'AskStoreTransaction':
try: try:
l[i].acquire() l[i].acquire()
except IndexError: except IndexError:
...@@ -2056,15 +2089,16 @@ class Test(NEOThreadedTest): ...@@ -2056,15 +2089,16 @@ class Test(NEOThreadedTest):
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes, x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, 0, 0, 1, 0, 0, 1), (1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, 0, 0, 1, 0, 0, 1),
('tpc_begin', 'tpc_begin', 1, 2, 3, 'tpc_begin', 1, 2, 4, 3, 4, ('tpc_begin', 'tpc_begin', 1, 2, 3, 'tpc_begin', 1, 2, 4, 3, 4,
'StoreTransaction', 'RebaseTransaction', 'RebaseTransaction', 'AskStoreTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AskRebaseTransaction', 'AnswerRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction'), 'AnswerRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction'),
[4, 6, 2, 6]) [4, 6, 2, 6])
try: try:
x[1].remove(1) x[1].remove(1)
except ValueError: except ValueError:
pass pass
self.assertEqual(x, {0: [2, 'StoreTransaction'], 1: ['tpc_abort']}) self.assertEqual(x, {0: [2, 'AskStoreTransaction'], 1: ['tpc_abort']})
def testCascadedDeadlockAvoidanceWithOneStorage2(self): def testCascadedDeadlockAvoidanceWithOneStorage2(self):
def changes(r1, r2, r3): def changes(r1, r2, r3):
...@@ -2087,8 +2121,8 @@ class Test(NEOThreadedTest): ...@@ -2087,8 +2121,8 @@ class Test(NEOThreadedTest):
(0, 1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1, (0, 1, 1, 0, 1, 2, 2, 2, 2, 0, 1, 2, 1,
0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1), 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1),
('tpc_begin', 1, 'tpc_begin', 1, 2, 3, 'tpc_begin', ('tpc_begin', 1, 'tpc_begin', 1, 2, 3, 'tpc_begin',
2, 3, 4, 3, 4, 'StoreTransaction', 'RebaseTransaction', 2, 3, 4, 3, 4, 'AskStoreTransaction', 'AskRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction'), 'AskRebaseTransaction', 'AnswerRebaseTransaction'),
[1, 7, 9, 0]) [1, 7, 9, 0])
x[0].sort(key=str) x[0].sort(key=str)
try: try:
...@@ -2097,8 +2131,8 @@ class Test(NEOThreadedTest): ...@@ -2097,8 +2131,8 @@ class Test(NEOThreadedTest):
pass pass
self.assertEqual(x, { self.assertEqual(x, {
0: [2, 3, 'AnswerRebaseTransaction', 0: [2, 3, 'AnswerRebaseTransaction',
'RebaseTransaction', 'StoreTransaction'], 'AskRebaseTransaction', 'AskStoreTransaction'],
1: ['AnswerRebaseTransaction','RebaseTransaction', 1: ['AnswerRebaseTransaction','AskRebaseTransaction',
'AnswerRebaseTransaction', 'tpc_abort'], 'AnswerRebaseTransaction', 'tpc_abort'],
}) })
...@@ -2131,7 +2165,7 @@ class Test(NEOThreadedTest): ...@@ -2131,7 +2165,7 @@ class Test(NEOThreadedTest):
end = self._testComplexDeadlockAvoidanceWithOneStorage(changes, end = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(0, 1, 1, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, vote_t2, tic_t1), (0, 1, 1, 0, 1, 1, 0, 0, 2, 2, 2, 2, 1, vote_t2, tic_t1),
('tpc_begin', 1) * 2, [3, 0, 0, 0], None) ('tpc_begin', 1) * 2, [3, 0, 0, 0], None)
self.assertLessEqual(2, end[0].count('RebaseTransaction')) self.assertLessEqual(2, end[0].count('AskRebaseTransaction'))
def testFailedConflictOnBigValueDuringDeadlockAvoidance(self): def testFailedConflictOnBigValueDuringDeadlockAvoidance(self):
def changes(r1, r2, r3): def changes(r1, r2, r3):
...@@ -2147,10 +2181,10 @@ class Test(NEOThreadedTest): ...@@ -2147,10 +2181,10 @@ class Test(NEOThreadedTest):
x = self._testComplexDeadlockAvoidanceWithOneStorage(changes, x = self._testComplexDeadlockAvoidanceWithOneStorage(changes,
(1, 1, 1, 2, 2, 2, 1, 2, 2, 0, 0, 1, 1, 1, 0), (1, 1, 1, 2, 2, 2, 1, 2, 2, 0, 0, 1, 1, 1, 0),
('tpc_begin', 'tpc_begin', 1, 2, 'tpc_begin', 1, 3, 3, 4, ('tpc_begin', 'tpc_begin', 1, 2, 'tpc_begin', 1, 3, 3, 4,
'StoreTransaction', 2, 4, 'RebaseTransaction', 'AskStoreTransaction', 2, 4, 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'tpc_abort'), 'AnswerRebaseTransaction', 'tpc_abort'),
[5, 1, 0, 2], POSException.ConflictError) [5, 1, 0, 2], POSException.ConflictError)
self.assertEqual(x, {0: ['StoreTransaction']}) self.assertEqual(x, {0: ['AskStoreTransaction']})
@with_cluster(replicas=1, partitions=4) @with_cluster(replicas=1, partitions=4)
def testNotifyReplicated(self, cluster): def testNotifyReplicated(self, cluster):
...@@ -2237,7 +2271,7 @@ class Test(NEOThreadedTest): ...@@ -2237,7 +2271,7 @@ class Test(NEOThreadedTest):
def delayConflict(conn, packet): def delayConflict(conn, packet):
app = self.getConnectionApp(conn) app = self.getConnectionApp(conn)
if (isinstance(packet, Packets.AnswerStoreObject) if (isinstance(packet, Packets.AnswerStoreObject)
and packet.decode()[0]): and packet._args[0]):
conn, = cluster.client.getConnectionList(app) conn, = cluster.client.getConnectionList(app)
kw = conn._handlers._pending[0][0][packet._id][1] kw = conn._handlers._pending[0][0][packet._id][1]
return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop() return 1 == u64(kw['oid']) and delay_conflict[app.uuid].pop()
...@@ -2255,8 +2289,9 @@ class Test(NEOThreadedTest): ...@@ -2255,8 +2289,9 @@ class Test(NEOThreadedTest):
self.thread_switcher(threads, self.thread_switcher(threads,
(1, 2, 3, 0, 1, 0, 2, t3_c, 1, 3, 2, t3_resolve, 0, 0, 0, (1, 2, 3, 0, 1, 0, 2, t3_c, 1, 3, 2, t3_resolve, 0, 0, 0,
t1_rebase, 2, t3_b, 3, t4_d, 0, 2, 2), t1_rebase, 2, t3_b, 3, t4_d, 0, 2, 2),
('tpc_begin', 'tpc_begin', 'tpc_begin', 'tpc_begin', 2, 1, 1, ('tpc_begin', 'tpc_begin', 'tpc_begin', 'tpc_begin',
3, 3, 4, 4, 3, 1, 'RebaseTransaction', 'RebaseTransaction', 2, 1, 1, 3, 3, 4, 4, 3, 1,
'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 2 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 2
)) as end: )) as end:
delay = f.delayAskFetchTransactions() delay = f.delayAskFetchTransactions()
...@@ -2268,11 +2303,11 @@ class Test(NEOThreadedTest): ...@@ -2268,11 +2303,11 @@ class Test(NEOThreadedTest):
t4.begin() t4.begin()
self.assertEqual([15, 11, 13, 16], [r[x].value for x in 'abcd']) self.assertEqual([15, 11, 13, 16], [r[x].value for x in 'abcd'])
self.assertEqual([2, 2], map(end.pop(2).count, self.assertEqual([2, 2], map(end.pop(2).count,
['RebaseTransaction', 'AnswerRebaseTransaction'])) ['AskRebaseTransaction', 'AnswerRebaseTransaction']))
self.assertEqual(end, { self.assertEqual(end, {
0: [1, 'StoreTransaction'], 0: [1, 'AskStoreTransaction'],
1: ['StoreTransaction'], 1: ['AskStoreTransaction'],
3: [4, 'StoreTransaction'], 3: [4, 'AskStoreTransaction'],
}) })
self.assertFalse(s1.dm.getOrphanList()) self.assertFalse(s1.dm.getOrphanList())
...@@ -2308,7 +2343,8 @@ class Test(NEOThreadedTest): ...@@ -2308,7 +2343,8 @@ class Test(NEOThreadedTest):
self.thread_switcher((thread,), self.thread_switcher((thread,),
(1, 0, 1, 1, t2_b, 0, 0, 1, t2_vote, 0, 0), (1, 0, 1, 1, t2_b, 0, 0, 1, t2_vote, 0, 0),
('tpc_begin', 'tpc_begin', 1, 1, 2, 2, ('tpc_begin', 'tpc_begin', 1, 1, 2, 2,
'RebaseTransaction', 'RebaseTransaction', 'StoreTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AskStoreTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
)) as end: )) as end:
delay = f.delayAskFetchTransactions() delay = f.delayAskFetchTransactions()
...@@ -2361,15 +2397,20 @@ class Test(NEOThreadedTest): ...@@ -2361,15 +2397,20 @@ class Test(NEOThreadedTest):
# Check that the storage hasn't answered to the store, # Check that the storage hasn't answered to the store,
# which means that a lock is still taken for r[''] by t1. # which means that a lock is still taken for r[''] by t1.
self.tic() self.tic()
try:
txn = txn.data(c3)
except (AttributeError, KeyError): # BBB: ZODB < 5
pass
txn_context = db.storage.app._txn_container.get(txn) txn_context = db.storage.app._txn_container.get(txn)
raise Abort(txn_context.queue.empty()) raise Abort(txn_context.queue.empty())
TransactionalResource(t3, 1, commit=t3_commit) TransactionalResource(t3, 1, commit=t3_commit)
with self.thread_switcher((commit23,), with self.thread_switcher((commit23,),
(1, 1, 0, 0, t1_rebase, 0, 0, 0, 1, 1, 1, 1, 0), (1, 1, 0, 0, t1_rebase, 0, 0, 0, 1, 1, 1, 1, 0),
('tpc_begin', 'tpc_begin', 0, 1, 0, ('tpc_begin', 'tpc_begin', 0, 1, 0,
'RebaseTransaction', 'RebaseTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction', 'tpc_begin', 1, 'tpc_abort')) as end: 'AskStoreTransaction', 'tpc_begin', 1, 'tpc_abort',
)) as end:
self.assertRaises(POSException.ConflictError, t1.commit) self.assertRaises(POSException.ConflictError, t1.commit)
commit23.join() commit23.join()
self.assertEqual(end, {0: ['tpc_abort']}) self.assertEqual(end, {0: ['tpc_abort']})
...@@ -2407,8 +2448,8 @@ class Test(NEOThreadedTest): ...@@ -2407,8 +2448,8 @@ class Test(NEOThreadedTest):
for x in 'ab': for x in 'ab':
r[x] = PCounterWithResolution() r[x] = PCounterWithResolution()
t1.commit() t1.commit()
cluster.stop(replicas=1) cluster.neoctl.setNumReplicas(1)
cluster.start() self.tic()
s0, s1 = cluster.sortStorageList() s0, s1 = cluster.sortStorageList()
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
r = c1.root() r = c1.root()
...@@ -2456,9 +2497,9 @@ class Test(NEOThreadedTest): ...@@ -2456,9 +2497,9 @@ class Test(NEOThreadedTest):
self.thread_switcher((commit2,), self.thread_switcher((commit2,),
(1, 1, 0, 0, t1_b, t1_resolve, 0, 0, 0, 0, 1, t2_vote, t1_end), (1, 1, 0, 0, t1_b, t1_resolve, 0, 0, 0, 0, 1, t2_vote, t1_end),
('tpc_begin', 'tpc_begin', 2, 1, 2, 1, 1, ('tpc_begin', 'tpc_begin', 2, 1, 2, 1, 1,
'RebaseTransaction', 'RebaseTransaction', 'AskRebaseTransaction', 'AskRebaseTransaction',
'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction')) as end: 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
commit2.join() commit2.join()
t1.begin() t1.begin()
...@@ -2466,7 +2507,7 @@ class Test(NEOThreadedTest): ...@@ -2466,7 +2507,7 @@ class Test(NEOThreadedTest):
self.assertEqual(r['a'].value, 9) self.assertEqual(r['a'].value, 9)
self.assertEqual(r['b'].value, 6) self.assertEqual(r['b'].value, 6)
t1 = end.pop(0) t1 = end.pop(0)
self.assertEqual(t1.pop(), 'StoreTransaction') self.assertEqual(t1.pop(), 'AskStoreTransaction')
self.assertEqual(sorted(t1), [1, 2]) self.assertEqual(sorted(t1), [1, 2])
self.assertFalse(end) self.assertFalse(end)
self.assertPartitionTable(cluster, 'UU|UU') self.assertPartitionTable(cluster, 'UU|UU')
...@@ -2568,9 +2609,9 @@ class Test(NEOThreadedTest): ...@@ -2568,9 +2609,9 @@ class Test(NEOThreadedTest):
with Patch(cluster.client, _loadFromStorage=load) as p, \ with Patch(cluster.client, _loadFromStorage=load) as p, \
self.thread_switcher((commit2,), self.thread_switcher((commit2,),
(1, 0, tic1, 0, t1_resolve, 1, t2_begin, 0, 1, 1, 0), (1, 0, tic1, 0, t1_resolve, 1, t2_begin, 0, 1, 1, 0),
('tpc_begin', 'tpc_begin', 1, 1, 1, 'StoreTransaction', ('tpc_begin', 'tpc_begin', 1, 1, 1, 'AskStoreTransaction',
'tpc_begin', 'RebaseTransaction', 'RebaseTransaction', 1, 'tpc_begin', 'AskRebaseTransaction', 'AskRebaseTransaction',
'StoreTransaction')) as end: 1, 'AskStoreTransaction')) as end:
self.assertRaisesRegexp(NEOStorageError, self.assertRaisesRegexp(NEOStorageError,
'^partition 0 not fully write-locked$', '^partition 0 not fully write-locked$',
t1.commit) t1.commit)
...@@ -2592,8 +2633,8 @@ class Test(NEOThreadedTest): ...@@ -2592,8 +2633,8 @@ class Test(NEOThreadedTest):
for x in 'ab': for x in 'ab':
r[x] = PCounterWithResolution() r[x] = PCounterWithResolution()
t1.commit() t1.commit()
cluster.stop(replicas=1) cluster.neoctl.setNumReplicas(1)
cluster.start() self.tic()
s0, s1 = cluster.sortStorageList() s0, s1 = cluster.sortStorageList()
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
r = c1.root() r = c1.root()
...@@ -2623,13 +2664,14 @@ class Test(NEOThreadedTest): ...@@ -2623,13 +2664,14 @@ class Test(NEOThreadedTest):
f.remove(delayFinish) f.remove(delayFinish)
with self.thread_switcher((commit2,), with self.thread_switcher((commit2,),
(1, 0, 0, 1, t2_b, 0, t1_resolve), (1, 0, 0, 1, t2_b, 0, t1_resolve),
('tpc_begin', 'tpc_begin', 0, 2, 2, 'StoreTransaction')) as end: ('tpc_begin', 'tpc_begin', 0, 2, 2, 'AskStoreTransaction')
) as end:
t1.commit() t1.commit()
commit2.join() commit2.join()
t1.begin() t1.begin()
self.assertEqual(c1.root()['b'].value, 6) self.assertEqual(c1.root()['b'].value, 6)
self.assertPartitionTable(cluster, 'UU|UU') self.assertPartitionTable(cluster, 'UU|UU')
self.assertEqual(end, {0: [2, 2, 'StoreTransaction']}) self.assertEqual(end, {0: [2, 2, 'AskStoreTransaction']})
self.assertFalse(s1.dm.getOrphanList()) self.assertFalse(s1.dm.getOrphanList())
@with_cluster(storage_count=2, partitions=2) @with_cluster(storage_count=2, partitions=2)
...@@ -2652,19 +2694,19 @@ class Test(NEOThreadedTest): ...@@ -2652,19 +2694,19 @@ class Test(NEOThreadedTest):
yield 1 yield 1
self.tic() self.tic()
with self.thread_switcher((t,), (1, 0, 1, 0, t1_b, 0, 0, 0, 1), with self.thread_switcher((t,), (1, 0, 1, 0, t1_b, 0, 0, 0, 1),
('tpc_begin', 'tpc_begin', 1, 3, 3, 1, 'RebaseTransaction', ('tpc_begin', 'tpc_begin', 1, 3, 3, 1, 'AskRebaseTransaction',
2, 'AnswerRebaseTransaction')) as end: 2, 'AnswerRebaseTransaction')) as end:
t1.commit() t1.commit()
t.join() t.join()
t2.begin() t2.begin()
self.assertEqual([6, 9, 6], [r[x].value for x in 'abc']) self.assertEqual([6, 9, 6], [r[x].value for x in 'abc'])
self.assertEqual([2, 2], map(end.pop(1).count, self.assertEqual([2, 2], map(end.pop(1).count,
['RebaseTransaction', 'AnswerRebaseTransaction'])) ['AskRebaseTransaction', 'AnswerRebaseTransaction']))
# Rarely, there's an extra deadlock for t1: # Rarely, there's an extra deadlock for t1:
# 0: ['AnswerRebaseTransaction', 'RebaseTransaction', # 0: ['AnswerRebaseTransaction', 'AskRebaseTransaction',
# 'RebaseTransaction', 'AnswerRebaseTransaction', # 'AskRebaseTransaction', 'AnswerRebaseTransaction',
# 'AnswerRebaseTransaction', 2, 3, 1, # 'AnswerRebaseTransaction', 2, 3, 1,
# 'StoreTransaction', 'VoteTransaction'] # 'AskStoreTransaction', 'VoteTransaction']
self.assertEqual(end.pop(0)[0], 'AnswerRebaseTransaction') self.assertEqual(end.pop(0)[0], 'AnswerRebaseTransaction')
self.assertFalse(end) self.assertFalse(end)
...@@ -2694,13 +2736,13 @@ class Test(NEOThreadedTest): ...@@ -2694,13 +2736,13 @@ class Test(NEOThreadedTest):
threads = map(self.newPausedThread, (t2.commit, t3.commit)) threads = map(self.newPausedThread, (t2.commit, t3.commit))
with self.thread_switcher(threads, (1, 2, 0, 1, 2, 1, 0, 2, 0, 1, 2), with self.thread_switcher(threads, (1, 2, 0, 1, 2, 1, 0, 2, 0, 1, 2),
('tpc_begin', 'tpc_begin', 'tpc_begin', 1, 2, 3, 4, 4, 4, ('tpc_begin', 'tpc_begin', 'tpc_begin', 1, 2, 3, 4, 4, 4,
'RebaseTransaction', 'StoreTransaction')) as end: 'AskRebaseTransaction', 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
for t in threads: for t in threads:
t.join() t.join()
self.assertEqual(end, { self.assertEqual(end, {
0: ['AnswerRebaseTransaction', 'StoreTransaction'], 0: ['AnswerRebaseTransaction', 'AskStoreTransaction'],
2: ['StoreTransaction']}) 2: ['AskStoreTransaction']})
@with_cluster(replicas=1) @with_cluster(replicas=1)
def testConflictAfterDeadlockWithSlowReplica1(self, cluster, def testConflictAfterDeadlockWithSlowReplica1(self, cluster,
...@@ -2743,16 +2785,16 @@ class Test(NEOThreadedTest): ...@@ -2743,16 +2785,16 @@ class Test(NEOThreadedTest):
order[-1] = t1_resolve order[-1] = t1_resolve
delay = f.delayAskStoreObject() delay = f.delayAskStoreObject()
with self.thread_switcher((t,), order, with self.thread_switcher((t,), order,
('tpc_begin', 'tpc_begin', 1, 1, 2, 2, 'RebaseTransaction', ('tpc_begin', 'tpc_begin', 1, 1, 2, 2, 'AskRebaseTransaction',
'RebaseTransaction', 'AnswerRebaseTransaction', 'AskRebaseTransaction', 'AnswerRebaseTransaction',
'StoreTransaction')) as end: 'AskStoreTransaction')) as end:
t1.commit() t1.commit()
t.join() t.join()
self.assertNotIn(delay, f) self.assertNotIn(delay, f)
t2.begin() t2.begin()
end[0].sort(key=str) end[0].sort(key=str)
self.assertEqual(end, {0: [1, 'AnswerRebaseTransaction', self.assertEqual(end, {0: [1, 'AnswerRebaseTransaction',
'StoreTransaction']}) 'AskStoreTransaction']})
self.assertEqual([4, 2], [r[x].value for x in 'ab']) self.assertEqual([4, 2], [r[x].value for x in 'ab'])
def testConflictAfterDeadlockWithSlowReplica2(self): def testConflictAfterDeadlockWithSlowReplica2(self):
...@@ -2803,7 +2845,7 @@ class Test(NEOThreadedTest): ...@@ -2803,7 +2845,7 @@ class Test(NEOThreadedTest):
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.add(lambda conn, packet:
isinstance(packet, Packets.RequestIdentification) isinstance(packet, Packets.RequestIdentification)
and packet.decode()[0] == NodeTypes.STORAGE) and packet._args[0] == NodeTypes.STORAGE)
self.tic() self.tic()
m2.start() m2.start()
self.tic() self.tic()
...@@ -2843,7 +2885,7 @@ class Test(NEOThreadedTest): ...@@ -2843,7 +2885,7 @@ class Test(NEOThreadedTest):
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.add(lambda conn, packet: f.add(lambda conn, packet:
isinstance(packet, Packets.RequestIdentification) isinstance(packet, Packets.RequestIdentification)
and packet.decode()[0] == NodeTypes.MASTER) and packet._args[0] == NodeTypes.MASTER)
cluster.start(recovering=True) cluster.start(recovering=True)
neoctl = cluster.neoctl neoctl = cluster.neoctl
getClusterState = neoctl.getClusterState getClusterState = neoctl.getClusterState
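The two identification filters above illustrate a protocol-level change that recurs throughout this diff: packets no longer expose a decode() method, so the tests read positional arguments straight from packet._args. A minimal sketch of such a predicate, using only names already imported by these tests:

    # Sketch: delay identification requests coming from storage nodes by
    # matching the first positional argument of the packet (the node type).
    def delay_storage_identification(conn, packet):
        return (isinstance(packet, Packets.RequestIdentification)
                and packet._args[0] == NodeTypes.STORAGE)
    # registered exactly like the lambdas above: f.add(delay_storage_identification)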
...@@ -2894,9 +2936,9 @@ class Test(NEOThreadedTest): ...@@ -2894,9 +2936,9 @@ class Test(NEOThreadedTest):
dm = s.dm dm = s.dm
dm.commit() dm.commit()
dump_dict[s.uuid] = dm.dump() dump_dict[s.uuid] = dm.dump()
dm.erase()
with open(path % (s.getAdapter(), s.uuid)) as f: with open(path % (s.getAdapter(), s.uuid)) as f:
dm.restore(f.read()) dm.restore(f.read())
dm.setConfiguration('partitions', None) # XXX: see dm._migrate4
with NEOCluster(storage_count=3, partitions=3, replicas=1, with NEOCluster(storage_count=3, partitions=3, replicas=1,
name=self._testMethodName) as cluster: name=self._testMethodName) as cluster:
s1, s2, s3 = cluster.storage_list s1, s2, s3 = cluster.storage_list
......
...@@ -17,16 +17,19 @@ ...@@ -17,16 +17,19 @@
from cPickle import Pickler, Unpickler from cPickle import Pickler, Unpickler
from cStringIO import StringIO from cStringIO import StringIO
from itertools import izip_longest from itertools import izip_longest
import os, random, shutil, time, unittest import os, random, shutil, threading, time, unittest
import transaction, ZODB import transaction, ZODB
from neo.client.exception import NEOPrimaryMasterLost from neo.client.exception import NEOPrimaryMasterLost
from neo.lib import logging from neo.lib import logging
from neo.lib.util import u64 from neo.lib.util import cached_property, p64, u64
from neo.master.transactions import TransactionManager
from neo.storage.database import getAdapterKlass, importer, manager from neo.storage.database import getAdapterKlass, importer, manager
from neo.storage.database.importer import Repickler, TransactionRecord from neo.storage.database.importer import \
Repickler, TransactionRecord, WriteBack
from .. import expectedFailure, getTempDirectory, random_tree, Patch from .. import expectedFailure, getTempDirectory, random_tree, Patch
from . import NEOCluster, NEOThreadedTest from . import NEOCluster, NEOThreadedTest
from ZODB import serialize from ZODB import serialize
from ZODB.DB import TransactionalUndo
from ZODB.FileStorage import FileStorage from ZODB.FileStorage import FileStorage
class Equal: class Equal:
...@@ -128,31 +131,51 @@ class ImporterTests(NEOThreadedTest): ...@@ -128,31 +131,51 @@ class ImporterTests(NEOThreadedTest):
self.assertIs(Obj, load()) self.assertIs(Obj, load())
self.assertDictEqual(state, load()) self.assertDictEqual(state, load())
def _importFromFileStorage(self, multi=(), @cached_property
root_filter=None, sub_filter=None): def getFS(self):
import_hash = '1d4ff03730fe6bcbf235e3739fbe5f5b' fs_dir = os.path.join(getTempDirectory(), self.id())
shutil.rmtree(fs_dir, 1) # for --loop
os.mkdir(fs_dir)
def getFS(db='root'):
path = os.path.join(fs_dir, '%s.fs' % db)
return path, {
"storage": "<filestorage>\npath %s\n</filestorage>" % path
}
return getFS
def getData(self, tree=random_tree.generateTree(random.Random(0))):
txn_size = 10 txn_size = 10
tree = random_tree.generateTree(random.Random(0))
i = len(tree) // 3 i = len(tree) // 3
assert i > txn_size assert i > txn_size
before_tree = tree[:i] before_tree = tree[:i]
after_tree = tree[i:] after_tree = tree[i:]
fs_dir = os.path.join(getTempDirectory(), self.id()) def beforeCheck(h, count=52):
shutil.rmtree(fs_dir, 1) # for --loop self.assertEqual(count, h())
os.mkdir(fs_dir) self.assertEqual('1d4ff03730fe6bcbf235e3739fbe5f5b', h.hexdigest())
def finalCheck(r):
h = random_tree.hashTree(r)
self.assertEqual(93, h())
self.assertEqual('6bf0f0cb2d6c1aae9e52c412ef0e25b6', h.hexdigest())
return (
beforeCheck,
lambda r, *f: random_tree.importTree(r, before_tree, txn_size, *f),
finalCheck,
lambda r: random_tree.importTree(r, after_tree, txn_size),
)
def _importFromFileStorage(self, multi=(),
root_filter=None, sub_filter=None):
beforeCheck, before, finalCheck, after = self.getData()
iter_list = [] iter_list = []
db_list = [] db_list = []
# Setup several FileStorage databases. # Setup several FileStorage databases.
for i, db in enumerate(('root',) + multi): for i, db in enumerate(('root',) + multi):
fs_path = os.path.join(fs_dir, '%s.fs' % db) fs_path, cfg = self.getFS(db)
c = ZODB.DB(FileStorage(fs_path)).open() c = ZODB.DB(FileStorage(fs_path)).open()
r = c.root()['tree'] = random_tree.Node() r = c.root()['tree'] = random_tree.Node()
transaction.commit() transaction.commit()
iter_list.append(random_tree.importTree(r, before_tree, txn_size, iter_list.append(before(r, sub_filter(db) if i else root_filter))
sub_filter(db) if i else root_filter)) db_list.append((db, r, cfg))
db_list.append((db, r, {
"storage": "<filestorage>\npath %s\n</filestorage>" % fs_path
}))
# Populate FileStorage databases. # Populate FileStorage databases.
for i, iter_list in enumerate(izip_longest(*iter_list)): for i, iter_list in enumerate(izip_longest(*iter_list)):
for r in iter_list: for r in iter_list:
...@@ -167,9 +190,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -167,9 +190,7 @@ class ImporterTests(NEOThreadedTest):
for x in multi: for x in multi:
cfg['_%s' % x] = str(u64(r[x]._p_oid)) cfg['_%s' % x] = str(u64(r[x]._p_oid))
else: else:
h = random_tree.hashTree(r) beforeCheck(random_tree.hashTree(r))
h()
self.assertEqual(import_hash, h.hexdigest())
importer['writeback'] = 'true' importer['writeback'] = 'true'
else: else:
cfg["oid"] = str(u64(r[db]._p_oid)) cfg["oid"] = str(u64(r[db]._p_oid))
...@@ -179,7 +200,7 @@ class ImporterTests(NEOThreadedTest): ...@@ -179,7 +200,7 @@ class ImporterTests(NEOThreadedTest):
del db_list, iter_list del db_list, iter_list
#del zodb[0][1][zodb.pop()[0]] #del zodb[0][1][zodb.pop()[0]]
# Start NEO cluster with transparent import. # Start NEO cluster with transparent import.
with NEOCluster(importer=importer) as cluster: with NEOCluster(importer=importer, partitions=2) as cluster:
# Suspend import for a while, so that import # Suspend import for a while, so that import
# is finished in the middle of the below 'for' loop. # is finished in the middle of the below 'for' loop.
# Use a slightly different main loop for storage so that it # Use a slightly different main loop for storage so that it
...@@ -214,25 +235,26 @@ class ImporterTests(NEOThreadedTest): ...@@ -214,25 +235,26 @@ class ImporterTests(NEOThreadedTest):
logging.info("start migration") logging.info("start migration")
dm.doOperation(cluster.storage) dm.doOperation(cluster.storage)
# Adjust if needed. Must remain > 0. # Adjust if needed. Must remain > 0.
self.assertEqual(22, h()) beforeCheck(h, 22)
self.assertEqual(import_hash, h.hexdigest())
# New writes after the switch to NEO. # New writes after the switch to NEO.
last_import = -1 last_import = -1
for i, r in enumerate(random_tree.importTree( for i, r in enumerate(after(r)):
r, after_tree, txn_size)): c.readCurrent(r)
t.commit() t.commit()
if cluster.storage.dm._import: if cluster.storage.dm._import:
last_import = i last_import = i
for x in 0, 1:
undo = TransactionalUndo(c.db(), [storage.lastTransaction()])
txn = transaction.Transaction()
undo.tpc_begin(txn)
undo.commit(txn)
undo.tpc_vote(txn)
undo.tpc_finish(txn)
self.tic() self.tic()
# Same as above. We want last_import small enough compared to i # Same as above. We want last_import small enough compared to i
assert i < last_import * 3 < 2 * i, (last_import, i) assert i < last_import * 3 < 2 * i, (last_import, i)
self.assertFalse(cluster.storage.dm._import) self.assertFalse(cluster.storage.dm._import)
storage._cache.clear() storage._cache.clear()
def finalCheck(r):
h = random_tree.hashTree(r)
self.assertEqual(93, h())
self.assertEqual('6bf0f0cb2d6c1aae9e52c412ef0e25b6',
h.hexdigest())
finalCheck(r) finalCheck(r)
if dm._writeback: if dm._writeback:
dm.commit() dm.commit()
...@@ -243,10 +265,10 @@ class ImporterTests(NEOThreadedTest): ...@@ -243,10 +265,10 @@ class ImporterTests(NEOThreadedTest):
db.close() db.close()
@unittest.skipUnless(importer.FORK, 'no os.fork') @unittest.skipUnless(importer.FORK, 'no os.fork')
def test1(self): def testMultiProcessWriteBack(self):
self._importFromFileStorage() self._importFromFileStorage()
def testThreadedWriteback(self): def testThreadedWritebackAndDBReconnection(self):
# Also check reconnection to the underlying DB for relevant backends. # Also check reconnection to the underlying DB for relevant backends.
tid_list = [] tid_list = []
def __init__(orig, tr, db, tid): def __init__(orig, tr, db, tid):
...@@ -272,7 +294,25 @@ class ImporterTests(NEOThreadedTest): ...@@ -272,7 +294,25 @@ class ImporterTests(NEOThreadedTest):
Patch(time, sleep=sleep) as p: Patch(time, sleep=sleep) as p:
self._importFromFileStorage() self._importFromFileStorage()
self.assertFalse(p.applied) self.assertFalse(p.applied)
self.assertEqual(len(tid_list), 11) self.assertEqual(len(tid_list), 13)
def testThreadedWritebackWithUnbalancedPartitions(self):
N = 7
nonlocal_ = [0]
def committed(orig, self):
if nonlocal_[0] > N:
orig(self)
def _nextTID(orig, self, *args):
if args:
return orig(self, *args)
nonlocal_[0] += 1
return orig(self, p64(nonlocal_[0] == N), 2)
with Patch(importer, FORK=False), \
Patch(TransactionManager, _nextTID=_nextTID), \
Patch(WriteBack, chunk_size=N-2), \
Patch(WriteBack, committed=committed):
self._importFromFileStorage()
self.assertEqual(nonlocal_[0], 12)
def testMerge(self): def testMerge(self):
multi = 1, 2, 3 multi = 1, 2, 3
...@@ -285,5 +325,52 @@ class ImporterTests(NEOThreadedTest): ...@@ -285,5 +325,52 @@ class ImporterTests(NEOThreadedTest):
# merge several DB. # merge several DB.
testMerge = expectedFailure(NEOPrimaryMasterLost)(testMerge) testMerge = expectedFailure(NEOPrimaryMasterLost)(testMerge)
def testIncremental(self):
"""
This reproduces an undocumented way to speed up the import of a single
ZODB by doing most of the work before switching to NEO.
"""
beforeCheck, before, finalCheck, after = self.getData()
fs_path, cfg = self.getFS()
c = ZODB.DB(FileStorage(fs_path)).open()
r = c.root()['tree'] = random_tree.Node()
transaction.commit()
for _ in before(r):
transaction.commit()
c.db().close()
importer = {'zodb': [('root', cfg)]}
# Start NEO cluster with transparent import.
with NEOCluster(importer=importer, partitions=2) as cluster:
s = cluster.storage
l = threading.Lock()
l.acquire()
def _finished(orig):
orig()
l.release()
with Patch(s.dm, _finished=_finished):
cluster.start()
l.acquire()
t, c = cluster.getTransaction()
r = c.root()['tree']
beforeCheck(random_tree.hashTree(r))
c = ZODB.DB(FileStorage(fs_path)).open()
for _ in after(c.root()['tree']):
transaction.commit()
c.db().close()
# TODO: Add a storage option that only does this and exits.
# Such a command would also check that there's no data after
# what's already imported.
s.dm.setConfiguration('zodb', None)
s.stop()
cluster.join((s,))
s.resetNode()
with Patch(s.dm, _finished=_finished):
s.start()
self.tic()
l.acquire()
t.begin()
finalCheck(r)
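For reference, the importer mapping handed to NEOCluster in the tests above keeps the shape built by getFS(): a 'zodb' entry listing (name, config) pairs, where each config embeds a ZConfig <filestorage> snippet. A minimal sketch, with the path purely illustrative:

    # Shape of the importer configuration used by these tests (path is made up).
    path = '/tmp/root.fs'
    importer = {'zodb': [
        ('root', {'storage': '<filestorage>\npath %s\n</filestorage>' % path}),
    ]}
    # then, as above: with NEOCluster(importer=importer, partitions=2) as cluster: ...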
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -29,7 +29,7 @@ from neo.storage.database.manager import DatabaseManager ...@@ -29,7 +29,7 @@ from neo.storage.database.manager import DatabaseManager
from neo.storage import replicator from neo.storage import replicator
from neo.lib.connector import SocketConnector from neo.lib.connector import SocketConnector
from neo.lib.connection import ClientConnection from neo.lib.connection import ClientConnection
from neo.lib.protocol import CellStates, ClusterStates, Packets, \ from neo.lib.protocol import CellStates, ClusterStates, NodeStates, Packets, \
ZERO_OID, ZERO_TID, MAX_TID, uuid_str ZERO_OID, ZERO_TID, MAX_TID, uuid_str
from neo.lib.util import add64, p64, u64 from neo.lib.util import add64, p64, u64
from .. import Patch, TransactionalResource from .. import Patch, TransactionalResource
...@@ -74,6 +74,8 @@ class ReplicationTests(NEOThreadedTest): ...@@ -74,6 +74,8 @@ class ReplicationTests(NEOThreadedTest):
source_dict = {x.uuid: x for x in cluster.upstream.storage_list} source_dict = {x.uuid: x for x in cluster.upstream.storage_list}
for storage in cluster.storage_list: for storage in cluster.storage_list:
self.assertFalse(storage.dm._uncommitted_data) self.assertFalse(storage.dm._uncommitted_data)
if storage.pt is None:
storage.loadPartitionTable()
self.assertEqual(np, storage.pt.getPartitions()) self.assertEqual(np, storage.pt.getPartitions())
for partition in pt.getAssignedPartitionList(storage.uuid): for partition in pt.getAssignedPartitionList(storage.uuid):
cell_list = upstream_pt.getCellList(partition, readable=True) cell_list = upstream_pt.getCellList(partition, readable=True)
...@@ -89,6 +91,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -89,6 +91,7 @@ class ReplicationTests(NEOThreadedTest):
checksum_list = [ checksum_list = [
self.checksumPartition(storage_dict[x.getUUID()], offset) self.checksumPartition(storage_dict[x.getUUID()], offset)
for x in pt.getCellList(offset)] for x in pt.getCellList(offset)]
self.assertLess(1, len(checksum_list))
self.assertEqual(1, len(set(checksum_list)), self.assertEqual(1, len(set(checksum_list)),
(offset, checksum_list)) (offset, checksum_list))
...@@ -103,7 +106,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -103,7 +106,7 @@ class ReplicationTests(NEOThreadedTest):
importZODB(3) importZODB(3)
def delaySecondary(conn, packet): def delaySecondary(conn, packet):
if isinstance(packet, Packets.Replicate): if isinstance(packet, Packets.Replicate):
tid, upstream_name, source_dict = packet.decode() tid, upstream_name, source_dict = packet._args
return not upstream_name and all(source_dict.itervalues()) return not upstream_name and all(source_dict.itervalues())
with NEOCluster(partitions=np, replicas=nr-1, storage_count=5, with NEOCluster(partitions=np, replicas=nr-1, storage_count=5,
upstream=upstream) as backup: upstream=upstream) as backup:
...@@ -346,6 +349,22 @@ class ReplicationTests(NEOThreadedTest): ...@@ -346,6 +349,22 @@ class ReplicationTests(NEOThreadedTest):
self.tic() self.tic()
self.assertTrue(backup.master.is_alive()) self.assertTrue(backup.master.is_alive())
@with_cluster(master_count=2)
def testBackupFromUpstreamWithSecondaryMaster(self, upstream):
"""
Check that the backup master reacts correctly when connecting first
to a secondary master of the upstream cluster.
"""
with NEOCluster(upstream=upstream) as backup:
primary = upstream.primary_master
m, = (m for m in upstream.master_list if m is not primary)
backup.master.resetNode(upstream_masters=[m.server])
backup.start()
backup.neoctl.setClusterState(ClusterStates.STARTING_BACKUP)
self.tic()
self.assertEqual(backup.neoctl.getClusterState(),
ClusterStates.BACKINGUP)
@backup_test() @backup_test()
def testCreationUndone(self, backup): def testCreationUndone(self, backup):
""" """
...@@ -443,15 +462,15 @@ class ReplicationTests(NEOThreadedTest): ...@@ -443,15 +462,15 @@ class ReplicationTests(NEOThreadedTest):
""" """
def delayAskFetch(conn, packet): def delayAskFetch(conn, packet):
return isinstance(packet, delayed) and \ return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \ packet._args[0] == offset and \
conn in s1.getConnectionList(s0) conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, cell_list): def changePartitionTable(orig, ptid, num_replicas, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list: if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
connection_filter.remove(delayAskFetch) connection_filter.remove(delayAskFetch)
# XXX: this is currently not done by # XXX: this is currently not done by
# default for performance reason # default for performance reason
orig.im_self.dropPartitions((offset,)) orig.im_self.dropPartitions((offset,))
return orig(ptid, cell_list) return orig(ptid, num_replicas, cell_list)
np = cluster.num_partitions np = cluster.num_partitions
s0, s1, s2 = cluster.storage_list s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects: for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
...@@ -511,7 +530,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -511,7 +530,9 @@ class ReplicationTests(NEOThreadedTest):
for x in 'ab': for x in 'ab':
r[x] = PCounter() r[x] = PCounter()
t.commit() t.commit()
cluster.stop(replicas=1) cluster.neoctl.setNumReplicas(1)
self.tic()
cluster.stop()
cluster.start((s1, s2)) cluster.start((s1, s2))
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.delayAddObject() f.delayAddObject()
...@@ -596,8 +617,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -596,8 +617,9 @@ class ReplicationTests(NEOThreadedTest):
tweak() tweak()
t.commit() t.commit()
t2.join() t2.join()
cluster.neoctl.dropNode(S[2].uuid) for s in S[2:]:
cluster.neoctl.dropNode(S[3].uuid) cluster.neoctl.killNode(s.uuid)
cluster.neoctl.dropNode(s.uuid)
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
if done: if done:
f.remove(delay) f.remove(delay)
...@@ -695,7 +717,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -695,7 +717,7 @@ class ReplicationTests(NEOThreadedTest):
def logReplication(conn, packet): def logReplication(conn, packet):
if isinstance(packet, (Packets.AskFetchTransactions, if isinstance(packet, (Packets.AskFetchTransactions,
Packets.AskFetchObjects)): Packets.AskFetchObjects)):
ask.append(packet.decode()[2:]) ask.append(packet._args[2:])
def getTIDList(): def getTIDList():
return [t.tid for t in c.db().storage.iterator()] return [t.tid for t in c.db().storage.iterator()]
s0, s1 = cluster.storage_list s0, s1 = cluster.storage_list
...@@ -796,7 +818,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -796,7 +818,7 @@ class ReplicationTests(NEOThreadedTest):
return True return True
elif not isinstance(packet, Packets.AskFetchTransactions): elif not isinstance(packet, Packets.AskFetchTransactions):
return return
ask.append(packet.decode()) ask.append(packet._args)
conn, = upstream.master.getConnectionList(backup.master) conn, = upstream.master.getConnectionList(backup.master)
with ConnectionFilter() as f, Patch(replicator.Replicator, with ConnectionFilter() as f, Patch(replicator.Replicator,
_nextPartitionSortKey=lambda orig, self, offset: offset): _nextPartitionSortKey=lambda orig, self, offset: offset):
...@@ -857,11 +879,11 @@ class ReplicationTests(NEOThreadedTest): ...@@ -857,11 +879,11 @@ class ReplicationTests(NEOThreadedTest):
@f.add @f.add
def delayReplicate(conn, packet): def delayReplicate(conn, packet):
if isinstance(packet, Packets.AskFetchTransactions): if isinstance(packet, Packets.AskFetchTransactions):
trans.append(packet.decode()[2]) trans.append(packet._args[2])
elif isinstance(packet, Packets.AskFetchObjects): elif isinstance(packet, Packets.AskFetchObjects):
if obj: if obj:
return True return True
obj.append(packet.decode()[2]) obj.append(packet._args[2])
s2.start() s2.start()
self.tic() self.tic()
cluster.neoctl.enableStorageList([s2.uuid]) cluster.neoctl.enableStorageList([s2.uuid])
...@@ -928,6 +950,74 @@ class ReplicationTests(NEOThreadedTest): ...@@ -928,6 +950,74 @@ class ReplicationTests(NEOThreadedTest):
def testReplicationBlockedByUnfinished2(self): def testReplicationBlockedByUnfinished2(self):
self.testReplicationBlockedByUnfinished1(True) self.testReplicationBlockedByUnfinished1(True)
@with_cluster(partitions=6, storage_count=5, start_cluster=0)
def testSplitAndMakeResilientUsingClone(self, cluster):
"""
Test cloning of storage nodes using --new-nid instead of NEO replication.
"""
s0 = cluster.storage_list[0]
s12 = cluster.storage_list[1:3]
s34 = cluster.storage_list[3:]
cluster.start(storage_list=(s0,))
cluster.importZODB()(6)
for s in s12:
s.start()
self.tic()
drop_list = [s0.uuid]
self.assertRaises(SystemExit, cluster.neoctl.tweakPartitionTable,
drop_list)
cluster.enableStorageList(s12)
def expected(changed):
s0 = 1, CellStates.UP_TO_DATE
s = CellStates.OUT_OF_DATE if changed else CellStates.UP_TO_DATE
return changed, 3 * ((s0, (2, s)), (s0, (3, s)))
for dry_run in True, False:
self.assertEqual(expected(True),
cluster.neoctl.tweakPartitionTable(drop_list, dry_run))
self.tic()
self.assertEqual(expected(False),
cluster.neoctl.tweakPartitionTable(drop_list))
for s, d in zip(s12, s34):
s.stop()
cluster.join((s,))
s.resetNode()
d.dm.restore(s.dm.dump())
d.resetNode(new_nid=True)
s.start()
d.start()
self.tic()
self.assertEqual(cluster.getNodeState(s), NodeStates.RUNNING)
self.assertEqual(cluster.getNodeState(d), NodeStates.DOWN)
cluster.join((d,))
d.resetNode(new_nid=False)
d.start()
self.tic()
self.checkReplicas(cluster)
expected = '|'.join(['UU.U.|U.U.U'] * 3)
self.assertPartitionTable(cluster, expected)
cluster.neoctl.setNumReplicas(1)
cluster.neoctl.tweakPartitionTable(drop_list)
self.tic()
self.assertPartitionTable(cluster, expected)
s0.stop()
cluster.join((s0,))
cluster.neoctl.dropNode(s0.uuid)
expected = '|'.join(['U.U.|.U.U'] * 3)
self.assertPartitionTable(cluster, expected)
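The cloning loop above condenses to the following steps per (source s, clone d) pair, written with the same threaded-test API used throughout this file:

    # Condensed from the loop above: clone a storage node from a dump of its DB.
    s.stop(); cluster.join((s,)); s.resetNode()      # stop the node to clone
    d.dm.restore(s.dm.dump())                        # copy its database verbatim
    d.resetNode(new_nid=True)                        # first start requests a fresh NID
    s.start(); d.start(); self.tic()                 # source comes back RUNNING, clone reported DOWN
    cluster.join((d,)); d.resetNode(new_nid=False)   # restart the clone normally
    d.start()                                        # it now joins with its new NID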
@with_cluster(partitions=3, replicas=1, storage_count=3)
def testAdminOnerousOperationCondition(self, cluster):
s = cluster.storage_list[2]
cluster.neoctl.killNode(s.uuid)
tweak = cluster.neoctl.tweakPartitionTable
self.assertRaises(SystemExit, tweak)
self.assertRaises(SystemExit, tweak, dry_run=True)
self.assertTrue(tweak((s.uuid,))[0])
self.tic()
cluster.neoctl.dropNode(s.uuid)
s = cluster.storage_list[1]
self.assertRaises(SystemExit, cluster.neoctl.dropNode, s.uuid)
@with_cluster(partitions=5, replicas=2, storage_count=3) @with_cluster(partitions=5, replicas=2, storage_count=3)
def testCheckReplicas(self, cluster): def testCheckReplicas(self, cluster):
from neo.storage import checker from neo.storage import checker
...@@ -940,8 +1030,8 @@ class ReplicationTests(NEOThreadedTest): ...@@ -940,8 +1030,8 @@ class ReplicationTests(NEOThreadedTest):
return s0.uuid return s0.uuid
def check(expected_state, expected_count): def check(expected_state, expected_count):
self.assertEqual(expected_count, len([None self.assertEqual(expected_count, len([None
for row in cluster.neoctl.getPartitionRowList()[1] for row in cluster.neoctl.getPartitionRowList()[2]
for cell in row[1] for cell in row
if cell[1] == CellStates.CORRUPTED])) if cell[1] == CellStates.CORRUPTED]))
self.assertEqual(expected_state, cluster.neoctl.getClusterState()) self.assertEqual(expected_state, cluster.neoctl.getClusterState())
np = cluster.num_partitions np = cluster.num_partitions
......
...@@ -33,8 +33,6 @@ class RecoveryTests(ZODBTestCase, StorageTestBase, RecoveryStorage): ...@@ -33,8 +33,6 @@ class RecoveryTests(ZODBTestCase, StorageTestBase, RecoveryStorage):
os.makedirs(dst_temp_dir) os.makedirs(dst_temp_dir)
self.neo_dst = NEOCluster(['test_neo1-dst'], partitions=1, replicas=0, self.neo_dst = NEOCluster(['test_neo1-dst'], partitions=1, replicas=0,
master_count=1, temp_dir=dst_temp_dir) master_count=1, temp_dir=dst_temp_dir)
self.neo_dst.stop()
self.neo_dst.setupDB()
self.neo_dst.start() self.neo_dst.start()
self._dst = self.neo.getZODBStorage() self._dst = self.neo.getZODBStorage()
self._dst_db = ZODB.DB(self._dst) self._dst_db = ZODB.DB(self._dst)
......
...@@ -53,7 +53,7 @@ extras_require = { ...@@ -53,7 +53,7 @@ extras_require = {
'master': [], 'master': [],
'storage-sqlite': [], 'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'], 'storage-mysqldb': ['mysqlclient'],
'storage-importer': zodb_require + ['msgpack>=0.5.6', 'setproctitle'], 'storage-importer': zodb_require + ['setproctitle'],
} }
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2', extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
'neoppod[%s]' % ', '.join(extras_require)] 'neoppod[%s]' % ', '.join(extras_require)]
...@@ -78,7 +78,7 @@ else: ...@@ -78,7 +78,7 @@ else:
setup( setup(
name = 'neoppod', name = 'neoppod',
version = '1.11', version = '1.12.0',
description = __doc__.strip(), description = __doc__.strip(),
author = 'Nexedi SA', author = 'Nexedi SA',
author_email = 'neo-dev@erp5.org', author_email = 'neo-dev@erp5.org',
...@@ -108,6 +108,7 @@ setup( ...@@ -108,6 +108,7 @@ setup(
], ],
}, },
install_requires = [ install_requires = [
'msgpack>=0.5.6',
'python-dateutil', # neolog --from 'python-dateutil', # neolog --from
], ],
extras_require = extras_require, extras_require = extras_require,
......
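The msgpack requirement moves from the storage-importer extra to install_requires, consistent with the protocol rework seen earlier in this diff (packet._args replacing packet.decode()): packet payloads are presumably now serialized with msgpack for all node types, not only by the importer. A purely illustrative round-trip, unrelated to NEO's actual packet layout:

    # Illustrative only: pack/unpack a tuple of arguments with msgpack.
    import msgpack
    args = ('RequestIdentification', 0)
    data = msgpack.packb(args, use_bin_type=True)
    assert msgpack.unpackb(data, raw=False) == list(args)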
...@@ -129,7 +129,7 @@ class MatrixImportBenchmark(BenchmarkRunner): ...@@ -129,7 +129,7 @@ class MatrixImportBenchmark(BenchmarkRunner):
finally: finally:
zodb.stop() zodb.stop()
# Clear DB if no error happened. # Clear DB if no error happened.
zodb.setupDB() zodb.resetDB()
return end - start return end - start
except: except:
traceback.print_exc() traceback.print_exc()
......
...@@ -53,7 +53,7 @@ class ImportBenchmark(BenchmarkRunner): ...@@ -53,7 +53,7 @@ class ImportBenchmark(BenchmarkRunner):
finally: finally:
neo.stop() neo.stop()
# Clear DB if no error happened. # Clear DB if no error happened.
neo.setupDB() neo.resetDB()
return result return result
except: except:
return 'Perf: import failed', ''.join(traceback.format_exc()) return 'Perf: import failed', ''.join(traceback.format_exc())
......
...@@ -7,6 +7,7 @@ from contextlib import contextmanager ...@@ -7,6 +7,7 @@ from contextlib import contextmanager
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from multiprocessing import Lock, RawArray from multiprocessing import Lock, RawArray
from multiprocessing.queues import SimpleQueue
from struct import Struct from struct import Struct
from netfilterqueue import NetfilterQueue from netfilterqueue import NetfilterQueue
import gevent.socket # preload for subprocesses import gevent.socket # preload for subprocesses
...@@ -19,7 +20,7 @@ from neo.lib.protocol import NodeTypes ...@@ -19,7 +20,7 @@ from neo.lib.protocol import NodeTypes
from neo.lib.util import timeStringFromTID, p64, u64 from neo.lib.util import timeStringFromTID, p64, u64
from neo.storage.app import DATABASE_MANAGER_DICT, \ from neo.storage.app import DATABASE_MANAGER_DICT, \
Application as StorageApplication Application as StorageApplication
from neo.tests import getTempDirectory from neo.tests import getTempDirectory, mysql_pool
from neo.tests.ConflictFree import ConflictFreeLog from neo.tests.ConflictFree import ConflictFreeLog
from neo.tests.functional import AlreadyStopped, NEOCluster, Process from neo.tests.functional import AlreadyStopped, NEOCluster, Process
from neo.tests.stress import StressApplication from neo.tests.stress import StressApplication
...@@ -312,13 +313,15 @@ class NEOCluster(NEOCluster): ...@@ -312,13 +313,15 @@ class NEOCluster(NEOCluster):
class Application(StressApplication): class Application(StressApplication):
_blocking = None _blocking = _kill_mysqld = None
def __init__(self, client_count, thread_count, restart_ratio, logrotate, def __init__(self, client_count, thread_count,
*args, **kw): fault_probability, restart_ratio, kill_mysqld,
logrotate, *args, **kw):
self.client_count = client_count self.client_count = client_count
self.thread_count = thread_count self.thread_count = thread_count
self.logrotate = logrotate self.logrotate = logrotate
self.fault_probability = fault_probability
self.restart_ratio = restart_ratio self.restart_ratio = restart_ratio
self.cluster = cluster = NEOCluster(*args, **kw) self.cluster = cluster = NEOCluster(*args, **kw)
# Make the firewall also affect connections between storage nodes. # Make the firewall also affect connections between storage nodes.
...@@ -326,7 +329,24 @@ class Application(StressApplication): ...@@ -326,7 +329,24 @@ class Application(StressApplication):
def __init__(self, config): def __init__(self, config):
dscpPatch(1) dscpPatch(1)
StorageApplication__init__(self, config) StorageApplication__init__(self, config)
StorageApplication.__init__ = __init__
if kill_mysqld:
from neo.scripts import neostorage
from neo.storage.database import mysqldb
neostorage_main = neostorage.main
self._kill_mysqld = kill_mysqld = SimpleQueue()
def main():
pid = os.getpid()
try:
neostorage_main()
except mysqldb.OperationalError as e:
code = e.args[0]
except mysqldb.MysqlError as e:
code = e.code
if mysqldb.SERVER_LOST != code != mysqldb.SERVER_GONE_ERROR:
raise
kill_mysqld.put(pid)
neostorage.main = main
super(Application, self).__init__(cluster.SSL, super(Application, self).__init__(cluster.SSL,
util.parseMasterList(cluster.master_nodes)) util.parseMasterList(cluster.master_nodes))
...@@ -398,6 +418,10 @@ class Application(StressApplication): ...@@ -398,6 +418,10 @@ class Application(StressApplication):
t = threading.Thread(target=self._logrotate_thread) t = threading.Thread(target=self._logrotate_thread)
t.daemon = 1 t.daemon = 1
t.start() t.start()
if self._kill_mysqld:
t = threading.Thread(target=self._watch_storage_thread)
t.daemon = 1
t.start()
def stopCluster(self, wait=None): def stopCluster(self, wait=None):
self.restart_lock.acquire() self.restart_lock.acquire()
...@@ -471,13 +495,30 @@ class Application(StressApplication): ...@@ -471,13 +495,30 @@ class Application(StressApplication):
except ValueError: except ValueError:
pass pass
def _watch_storage_thread(self):
get = self._kill_mysqld.get
storage_list = self.cluster.getStorageProcessList()
while 1:
pid = get()
p, = (p for p in storage_list if p.pid == pid)
p.wait()
p.start()
def restartStorages(self, nids): def restartStorages(self, nids):
processes = [p for p in self.cluster.getStorageProcessList() storage_list = self.cluster.getStorageProcessList()
if p.uuid in nids] if self._kill_mysqld:
for p in processes: p.kill(signal.SIGKILL) db_list = [db for db, p in zip(self.cluster.db_list, storage_list)
time.sleep(1) if p.uuid in nids]
for p in processes: p.wait() mysql_pool.kill(*db_list)
for p in processes: p.start() time.sleep(1)
with open(os.devnull, "wb") as f:
mysql_pool.start(*db_list, stderr=f)
else:
processes = [p for p in storage_list if p.uuid in nids]
for p in processes: p.kill(signal.SIGKILL)
time.sleep(1)
for p in processes: p.wait()
for p in processes: p.start()
def _cleanFirewall(self): def _cleanFirewall(self):
with open(os.devnull, "wb") as f: with open(os.devnull, "wb") as f:
...@@ -548,6 +589,7 @@ def main(): ...@@ -548,6 +589,7 @@ def main():
default=socket.AF_INET, const=socket.AF_INET6, help='(default: IPv4)') default=socket.AF_INET, const=socket.AF_INET6, help='(default: IPv4)')
_('-a', '--adapter', choices=adapters, default=default_adapter) _('-a', '--adapter', choices=adapters, default=default_adapter)
_('-d', '--datadir', help="(default: same as unit tests)") _('-d', '--datadir', help="(default: same as unit tests)")
_('-e', '--engine', help="database engine (MySQL only)")
_('-l', '--logdir', help="(default: same as --datadir)") _('-l', '--logdir', help="(default: same as --datadir)")
_('-m', '--masters', type=int, default=1) _('-m', '--masters', type=int, default=1)
_('-s', '--storages', type=int, default=8) _('-s', '--storages', type=int, default=8)
...@@ -571,9 +613,14 @@ def main(): ...@@ -571,9 +613,14 @@ def main():
help='number of client processes') help='number of client processes')
_('-t', '--threads', type=int, default=1, _('-t', '--threads', type=int, default=1,
help='number of thread workers per client process') help='number of thread workers per client process')
_('-f', '--fault-probability', type=ratio, default=1, metavar='P',
help='probability to cause faults every second')
_('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO', _('-r', '--restart-ratio', type=ratio, default=.5, metavar='RATIO',
help='probability to kill/restart a storage node, rather than just' help='probability to kill/restart a storage node, rather than just'
' RSTing a TCP connection with this node') ' RSTing a TCP connection with this node')
_('--kill-mysqld', action='store_true',
help='if r != 0 and if NEO_DB_MYCNF is set,'
' kill mysqld rather than storage node')
_('-C', '--console', type=int, default=0, _('-C', '--console', type=int, default=0,
help='console port (localhost) (default: any)') help='console port (localhost) (default: any)')
_('-D', '--delay', type=float, default=.01, _('-D', '--delay', type=float, default=.01,
...@@ -594,18 +641,31 @@ def main(): ...@@ -594,18 +641,31 @@ def main():
db_list = ['stress_neo%s' % x for x in xrange(args.storages)] db_list = ['stress_neo%s' % x for x in xrange(args.storages)]
if args.datadir: if args.datadir:
if args.adapter != 'SQLite': if args.adapter == 'SQLite':
parser.error('--datadir is only for SQLite adapter') db_list = [os.path.join(args.datadir, x + '.sqlite')
db_list = [os.path.join(args.datadir, x + '.sqlite') for x in db_list] for x in db_list]
elif mysql_pool:
mysql_pool.__init__(args.datadir)
else:
parser.error(
'--datadir: meaningless when using an existing MySQL server')
kw = dict(db_list=db_list, name='stress', kw = dict(db_list=db_list, name='stress',
partitions=args.partitions, replicas=args.replicas, partitions=args.partitions, replicas=args.replicas,
adapter=args.adapter, address_type=args.address_type, adapter=args.adapter, address_type=args.address_type,
temp_dir=args.logdir or args.datadir or getTempDirectory()) temp_dir=args.logdir or args.datadir or getTempDirectory(),
storage_kw={'engine': args.engine, 'wait': -1})
if args.command == 'run': if args.command == 'run':
NFQueue.delay = args.delay NFQueue.delay = args.delay
app = Application(args.clients, args.threads, args.restart_ratio, error = args.kill_mysqld and (
'invalid adapter' if args.adapter != 'MySQL' else
None if mysql_pool else 'NEO_DB_MYCNF not set'
)
if error:
parser.error('--kill-mysqld: ' + error)
app = Application(args.clients, args.threads,
args.fault_probability, args.restart_ratio, args.kill_mysqld,
int(round(args.logrotate * 3600, 0)), **kw) int(round(args.logrotate * 3600, 0)), **kw)
t = threading.Thread(target=console, args=(args.console, app)) t = threading.Thread(target=console, args=(args.console, app))
t.daemon = 1 t.daemon = 1
......