Commit 5941b27d authored by Julien Muchembled's avatar Julien Muchembled

lib.node: code refactoring

parent c17f5f91
...@@ -172,60 +172,6 @@ class Node(object): ...@@ -172,60 +172,6 @@ class Node(object):
id(self), id(self),
) )
def isMaster(self):
return False
def isStorage(self):
return False
def isClient(self):
return False
def isAdmin(self):
return False
def isRunning(self):
return self._state == NodeStates.RUNNING
def isUnknown(self):
return self._state == NodeStates.UNKNOWN
def isTemporarilyDown(self):
return self._state == NodeStates.TEMPORARILY_DOWN
def isDown(self):
return self._state == NodeStates.DOWN
def isBroken(self):
return self._state == NodeStates.BROKEN
def isHidden(self):
return self._state == NodeStates.HIDDEN
def isPending(self):
return self._state == NodeStates.PENDING
def setRunning(self):
self.setState(NodeStates.RUNNING)
def setUnknown(self):
self.setState(NodeStates.UNKNOWN)
def setTemporarilyDown(self):
self.setState(NodeStates.TEMPORARILY_DOWN)
def setDown(self):
self.setState(NodeStates.DOWN)
def setBroken(self):
self.setState(NodeStates.BROKEN)
def setHidden(self):
self.setState(NodeStates.HIDDEN)
def setPending(self):
self.setState(NodeStates.PENDING)
def asTuple(self): def asTuple(self):
""" Returned tuple is intended to be used in protocol encoders """ """ Returned tuple is intended to be used in protocol encoders """
return (self.getType(), self._address, self._uuid, self._state) return (self.getType(), self._address, self._uuid, self._state)
...@@ -236,12 +182,6 @@ class Node(object): ...@@ -236,12 +182,6 @@ class Node(object):
return self._uuid > node._uuid return self._uuid > node._uuid
return self._address > node._address return self._address > node._address
def getType(self):
try:
return NODE_CLASS_MAPPING[self.__class__]
except KeyError:
raise NotImplementedError
def whoSetState(self): def whoSetState(self):
""" """
Debugging method: call this method to know who set the current Debugging method: call this method to know who set the current
...@@ -251,43 +191,6 @@ class Node(object): ...@@ -251,43 +191,6 @@ class Node(object):
attributeTracker.track(Node) attributeTracker.track(Node)
class MasterNode(Node):
"""This class represents a master node."""
def isMaster(self):
return True
class StorageNode(Node):
"""This class represents a storage node."""
def isStorage(self):
return True
class ClientNode(Node):
"""This class represents a client node."""
def isClient(self):
return True
class AdminNode(Node):
"""This class represents an admin node."""
def isAdmin(self):
return True
NODE_TYPE_MAPPING = {
NodeTypes.MASTER: MasterNode,
NodeTypes.STORAGE: StorageNode,
NodeTypes.CLIENT: ClientNode,
NodeTypes.ADMIN: AdminNode,
}
NODE_CLASS_MAPPING = {
StorageNode: NodeTypes.STORAGE,
MasterNode: NodeTypes.MASTER,
ClientNode: NodeTypes.CLIENT,
AdminNode: NodeTypes.ADMIN,
}
class MasterDB(object): class MasterDB(object):
""" """
...@@ -361,7 +264,7 @@ class NodeManager(object): ...@@ -361,7 +264,7 @@ class NodeManager(object):
self._node_set.add(node) self._node_set.add(node)
self._updateAddress(node, None) self._updateAddress(node, None)
self._updateUUID(node, None) self._updateUUID(node, None)
self.__updateSet(self._type_dict, None, node.__class__, node) self.__updateSet(self._type_dict, None, node.getType(), node)
self.__updateSet(self._state_dict, None, node.getState(), node) self.__updateSet(self._state_dict, None, node.getState(), node)
self._updateIdentified(node) self._updateIdentified(node)
if node.isMaster() and self._master_db is not None: if node.isMaster() and self._master_db is not None:
...@@ -372,25 +275,19 @@ class NodeManager(object): ...@@ -372,25 +275,19 @@ class NodeManager(object):
logging.warning('removing unknown node %r, ignoring', node) logging.warning('removing unknown node %r, ignoring', node)
return return
self._node_set.remove(node) self._node_set.remove(node)
self.__drop(self._address_dict, node.getAddress()) # a node may have not be indexed by uuid or address, eg.:
self.__drop(self._uuid_dict, node.getUUID()) # - a client or admin node that don't have listening address
self._address_dict.pop(node.getAddress(), None)
# - a master known by address but without UUID
self._uuid_dict.pop(node.getUUID(), None)
self.__dropSet(self._state_dict, node.getState(), node) self.__dropSet(self._state_dict, node.getState(), node)
self.__dropSet(self._type_dict, node.__class__, node) self.__dropSet(self._type_dict, node.getType(), node)
uuid = node.getUUID() uuid = node.getUUID()
if uuid in self._identified_dict: if uuid in self._identified_dict:
del self._identified_dict[uuid] del self._identified_dict[uuid]
if node.isMaster() and self._master_db is not None: if node.isMaster() and self._master_db is not None:
self._master_db.discard(node.getAddress()) self._master_db.discard(node.getAddress())
def __drop(self, index_dict, key):
try:
del index_dict[key]
except KeyError:
# a node may have not be indexed by uuid or address, eg.:
# - a master known by address but without UUID
# - a client or admin node that don't have listening address
pass
def __update(self, index_dict, old_key, new_key, node): def __update(self, index_dict, old_key, new_key, node):
""" Update an index from old to new key """ """ Update an index from old to new key """
if old_key is not None: if old_key is not None:
...@@ -421,15 +318,14 @@ class NodeManager(object): ...@@ -421,15 +318,14 @@ class NodeManager(object):
self.__update(self._uuid_dict, old_uuid, node.getUUID(), node) self.__update(self._uuid_dict, old_uuid, node.getUUID(), node)
def __dropSet(self, set_dict, key, node): def __dropSet(self, set_dict, key, node):
if key in set_dict and node in set_dict[key]: if key in set_dict:
set_dict[key].remove(node) set_dict[key].remove(node)
def __updateSet(self, set_dict, old_key, new_key, node): def __updateSet(self, set_dict, old_key, new_key, node):
""" Update a set index from old to new key """ """ Update a set index from old to new key """
if old_key in set_dict: if old_key in set_dict:
set_dict[old_key].remove(node) set_dict[old_key].remove(node)
if new_key is not None: set_dict.setdefault(new_key, set()).add(node)
set_dict.setdefault(new_key, set()).add(node)
def _updateState(self, node, old_state): def _updateState(self, node, old_state):
assert not node.isDown(), node assert not node.isDown(), node
...@@ -457,35 +353,16 @@ class NodeManager(object): ...@@ -457,35 +353,16 @@ class NodeManager(object):
# TODO: use an index # TODO: use an index
return [x for x in self._node_set if x.isConnected()] return [x for x in self._node_set if x.isConnected()]
def __getList(self, index_dict, key):
return index_dict.setdefault(key, set())
def getByStateList(self, state): def getByStateList(self, state):
""" Get a node list filtered per the node state """ """ Get a node list filtered per the node state """
return list(self.__getList(self._state_dict, state)) return list(self._state_dict.get(state, ()))
def __getTypeList(self, type_klass, only_identified=False): def _getTypeList(self, node_type, only_identified=False):
node_set = self.__getList(self._type_dict, type_klass) node_set = self._type_dict.get(node_type, ())
if only_identified: if only_identified:
return [x for x in node_set if x.getUUID() in self._identified_dict] return [x for x in node_set if x.getUUID() in self._identified_dict]
return list(node_set) return list(node_set)
def getMasterList(self, only_identified=False):
""" Return a list with master nodes """
return self.__getTypeList(MasterNode, only_identified)
def getStorageList(self, only_identified=False):
""" Return a list with storage nodes """
return self.__getTypeList(StorageNode, only_identified)
def getClientList(self, only_identified=False):
""" Return a list with client nodes """
return self.__getTypeList(ClientNode, only_identified)
def getAdminList(self, only_identified=False):
""" Return a list with admin nodes """
return self.__getTypeList(AdminNode, only_identified)
def getByAddress(self, address): def getByAddress(self, address):
""" Return the node that match with a given address """ """ Return the node that match with a given address """
return self._address_dict.get(address, None) return self._address_dict.get(address, None)
...@@ -494,12 +371,6 @@ class NodeManager(object): ...@@ -494,12 +371,6 @@ class NodeManager(object):
""" Return the node that match with a given UUID """ """ Return the node that match with a given UUID """
return self._uuid_dict.get(uuid, None) return self._uuid_dict.get(uuid, None)
def hasAddress(self, address):
return address in self._address_dict
def hasUUID(self, uuid):
return uuid in self._uuid_dict
def _createNode(self, klass, address=None, uuid=None, **kw): def _createNode(self, klass, address=None, uuid=None, **kw):
by_address = self.getByAddress(address) by_address = self.getByAddress(address)
by_uuid = self.getByUUID(uuid) by_uuid = self.getByUUID(uuid)
...@@ -531,36 +402,14 @@ class NodeManager(object): ...@@ -531,36 +402,14 @@ class NodeManager(object):
assert node.__class__ is klass, (node.__class__, klass) assert node.__class__ is klass, (node.__class__, klass)
return node return node
def createMaster(self, **kw):
""" Create and register a new master """
return self._createNode(MasterNode, **kw)
def createStorage(self, **kw):
""" Create and register a new storage """
return self._createNode(StorageNode, **kw)
def createClient(self, **kw):
""" Create and register a new client """
return self._createNode(ClientNode, **kw)
def createAdmin(self, **kw):
""" Create and register a new admin """
return self._createNode(AdminNode, **kw)
def _getClassFromNodeType(self, node_type):
klass = NODE_TYPE_MAPPING.get(node_type)
if klass is None:
raise ValueError('Unknown node type : %s' % node_type)
return klass
def createFromNodeType(self, node_type, **kw): def createFromNodeType(self, node_type, **kw):
return self._createNode(self._getClassFromNodeType(node_type), **kw) return self._createNode(NODE_TYPE_MAPPING[node_type], **kw)
def update(self, node_list): def update(self, node_list):
for node_type, addr, uuid, state in node_list: for node_type, addr, uuid, state in node_list:
# This should be done here (although klass might not be used in this # This should be done here (although klass might not be used in this
# iteration), as it raises if type is not valid. # iteration), as it raises if type is not valid.
klass = self._getClassFromNodeType(node_type) klass = NODE_TYPE_MAPPING[node_type]
# lookup in current table # lookup in current table
node_by_uuid = self.getByUUID(uuid) node_by_uuid = self.getByUUID(uuid)
...@@ -614,3 +463,40 @@ class NodeManager(object): ...@@ -614,3 +463,40 @@ class NodeManager(object):
address = '%s:%d' % address address = '%s:%d' % address
logging.info(' * %*s | %8s | %22s | %s', logging.info(' * %*s | %8s | %22s | %s',
max_len, uuid, node.getType(), address, node.getState()) max_len, uuid, node.getType(), address, node.getState())
@apply
def NODE_TYPE_MAPPING():
def setmethod(cls, attr, value):
assert not hasattr(cls, attr), (cls, attr)
setattr(cls, attr, value)
def setfullmethod(cls, attr, value):
value.__name__ = attr
setmethod(cls, attr, value)
def camel_case(enum):
return str(enum).replace('_', ' ').title().replace(' ', '')
def setStateAccessors(state):
name = camel_case(state)
setfullmethod(Node, 'set' + name, lambda self: self.setState(state))
setfullmethod(Node, 'is' + name, lambda self: self._state == state)
map(setStateAccessors, NodeStates)
node_type_dict = {}
getType = lambda node_type: staticmethod(lambda: node_type)
true = staticmethod(lambda: True)
createNode = lambda cls: lambda self, **kw: self._createNode(cls, **kw)
getList = lambda node_type: lambda self, only_identified=False: \
self._getTypeList(node_type, only_identified)
bases = Node,
for node_type in NodeTypes:
name = camel_case(node_type)
is_name = 'is' + name
setmethod(Node, is_name, bool)
node_type_dict[node_type] = cls = type(name + 'Node', bases, {
'getType': getType(node_type),
is_name: true,
})
setfullmethod(NodeManager, 'create' + name, createNode(cls))
setfullmethod(NodeManager, 'get%sList' % name, getList(node_type))
return node_type_dict
...@@ -32,6 +32,7 @@ from functools import wraps ...@@ -32,6 +32,7 @@ from functools import wraps
from mock import Mock from mock import Mock
from neo.lib import debug, logging, protocol from neo.lib import debug, logging, protocol
from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES from neo.lib.protocol import NodeTypes, Packets, UUID_NAMESPACES
from neo.lib.util import cached_property
from time import time from time import time
from struct import pack, unpack from struct import pack, unpack
from unittest.case import _ExpectedFailure, _UnexpectedSuccess from unittest.case import _ExpectedFailure, _UnexpectedSuccess
...@@ -194,6 +195,15 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -194,6 +195,15 @@ class NeoUnitTestBase(NeoTestBase):
self.uuid_dict = {} self.uuid_dict = {}
NeoTestBase.setUp(self) NeoTestBase.setUp(self)
@cached_property
def nm(self):
from neo.lib import node
return node.NodeManager()
def createStorage(self, *args):
return self.nm.createStorage(**dict(zip(
('address', 'uuid', 'state'), args)))
def prepareDatabase(self, number, prefix=DB_PREFIX): def prepareDatabase(self, number, prefix=DB_PREFIX):
""" create empty databases """ """ create empty databases """
adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL') adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL')
......
...@@ -21,7 +21,6 @@ from .. import NeoUnitTestBase ...@@ -21,7 +21,6 @@ from .. import NeoUnitTestBase
from neo.lib.protocol import NodeStates, CellStates from neo.lib.protocol import NodeStates, CellStates
from neo.lib.pt import PartitionTableException from neo.lib.pt import PartitionTableException
from neo.master.pt import PartitionTable from neo.master.pt import PartitionTable
from neo.lib.node import StorageNode
class MasterPartitionTableTests(NeoUnitTestBase): class MasterPartitionTableTests(NeoUnitTestBase):
...@@ -55,19 +54,19 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -55,19 +54,19 @@ class MasterPartitionTableTests(NeoUnitTestBase):
# create nodes # create nodes
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19002) server2 = ("127.0.0.2", 19002)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19003) server3 = ("127.0.0.3", 19003)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
uuid4 = self.getStorageUUID() uuid4 = self.getStorageUUID()
server4 = ("127.0.0.4", 19004) server4 = ("127.0.0.4", 19004)
sn4 = StorageNode(Mock(), server4, uuid4) sn4 = self.createStorage(server4, uuid4)
uuid5 = self.getStorageUUID() uuid5 = self.getStorageUUID()
server5 = ("127.0.0.5", 19005) server5 = ("127.0.0.5", 19005)
sn5 = StorageNode(Mock(), server5, uuid5) sn5 = self.createStorage(server5, uuid5)
# create partition table # create partition table
num_partitions = 5 num_partitions = 5
num_replicas = 3 num_replicas = 3
...@@ -117,7 +116,7 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -117,7 +116,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
self.assertEqual(cell.getState(), CellStates.UP_TO_DATE) self.assertEqual(cell.getState(), CellStates.UP_TO_DATE)
def test_15_dropNodeList(self): def test_15_dropNodeList(self):
sn = [StorageNode(Mock(), None, i + 1, NodeStates.RUNNING) sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(3)] for i in xrange(3)]
pt = PartitionTable(3, 0) pt = PartitionTable(3, 0)
pt.setCell(0, sn[0], CellStates.OUT_OF_DATE) pt.setCell(0, sn[0], CellStates.OUT_OF_DATE)
...@@ -153,22 +152,22 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -153,22 +152,22 @@ class MasterPartitionTableTests(NeoUnitTestBase):
# add nodes # add nodes
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1, NodeStates.RUNNING) sn1 = self.createStorage(server1, uuid1, NodeStates.RUNNING)
# add not running node # add not running node
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
sn2.setState(NodeStates.TEMPORARILY_DOWN) sn2.setState(NodeStates.TEMPORARILY_DOWN)
# add node without uuid # add node without uuid
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, None, NodeStates.RUNNING) sn3 = self.createStorage(server3, None, NodeStates.RUNNING)
# add clear node # add clear node
uuid4 = self.getStorageUUID() uuid4 = self.getStorageUUID()
server4 = ("127.0.0.4", 19001) server4 = ("127.0.0.4", 19001)
sn4 = StorageNode(Mock(), server4, uuid4, NodeStates.RUNNING) sn4 = self.createStorage(server4, uuid4, NodeStates.RUNNING)
uuid5 = self.getStorageUUID() uuid5 = self.getStorageUUID()
server5 = ("127.0.0.5", 1900) server5 = ("127.0.0.5", 1900)
sn5 = StorageNode(Mock(), server5, uuid5, NodeStates.RUNNING) sn5 = self.createStorage(server5, uuid5, NodeStates.RUNNING)
# make the table # make the table
pt.make([sn1, sn2, sn3, sn4, sn5]) pt.make([sn1, sn2, sn3, sn4, sn5])
# check it's ok, only running nodes and node with uuid # check it's ok, only running nodes and node with uuid
...@@ -231,7 +230,7 @@ class MasterPartitionTableTests(NeoUnitTestBase): ...@@ -231,7 +230,7 @@ class MasterPartitionTableTests(NeoUnitTestBase):
return change_list return change_list
def test_17_tweak(self): def test_17_tweak(self):
sn = [StorageNode(Mock(), None, i + 1, NodeStates.RUNNING) sn = [self.createStorage(None, i + 1, NodeStates.RUNNING)
for i in xrange(5)] for i in xrange(5)]
pt = PartitionTable(5, 2) pt = PartitionTable(5, 2)
# part 0 # part 0
......
...@@ -18,8 +18,7 @@ import unittest ...@@ -18,8 +18,7 @@ import unittest
from mock import Mock from mock import Mock
from neo.lib import protocol from neo.lib import protocol
from neo.lib.protocol import NodeTypes, NodeStates from neo.lib.protocol import NodeTypes, NodeStates
from neo.lib.node import Node, MasterNode, StorageNode, \ from neo.lib.node import Node, NodeManager, MasterDB
ClientNode, AdminNode, NodeManager, MasterDB
from . import NeoUnitTestBase, getTempDirectory from . import NeoUnitTestBase, getTempDirectory
from time import time from time import time
from os import chmod, mkdir, rmdir, unlink from os import chmod, mkdir, rmdir, unlink
...@@ -29,15 +28,15 @@ class NodesTests(NeoUnitTestBase): ...@@ -29,15 +28,15 @@ class NodesTests(NeoUnitTestBase):
def setUp(self): def setUp(self):
NeoUnitTestBase.setUp(self) NeoUnitTestBase.setUp(self)
self.manager = Mock() self.nm = Mock()
def _updatedByAddress(self, node, index=0): def _updatedByAddress(self, node, index=0):
calls = self.manager.mockGetNamedCalls('_updateAddress') calls = self.nm.mockGetNamedCalls('_updateAddress')
self.assertEqual(len(calls), index + 1) self.assertEqual(len(calls), index + 1)
self.assertEqual(calls[index].getParam(0), node) self.assertEqual(calls[index].getParam(0), node)
def _updatedByUUID(self, node, index=0): def _updatedByUUID(self, node, index=0):
calls = self.manager.mockGetNamedCalls('_updateUUID') calls = self.nm.mockGetNamedCalls('_updateUUID')
self.assertEqual(len(calls), index + 1) self.assertEqual(len(calls), index + 1)
self.assertEqual(calls[index].getParam(0), node) self.assertEqual(calls[index].getParam(0), node)
...@@ -45,7 +44,7 @@ class NodesTests(NeoUnitTestBase): ...@@ -45,7 +44,7 @@ class NodesTests(NeoUnitTestBase):
""" Check the node initialization """ """ Check the node initialization """
address = ('127.0.0.1', 10000) address = ('127.0.0.1', 10000)
uuid = self.getNewUUID(None) uuid = self.getNewUUID(None)
node = Node(self.manager, address=address, uuid=uuid) node = Node(self.nm, address=address, uuid=uuid)
self.assertEqual(node.getState(), NodeStates.UNKNOWN) self.assertEqual(node.getState(), NodeStates.UNKNOWN)
self.assertEqual(node.getAddress(), address) self.assertEqual(node.getAddress(), address)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
...@@ -53,7 +52,7 @@ class NodesTests(NeoUnitTestBase): ...@@ -53,7 +52,7 @@ class NodesTests(NeoUnitTestBase):
def testState(self): def testState(self):
""" Check if the last changed time is updated when state is changed """ """ Check if the last changed time is updated when state is changed """
node = Node(self.manager) node = Node(self.nm)
self.assertEqual(node.getState(), NodeStates.UNKNOWN) self.assertEqual(node.getState(), NodeStates.UNKNOWN)
self.assertTrue(time() - 1 < node.getLastStateChange() < time()) self.assertTrue(time() - 1 < node.getLastStateChange() < time())
previous_time = node.getLastStateChange() previous_time = node.getLastStateChange()
...@@ -64,7 +63,7 @@ class NodesTests(NeoUnitTestBase): ...@@ -64,7 +63,7 @@ class NodesTests(NeoUnitTestBase):
def testAddress(self): def testAddress(self):
""" Check if the node is indexed by address """ """ Check if the node is indexed by address """
node = Node(self.manager) node = Node(self.nm)
self.assertEqual(node.getAddress(), None) self.assertEqual(node.getAddress(), None)
address = ('127.0.0.1', 10000) address = ('127.0.0.1', 10000)
node.setAddress(address) node.setAddress(address)
...@@ -72,107 +71,55 @@ class NodesTests(NeoUnitTestBase): ...@@ -72,107 +71,55 @@ class NodesTests(NeoUnitTestBase):
def testUUID(self): def testUUID(self):
""" As for Address but UUID """ """ As for Address but UUID """
node = Node(self.manager) node = Node(self.nm)
self.assertEqual(node.getAddress(), None) self.assertEqual(node.getAddress(), None)
uuid = self.getNewUUID(None) uuid = self.getNewUUID(None)
node.setUUID(uuid) node.setUUID(uuid)
self._updatedByUUID(node) self._updatedByUUID(node)
def testTypes(self):
""" Check that the abstract node has no type """
node = Node(self.manager)
self.assertRaises(NotImplementedError, node.getType)
self.assertFalse(node.isStorage())
self.assertFalse(node.isMaster())
self.assertFalse(node.isClient())
self.assertFalse(node.isAdmin())
def testMaster(self):
""" Check Master sub class """
node = MasterNode(self.manager)
self.assertEqual(node.getType(), protocol.NodeTypes.MASTER)
self.assertTrue(node.isMaster())
self.assertFalse(node.isStorage())
self.assertFalse(node.isClient())
self.assertFalse(node.isAdmin())
def testStorage(self):
""" Check Storage sub class """
node = StorageNode(self.manager)
self.assertEqual(node.getType(), protocol.NodeTypes.STORAGE)
self.assertTrue(node.isStorage())
self.assertFalse(node.isMaster())
self.assertFalse(node.isClient())
self.assertFalse(node.isAdmin())
def testClient(self):
""" Check Client sub class """
node = ClientNode(self.manager)
self.assertEqual(node.getType(), protocol.NodeTypes.CLIENT)
self.assertTrue(node.isClient())
self.assertFalse(node.isMaster())
self.assertFalse(node.isStorage())
self.assertFalse(node.isAdmin())
def testAdmin(self):
""" Check Admin sub class """
node = AdminNode(self.manager)
self.assertEqual(node.getType(), protocol.NodeTypes.ADMIN)
self.assertTrue(node.isAdmin())
self.assertFalse(node.isMaster())
self.assertFalse(node.isStorage())
self.assertFalse(node.isClient())
class NodeManagerTests(NeoUnitTestBase): class NodeManagerTests(NeoUnitTestBase):
def setUp(self):
NeoUnitTestBase.setUp(self)
self.manager = NodeManager()
def _addStorage(self): def _addStorage(self):
self.storage = StorageNode(self.manager, ('127.0.0.1', 1000), self.getStorageUUID()) self.storage = self.nm.createStorage(
address=('127.0.0.1', 1000), uuid=self.getStorageUUID())
def _addMaster(self): def _addMaster(self):
self.master = MasterNode(self.manager, ('127.0.0.1', 2000), self.getMasterUUID()) self.master = self.nm.createMaster(
address=('127.0.0.1', 2000), uuid=self.getMasterUUID())
def _addClient(self): def _addClient(self):
self.client = ClientNode(self.manager, None, self.getClientUUID()) self.client = self.nm.createClient(uuid=self.getClientUUID())
def _addAdmin(self): def _addAdmin(self):
self.admin = AdminNode(self.manager, ('127.0.0.1', 4000), self.getAdminUUID()) self.admin = self.nm.createAdmin(
address=('127.0.0.1', 4000), uuid=self.getAdminUUID())
def checkNodes(self, node_list): def checkNodes(self, node_list):
manager = self.manager self.assertEqual(sorted(self.nm.getList()), sorted(node_list))
self.assertEqual(sorted(manager.getList()), sorted(node_list))
def checkMasters(self, master_list): def checkMasters(self, master_list):
manager = self.manager self.assertEqual(self.nm.getMasterList(), master_list)
self.assertEqual(manager.getMasterList(), master_list)
def checkStorages(self, storage_list): def checkStorages(self, storage_list):
manager = self.manager self.assertEqual(self.nm.getStorageList(), storage_list)
self.assertEqual(manager.getStorageList(), storage_list)
def checkClients(self, client_list): def checkClients(self, client_list):
manager = self.manager self.assertEqual(self.nm.getClientList(), client_list)
self.assertEqual(manager.getClientList(), client_list)
def checkByServer(self, node): def checkByServer(self, node):
node_found = self.manager.getByAddress(node.getAddress()) self.assertEqual(node, self.nm.getByAddress(node.getAddress()))
self.assertEqual(node_found, node)
def checkByUUID(self, node): def checkByUUID(self, node):
node_found = self.manager.getByUUID(node.getUUID()) self.assertEqual(node, self.nm.getByUUID(node.getUUID()))
self.assertEqual(node_found, node)
def checkIdentified(self, node_list, pool_set=None): def checkIdentified(self, node_list, pool_set=None):
identified_node_list = self.manager.getIdentifiedList(pool_set) identified_node_list = self.nm.getIdentifiedList(pool_set)
self.assertEqual(set(identified_node_list), set(node_list)) self.assertEqual(set(identified_node_list), set(node_list))
def testInit(self): def testInit(self):
""" Check the manager is empty when started """ """ Check the manager is empty when started """
manager = self.manager manager = self.nm
self.checkNodes([]) self.checkNodes([])
self.checkMasters([]) self.checkMasters([])
self.checkStorages([]) self.checkStorages([])
...@@ -186,7 +133,7 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -186,7 +133,7 @@ class NodeManagerTests(NeoUnitTestBase):
def testAdd(self): def testAdd(self):
""" Check if new nodes are registered in the manager """ """ Check if new nodes are registered in the manager """
manager = self.manager manager = self.nm
self.checkNodes([]) self.checkNodes([])
# storage # storage
self._addStorage() self._addStorage()
...@@ -225,7 +172,7 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -225,7 +172,7 @@ class NodeManagerTests(NeoUnitTestBase):
def testUpdate(self): def testUpdate(self):
""" Check manager content update """ """ Check manager content update """
# set up four nodes # set up four nodes
manager = self.manager manager = self.nm
self._addMaster() self._addMaster()
self._addStorage() self._addStorage()
self._addClient() self._addClient()
...@@ -268,7 +215,6 @@ class NodeManagerTests(NeoUnitTestBase): ...@@ -268,7 +215,6 @@ class NodeManagerTests(NeoUnitTestBase):
def testIdentified(self): def testIdentified(self):
# set up four nodes # set up four nodes
manager = self.manager
self._addMaster() self._addMaster()
self._addStorage() self._addStorage()
self._addClient() self._addClient()
......
...@@ -18,7 +18,6 @@ import unittest ...@@ -18,7 +18,6 @@ import unittest
from mock import Mock from mock import Mock
from neo.lib.protocol import NodeStates, CellStates from neo.lib.protocol import NodeStates, CellStates
from neo.lib.pt import Cell, PartitionTable, PartitionTableException from neo.lib.pt import Cell, PartitionTable, PartitionTableException
from neo.lib.node import StorageNode
from . import NeoUnitTestBase from . import NeoUnitTestBase
class PartitionTableTests(NeoUnitTestBase): class PartitionTableTests(NeoUnitTestBase):
...@@ -26,7 +25,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -26,7 +25,7 @@ class PartitionTableTests(NeoUnitTestBase):
def test_01_Cell(self): def test_01_Cell(self):
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
server = ("127.0.0.1", 19001) server = ("127.0.0.1", 19001)
sn = StorageNode(Mock(), server, uuid) sn = self.createStorage(server, uuid)
cell = Cell(sn) cell = Cell(sn)
self.assertEqual(cell.node, sn) self.assertEqual(cell.node, sn)
self.assertEqual(cell.state, CellStates.UP_TO_DATE) self.assertEqual(cell.state, CellStates.UP_TO_DATE)
...@@ -50,7 +49,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -50,7 +49,7 @@ class PartitionTableTests(NeoUnitTestBase):
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
self.assertEqual(len(pt.partition_list[x]), 0) self.assertEqual(len(pt.partition_list[x]), 0)
# add a cell to an empty row # add a cell to an empty row
...@@ -131,7 +130,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -131,7 +130,7 @@ class PartitionTableTests(NeoUnitTestBase):
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
self.assertEqual(len(pt.partition_list[x]), 0) self.assertEqual(len(pt.partition_list[x]), 0)
# add a cell to an empty row # add a cell to an empty row
...@@ -171,19 +170,19 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -171,19 +170,19 @@ class PartitionTableTests(NeoUnitTestBase):
# add two kind of node, usable and unusable # add two kind of node, usable and unusable
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
pt.setCell(0, sn2, CellStates.OUT_OF_DATE) pt.setCell(0, sn2, CellStates.OUT_OF_DATE)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
pt.setCell(0, sn3, CellStates.FEEDING) pt.setCell(0, sn3, CellStates.FEEDING)
uuid4 = self.getStorageUUID() uuid4 = self.getStorageUUID()
server4 = ("127.0.0.4", 19001) server4 = ("127.0.0.4", 19001)
sn4 = StorageNode(Mock(), server4, uuid4) sn4 = self.createStorage(server4, uuid4)
pt.setCell(0, sn4, CellStates.DISCARDED) # won't be added pt.setCell(0, sn4, CellStates.DISCARDED) # won't be added
# now checks result # now checks result
self.assertEqual(len(pt.partition_list[0]), 3) self.assertEqual(len(pt.partition_list[0]), 3)
...@@ -217,15 +216,15 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -217,15 +216,15 @@ class PartitionTableTests(NeoUnitTestBase):
# add two kind of node, usable and unusable # add two kind of node, usable and unusable
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
pt.setCell(1, sn2, CellStates.OUT_OF_DATE) pt.setCell(1, sn2, CellStates.OUT_OF_DATE)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
pt.setCell(2, sn3, CellStates.FEEDING) pt.setCell(2, sn3, CellStates.FEEDING)
# now checks result # now checks result
self.assertEqual(len(pt.partition_list[0]), 1) self.assertEqual(len(pt.partition_list[0]), 1)
...@@ -247,19 +246,19 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -247,19 +246,19 @@ class PartitionTableTests(NeoUnitTestBase):
# add two kind of node, usable and unusable # add two kind of node, usable and unusable
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
pt.setCell(0, sn2, CellStates.OUT_OF_DATE) pt.setCell(0, sn2, CellStates.OUT_OF_DATE)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
pt.setCell(0, sn3, CellStates.FEEDING) pt.setCell(0, sn3, CellStates.FEEDING)
uuid4 = self.getStorageUUID() uuid4 = self.getStorageUUID()
server4 = ("127.0.0.4", 19001) server4 = ("127.0.0.4", 19001)
sn4 = StorageNode(Mock(), server4, uuid4) sn4 = self.createStorage(server4, uuid4)
pt.setCell(0, sn4, CellStates.DISCARDED) # won't be added pt.setCell(0, sn4, CellStates.DISCARDED) # won't be added
# must get only two node as feeding and discarded not taken # must get only two node as feeding and discarded not taken
# into account # into account
...@@ -276,7 +275,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -276,7 +275,7 @@ class PartitionTableTests(NeoUnitTestBase):
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.UP_TO_DATE) pt.setCell(x, sn1, CellStates.UP_TO_DATE)
self.assertEqual(pt.num_filled_rows, num_partitions) self.assertEqual(pt.num_filled_rows, num_partitions)
...@@ -289,7 +288,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -289,7 +288,7 @@ class PartitionTableTests(NeoUnitTestBase):
# add two kind of node, usable and unusable # add two kind of node, usable and unusable
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
# now test # now test
self.assertTrue(pt.hasOffset(0)) self.assertTrue(pt.hasOffset(0))
...@@ -298,15 +297,16 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -298,15 +297,16 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertFalse(pt.hasOffset(50)) self.assertFalse(pt.hasOffset(50))
def test_10_operational(self): def test_10_operational(self):
def createStorage():
uuid = self.getStorageUUID()
return self.createStorage(("127.0.0.1", uuid), uuid)
num_partitions = 5 num_partitions = 5
num_replicas = 2 num_replicas = 2
pt = PartitionTable(num_partitions, num_replicas) pt = PartitionTable(num_partitions, num_replicas)
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() sn1 = createStorage()
server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.UP_TO_DATE) pt.setCell(x, sn1, CellStates.UP_TO_DATE)
self.assertTrue(pt.filled()) self.assertTrue(pt.filled())
...@@ -318,9 +318,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -318,9 +318,7 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() sn1 = createStorage()
server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.FEEDING) pt.setCell(x, sn1, CellStates.FEEDING)
self.assertTrue(pt.filled()) self.assertTrue(pt.filled())
...@@ -333,9 +331,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -333,9 +331,7 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() sn1 = createStorage()
server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1)
sn1.setState(NodeStates.TEMPORARILY_DOWN) sn1.setState(NodeStates.TEMPORARILY_DOWN)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.FEEDING) pt.setCell(x, sn1, CellStates.FEEDING)
...@@ -348,9 +344,7 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -348,9 +344,7 @@ class PartitionTableTests(NeoUnitTestBase):
self.assertFalse(pt.filled()) self.assertFalse(pt.filled())
self.assertFalse(pt.operational()) self.assertFalse(pt.operational())
# adding a node in all partition # adding a node in all partition
uuid1 = self.getStorageUUID() sn1 = createStorage()
server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1)
for x in xrange(num_partitions): for x in xrange(num_partitions):
pt.setCell(x, sn1, CellStates.OUT_OF_DATE) pt.setCell(x, sn1, CellStates.OUT_OF_DATE)
self.assertTrue(pt.filled()) self.assertTrue(pt.filled())
...@@ -364,18 +358,18 @@ class PartitionTableTests(NeoUnitTestBase): ...@@ -364,18 +358,18 @@ class PartitionTableTests(NeoUnitTestBase):
# add nodes # add nodes
uuid1 = self.getStorageUUID() uuid1 = self.getStorageUUID()
server1 = ("127.0.0.1", 19001) server1 = ("127.0.0.1", 19001)
sn1 = StorageNode(Mock(), server1, uuid1) sn1 = self.createStorage(server1, uuid1)
pt.setCell(0, sn1, CellStates.UP_TO_DATE) pt.setCell(0, sn1, CellStates.UP_TO_DATE)
pt.setCell(1, sn1, CellStates.UP_TO_DATE) pt.setCell(1, sn1, CellStates.UP_TO_DATE)
pt.setCell(2, sn1, CellStates.UP_TO_DATE) pt.setCell(2, sn1, CellStates.UP_TO_DATE)
uuid2 = self.getStorageUUID() uuid2 = self.getStorageUUID()
server2 = ("127.0.0.2", 19001) server2 = ("127.0.0.2", 19001)
sn2 = StorageNode(Mock(), server2, uuid2) sn2 = self.createStorage(server2, uuid2)
pt.setCell(0, sn2, CellStates.UP_TO_DATE) pt.setCell(0, sn2, CellStates.UP_TO_DATE)
pt.setCell(1, sn2, CellStates.UP_TO_DATE) pt.setCell(1, sn2, CellStates.UP_TO_DATE)
uuid3 = self.getStorageUUID() uuid3 = self.getStorageUUID()
server3 = ("127.0.0.3", 19001) server3 = ("127.0.0.3", 19001)
sn3 = StorageNode(Mock(), server3, uuid3) sn3 = self.createStorage(server3, uuid3)
pt.setCell(0, sn3, CellStates.UP_TO_DATE) pt.setCell(0, sn3, CellStates.UP_TO_DATE)
# test # test
row_0 = pt.getRow(0) row_0 = pt.getRow(0)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment