Commit 0af75181 authored by Julien Muchembled

Make the number of replicas modifiable when the cluster is running

neoctl gets a new command to change the number of replicas.

The number of replicas becomes a new partition table attribute and
like the PT id, it is stored in the config table. On the other hand,
the configuration value for the number of partitions is dropped,
since it can be computed from the partition table, which is
always stored in full.

The -p/-r master options now only apply at database creation.

Some implementation notes:

- The protocol is slightly optimized in that the master now
  automatically sends the whole partition table to the admin & client
  nodes upon connection, like for storage nodes.
  This makes the protocol more consistent, and the master is the
  only remaining node requesting partition tables, during recovery.

- Some parts become tricky because app.pt can be None in more cases.
  For example, the extra condition in NodeManager.update
  (before app.pt.dropNode) was added for this reason.
  Or the 'loadPartitionTable' method (storage) that is not inlined
  because of unit tests.
  Overall, this commit simplifies more than it complicates.

- In the master handlers, we stop hijacking the 'connectionCompleted'
  method for tasks to be performed (often send the full partition
  table) on handler switches.

- The admin's 'bootstrapped' flag could have been removed earlier:
  race conditions can't happen since the AskNodeInformation packet
  was removed (commit d048a52d).
parent b9cac3f8
...@@ -21,7 +21,6 @@ from neo.lib.exception import PrimaryFailure ...@@ -21,7 +21,6 @@ from neo.lib.exception import PrimaryFailure
from .handler import AdminEventHandler, MasterEventHandler, \ from .handler import AdminEventHandler, MasterEventHandler, \
MasterRequestEventHandler MasterRequestEventHandler
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from neo.lib.pt import PartitionTable
from neo.lib.protocol import ClusterStates, Errors, NodeTypes, Packets from neo.lib.protocol import ClusterStates, Errors, NodeTypes, Packets
from neo.lib.debug import register as registerLiveDebugger from neo.lib.debug import register as registerLiveDebugger
...@@ -66,7 +65,6 @@ class Application(BaseApplication): ...@@ -66,7 +65,6 @@ class Application(BaseApplication):
super(Application, self).close() super(Application, self).close()
def reset(self): def reset(self):
self.bootstrapped = False
self.master_conn = None self.master_conn = None
self.master_node = None self.master_node = None
...@@ -117,39 +115,17 @@ class Application(BaseApplication): ...@@ -117,39 +115,17 @@ class Application(BaseApplication):
self.cluster_state = None self.cluster_state = None
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, NodeTypes.ADMIN, self.server) bootstrap = BootstrapManager(self, NodeTypes.ADMIN, self.server)
self.master_node, self.master_conn, num_partitions, num_replicas = \ self.master_node, self.master_conn = bootstrap.getPrimaryConnection()
bootstrap.getPrimaryConnection()
if self.pt is None:
self.pt = PartitionTable(num_partitions, num_replicas)
elif self.pt.getPartitions() != num_partitions:
# XXX: shouldn't we recover instead of raising ?
raise RuntimeError('the number of partitions is inconsistent')
elif self.pt.getReplicas() != num_replicas:
# XXX: shouldn't we recover instead of raising ?
raise RuntimeError('the number of replicas is inconsistent')
# passive handler # passive handler
self.master_conn.setHandler(self.master_event_handler) self.master_conn.setHandler(self.master_event_handler)
self.master_conn.ask(Packets.AskClusterState()) self.master_conn.ask(Packets.AskClusterState())
self.master_conn.ask(Packets.AskPartitionTable())
def sendPartitionTable(self, conn, min_offset, max_offset, uuid): def sendPartitionTable(self, conn, min_offset, max_offset, uuid):
# we have a pt
self.pt.log()
row_list = []
if max_offset == 0: if max_offset == 0:
max_offset = self.pt.getPartitions() max_offset = self.pt.getPartitions()
try: try:
for offset in xrange(min_offset, max_offset): row_list = map(self.pt.getRow, xrange(min_offset, max_offset))
row = []
try:
for cell in self.pt.getCellList(offset):
if uuid is None or cell.getUUID() == uuid:
row.append((cell.getUUID(), cell.getState()))
except TypeError:
pass
row_list.append((offset, row))
except IndexError: except IndexError:
conn.send(Errors.ProtocolError('invalid partition table offset')) conn.send(Errors.ProtocolError('invalid partition table offset'))
else: else:
......
...@@ -17,11 +17,12 @@ ...@@ -17,11 +17,12 @@
from neo.lib import logging, protocol from neo.lib import logging, protocol
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import uuid_str, Packets from neo.lib.protocol import uuid_str, Packets
from neo.lib.pt import PartitionTable
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
def check_primary_master(func): def check_primary_master(func):
def wrapper(self, *args, **kw): def wrapper(self, *args, **kw):
if self.app.bootstrapped: if self.app.master_conn is not None:
return func(self, *args, **kw) return func(self, *args, **kw)
raise protocol.NotReadyError('Not connected to a primary master.') raise protocol.NotReadyError('Not connected to a primary master.')
return wrapper return wrapper
...@@ -74,6 +75,7 @@ class AdminEventHandler(EventHandler): ...@@ -74,6 +75,7 @@ class AdminEventHandler(EventHandler):
tweakPartitionTable = forward_ask(Packets.TweakPartitionTable) tweakPartitionTable = forward_ask(Packets.TweakPartitionTable)
setClusterState = forward_ask(Packets.SetClusterState) setClusterState = forward_ask(Packets.SetClusterState)
setNodeState = forward_ask(Packets.SetNodeState) setNodeState = forward_ask(Packets.SetNodeState)
setNumReplicas = forward_ask(Packets.SetNumReplicas)
checkReplicas = forward_ask(Packets.CheckReplicas) checkReplicas = forward_ask(Packets.CheckReplicas)
truncate = forward_ask(Packets.Truncate) truncate = forward_ask(Packets.Truncate)
repair = forward_ask(Packets.Repair) repair = forward_ask(Packets.Repair)
...@@ -112,16 +114,12 @@ class MasterEventHandler(EventHandler): ...@@ -112,16 +114,12 @@ class MasterEventHandler(EventHandler):
def answerClusterState(self, conn, state): def answerClusterState(self, conn, state):
self.app.cluster_state = state self.app.cluster_state = state
def notifyPartitionChanges(self, conn, ptid, cell_list): def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
self.app.pt.update(ptid, cell_list, self.app.nm) pt = self.app.pt = object.__new__(PartitionTable)
pt.load(ptid, num_replicas, row_list, self.app.nm)
def answerPartitionTable(self, conn, ptid, row_list): def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
self.app.pt.load(ptid, row_list, self.app.nm) self.app.pt.update(ptid, num_replicas, cell_list, self.app.nm)
self.app.bootstrapped = True
def sendPartitionTable(self, conn, ptid, row_list):
if self.app.bootstrapped:
self.app.pt.load(ptid, row_list, self.app.nm)
def notifyClusterInformation(self, conn, cluster_state): def notifyClusterInformation(self, conn, cluster_state):
self.app.cluster_state = cluster_state self.app.cluster_state = cluster_state
......
...@@ -244,7 +244,6 @@ class Application(ThreadedApplication): ...@@ -244,7 +244,6 @@ class Application(ThreadedApplication):
# operational. Might raise ConnectionClosed so that the new # operational. Might raise ConnectionClosed so that the new
# primary can be looked-up again. # primary can be looked-up again.
logging.info('Initializing from master') logging.info('Initializing from master')
ask(conn, Packets.AskPartitionTable(), handler=handler)
ask(conn, Packets.AskLastTransaction(), handler=handler) ask(conn, Packets.AskLastTransaction(), handler=handler)
if self.pt.operational(): if self.pt.operational():
break break
......
...@@ -26,10 +26,6 @@ from ..exception import NEOStorageError ...@@ -26,10 +26,6 @@ from ..exception import NEOStorageError
class PrimaryBootstrapHandler(AnswerBaseHandler): class PrimaryBootstrapHandler(AnswerBaseHandler):
""" Bootstrap handler used when looking for the primary master """ """ Bootstrap handler used when looking for the primary master """
def answerPartitionTable(self, conn, ptid, row_list):
assert row_list
self.app.pt.load(ptid, row_list, self.app.nm)
def answerLastTransaction(*args): def answerLastTransaction(*args):
pass pass
...@@ -42,9 +38,6 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -42,9 +38,6 @@ class PrimaryNotificationsHandler(MTEventHandler):
except PrimaryElected, e: except PrimaryElected, e:
self.app.primary_master_node, = e.args self.app.primary_master_node, = e.args
def _acceptIdentification(self, node, num_partitions, num_replicas):
self.app.pt = PartitionTable(num_partitions, num_replicas)
def answerLastTransaction(self, conn, ltid): def answerLastTransaction(self, conn, ltid):
app = self.app app = self.app
app_last_tid = app.__dict__.get('last_tid', '') app_last_tid = app.__dict__.get('last_tid', '')
...@@ -134,9 +127,12 @@ class PrimaryNotificationsHandler(MTEventHandler): ...@@ -134,9 +127,12 @@ class PrimaryNotificationsHandler(MTEventHandler):
finally: finally:
app._cache_lock_release() app._cache_lock_release()
def notifyPartitionChanges(self, conn, ptid, cell_list): def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
if self.app.pt.filled(): pt = self.app.pt = object.__new__(PartitionTable)
self.app.pt.update(ptid, cell_list, self.app.nm) pt.load(ptid, num_replicas, row_list, self.app.nm)
def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
self.app.pt.update(ptid, num_replicas, cell_list, self.app.nm)
def notifyNodeInformation(self, conn, timestamp, node_list): def notifyNodeInformation(self, conn, timestamp, node_list):
super(PrimaryNotificationsHandler, self).notifyNodeInformation( super(PrimaryNotificationsHandler, self).notifyNodeInformation(
......
...@@ -36,8 +36,6 @@ class BootstrapManager(EventHandler): ...@@ -36,8 +36,6 @@ class BootstrapManager(EventHandler):
self.devpath = devpath self.devpath = devpath
self.new_nid = new_nid self.new_nid = new_nid
self.node_type = node_type self.node_type = node_type
self.num_replicas = None
self.num_partitions = None
app.nm.reset() app.nm.reset()
uuid = property(lambda self: self.app.uuid) uuid = property(lambda self: self.app.uuid)
...@@ -54,10 +52,8 @@ class BootstrapManager(EventHandler): ...@@ -54,10 +52,8 @@ class BootstrapManager(EventHandler):
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
self.current = None self.current = None
def _acceptIdentification(self, node, num_partitions, num_replicas): def _acceptIdentification(self, node):
assert self.current is node, (self.current, node) assert self.current is node, (self.current, node)
self.num_partitions = num_partitions
self.num_replicas = num_replicas
def getPrimaryConnection(self): def getPrimaryConnection(self):
""" """
...@@ -74,8 +70,7 @@ class BootstrapManager(EventHandler): ...@@ -74,8 +70,7 @@ class BootstrapManager(EventHandler):
try: try:
while self.current: while self.current:
if self.current.isIdentified(): if self.current.isIdentified():
return (self.current, self.current.getConnection(), return self.current, self.current.getConnection()
self.num_partitions, self.num_replicas)
poll(1) poll(1)
except PrimaryElected, e: except PrimaryElected, e:
if self.current: if self.current:
......
...@@ -160,8 +160,7 @@ class EventHandler(object): ...@@ -160,8 +160,7 @@ class EventHandler(object):
def _acceptIdentification(*args): def _acceptIdentification(*args):
pass pass
def acceptIdentification(self, conn, node_type, uuid, def acceptIdentification(self, conn, node_type, uuid, your_uuid):
num_partitions, num_replicas, your_uuid):
app = self.app app = self.app
node = app.nm.getByAddress(conn.getAddress()) node = app.nm.getByAddress(conn.getAddress())
assert node.getConnection() is conn, (node.getConnection(), conn) assert node.getConnection() is conn, (node.getConnection(), conn)
...@@ -180,7 +179,7 @@ class EventHandler(object): ...@@ -180,7 +179,7 @@ class EventHandler(object):
elif node.getUUID() != uuid or app.uuid != your_uuid != None: elif node.getUUID() != uuid or app.uuid != your_uuid != None:
raise ProtocolError('invalid uuids') raise ProtocolError('invalid uuids')
node.setIdentified() node.setIdentified()
self._acceptIdentification(node, num_partitions, num_replicas) self._acceptIdentification(node)
return return
conn.close() conn.close()
......
...@@ -486,7 +486,7 @@ class NodeManager(EventQueue): ...@@ -486,7 +486,7 @@ class NodeManager(EventQueue):
# For the first notification, we receive a full list of nodes from # For the first notification, we receive a full list of nodes from
# the master. Remove all unknown nodes from a previous connection. # the master. Remove all unknown nodes from a previous connection.
for node in self._node_set.difference(added_list): for node in self._node_set.difference(added_list):
if app.pt.dropNode(node): if not node.isStorage() or app.pt.dropNode(node):
self.remove(node) self.remove(node)
self.log() self.log()
self.executeQueuedEvents() self.executeQueuedEvents()
......
...@@ -616,10 +616,7 @@ PFCellList = PList('cell_list', ...@@ -616,10 +616,7 @@ PFCellList = PList('cell_list',
) )
PFRowList = PList('row_list', PFRowList = PList('row_list',
PStruct('row',
PNumber('offset'),
PFCellList, PFCellList,
),
) )
PFHistoryList = PList('history_list', PFHistoryList = PList('history_list',
...@@ -694,8 +691,6 @@ class RequestIdentification(Packet): ...@@ -694,8 +691,6 @@ class RequestIdentification(Packet):
_answer = PStruct('accept_identification', _answer = PStruct('accept_identification',
PFNodeType, PFNodeType,
PUUID('my_uuid'), PUUID('my_uuid'),
PNumber('num_partitions'),
PNumber('num_replicas'),
PUUID('your_uuid'), PUUID('your_uuid'),
) )
...@@ -751,23 +746,24 @@ class LastIDs(Packet): ...@@ -751,23 +746,24 @@ class LastIDs(Packet):
class PartitionTable(Packet): class PartitionTable(Packet):
""" """
Ask storage node the remaining data needed by master to recover. Ask storage node the remaining data needed by master to recover.
This is also how the clients get the full partition table on connection.
:nodes: M -> S; C -> M :nodes: M -> S
""" """
_answer = PStruct('answer_partition_table', _answer = PStruct('answer_partition_table',
PPTID('ptid'), PPTID('ptid'),
PNumber('num_replicas'),
PFRowList, PFRowList,
) )
class NotifyPartitionTable(Packet): class NotifyPartitionTable(Packet):
""" """
Send the full partition table to admin/storage nodes on connection. Send the full partition table to admin/client/storage nodes on connection.
:nodes: M -> A, S :nodes: M -> A, C, S
""" """
_fmt = PStruct('send_partition_table', _fmt = PStruct('send_partition_table',
PPTID('ptid'), PPTID('ptid'),
PNumber('num_replicas'),
PFRowList, PFRowList,
) )
...@@ -779,6 +775,7 @@ class PartitionChanges(Packet): ...@@ -779,6 +775,7 @@ class PartitionChanges(Packet):
""" """
_fmt = PStruct('notify_partition_changes', _fmt = PStruct('notify_partition_changes',
PPTID('ptid'), PPTID('ptid'),
PNumber('num_replicas'),
PList('cell_list', PList('cell_list',
PStruct('cell', PStruct('cell',
PNumber('offset'), PNumber('offset'),
...@@ -1271,6 +1268,18 @@ class NotifyNodeInformation(Packet): ...@@ -1271,6 +1268,18 @@ class NotifyNodeInformation(Packet):
PFNodeList, PFNodeList,
) )
class SetNumReplicas(Packet):
"""
Set the number of replicas.
:nodes: ctl -> A -> M
"""
_fmt = PStruct('set_num_replicas',
PNumber('num_replicas'),
)
_answer = Error
class SetClusterState(Packet): class SetClusterState(Packet):
""" """
Set the cluster state. Set the cluster state.
...@@ -1766,6 +1775,8 @@ class Packets(dict): ...@@ -1766,6 +1775,8 @@ class Packets(dict):
AddPendingNodes, ignore_when_closed=False) AddPendingNodes, ignore_when_closed=False)
TweakPartitionTable = register( TweakPartitionTable = register(
TweakPartitionTable, ignore_when_closed=False) TweakPartitionTable, ignore_when_closed=False)
SetNumReplicas = register(
SetNumReplicas, ignore_when_closed=False)
SetClusterState = register( SetClusterState = register(
SetClusterState, ignore_when_closed=False) SetClusterState, ignore_when_closed=False)
Repair = register( Repair = register(
......
...@@ -86,15 +86,9 @@ class PartitionTable(object): ...@@ -86,15 +86,9 @@ class PartitionTable(object):
'a cell became non-readable whereas all cells were readable' 'a cell became non-readable whereas all cells were readable'
def __init__(self, num_partitions, num_replicas): def __init__(self, num_partitions, num_replicas):
self._id = None
self.np = num_partitions self.np = num_partitions
self.nr = num_replicas self.nr = num_replicas
self.num_filled_rows = 0 self.clear()
# Note: don't use [[]] * num_partition construct, as it duplicates
# instance *references*, so the outer list contains really just one
# inner list instance.
self.partition_list = [[] for _ in xrange(num_partitions)]
self.count_dict = {}
def getID(self): def getID(self):
return self._id return self._id
...@@ -113,7 +107,7 @@ class PartitionTable(object): ...@@ -113,7 +107,7 @@ class PartitionTable(object):
# instance *references*, so the outer list contains really just one # instance *references*, so the outer list contains really just one
# inner list instance. # inner list instance.
self.partition_list = [[] for _ in xrange(self.np)] self.partition_list = [[] for _ in xrange(self.np)]
self.count_dict.clear() self.count_dict = {}
def getAssignedPartitionList(self, uuid): def getAssignedPartitionList(self, uuid):
""" Return the partition assigned to the specified UUID """ """ Return the partition assigned to the specified UUID """
...@@ -203,31 +197,31 @@ class PartitionTable(object): ...@@ -203,31 +197,31 @@ class PartitionTable(object):
del self.count_dict[node] del self.count_dict[node]
return not count return not count
def load(self, ptid, row_list, nm): def _load(self, ptid, num_replicas, row_list, getByUUID):
self.__init__(len(row_list), num_replicas)
self._id = ptid
for offset, row in enumerate(row_list):
for uuid, state in row:
node = getByUUID(uuid)
self._setCell(offset, node, state)
def load(self, ptid, num_replicas, row_list, nm):
""" """
Load the partition table with the specified PTID, discard all previous Load the partition table with the specified PTID, discard all previous
content. content.
""" """
self.clear() self._load(ptid, num_replicas, row_list, nm.getByUUID)
self._id = ptid
for offset, row in row_list:
if offset >= self.getPartitions():
raise IndexError
for uuid, state in row:
node = nm.getByUUID(uuid)
# the node must be known by the node manager
assert node is not None
self._setCell(offset, node, state)
logging.debug('partition table loaded (ptid=%s)', ptid) logging.debug('partition table loaded (ptid=%s)', ptid)
self.log() self.log()
def update(self, ptid, cell_list, nm): def update(self, ptid, num_replicas, cell_list, nm):
""" """
Update the partition with the cell list supplied. If a node Update the partition with the cell list supplied. If a node
is not known, it is created in the node manager and set as unavailable is not known, it is created in the node manager and set as unavailable
""" """
assert self._id < ptid, (self._id, ptid) assert self._id < ptid, (self._id, ptid)
self._id = ptid self._id = ptid
self.nr = num_replicas
readable_list = [] readable_list = []
for row in self.partition_list: for row in self.partition_list:
if not all(cell.isReadable() for cell in row): if not all(cell.isReadable() for cell in row):
...@@ -310,14 +304,11 @@ class PartitionTable(object): ...@@ -310,14 +304,11 @@ class PartitionTable(object):
return True return True
def getRow(self, offset): def getRow(self, offset):
row = self.partition_list[offset] return [(cell.getUUID(), cell.getState())
if row is None: for cell in self.partition_list[offset]]
return []
return [(cell.getUUID(), cell.getState()) for cell in row]
def getRowList(self): def getRowList(self):
getRow = self.getRow return map(self.getRow, xrange(self.np))
return [(x, getRow(x)) for x in xrange(self.np)]
class MTPartitionTable(PartitionTable): class MTPartitionTable(PartitionTable):
""" Thread-safe aware version of the partition table, override only methods """ Thread-safe aware version of the partition table, override only methods
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import sys import sys
from collections import defaultdict from collections import defaultdict
from functools import partial
from time import time from time import time
from neo.lib import logging, util from neo.lib import logging, util
...@@ -76,13 +77,11 @@ class Application(BaseApplication): ...@@ -76,13 +77,11 @@ class Application(BaseApplication):
@classmethod @classmethod
def _buildOptionParser(cls): def _buildOptionParser(cls):
_ = cls.option_parser parser = cls.option_parser
_.description = "NEO Master node" parser.description = "NEO Master node"
cls.addCommonServerOptions('master', '127.0.0.1:10000', '') cls.addCommonServerOptions('master', '127.0.0.1:10000', '')
_ = _.group('master') _ = parser.group('master')
_.int('r', 'replicas', default=0, help="replicas number")
_.int('p', 'partitions', default=100, help="partitions number")
_.int('A', 'autostart', _.int('A', 'autostart',
help="minimum number of pending storage nodes to automatically" help="minimum number of pending storage nodes to automatically"
" start new cluster (to avoid unwanted recreation of the" " start new cluster (to avoid unwanted recreation of the"
...@@ -94,6 +93,10 @@ class Application(BaseApplication): ...@@ -94,6 +93,10 @@ class Application(BaseApplication):
_.int('i', 'nid', _.int('i', 'nid',
help="specify an NID to use for this process (testing purpose)") help="specify an NID to use for this process (testing purpose)")
_ = parser.group('database creation')
_.int('r', 'replicas', default=0, help="replicas number")
_.int('p', 'partitions', default=100, help="partitions number")
def __init__(self, config): def __init__(self, config):
super(Application, self).__init__( super(Application, self).__init__(
config.get('ssl'), config.get('dynamic_master_list')) config.get('ssl'), config.get('dynamic_master_list'))
...@@ -117,14 +120,14 @@ class Application(BaseApplication): ...@@ -117,14 +120,14 @@ class Application(BaseApplication):
replicas = config['replicas'] replicas = config['replicas']
partitions = config['partitions'] partitions = config['partitions']
if replicas < 0: if replicas < 0:
raise RuntimeError, 'replicas must be a positive integer' sys.exit('replicas must be a positive integer')
if partitions <= 0: if partitions <= 0:
raise RuntimeError, 'partitions must be more than zero' sys.exit('partitions must be more than zero')
self.pt = PartitionTable(partitions, replicas)
logging.info('Configuration:') logging.info('Configuration:')
logging.info('Partitions: %d', partitions) logging.info('Partitions: %d', partitions)
logging.info('Replicas : %d', replicas) logging.info('Replicas : %d', replicas)
logging.info('Name : %s', self.name) logging.info('Name : %s', self.name)
self.newPartitionTable = partial(PartitionTable, partitions, replicas)
self.listening_conn = None self.listening_conn = None
self.cluster_state = None self.cluster_state = None
...@@ -212,12 +215,18 @@ class Application(BaseApplication): ...@@ -212,12 +215,18 @@ class Application(BaseApplication):
if node_list: if node_list:
node.send(Packets.NotifyNodeInformation(now, node_list)) node.send(Packets.NotifyNodeInformation(now, node_list))
def broadcastPartitionChanges(self, cell_list): def broadcastPartitionChanges(self, cell_list, num_replicas=None):
"""Broadcast a Notify Partition Changes packet.""" """Broadcast a Notify Partition Changes packet."""
if cell_list: pt = self.pt
ptid = self.pt.setNextID() if num_replicas is not None:
self.pt.logUpdated() pt.setReplicas(num_replicas)
packet = Packets.NotifyPartitionChanges(ptid, cell_list) elif cell_list:
num_replicas = pt.getReplicas()
else:
return
packet = Packets.NotifyPartitionChanges(
pt.setNextID(), num_replicas, cell_list)
pt.logUpdated()
for node in self.nm.getIdentifiedList(): for node in self.nm.getIdentifiedList():
# As for broadcastNodesInformation, we don't send the full PT # As for broadcastNodesInformation, we don't send the full PT
# when pending storage nodes are added, so keep them notified. # when pending storage nodes are added, so keep them notified.
...@@ -437,16 +446,7 @@ class Application(BaseApplication): ...@@ -437,16 +446,7 @@ class Application(BaseApplication):
conn.send(notification_packet) conn.send(notification_packet)
elif conn.isServer(): elif conn.isServer():
continue continue
if node.isClient(): if node.isMaster():
if state == ClusterStates.RUNNING:
handler = self.client_service_handler
elif state == ClusterStates.BACKINGUP:
handler = self.client_ro_service_handler
else:
if state != ClusterStates.STOPPING:
conn.abort()
continue
elif node.isMaster():
if state == ClusterStates.RECOVERING: if state == ClusterStates.RECOVERING:
handler = self.election_handler handler = self.election_handler
else: else:
...@@ -454,10 +454,16 @@ class Application(BaseApplication): ...@@ -454,10 +454,16 @@ class Application(BaseApplication):
elif node.isStorage() and storage_handler: elif node.isStorage() and storage_handler:
handler = storage_handler handler = storage_handler
else: else:
# There's a single handler type for admins.
# Client can't change handler without being first disconnected.
assert state in (
ClusterStates.STOPPING,
ClusterStates.STOPPING_BACKUP,
) or not node.isClient(), (state, node)
continue # keep handler continue # keep handler
if type(handler) is not type(conn.getLastHandler()): if type(handler) is not type(conn.getLastHandler()):
conn.setHandler(handler) conn.setHandler(handler)
handler.connectionCompleted(conn, new=False) handler.handlerSwitched(conn, new=False)
self.cluster_state = state self.cluster_state = state
def getNewUUID(self, uuid, address, node_type): def getNewUUID(self, uuid, address, node_type):
......
...@@ -111,17 +111,12 @@ class BackupApplication(object): ...@@ -111,17 +111,12 @@ class BackupApplication(object):
else: else:
break break
poll(1) poll(1)
node, conn, num_partitions, num_replicas = \ node, conn = bootstrap.getPrimaryConnection()
bootstrap.getPrimaryConnection()
try: try:
app.changeClusterState(ClusterStates.BACKINGUP) app.changeClusterState(ClusterStates.BACKINGUP)
del bootstrap, node del bootstrap, node
if num_partitions != pt.getPartitions():
raise RuntimeError("inconsistent number of partitions")
self.ignore_invalidations = True self.ignore_invalidations = True
self.pt = PartitionTable(num_partitions, num_replicas)
conn.setHandler(BackupHandler(self)) conn.setHandler(BackupHandler(self))
conn.ask(Packets.AskPartitionTable())
conn.ask(Packets.AskLastTransaction()) conn.ask(Packets.AskLastTransaction())
# debug variable to log how big 'tid_list' can be. # debug variable to log how big 'tid_list' can be.
self.debug_tid_count = 0 self.debug_tid_count = 0
......
...@@ -23,10 +23,6 @@ from neo.lib.protocol import Packets ...@@ -23,10 +23,6 @@ from neo.lib.protocol import Packets
class MasterHandler(EventHandler): class MasterHandler(EventHandler):
"""This class implements a generic part of the event handlers.""" """This class implements a generic part of the event handlers."""
def connectionCompleted(self, conn, new=None):
if new is None:
super(MasterHandler, self).connectionCompleted(conn)
def connectionLost(self, conn, new_state=None): def connectionLost(self, conn, new_state=None):
if self.app.listening_conn: # if running if self.app.listening_conn: # if running
self._connectionLost(conn) self._connectionLost(conn)
...@@ -59,17 +55,20 @@ class MasterHandler(EventHandler): ...@@ -59,17 +55,20 @@ class MasterHandler(EventHandler):
+ app.getNodeInformationDict(node_list)[node.getType()]) + app.getNodeInformationDict(node_list)[node.getType()])
conn.send(Packets.NotifyNodeInformation(monotonic_time(), node_list)) conn.send(Packets.NotifyNodeInformation(monotonic_time(), node_list))
def askPartitionTable(self, conn): def handlerSwitched(self, conn, new):
pt = self.app.pt pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList())) # Except storages during recovery and secondary masters, all nodes
# receive the full partition table as soon as they're identified.
# It is also sent in 2 other cases:
# - to admins during recovery, whenever a newer PT is loaded;
# - to storage when switching from recovery to verification.
# After that, non-master nodes only receive incremental updates.
conn.send(Packets.SendPartitionTable(
pt.getID(), pt.getReplicas(), pt.getRowList()))
class BaseServiceHandler(MasterHandler): class BaseServiceHandler(MasterHandler):
"""This class deals with events for a service phase.""" """Common handler class for storage nodes."""
def connectionCompleted(self, conn, new):
pt = self.app.pt
conn.send(Packets.SendPartitionTable(pt.getID(), pt.getRowList()))
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
app = self.app app = self.app
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import random import random
from functools import wraps
from . import MasterHandler from . import MasterHandler
from ..app import monotonic_time, StateChangedException from ..app import monotonic_time, StateChangedException
...@@ -38,9 +39,25 @@ NODE_STATE_WORKFLOW = { ...@@ -38,9 +39,25 @@ NODE_STATE_WORKFLOW = {
NodeTypes.STORAGE: (NodeStates.DOWN, NodeStates.UNKNOWN), NodeTypes.STORAGE: (NodeStates.DOWN, NodeStates.UNKNOWN),
} }
def check_state(*states):
def decorator(wrapped):
def wrapper(self, *args):
state = self.app.getClusterState()
if state not in states:
raise ProtocolError('%s RPC can not be used in %s state'
% (wrapped.__name__, state))
wrapped(self, *args)
return wraps(wrapped)(wrapper)
return decorator
class AdministrationHandler(MasterHandler): class AdministrationHandler(MasterHandler):
"""This class deals with messages from the admin node only""" """This class deals with messages from the admin node only"""
def handlerSwitched(self, conn, new):
assert new
super(AdministrationHandler, self).handlerSwitched(conn, new)
def connectionLost(self, conn, new_state): def connectionLost(self, conn, new_state):
node = self.app.nm.getByUUID(conn.getUUID()) node = self.app.nm.getByUUID(conn.getUUID())
if node is not None: if node is not None:
...@@ -134,16 +151,17 @@ class AdministrationHandler(MasterHandler): ...@@ -134,16 +151,17 @@ class AdministrationHandler(MasterHandler):
monotonic_time(), [node.asTuple()])) monotonic_time(), [node.asTuple()]))
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
# XXX: Would it be safe to allow more states ?
__change_pt_rpc = check_state(
ClusterStates.RUNNING,
ClusterStates.STARTING_BACKUP,
ClusterStates.BACKINGUP)
@__change_pt_rpc
def addPendingNodes(self, conn, uuid_list): def addPendingNodes(self, conn, uuid_list):
uuids = ', '.join(map(uuid_str, uuid_list)) uuids = ', '.join(map(uuid_str, uuid_list))
logging.debug('Add nodes %s', uuids) logging.debug('Add nodes %s', uuids)
app = self.app app = self.app
state = app.getClusterState()
# XXX: Would it be safe to allow more states ?
if state not in (ClusterStates.RUNNING,
ClusterStates.STARTING_BACKUP,
ClusterStates.BACKINGUP):
raise ProtocolError('Can not add nodes in %s state' % state)
# take all pending nodes # take all pending nodes
node_list = list(app.pt.addNodeList(node node_list = list(app.pt.addNodeList(node
for node in app.nm.getStorageList() for node in app.nm.getStorageList()
...@@ -172,24 +190,21 @@ class AdministrationHandler(MasterHandler): ...@@ -172,24 +190,21 @@ class AdministrationHandler(MasterHandler):
node.send(repair) node.send(repair)
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
@__change_pt_rpc
def setNumReplicas(self, conn, num_replicas):
    """Broadcast a new number of replicas and acknowledge the request.

    The partition change that is broadcast carries an empty cell list,
    together with `num_replicas`. `conn` then receives an empty Ack.
    The `__change_pt_rpc` decorator is built by check_state() for the
    RUNNING, STARTING_BACKUP and BACKINGUP cluster states.
    """
    app = self.app
    no_cell_change = ()
    app.broadcastPartitionChanges(no_cell_change, num_replicas)
    ack = Errors.Ack('')
    conn.answer(ack)
@__change_pt_rpc
def tweakPartitionTable(self, conn, uuid_list): def tweakPartitionTable(self, conn, uuid_list):
app = self.app app = self.app
state = app.getClusterState()
# XXX: Would it be safe to allow more states ?
if state not in (ClusterStates.RUNNING,
ClusterStates.STARTING_BACKUP,
ClusterStates.BACKINGUP):
raise ProtocolError('Can not tweak partition table in %s state'
% state)
app.broadcastPartitionChanges(app.pt.tweak([node app.broadcastPartitionChanges(app.pt.tweak([node
for node in app.nm.getStorageList() for node in app.nm.getStorageList()
if node.getUUID() in uuid_list or not node.isRunning()])) if node.getUUID() in uuid_list or not node.isRunning()]))
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
@check_state(ClusterStates.RUNNING)
def truncate(self, conn, tid): def truncate(self, conn, tid):
app = self.app
if app.cluster_state != ClusterStates.RUNNING:
raise ProtocolError('Can not truncate in this state')
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
raise StoppedOperation(tid) raise StoppedOperation(tid)
...@@ -237,3 +252,5 @@ class AdministrationHandler(MasterHandler): ...@@ -237,3 +252,5 @@ class AdministrationHandler(MasterHandler):
node.send(Packets.CheckPartition( node.send(Packets.CheckPartition(
offset, source, min_tid, max_tid)) offset, source, min_tid, max_tid))
conn.answer(Errors.Ack('')) conn.answer(Errors.Ack(''))
del __change_pt_rpc
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
from neo.lib.exception import PrimaryFailure from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler from neo.lib.handler import EventHandler
from neo.lib.protocol import ZERO_TID from neo.lib.protocol import ZERO_TID
from neo.lib.pt import PartitionTable
class BackupHandler(EventHandler): class BackupHandler(EventHandler):
"""Handler dedicated to upstream master during BACKINGUP state""" """Handler dedicated to upstream master during BACKINGUP state"""
...@@ -25,12 +26,15 @@ class BackupHandler(EventHandler): ...@@ -25,12 +26,15 @@ class BackupHandler(EventHandler):
if self.app.app.listening_conn: # if running if self.app.app.listening_conn: # if running
raise PrimaryFailure('connection lost') raise PrimaryFailure('connection lost')
def answerPartitionTable(self, conn, ptid, row_list): def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
self.app.pt.load(ptid, row_list, self.app.nm) app = self.app
pt = app.pt = object.__new__(PartitionTable)
pt.load(ptid, num_replicas, row_list, self.app.nm)
if pt.getPartitions() != app.app.pt.getPartitions():
raise RuntimeError("inconsistent number of partitions")
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
if self.app.pt.filled(): self.app.pt.update(ptid, num_replicas, cell_list, self.app.nm)
self.app.pt.update(ptid, cell_list, self.app.nm)
def answerLastTransaction(self, conn, tid): def answerLastTransaction(self, conn, tid):
app = self.app app = self.app
......
...@@ -22,6 +22,10 @@ from . import MasterHandler ...@@ -22,6 +22,10 @@ from . import MasterHandler
class ClientServiceHandler(MasterHandler): class ClientServiceHandler(MasterHandler):
""" Handler dedicated to client during service state """ """ Handler dedicated to client during service state """
def handlerSwitched(self, conn, new):
    """Delegate the handler switch on `conn` to the parent class.

    `new` is asserted to be true before delegating.
    """
    assert new
    parent = super(ClientServiceHandler, self)
    parent.handlerSwitched(conn, new)
def _connectionLost(self, conn): def _connectionLost(self, conn):
# cancel its transactions and forgot the node # cancel its transactions and forgot the node
app = self.app app = self.app
......
...@@ -128,11 +128,9 @@ class IdentificationHandler(EventHandler): ...@@ -128,11 +128,9 @@ class IdentificationHandler(EventHandler):
conn.answer(Packets.AcceptIdentification( conn.answer(Packets.AcceptIdentification(
NodeTypes.MASTER, NodeTypes.MASTER,
app.uuid, app.uuid,
app.pt.getPartitions(),
app.pt.getReplicas(),
uuid)) uuid))
handler._notifyNodeInformation(conn) handler._notifyNodeInformation(conn)
handler.connectionCompleted(conn, True) handler.handlerSwitched(conn, True)
class SecondaryIdentificationHandler(EventHandler): class SecondaryIdentificationHandler(EventHandler):
......
...@@ -23,6 +23,9 @@ from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets ...@@ -23,6 +23,9 @@ from neo.lib.protocol import ClusterStates, NodeStates, NodeTypes, Packets
class SecondaryHandler(MasterHandler): class SecondaryHandler(MasterHandler):
"""Handler used by primary to handle secondary masters""" """Handler used by primary to handle secondary masters"""
def handlerSwitched(self, conn, new):
    """No-op override: nothing is done for `conn`, whatever `new` is."""
    return None
def _connectionLost(self, conn): def _connectionLost(self, conn):
app = self.app app = self.app
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
...@@ -30,11 +33,10 @@ class SecondaryHandler(MasterHandler): ...@@ -30,11 +33,10 @@ class SecondaryHandler(MasterHandler):
app.broadcastNodesInformation([node]) app.broadcastNodesInformation([node])
class ElectionHandler(MasterHandler): class ElectionHandler(SecondaryHandler):
"""Handler used by primary to handle secondary masters during election""" """Handler used by primary to handle secondary masters during election"""
def connectionCompleted(self, conn, new=None): def connectionCompleted(self, conn):
if new is None:
super(ElectionHandler, self).connectionCompleted(conn) super(ElectionHandler, self).connectionCompleted(conn)
app = self.app app = self.app
conn.ask(Packets.RequestIdentification(NodeTypes.MASTER, conn.ask(Packets.RequestIdentification(NodeTypes.MASTER,
...@@ -44,7 +46,7 @@ class ElectionHandler(MasterHandler): ...@@ -44,7 +46,7 @@ class ElectionHandler(MasterHandler):
super(ElectionHandler, self).connectionFailed(conn) super(ElectionHandler, self).connectionFailed(conn)
self.connectionLost(conn) self.connectionLost(conn)
def _acceptIdentification(self, node, *args): def _acceptIdentification(self, node):
raise PrimaryElected(node) raise PrimaryElected(node)
def _connectionLost(self, *args): def _connectionLost(self, *args):
...@@ -66,7 +68,7 @@ class ElectionHandler(MasterHandler): ...@@ -66,7 +68,7 @@ class ElectionHandler(MasterHandler):
class PrimaryHandler(ElectionHandler): class PrimaryHandler(ElectionHandler):
"""Handler used by secondaries to handle primary master""" """Handler used by secondaries to handle primary master"""
def _acceptIdentification(self, node, num_partitions, num_replicas): def _acceptIdentification(self, node):
assert self.app.primary_master is node, (self.app.primary_master, node) assert self.app.primary_master is node, (self.app.primary_master, node)
def _connectionLost(self, conn): def _connectionLost(self, conn):
......
...@@ -26,10 +26,10 @@ from . import BaseServiceHandler ...@@ -26,10 +26,10 @@ from . import BaseServiceHandler
class StorageServiceHandler(BaseServiceHandler): class StorageServiceHandler(BaseServiceHandler):
""" Handler dedicated to storages during service state """ """ Handler dedicated to storages during service state """
def connectionCompleted(self, conn, new): def handlerSwitched(self, conn, new):
app = self.app app = self.app
if new: if new:
super(StorageServiceHandler, self).connectionCompleted(conn, new) super(StorageServiceHandler, self).handlerSwitched(conn, new)
node = app.nm.getByUUID(conn.getUUID()) node = app.nm.getByUUID(conn.getUUID())
if node.isRunning(): # node may be PENDING if node.isRunning(): # node may be PENDING
app.startStorage(node) app.startStorage(node)
......
...@@ -56,6 +56,10 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -56,6 +56,10 @@ class PartitionTable(neo.lib.pt.PartitionTable):
self._id += 1 self._id += 1
return self._id return self._id
def setReplicas(self, num_replicas):
    """Store `num_replicas` as the number of replicas (`self.nr`).

    A negative value trips an assertion whose message is the value itself.
    """
    assert 0 <= num_replicas, num_replicas
    self.nr = num_replicas
def make(self, node_list): def make(self, node_list):
"""Make a new partition table from scratch.""" """Make a new partition table from scratch."""
assert self._id is None and node_list, (self._id, node_list) assert self._id is None and node_list, (self._id, node_list)
...@@ -108,26 +112,19 @@ class PartitionTable(neo.lib.pt.PartitionTable): ...@@ -108,26 +112,19 @@ class PartitionTable(neo.lib.pt.PartitionTable):
self.num_filled_rows = len(filter(None, self.partition_list)) self.num_filled_rows = len(filter(None, self.partition_list))
return change_list return change_list
def load(self, ptid, row_list, nm): def load(self, ptid, num_replicas, row_list, nm):
""" """
Load a partition table from a storage node during the recovery. Load a partition table from a storage node during the recovery.
Return the new storage nodes registered Return the new storage nodes registered
""" """
# check offsets
for offset, _row in row_list:
if offset >= self.getPartitions():
raise IndexError, offset
# store the partition table
self.clear()
self._id = ptid
new_nodes = [] new_nodes = []
for offset, row in row_list: def getByUUID(nid):
for uuid, state in row: node = nm.getByUUID(nid)
node = nm.getByUUID(uuid)
if node is None: if node is None:
node = nm.createStorage(uuid=uuid) node = nm.createStorage(uuid=nid)
new_nodes.append(node.asTuple()) new_nodes.append(node.asTuple())
self._setCell(offset, node, state) return node
self._load(ptid, num_replicas, row_list, getByUUID)
return new_nodes return new_nodes
def setUpToDate(self, node, offset): def setUpToDate(self, node, offset):
......
...@@ -28,7 +28,7 @@ class RecoveryManager(MasterHandler): ...@@ -28,7 +28,7 @@ class RecoveryManager(MasterHandler):
def __init__(self, app): def __init__(self, app):
# The target node's uuid to request next. # The target node's uuid to request next.
self.target_ptid = None self.target_ptid = 0
self.ask_pt = [] self.ask_pt = []
self.backup_tid_dict = {} self.backup_tid_dict = {}
self.truncate_dict = {} self.truncate_dict = {}
...@@ -52,9 +52,8 @@ class RecoveryManager(MasterHandler): ...@@ -52,9 +52,8 @@ class RecoveryManager(MasterHandler):
""" """
logging.info('begin the recovery of the status') logging.info('begin the recovery of the status')
app = self.app app = self.app
pt = app.pt pt = app.pt = app.newPartitionTable()
app.changeClusterState(ClusterStates.RECOVERING) app.changeClusterState(ClusterStates.RECOVERING)
pt.clear()
self.try_secondary = True self.try_secondary = True
...@@ -113,7 +112,7 @@ class RecoveryManager(MasterHandler): ...@@ -113,7 +112,7 @@ class RecoveryManager(MasterHandler):
for node in node_list: for node in node_list:
conn = node.getConnection() conn = node.getConnection()
conn.send(truncate) conn.send(truncate)
self.connectionCompleted(conn, False) self.handlerSwitched(conn, False)
continue continue
node_list = pt.getConnectedNodeList() node_list = pt.getConnectedNodeList()
break break
...@@ -140,12 +139,12 @@ class RecoveryManager(MasterHandler): ...@@ -140,12 +139,12 @@ class RecoveryManager(MasterHandler):
logging.info('creating a new partition table') logging.info('creating a new partition table')
pt.make(node_list) pt.make(node_list)
self._notifyAdmins(Packets.SendPartitionTable( self._notifyAdmins(Packets.SendPartitionTable(
pt.getID(), pt.getRowList())) pt.getID(), pt.getReplicas(), pt.getRowList()))
else: else:
cell_list = pt.outdate() cell_list = pt.outdate()
if cell_list: if cell_list:
self._notifyAdmins(Packets.NotifyPartitionChanges( self._notifyAdmins(Packets.NotifyPartitionChanges(
pt.setNextID(), cell_list)) pt.setNextID(), pt.getReplicas(), cell_list))
if app.backup_tid: if app.backup_tid:
pt.setBackupTidDict(self.backup_tid_dict) pt.setBackupTidDict(self.backup_tid_dict)
app.backup_tid = pt.getBackupTid() app.backup_tid = pt.getBackupTid()
...@@ -175,16 +174,16 @@ class RecoveryManager(MasterHandler): ...@@ -175,16 +174,16 @@ class RecoveryManager(MasterHandler):
if node is None or node.getState() == new_state: if node is None or node.getState() == new_state:
return return
node.setState(new_state) node.setState(new_state)
# broadcast to all so that admin nodes gets informed
self.app.broadcastNodesInformation([node]) self.app.broadcastNodesInformation([node])
def connectionCompleted(self, conn, new): def handlerSwitched(self, conn, new):
# ask the last IDs to perform the recovery # ask the last IDs to perform the recovery
conn.ask(Packets.AskRecovery()) conn.ask(Packets.AskRecovery())
def answerRecovery(self, conn, ptid, backup_tid, truncate_tid): def answerRecovery(self, conn, ptid, backup_tid, truncate_tid):
uuid = conn.getUUID() uuid = conn.getUUID()
if self.target_ptid <= ptid: # ptid is None if the node has an empty partition table.
if ptid and self.target_ptid <= ptid:
# Maybe a newer partition table. # Maybe a newer partition table.
if self.target_ptid == ptid and self.ask_pt: if self.target_ptid == ptid and self.ask_pt:
# Another node is already asked. # Another node is already asked.
...@@ -197,17 +196,14 @@ class RecoveryManager(MasterHandler): ...@@ -197,17 +196,14 @@ class RecoveryManager(MasterHandler):
self.backup_tid_dict[uuid] = backup_tid self.backup_tid_dict[uuid] = backup_tid
self.truncate_dict[uuid] = truncate_tid self.truncate_dict[uuid] = truncate_tid
def answerPartitionTable(self, conn, ptid, row_list): def answerPartitionTable(self, conn, ptid, num_replicas, row_list):
# If this is not from a target node, ignore it. # If this is not from a target node, ignore it.
if ptid == self.target_ptid: if ptid == self.target_ptid:
app = self.app app = self.app
try: new_nodes = app.pt.load(ptid, num_replicas, row_list, app.nm)
new_nodes = app.pt.load(ptid, row_list, app.nm)
except IndexError:
raise ProtocolError('Invalid offset')
self._notifyAdmins( self._notifyAdmins(
Packets.NotifyNodeInformation(monotonic_time(), new_nodes), Packets.NotifyNodeInformation(monotonic_time(), new_nodes),
Packets.SendPartitionTable(ptid, row_list)) Packets.SendPartitionTable(ptid, num_replicas, row_list))
self.ask_pt = () self.ask_pt = ()
uuid = conn.getUUID() uuid = conn.getUUID()
app.backup_tid = self.backup_tid_dict[uuid] app.backup_tid = self.backup_tid_dict[uuid]
......
...@@ -30,6 +30,7 @@ action_dict = { ...@@ -30,6 +30,7 @@ action_dict = {
}, },
'set': { 'set': {
'cluster': 'setClusterState', 'cluster': 'setClusterState',
'replicas': 'setNumReplicas',
}, },
'check': 'checkReplicas', 'check': 'checkReplicas',
'start': 'startCluster', 'start': 'startCluster',
...@@ -108,7 +109,7 @@ class TerminalNeoCTL(object): ...@@ -108,7 +109,7 @@ class TerminalNeoCTL(object):
ptid, row_list = self.neoctl.getPartitionRowList( ptid, row_list = self.neoctl.getPartitionRowList(
min_offset=min_offset, max_offset=max_offset, node=node) min_offset=min_offset, max_offset=max_offset, node=node)
# TODO: return ptid # TODO: return ptid
return self.formatRowList(row_list) return self.formatRowList(enumerate(row_list, min_offset))
def getNodeList(self, params): def getNodeList(self, params):
""" """
...@@ -140,6 +141,18 @@ class TerminalNeoCTL(object): ...@@ -140,6 +141,18 @@ class TerminalNeoCTL(object):
assert len(params) == 1 assert len(params) == 1
return self.neoctl.setClusterState(self.asClusterState(params[0])) return self.neoctl.setClusterState(self.asClusterState(params[0]))
def setNumReplicas(self, params):
    """
      Set number of replicas.
      Parameters: nr
        nr: non-negative integer (0 means no redundancy)
    """
    # Exactly one argument is expected, as for the other 'set' actions.
    assert len(params) == 1
    try:
        nr = int(params[0])
    except ValueError:
        # Not a number at all: report it the same way as a negative value
        # instead of letting a traceback reach the user.
        nr = -1
    if nr < 0:
        sys.exit('invalid number of replicas')
    # Forward the validated value and hand back whatever neoctl returns.
    return self.neoctl.setNumReplicas(nr)
def startCluster(self, params): def startCluster(self, params):
""" """
Starts cluster operation after a startup. Starts cluster operation after a startup.
......
...@@ -97,6 +97,12 @@ class NeoCTL(BaseApplication): ...@@ -97,6 +97,12 @@ class NeoCTL(BaseApplication):
raise RuntimeError(response) raise RuntimeError(response)
return response[2] return response[2]
def setNumReplicas(self, nr):
    """Send a SetNumReplicas(nr) request and return the answer's message.

    The response is expected to be an Error packet carrying the ACK code;
    its third element is returned. Any other response is raised, whole,
    inside a RuntimeError.
    """
    response = self.__ask(Packets.SetNumReplicas(nr))
    acknowledged = (response[0] == Packets.Error
                    and response[1] == ErrorCodes.ACK)
    if not acknowledged:
        raise RuntimeError(response)
    return response[2]
def setClusterState(self, state): def setClusterState(self, state):
""" """
Set cluster state. Set cluster state.
......
...@@ -169,36 +169,27 @@ class Application(BaseApplication): ...@@ -169,36 +169,27 @@ class Application(BaseApplication):
# load configuration # load configuration
self.uuid = dm.getUUID() self.uuid = dm.getUUID()
logging.node(self.name, self.uuid) logging.node(self.name, self.uuid)
num_partitions = dm.getNumPartitions()
num_replicas = dm.getNumReplicas()
ptid = dm.getPTID()
# check partition table configuration
if num_partitions is not None and num_replicas is not None:
if num_partitions <= 0:
raise RuntimeError, 'partitions must be more than zero'
# create a partition table
self.pt = PartitionTable(num_partitions, num_replicas)
logging.info('Configuration loaded:') logging.info('Configuration loaded:')
logging.info('PTID : %s', dump(ptid)) logging.info('PTID : %s', dump(dm.getPTID()))
logging.info('Name : %s', self.name) logging.info('Name : %s', self.name)
logging.info('Partitions: %s', num_partitions)
logging.info('Replicas : %s', num_replicas)
def loadPartitionTable(self): def loadPartitionTable(self):
"""Load a partition table from the database.""" """Load a partition table from the database."""
self.pt.clear()
ptid = self.dm.getPTID() ptid = self.dm.getPTID()
if ptid is None: if ptid is None:
self.pt = PartitionTable(0, 0)
return return
cell_list = [] row_list = []
for offset, uuid, state in self.dm.getPartitionTable(): for offset, uuid, state in self.dm.getPartitionTable():
while len(row_list) <= offset:
row_list.append([])
# register unknown nodes # register unknown nodes
if self.nm.getByUUID(uuid) is None: if self.nm.getByUUID(uuid) is None:
self.nm.createStorage(uuid=uuid) self.nm.createStorage(uuid=uuid)
cell_list.append((offset, uuid, CellStates[state])) row_list[offset].append((uuid, CellStates[state]))
self.pt.update(ptid, cell_list, self.nm) self.pt = object.__new__(PartitionTable)
self.pt.load(ptid, self.dm.getNumReplicas(), row_list, self.nm)
def run(self): def run(self):
try: try:
...@@ -258,29 +249,15 @@ class Application(BaseApplication): ...@@ -258,29 +249,15 @@ class Application(BaseApplication):
Note that I do not accept any connection from non-master nodes Note that I do not accept any connection from non-master nodes
at this stage.""" at this stage."""
pt = self.pt
# search, find, connect and identify to the primary master # search, find, connect and identify to the primary master
bootstrap = BootstrapManager(self, NodeTypes.STORAGE, bootstrap = BootstrapManager(self, NodeTypes.STORAGE,
None if self.new_nid else self.server, None if self.new_nid else self.server,
self.devpath, self.new_nid) self.devpath, self.new_nid)
self.master_node, self.master_conn, num_partitions, num_replicas = \ self.master_node, self.master_conn = bootstrap.getPrimaryConnection()
bootstrap.getPrimaryConnection()
self.dm.setUUID(self.uuid) self.dm.setUUID(self.uuid)
# Reload a partition table from the database. This is necessary # Reload a partition table from the database,
# when a previous primary master died while sending a partition # in case that we're in RECOVERING phase.
# table, because the table might be incomplete.
if pt is not None:
self.loadPartitionTable()
if num_partitions != pt.getPartitions():
raise RuntimeError('the number of partitions is inconsistent')
if pt is None or pt.getReplicas() != num_replicas:
# changing number of replicas is not an issue
self.dm.setNumPartitions(num_partitions)
self.dm.setNumReplicas(num_replicas)
self.pt = PartitionTable(num_partitions, num_replicas)
self.loadPartitionTable() self.loadPartitionTable()
def initialize(self): def initialize(self):
......
...@@ -378,7 +378,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -378,7 +378,7 @@ class ImporterDatabaseManager(DatabaseManager):
conf = self._conf conf = self._conf
db = self.db = buildDatabaseManager(conf['adapter'], db = self.db = buildDatabaseManager(conf['adapter'],
(conf['database'], conf.get('engine'), conf['wait'])) (conf['database'], conf.get('engine'), conf['wait']))
for x in """getConfiguration _setConfiguration setNumPartitions for x in """getConfiguration _setConfiguration _getMaxPartition
query erase getPartitionTable iterAssignedCells query erase getPartitionTable iterAssignedCells
updateCellTID getUnfinishedTIDDict dropUnfinishedData updateCellTID getUnfinishedTIDDict dropUnfinishedData
abortTransaction storeTransaction lockTransaction abortTransaction storeTransaction lockTransaction
...@@ -396,7 +396,7 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -396,7 +396,7 @@ class ImporterDatabaseManager(DatabaseManager):
self._writeback.committed() self._writeback.committed()
self.commit = db.commit = commit self.commit = db.commit = commit
def _updateReadable(self): def _updateReadable(*_):
raise AssertionError raise AssertionError
def setUUID(self, nid): def setUUID(self, nid):
...@@ -443,7 +443,8 @@ class ImporterDatabaseManager(DatabaseManager): ...@@ -443,7 +443,8 @@ class ImporterDatabaseManager(DatabaseManager):
self.zodb_ltid = max(x.ltid for x in self.zodb) self.zodb_ltid = max(x.ltid for x in self.zodb)
zodb = self.zodb[-1] zodb = self.zodb[-1]
self.zodb_loid = zodb.shift_oid + zodb.next_oid - 1 self.zodb_loid = zodb.shift_oid + zodb.next_oid - 1
self.zodb_tid = self.db.getLastTID(self.zodb_ltid) or 0 self.zodb_tid = self._getMaxPartition() is not None and \
self.db.getLastTID(self.zodb_ltid) or 0
if callable(self._import): # XXX: why ? if callable(self._import): # XXX: why ?
if self.zodb_tid == self.zodb_ltid: if self.zodb_tid == self.zodb_ltid:
self._finished() self._finished()
...@@ -723,7 +724,7 @@ class WriteBack(object): ...@@ -723,7 +724,7 @@ class WriteBack(object):
self._event = Event() self._event = Event()
self._idle = Event() self._idle = Event()
self._stop = Event() self._stop = Event()
self._np = self._db.getNumPartitions() self._np = 1 + self._db._getMaxPartition()
self._db = cPickle.dumps(self._db, 2) self._db = cPickle.dumps(self._db, 2)
self._process = Process(target=self._run) self._process = Process(target=self._run)
self._process.daemon = True self._process.daemon = True
......
...@@ -102,25 +102,24 @@ class DatabaseManager(object): ...@@ -102,25 +102,24 @@ class DatabaseManager(object):
finally: finally:
db.close() db.close()
_cached_attr_list = (
'_readable_set', '_getPartition', '_getReadablePartition')
def __getattr__(self, attr): def __getattr__(self, attr):
if attr in ('_readable_set', '_getPartition', '_getReadablePartition'): if attr in self._cached_attr_list:
self._updateReadable() self._updateReadable()
return self.__getattribute__(attr) return self.__getattribute__(attr)
def _partitionTableChanged(self):
try:
del (self._readable_set,
self._getPartition,
self._getReadablePartition)
except AttributeError:
pass
def __enter__(self): def __enter__(self):
assert not self.LOCK, "not a secondary connection" assert not self.LOCK, "not a secondary connection"
# XXX: All config caching should be done in this class, # XXX: All config caching should be done in this class,
# rather than in backend classes. # rather than in backend classes.
self._config.clear() self._config.clear()
self._partitionTableChanged() try:
for attr in self._cached_attr_list:
delattr(self, attr)
except AttributeError:
pass
def __exit__(self, t, v, tb): def __exit__(self, t, v, tb):
if v is None: if v is None:
...@@ -309,21 +308,6 @@ class DatabaseManager(object): ...@@ -309,21 +308,6 @@ class DatabaseManager(object):
for x, tid in ((x, None), (nid, tid))) for x, tid in ((x, None), (nid, tid)))
self.setConfiguration('nid', str(nid)) self.setConfiguration('nid', str(nid))
def getNumPartitions(self):
"""
Load the number of partitions from a database.
"""
n = self.getConfiguration('partitions')
if n is not None:
return int(n)
def setNumPartitions(self, num_partitions):
"""
Store the number of partitions into a database.
"""
self.setConfiguration('partitions', num_partitions)
self._partitionTableChanged()
def getNumReplicas(self): def getNumReplicas(self):
""" """
Load the number of replicas from a database. Load the number of replicas from a database.
...@@ -332,12 +316,6 @@ class DatabaseManager(object): ...@@ -332,12 +316,6 @@ class DatabaseManager(object):
if n is not None: if n is not None:
return int(n) return int(n)
def setNumReplicas(self, num_replicas):
"""
Store the number of replicas into a database.
"""
self.setConfiguration('replicas', num_replicas)
def getName(self): def getName(self):
""" """
Load a name from a database. Load a name from a database.
...@@ -398,8 +376,9 @@ class DatabaseManager(object): ...@@ -398,8 +376,9 @@ class DatabaseManager(object):
tids are in unpacked format. tids are in unpacked format.
""" """
if self.getNumPartitions(): x = self._readable_set
return max(self._getLastTID(x, max_tid) for x in self._readable_set) if x:
return max(self._getLastTID(x, max_tid) for x in x)
def _getLastIDs(self, partition): def _getLastIDs(self, partition):
"""Return max(tid) & max(oid) for objects of given partition """Return max(tid) & max(oid) for objects of given partition
...@@ -560,13 +539,15 @@ class DatabaseManager(object): ...@@ -560,13 +539,15 @@ class DatabaseManager(object):
""" """
""" """
@requires(_getDataLastId) def _getMaxPartition(self):
def _updateReadable(self): """
try: """
readable_set = self.__dict__['_readable_set']
except KeyError: @requires(_getDataLastId, _getMaxPartition)
def _updateReadable(self, reset=True):
if reset:
readable_set = self._readable_set = set() readable_set = self._readable_set = set()
np = self.getNumPartitions() np = 1 + self._getMaxPartition()
def _getPartition(x, np=np): def _getPartition(x, np=np):
return x % np return x % np
def _getReadablePartition(x, np=np, r=readable_set): def _getReadablePartition(x, np=np, r=readable_set):
...@@ -581,12 +562,13 @@ class DatabaseManager(object): ...@@ -581,12 +562,13 @@ class DatabaseManager(object):
i = self._getDataLastId(p) i = self._getDataLastId(p)
d.append(p << 48 if i is None else i + 1) d.append(p << 48 if i is None else i + 1)
else: else:
readable_set = self._readable_set
readable_set.clear() readable_set.clear()
readable_set.update(x[0] for x in self.iterAssignedCells() readable_set.update(x[0] for x in self.iterAssignedCells()
if -x[1] in READABLE) if -x[1] in READABLE)
@requires(_changePartitionTable, _getLastIDs, _getLastTID) @requires(_changePartitionTable, _getLastIDs, _getLastTID)
def changePartitionTable(self, ptid, cell_list, reset=False): def changePartitionTable(self, ptid, num_replicas, cell_list, reset=False):
my_nid = self.getUUID() my_nid = self.getUUID()
pt = dict(self.iterAssignedCells()) pt = dict(self.iterAssignedCells())
# In backup mode, the last transactions of a readable cell may be # In backup mode, the last transactions of a readable cell may be
...@@ -607,9 +589,10 @@ class DatabaseManager(object): ...@@ -607,9 +589,10 @@ class DatabaseManager(object):
outofdate_tid(offset))) outofdate_tid(offset)))
for offset, nid, state in cell_list] for offset, nid, state in cell_list]
self._changePartitionTable(cell_list, reset) self._changePartitionTable(cell_list, reset)
self._updateReadable() self._updateReadable(reset)
assert isinstance(ptid, (int, long)), ptid assert isinstance(ptid, (int, long)), ptid
self._setConfiguration('ptid', str(ptid)) self._setConfiguration('ptid', str(ptid))
self._setConfiguration('replicas', str(num_replicas))
@requires(_changePartitionTable) @requires(_changePartitionTable)
def updateCellTID(self, partition, tid): def updateCellTID(self, partition, tid):
......
...@@ -270,6 +270,12 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -270,6 +270,12 @@ class MySQLDatabaseManager(DatabaseManager):
" ELSE 1-state" " ELSE 1-state"
" END as tid") " END as tid")
# Deliberately disabled: the clean-up below is postponed until a more
# important change comes, so that users can still downgrade meanwhile.
if False:
    def _migrate4(self, schema_dict):
        """Set the 'partitions' configuration entry to None.

        `schema_dict` is not used here.
        """
        self._setConfiguration('partitions', None)
def _setup(self, dedup=False): def _setup(self, dedup=False):
self._config.clear() self._config.clear()
q = self.query q = self.query
...@@ -407,6 +413,9 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -407,6 +413,9 @@ class MySQLDatabaseManager(DatabaseManager):
q("ALTER TABLE config MODIFY value VARBINARY(%s) NULL" % len(value)) q("ALTER TABLE config MODIFY value VARBINARY(%s) NULL" % len(value))
q(sql) q(sql)
def _getMaxPartition(self):
return self.query("SELECT MAX(`partition`) FROM pt")[0][0]
def _getPartitionTable(self): def _getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
......
...@@ -144,6 +144,12 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -144,6 +144,12 @@ class SQLiteDatabaseManager(DatabaseManager):
" WHEN 2 THEN -2" # FEEDING " WHEN 2 THEN -2" # FEEDING
" ELSE 1-state END") " ELSE 1-state END")
# Deliberately disabled: the clean-up below is postponed until a more
# important change comes, so that users can still downgrade meanwhile.
if False:
    def _migrate4(self, schema_dict, index_dict):
        """Set the 'partitions' configuration entry to None.

        Neither `schema_dict` nor `index_dict` is used here.
        """
        self._setConfiguration('partitions', None)
def _setup(self, dedup=False): def _setup(self, dedup=False):
# BBB: SQLite has transactional DDL but before Python 3.6, # BBB: SQLite has transactional DDL but before Python 3.6,
# the binding automatically commits between such statements. # the binding automatically commits between such statements.
...@@ -265,6 +271,9 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -265,6 +271,9 @@ class SQLiteDatabaseManager(DatabaseManager):
else: else:
q("REPLACE INTO config VALUES (?,?)", (key, str(value))) q("REPLACE INTO config VALUES (?,?)", (key, str(value)))
def _getMaxPartition(self):
return self.query("SELECT MAX(`partition`) FROM pt").next()[0]
def _getPartitionTable(self): def _getPartitionTable(self):
return self.query("SELECT * FROM pt") return self.query("SELECT * FROM pt")
......
...@@ -65,14 +65,14 @@ class BaseMasterHandler(BaseHandler): ...@@ -65,14 +65,14 @@ class BaseMasterHandler(BaseHandler):
# See comment in ClientOperationHandler.connectionClosed # See comment in ClientOperationHandler.connectionClosed
self.app.tm.abortFor(uuid, even_if_voted=True) self.app.tm.abortFor(uuid, even_if_voted=True)
def notifyPartitionChanges(self, conn, ptid, cell_list): def notifyPartitionChanges(self, conn, ptid, num_replicas, cell_list):
"""This is very similar to Send Partition Table, except that """This is very similar to Send Partition Table, except that
the information is only about changes from the previous.""" the information is only about changes from the previous."""
app = self.app app = self.app
if ptid != 1 + app.pt.getID(): if ptid != 1 + app.pt.getID():
raise ProtocolError('wrong partition table id') raise ProtocolError('wrong partition table id')
app.pt.update(ptid, cell_list, app.nm) app.pt.update(ptid, num_replicas, cell_list, app.nm)
app.dm.changePartitionTable(ptid, cell_list) app.dm.changePartitionTable(ptid, num_replicas, cell_list)
if app.operational: if app.operational:
app.replicator.notifyPartitionChanges(cell_list) app.replicator.notifyPartitionChanges(cell_list)
app.dm.commit() app.dm.commit()
......
...@@ -65,6 +65,6 @@ class IdentificationHandler(EventHandler): ...@@ -65,6 +65,6 @@ class IdentificationHandler(EventHandler):
conn.setHandler(handler) conn.setHandler(handler)
node.setConnection(conn, force) node.setConnection(conn, force)
# accept the identification and trigger an event # accept the identification and trigger an event
conn.answer(Packets.AcceptIdentification(NodeTypes.STORAGE, uuid and conn.answer(Packets.AcceptIdentification(
app.uuid, app.pt.getPartitions(), app.pt.getReplicas(), uuid)) NodeTypes.STORAGE, uuid and app.uuid, uuid))
handler.connectionCompleted(conn) handler.connectionCompleted(conn)
...@@ -20,10 +20,10 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID ...@@ -20,10 +20,10 @@ from neo.lib.protocol import Packets, ProtocolError, ZERO_TID
class InitializationHandler(BaseMasterHandler): class InitializationHandler(BaseMasterHandler):
def sendPartitionTable(self, conn, ptid, row_list): def sendPartitionTable(self, conn, ptid, num_replicas, row_list):
app = self.app app = self.app
pt = app.pt pt = app.pt
pt.load(ptid, row_list, app.nm) pt.load(ptid, num_replicas, row_list, app.nm)
if not pt.filled(): if not pt.filled():
raise ProtocolError('Partial partition table received') raise ProtocolError('Partial partition table received')
# Install the partition table into the database for persistence. # Install the partition table into the database for persistence.
...@@ -44,7 +44,7 @@ class InitializationHandler(BaseMasterHandler): ...@@ -44,7 +44,7 @@ class InitializationHandler(BaseMasterHandler):
logging.debug('drop data for partitions %r', unassigned) logging.debug('drop data for partitions %r', unassigned)
dm.dropPartitions(unassigned) dm.dropPartitions(unassigned)
dm.changePartitionTable(ptid, cell_list, reset=True) dm.changePartitionTable(ptid, num_replicas, cell_list, reset=True)
dm.commit() dm.commit()
def truncate(self, conn, tid): def truncate(self, conn, tid):
...@@ -68,7 +68,8 @@ class InitializationHandler(BaseMasterHandler): ...@@ -68,7 +68,8 @@ class InitializationHandler(BaseMasterHandler):
def askPartitionTable(self, conn): def askPartitionTable(self, conn):
pt = self.app.pt pt = self.app.pt
conn.answer(Packets.AnswerPartitionTable(pt.getID(), pt.getRowList())) conn.answer(Packets.AnswerPartitionTable(
pt.getID(), pt.getReplicas(), pt.getRowList()))
def askLockedTransactions(self, conn): def askLockedTransactions(self, conn):
conn.answer(Packets.AnswerLockedTransactions( conn.answer(Packets.AnswerLockedTransactions(
......
...@@ -98,9 +98,12 @@ class TransactionManager(EventQueue): ...@@ -98,9 +98,12 @@ class TransactionManager(EventQueue):
self._load_lock_dict = {} self._load_lock_dict = {}
self._replicated = {} self._replicated = {}
self._replicating = set() self._replicating = set()
def getPartition(self, oid):
from neo.lib.util import u64 from neo.lib.util import u64
np = app.pt.getPartitions() np = self._app.pt.getPartitions()
self.getPartition = lambda oid: u64(oid) % np self.getPartition = lambda oid: u64(oid) % np
return self.getPartition(oid)
def discarded(self, offset_list): def discarded(self, offset_list):
self._replicating.difference_update(offset_list) self._replicating.difference_update(offset_list)
......
...@@ -656,7 +656,7 @@ class NEOCluster(object): ...@@ -656,7 +656,7 @@ class NEOCluster(object):
row_list = self.neoctl.getPartitionRowList()[1] row_list = self.neoctl.getPartitionRowList()[1]
number_of_outdated = 0 number_of_outdated = 0
for row in row_list: for row in row_list:
for cell in row[1]: for cell in row:
if cell[1] == CellStates.OUT_OF_DATE: if cell[1] == CellStates.OUT_OF_DATE:
number_of_outdated += 1 number_of_outdated += 1
return number_of_outdated == number, number_of_outdated return number_of_outdated == number, number_of_outdated
...@@ -667,7 +667,7 @@ class NEOCluster(object): ...@@ -667,7 +667,7 @@ class NEOCluster(object):
row_list = self.neoctl.getPartitionRowList()[1] row_list = self.neoctl.getPartitionRowList()[1]
assigned_cells_number = 0 assigned_cells_number = 0
for row in row_list: for row in row_list:
for cell in row[1]: for cell in row:
if cell[0] == process.getUUID(): if cell[0] == process.getUUID():
assigned_cells_number += 1 assigned_cells_number += 1
return assigned_cells_number == number, assigned_cells_number return assigned_cells_number == number, assigned_cells_number
......
...@@ -30,8 +30,6 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -30,8 +30,6 @@ class MasterClientHandlerTests(NeoUnitTestBase):
config = self.getMasterConfiguration(master_number=1, replicas=1) config = self.getMasterConfiguration(master_number=1, replicas=1)
self.app = Application(config) self.app = Application(config)
self.app.em.close() self.app.em.close()
self.app.pt.clear()
self.app.pt.setID(1)
self.app.em = Mock() self.app.em = Mock()
self.app.loid = '\0' * 8 self.app.loid = '\0' * 8
self.app.tm.setLastTID('\0' * 8) self.app.tm.setLastTID('\0' * 8)
......
...@@ -26,7 +26,6 @@ class MasterAppTests(NeoUnitTestBase): ...@@ -26,7 +26,6 @@ class MasterAppTests(NeoUnitTestBase):
# create an application object # create an application object
config = self.getMasterConfiguration() config = self.getMasterConfiguration()
self.app = Application(config) self.app = Application(config)
self.app.pt.clear()
def _tearDown(self, success): def _tearDown(self, success):
self.app.close() self.app.close()
......
...@@ -18,8 +18,8 @@ import unittest ...@@ -18,8 +18,8 @@ import unittest
from ..mock import Mock from ..mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, Packets from neo.lib.protocol import NodeTypes, Packets
from neo.master.handlers.storage import StorageServiceHandler
from neo.master.app import Application from neo.master.app import Application
from neo.master.handlers.storage import StorageServiceHandler
class MasterStorageHandlerTests(NeoUnitTestBase): class MasterStorageHandlerTests(NeoUnitTestBase):
...@@ -29,7 +29,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -29,7 +29,6 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
config = self.getMasterConfiguration(master_number=1, replicas=1) config = self.getMasterConfiguration(master_number=1, replicas=1)
self.app = Application(config) self.app = Application(config)
self.app.em.close() self.app.em.close()
self.app.pt.clear()
self.app.em = Mock() self.app.em = Mock()
self.service = StorageServiceHandler(self.app) self.service = StorageServiceHandler(self.app)
......
...@@ -56,7 +56,7 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -56,7 +56,7 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
self.app.pt = Mock({'getID': 1}) self.app.pt = Mock({'getID': 1})
count = len(self.app.nm.getList()) count = len(self.app.nm.getList())
self.assertRaises(ProtocolError, self.operation.notifyPartitionChanges, self.assertRaises(ProtocolError, self.operation.notifyPartitionChanges,
conn, 0, ()) conn, 0, 0, ())
self.assertEqual(self.app.pt.getID(), 1) self.assertEqual(self.app.pt.getID(), 1)
self.assertEqual(len(self.app.nm.getList()), count) self.assertEqual(len(self.app.nm.getList()), count)
calls = self.app.replicator.mockGetNamedCalls('removePartition') calls = self.app.replicator.mockGetNamedCalls('removePartition')
...@@ -84,13 +84,13 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -84,13 +84,13 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
ptid = 2 ptid = 2
app.dm = Mock({ }) app.dm = Mock({ })
app.replicator = Mock({}) app.replicator = Mock({})
self.operation.notifyPartitionChanges(conn, ptid, cells) self.operation.notifyPartitionChanges(conn, ptid, 1, cells)
# ptid set # ptid set
self.assertEqual(app.pt.getID(), ptid) self.assertEqual(app.pt.getID(), ptid)
# dm call # dm call
calls = self.app.dm.mockGetNamedCalls('changePartitionTable') calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(ptid, cells) calls[0].checkArgs(ptid, 1, cells)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -48,30 +48,15 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -48,30 +48,15 @@ class StorageDBTests(NeoUnitTestBase):
raise NotImplementedError raise NotImplementedError
def setNumPartitions(self, num_partitions, reset=0): def setNumPartitions(self, num_partitions, reset=0):
try: assert not hasattr(self, '_db')
db = self._db
except AttributeError:
self._db = db = self.getDB(reset) self._db = db = self.getDB(reset)
else:
if reset:
db.setup(reset)
else:
try:
n = db.getNumPartitions()
except KeyError:
n = 0
if num_partitions == n:
return
if num_partitions < n:
db.dropPartitions(n)
db.setNumPartitions(num_partitions)
self.assertEqual(num_partitions, db.getNumPartitions())
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
db.setUUID(uuid) db.setUUID(uuid)
self.assertEqual(uuid, db.getUUID()) self.assertEqual(uuid, db.getUUID())
db.changePartitionTable(1, db.changePartitionTable(1, 0,
[(i, uuid, CellStates.UP_TO_DATE) for i in xrange(num_partitions)], [(i, uuid, CellStates.UP_TO_DATE) for i in xrange(num_partitions)],
reset=True) reset=True)
self.assertEqual(num_partitions, 1 + db._getMaxPartition())
db.commit() db.commit()
def checkConfigEntry(self, get_call, set_call, value): def checkConfigEntry(self, get_call, set_call, value):
...@@ -102,16 +87,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -102,16 +87,6 @@ class StorageDBTests(NeoUnitTestBase):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getName, db.setName, 'TEST_NAME') self.checkConfigEntry(db.getName, db.setName, 'TEST_NAME')
def test_getPartitionTable(self):
db = self.getDB()
db.setNumPartitions(3)
uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
cell1 = (0, uuid1, CellStates.OUT_OF_DATE)
cell2 = (1, uuid1, CellStates.UP_TO_DATE)
db.changePartitionTable(1, [cell1, cell2], 1)
result = db.getPartitionTable()
self.assertEqual(set(result), {cell1, cell2})
def getOIDs(self, count): def getOIDs(self, count):
return map(p64, xrange(count)) return map(p64, xrange(count))
...@@ -202,52 +177,6 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -202,52 +177,6 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getObject(oid1, before_tid=tid2), self.assertEqual(self.db.getObject(oid1, before_tid=tid2),
OBJECT_T1_NEXT) OBJECT_T1_NEXT)
def test_setPartitionTable(self):
db = self.getDB()
db.setNumPartitions(3)
ptid = 1
uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE
cell3 = 1, uuid, CellStates.DISCARDED
# no partition table
self.assertEqual(list(db.getPartitionTable()), [])
# set one
db.changePartitionTable(ptid, [cell1], 1)
result = db.getPartitionTable()
self.assertEqual(list(result), [cell1])
# then another
db.changePartitionTable(ptid, [cell2], 1)
result = db.getPartitionTable()
self.assertEqual(list(result), [cell2])
# drop discarded cells
db.changePartitionTable(ptid, [cell2, cell3], 1)
result = db.getPartitionTable()
self.assertEqual(list(result), [])
def test_changePartitionTable(self):
db = self.getDB()
db.setNumPartitions(3)
ptid = 1
uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE
cell3 = 1, uuid, CellStates.DISCARDED
# no partition table
self.assertEqual(list(db.getPartitionTable()), [])
# set one
db.changePartitionTable(ptid, [cell1])
result = db.getPartitionTable()
self.assertEqual(list(result), [cell1])
# add more entries
db.changePartitionTable(ptid, [cell2])
result = db.getPartitionTable()
self.assertEqual(set(result), {cell1, cell2})
# drop discarded cells
db.changePartitionTable(ptid, [cell2, cell3])
result = db.getPartitionTable()
self.assertEqual(list(result), [cell1])
def test_commitTransaction(self): def test_commitTransaction(self):
oid1, oid2 = self.getOIDs(2) oid1, oid2 = self.getOIDs(2)
tid1, tid2 = self.getTIDs(2) tid1, tid2 = self.getTIDs(2)
......
...@@ -19,12 +19,9 @@ class Handler(MasterEventHandler): ...@@ -19,12 +19,9 @@ class Handler(MasterEventHandler):
super(Handler, self).answerClusterState(conn, state) super(Handler, self).answerClusterState(conn, state)
self.app.refresh('state') self.app.refresh('state')
def answerPartitionTable(self, *args):
super(Handler, self).answerPartitionTable(*args)
self.app.refresh('pt')
def sendPartitionTable(self, *args): def sendPartitionTable(self, *args):
raise AssertionError super(Handler, self).sendPartitionTable(*args)
self.app.refresh('pt')
def notifyPartitionChanges(self, *args): def notifyPartitionChanges(self, *args):
super(Handler, self).notifyPartitionChanges(*args) super(Handler, self).notifyPartitionChanges(*args)
......
...@@ -814,7 +814,7 @@ class NEOCluster(object): ...@@ -814,7 +814,7 @@ class NEOCluster(object):
master_list = self.master_list master_list = self.master_list
if storage_list is None: if storage_list is None:
storage_list = self.storage_list storage_list = self.storage_list
def answerPartitionTable(release, orig, *args): def sendPartitionTable(release, orig, *args):
orig(*args) orig(*args)
release() release()
def dispatch(release, orig, handler, *args): def dispatch(release, orig, handler, *args):
...@@ -830,7 +830,7 @@ class NEOCluster(object): ...@@ -830,7 +830,7 @@ class NEOCluster(object):
if state in expected_state: if state in expected_state:
release() release()
with Serialized.until(MasterEventHandler, with Serialized.until(MasterEventHandler,
answerPartitionTable=answerPartitionTable) as tic1, \ sendPartitionTable=sendPartitionTable) as tic1, \
Serialized.until(RecoveryManager, dispatch=dispatch) as tic2, \ Serialized.until(RecoveryManager, dispatch=dispatch) as tic2, \
Serialized.until(MasterEventHandler, Serialized.until(MasterEventHandler,
notifyClusterInformation=notifyClusterInformation) as tic3: notifyClusterInformation=notifyClusterInformation) as tic3:
......
...@@ -42,6 +42,7 @@ from neo.lib.util import add64, makeChecksum, p64, u64 ...@@ -42,6 +42,7 @@ from neo.lib.util import add64, makeChecksum, p64, u64
from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError from neo.client.exception import NEOPrimaryMasterLost, NEOStorageError
from neo.client.transactions import Transaction from neo.client.transactions import Transaction
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.master.pt import PartitionTable
from neo.storage.database import DatabaseFailure from neo.storage.database import DatabaseFailure
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.storage.handlers.identification import IdentificationHandler from neo.storage.handlers.identification import IdentificationHandler
...@@ -1303,7 +1304,7 @@ class Test(NEOThreadedTest): ...@@ -1303,7 +1304,7 @@ class Test(NEOThreadedTest):
del conn._queue[:] # XXX del conn._queue[:] # XXX
conn.close() conn.close()
if 1: if 1:
with Patch(cluster.master.pt, make=make), \ with Patch(PartitionTable, make=make), \
Patch(InitializationHandler, Patch(InitializationHandler,
askPartitionTable=askPartitionTable) as p: askPartitionTable=askPartitionTable) as p:
cluster.start() cluster.start()
...@@ -2328,8 +2329,8 @@ class Test(NEOThreadedTest): ...@@ -2328,8 +2329,8 @@ class Test(NEOThreadedTest):
for x in 'ab': for x in 'ab':
r[x] = PCounterWithResolution() r[x] = PCounterWithResolution()
t1.commit() t1.commit()
cluster.stop(replicas=1) cluster.neoctl.setNumReplicas(1)
cluster.start() self.tic()
s0, s1 = cluster.sortStorageList() s0, s1 = cluster.sortStorageList()
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
r = c1.root() r = c1.root()
...@@ -2513,8 +2514,8 @@ class Test(NEOThreadedTest): ...@@ -2513,8 +2514,8 @@ class Test(NEOThreadedTest):
for x in 'ab': for x in 'ab':
r[x] = PCounterWithResolution() r[x] = PCounterWithResolution()
t1.commit() t1.commit()
cluster.stop(replicas=1) cluster.neoctl.setNumReplicas(1)
cluster.start() self.tic()
s0, s1 = cluster.sortStorageList() s0, s1 = cluster.sortStorageList()
t1, c1 = cluster.getTransaction() t1, c1 = cluster.getTransaction()
r = c1.root() r = c1.root()
...@@ -2817,6 +2818,7 @@ class Test(NEOThreadedTest): ...@@ -2817,6 +2818,7 @@ class Test(NEOThreadedTest):
dump_dict[s.uuid] = dm.dump() dump_dict[s.uuid] = dm.dump()
with open(path % (s.getAdapter(), s.uuid)) as f: with open(path % (s.getAdapter(), s.uuid)) as f:
dm.restore(f.read()) dm.restore(f.read())
dm.setConfiguration('partitions', None) # XXX: see dm._migrate4
with NEOCluster(storage_count=3, partitions=3, replicas=1, with NEOCluster(storage_count=3, partitions=3, replicas=1,
name=self._testMethodName) as cluster: name=self._testMethodName) as cluster:
s1, s2, s3 = cluster.storage_list s1, s2, s3 = cluster.storage_list
......
...@@ -74,6 +74,8 @@ class ReplicationTests(NEOThreadedTest): ...@@ -74,6 +74,8 @@ class ReplicationTests(NEOThreadedTest):
source_dict = {x.uuid: x for x in cluster.upstream.storage_list} source_dict = {x.uuid: x for x in cluster.upstream.storage_list}
for storage in cluster.storage_list: for storage in cluster.storage_list:
self.assertFalse(storage.dm._uncommitted_data) self.assertFalse(storage.dm._uncommitted_data)
if storage.pt is None:
storage.loadPartitionTable()
self.assertEqual(np, storage.pt.getPartitions()) self.assertEqual(np, storage.pt.getPartitions())
for partition in pt.getAssignedPartitionList(storage.uuid): for partition in pt.getAssignedPartitionList(storage.uuid):
cell_list = upstream_pt.getCellList(partition, readable=True) cell_list = upstream_pt.getCellList(partition, readable=True)
...@@ -89,6 +91,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -89,6 +91,7 @@ class ReplicationTests(NEOThreadedTest):
checksum_list = [ checksum_list = [
self.checksumPartition(storage_dict[x.getUUID()], offset) self.checksumPartition(storage_dict[x.getUUID()], offset)
for x in pt.getCellList(offset)] for x in pt.getCellList(offset)]
self.assertLess(1, len(checksum_list))
self.assertEqual(1, len(set(checksum_list)), self.assertEqual(1, len(set(checksum_list)),
(offset, checksum_list)) (offset, checksum_list))
...@@ -445,13 +448,13 @@ class ReplicationTests(NEOThreadedTest): ...@@ -445,13 +448,13 @@ class ReplicationTests(NEOThreadedTest):
return isinstance(packet, delayed) and \ return isinstance(packet, delayed) and \
packet.decode()[0] == offset and \ packet.decode()[0] == offset and \
conn in s1.getConnectionList(s0) conn in s1.getConnectionList(s0)
def changePartitionTable(orig, ptid, cell_list): def changePartitionTable(orig, ptid, num_replicas, cell_list):
if (offset, s0.uuid, CellStates.DISCARDED) in cell_list: if (offset, s0.uuid, CellStates.DISCARDED) in cell_list:
connection_filter.remove(delayAskFetch) connection_filter.remove(delayAskFetch)
# XXX: this is currently not done by # XXX: this is currently not done by
# default for performance reason # default for performance reason
orig.im_self.dropPartitions((offset,)) orig.im_self.dropPartitions((offset,))
return orig(ptid, cell_list) return orig(ptid, num_replicas, cell_list)
np = cluster.num_partitions np = cluster.num_partitions
s0, s1, s2 = cluster.storage_list s0, s1, s2 = cluster.storage_list
for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects: for delayed in Packets.AskFetchTransactions, Packets.AskFetchObjects:
...@@ -511,7 +514,9 @@ class ReplicationTests(NEOThreadedTest): ...@@ -511,7 +514,9 @@ class ReplicationTests(NEOThreadedTest):
for x in 'ab': for x in 'ab':
r[x] = PCounter() r[x] = PCounter()
t.commit() t.commit()
cluster.stop(replicas=1) cluster.neoctl.setNumReplicas(1)
self.tic()
cluster.stop()
cluster.start((s1, s2)) cluster.start((s1, s2))
with ConnectionFilter() as f: with ConnectionFilter() as f:
f.delayAddObject() f.delayAddObject()
...@@ -940,7 +945,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -940,7 +945,7 @@ class ReplicationTests(NEOThreadedTest):
self.tic() self.tic()
with Patch(cluster, storage_list=s01): with Patch(cluster, storage_list=s01):
cluster.sortStorageList() cluster.sortStorageList()
cluster.stop(replicas=1) cluster.stop()
cluster.storage_list[:2] = s01 cluster.storage_list[:2] = s01
storage_dict = {} storage_dict = {}
for s, d in zip(s01, s23): for s, d in zip(s01, s23):
...@@ -957,6 +962,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -957,6 +962,7 @@ class ReplicationTests(NEOThreadedTest):
self.checkReplicas(cluster) self.checkReplicas(cluster)
expected = '|'.join(['U.U.|.U.U'] * 3) expected = '|'.join(['U.U.|.U.U'] * 3)
self.assertPartitionTable(cluster, expected) self.assertPartitionTable(cluster, expected)
cluster.neoctl.setNumReplicas(1)
cluster.neoctl.tweakPartitionTable() cluster.neoctl.tweakPartitionTable()
self.tic() self.tic()
self.assertPartitionTable(cluster, expected) self.assertPartitionTable(cluster, expected)
...@@ -974,7 +980,7 @@ class ReplicationTests(NEOThreadedTest): ...@@ -974,7 +980,7 @@ class ReplicationTests(NEOThreadedTest):
def check(expected_state, expected_count): def check(expected_state, expected_count):
self.assertEqual(expected_count, len([None self.assertEqual(expected_count, len([None
for row in cluster.neoctl.getPartitionRowList()[1] for row in cluster.neoctl.getPartitionRowList()[1]
for cell in row[1] for cell in row
if cell[1] == CellStates.CORRUPTED])) if cell[1] == CellStates.CORRUPTED]))
self.assertEqual(expected_state, cluster.neoctl.getClusterState()) self.assertEqual(expected_state, cluster.neoctl.getClusterState())
np = cluster.num_partitions np = cluster.num_partitions
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment