Commit b3b5175f authored by Julien Muchembled

tests: make it possible to run several threaded clusters at the same time

parent dcbf0b02
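
Before this change, the three virtual IPs were module-level constants shared by
every cluster (one per server type), virtual addresses were resolved through the
single owning NEOCluster, and a global `_patched` lock made a second cluster fail
with "Can't run several cluster at the same time". The diff below lifts these
restrictions: virtual IPs are now allocated per ServerNode subclass by its
metaclass (the "port" of a virtual address being the node's creation index), the
global monkey-patches are reference-counted so the first running cluster applies
them and the last one reverts them, and cluster names and database names are
drawn from a registry of live clusters to avoid collisions.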
@@ -18,6 +18,7 @@
 import os, random, socket, sys, tempfile, threading, time, types, weakref
 from collections import deque
+from itertools import count
 from functools import wraps
 from zlib import decompress
 from mock import Mock
@@ -37,12 +38,6 @@ from .. import NeoTestBase, getTempDirectory, setupMySQLdb, \
 
 BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
 LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
-SERVER_TYPE = ['master', 'storage', 'admin']
-VIRTUAL_IP = [socket.inet_ntop(ADDRESS_TYPE, LOCAL_IP[:-1] + chr(2 + i))
-              for i in xrange(len(SERVER_TYPE))]
-
-def getVirtualIp(server_type):
-    return VIRTUAL_IP[SERVER_TYPE.index(server_type)]
 
 class Serialized(object):
@@ -57,12 +52,12 @@ class Serialized(object):
         cls.pending = 0
 
     @classmethod
-    def release(cls, lock=None, wake_other=True, stop=False):
+    def release(cls, lock=None, wake_other=True, stop=None):
        """Suspend lock owner and resume first suspended thread"""
        if lock is None:
            lock = cls._global_lock
            if stop: # XXX: we should fix ClusterStates.STOPPING
-                cls.pending = None
+                cls.pending = frozenset(stop)
            else:
                cls.pending = 0
        try:
@@ -86,10 +81,10 @@ class Serialized(object):
         if lock is None:
             lock = cls._global_lock
         lock.acquire()
-        if cls.pending is None: # XXX
+        if type(cls.pending) is frozenset: # XXX
             if lock is cls._global_lock:
                 cls.pending = 0
-            else:
+            elif threading.currentThread() in cls.pending:
                 sys.exit()
         if cls._pdb:
             cls._pdb = False
@@ -143,7 +138,7 @@ class SerializedEventManager(EventManager):
             self.writer_set):
             return
         else:
-            if self.writer_set and Serialized.pending is not None:
+            if self.writer_set and Serialized.pending == 0:
                 Serialized.pending = 1
             # Jump to another thread before polling, so that when a message is
             # sent on the network, one can debug immediately the receiving part.
@@ -154,7 +149,7 @@ class SerializedEventManager(EventManager):
         Serialized.tic(self._lock)
         if timeout != 0:
             timeout = self._timeout
-            if timeout != 0 and Serialized.pending:
+            if timeout != 0 and Serialized.pending == 1:
                 Serialized.pending = timeout = 0
         EventManager._poll(self, timeout)
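
Note on the two hunks above: `Serialized.pending` is now a three-valued flag, so
the identity tests against None are replaced by exact comparisons with 0 and 1.
A minimal sketch of the resulting states as this edit reads them (the helper
function is illustrative, not part of the diff):

    import threading

    def describe_pending(pending):
        # 0: nothing to send, _poll() may block with a timeout
        if pending == 0:
            return 'idle'
        # 1: some node has buffered writes, _poll() must not block
        if pending == 1:
            return 'write pending'
        # frozenset: a cluster is stopping; only its own node threads
        # must exit, threads of other clusters keep running
        if threading.currentThread() in pending:
            return 'exit'
        return 'keep running'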
@@ -173,25 +168,50 @@ class Node(object):
 
 class ServerNode(Node):
 
+    _server_class_dict = {}
+
     class __metaclass__(type):
         def __init__(cls, name, bases, d):
             type.__init__(cls, name, bases, d)
             if Node not in bases and threading.Thread not in cls.__mro__:
                 cls.__bases__ = bases + (threading.Thread,)
+                cls.node_type = getattr(NodeTypes, name[:-11].upper())
+                cls._node_list = []
+                cls._virtual_ip = socket.inet_ntop(ADDRESS_TYPE,
+                    LOCAL_IP[:-1] + chr(2 + len(cls._server_class_dict)))
+                cls._server_class_dict[cls._virtual_ip] = cls
+
+    @classmethod
+    def newAddress(cls):
+        address = cls._virtual_ip, len(cls._node_list)
+        cls._node_list.append(None)
+        return address
+
+    @classmethod
+    def resolv(cls, address):
+        try:
+            cls = cls._server_class_dict[address[0]]
+        except KeyError:
+            return address
+        return cls._node_list[address[1]].getListeningAddress()
 
     @SerializedEventManager.decorate
-    def __init__(self, cluster, address, **kw):
-        self._init_args = (cluster, address), dict(kw)
+    def __init__(self, cluster, address=None, **kw):
+        if not address:
+            address = self.newAddress()
+        port = address[1]
+        self._node_list[port] = weakref.proxy(self)
+        self._init_args = (cluster, address), kw.copy()
         threading.Thread.__init__(self)
         self.daemon = True
-        h, p = address
-        self.node_type = getattr(NodeTypes,
-            SERVER_TYPE[VIRTUAL_IP.index(h)].upper())
-        self.node_name = '%s_%u' % (self.node_type, p)
+        self.node_name = '%s_%u' % (self.node_type, port)
         kw.update(getCluster=cluster.name, getBind=address,
             getMasters=parseMasterList(cluster.master_nodes, address))
         super(ServerNode, self).__init__(Mock(kw))
 
+    def getVirtualAddress(self):
+        return self._init_args[0][1]
+
     def resetNode(self):
         assert not self.isAlive()
         args, kw = self._init_args
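
With the hunk above, every concrete server class gets its own virtual IP from
the metaclass, and a node's virtual address is (virtual_ip, creation_index);
ServerNode.resolv() maps such an address back to the real socket address the
node is listening on, and passes any other address through unchanged. A
standalone sketch of the IP derivation, assuming IPv4 and the
master/storage/admin class-definition order used before this change:

    import socket

    LOCAL_IP = socket.inet_pton(socket.AF_INET, '127.0.0.1')
    for i in xrange(3):
        # replace the last byte of 127.0.0.1 with 2, 3, 4, ...
        print socket.inet_ntop(socket.AF_INET, LOCAL_IP[:-1] + chr(2 + i))
    # -> 127.0.0.2 (masters), 127.0.0.3 (storages), 127.0.0.4 (admins)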
@@ -321,12 +341,12 @@ class ClientApplication(Node, neo.client.app.Application):
 
 class NeoCTL(neo.neoctl.app.NeoCTL):
 
     @SerializedEventManager.decorate
-    def __init__(self, cluster, address=(getVirtualIp('admin'), 0)):
+    def __init__(self, cluster):
         self._cluster = cluster
-        super(NeoCTL, self).__init__(address)
+        super(NeoCTL, self).__init__(cluster.admin.getVirtualAddress())
         self.em._timeout = -1
 
-    server = property(lambda self: self._cluster.resolv(self._server),
+    server = property(lambda self: ServerNode.resolv(self._server),
                       lambda self, address: setattr(self, '_server', address))
@@ -441,16 +461,24 @@ class NEOCluster(object):
         SocketConnector.makeListeningConnection)
     SocketConnector_send = staticmethod(SocketConnector.send)
     Storage__init__ = staticmethod(Storage.__init__)
+    _patch_count = 0
+    _resource_dict = weakref.WeakValueDictionary()
 
-    _patched = threading.Lock()
+    def _allocate(self, resource, new):
+        result = resource, new()
+        while result in self._resource_dict:
+            result = resource, new()
+        self._resource_dict[result] = self
+        return result[1]
 
     def _patch(cluster):
         cls = cluster.__class__
-        if not cls._patched.acquire(0):
-            raise RuntimeError("Can't run several cluster at the same time")
+        cls._patch_count += 1
+        if cls._patch_count > 1:
+            return
         def makeClientConnection(self, addr):
+            real_addr = ServerNode.resolv(addr)
             try:
-                real_addr = cluster.resolv(addr)
                 return cls.SocketConnector_makeClientConnection(self, real_addr)
             finally:
                 self.remote_addr = addr
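
The `_patched` lock is gone: `_patch()`/`_unpatch()` now keep a reference count,
so the patches are applied once by the first running cluster and reverted by the
last one. Unique per-cluster resources (the cluster name, the database names) go
through `_allocate()`, backed by a WeakValueDictionary so a value becomes
reusable as soon as the cluster owning it is garbage-collected. A self-contained
sketch of that scheme (DemoCluster is illustrative only):

    import random, weakref

    class DemoCluster(object):
        # (resource kind, value) -> owning cluster; entries disappear
        # automatically when the owner is garbage-collected
        _resource_dict = weakref.WeakValueDictionary()

        def _allocate(self, resource, new):
            result = resource, new()
            while result in self._resource_dict:  # owned by a live cluster
                result = resource, new()
            self._resource_dict[result] = self
            return result[1]

    a, b = DemoCluster(), DemoCluster()
    name_a = a._allocate('name', lambda: random.randint(0, 100))
    name_b = b._allocate('name', lambda: random.randint(0, 100))
    assert name_a != name_b  # two live clusters never share a name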
@@ -468,9 +496,14 @@ class NEOCluster(object):
             cls.SocketConnector_makeListeningConnection(self, BIND)
         SocketConnector.send = send
         Storage.setupLog = lambda *args, **kw: None
+        Serialized.init()
 
     @classmethod
     def _unpatch(cls):
+        assert cls._patch_count > 0
+        cls._patch_count -= 1
+        if cls._patch_count:
+            return
         bootstrap.sleep = time.sleep
         BaseConnection.checkTimeout = cls.BaseConnection_checkTimeout
         SocketConnector.makeClientConnection = \
@@ -479,7 +512,6 @@ class NEOCluster(object):
             cls.SocketConnector_makeListeningConnection
         SocketConnector.send = cls.SocketConnector_send
         Storage.setupLog = setupLog
-        cls._patched.release()
 
     def __init__(self, master_count=1, partitions=1, replicas=0,
                  adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'),
@@ -492,27 +524,27 @@ class NEOCluster(object):
         log_file = tempfile.mkstemp('.log', '', temp_dir)[1]
         print 'Logging to %r' % log_file
         setupLog(LoggerThreadName(), log_file, verbose)
-        self.name = 'neo_%s' % random.randint(0, 100)
-        ip = getVirtualIp('master')
-        self.master_nodes = ' '.join('%s:%s' % (ip, i)
-                                     for i in xrange(master_count))
+        self.name = 'neo_%s' % self._allocate('name',
+            lambda: random.randint(0, 100))
+        master_list = [MasterApplication.newAddress()
+                       for _ in xrange(master_count)]
+        self.master_nodes = ' '.join('%s:%s' % x for x in master_list)
         weak_self = weakref.proxy(self)
         kw = dict(cluster=weak_self, getReplicas=replicas, getAdapter=adapter,
                   getPartitions=partitions, getReset=clear_databases)
-        self.master_list = [MasterApplication(address=(ip, i), **kw)
-                            for i in xrange(master_count)]
-        ip = getVirtualIp('storage')
+        self.master_list = [MasterApplication(address=x, **kw)
+                            for x in master_list]
         if db_list is None:
             if storage_count is None:
                 storage_count = replicas + 1
-            db_list = ['%s%u' % (DB_PREFIX, i) for i in xrange(storage_count)]
+            index = count().next
+            db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index))
+                       for _ in xrange(storage_count)]
         setupMySQLdb(db_list, db_user, db_password, clear_databases)
         db = '%s:%s@%%s' % (db_user, db_password)
-        self.storage_list = [StorageApplication(address=(ip, i),
-                                                getDatabase=db % x, **kw)
-                             for i, x in enumerate(db_list)]
-        ip = getVirtualIp('admin')
-        self.admin_list = [AdminApplication(address=(ip, 0), **kw)]
+        self.storage_list = [StorageApplication(getDatabase=db % x, **kw)
+                             for x in db_list]
+        self.admin_list = [AdminApplication(**kw)]
         self.client = ClientApplication(weak_self)
         self.neoctl = NeoCTL(weak_self)
@@ -531,16 +563,8 @@ class NEOCluster(object):
             return admin
 
     ###
 
-    def resolv(self, addr):
-        host, port = addr
-        try:
-            attr = SERVER_TYPE[VIRTUAL_IP.index(host)] + '_list'
-        except ValueError:
-            return addr
-        return getattr(self, attr)[port].getListeningAddress()
-
     def reset(self, clear_database=False):
-        for node_type in SERVER_TYPE:
+        for node_type in 'master', 'storage', 'admin':
             kw = {}
             if node_type == 'storage':
                 kw['clear_database'] = clear_database
@@ -551,7 +575,6 @@ class NEOCluster(object):
 
     def start(self, storage_list=None, fast_startup=False):
         self._patch()
-        Serialized.init()
         for node_type in 'master', 'admin':
             for node in getattr(self, node_type + '_list'):
                 node.start()
@@ -566,7 +589,8 @@ class NEOCluster(object):
         if not fast_startup:
             self._startCluster()
             self.tic()
-        assert self.neoctl.getClusterState() == ClusterStates.RUNNING
+        state = self.neoctl.getClusterState()
+        assert state == ClusterStates.RUNNING, state
         self.enableStorageList(storage_list)
 
     def _startCluster(self):
@@ -598,8 +622,9 @@ class NEOCluster(object):
         self.__dict__.pop('_db', self.client).close()
         #self.neoctl.setClusterState(ClusterStates.STOPPING) # TODO
         try:
-            Serialized.release(stop=1)
-            for node_type in SERVER_TYPE[::-1]:
+            Serialized.release(stop=
+                self.admin_list + self.storage_list + self.master_list)
+            for node_type in 'admin', 'storage', 'master':
                 for node in getattr(self, node_type + '_list'):
                     if node.isAlive():
                         node.join()
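
`stop=1` becomes a list of nodes: since ServerNode inherits from
threading.Thread, Serialized.release() can store that list as a frozenset of
threads, which Serialized.acquire() (hunk @@ -86,10 +81,10 @@ above) compares
against the current thread, so only the stopping cluster's own node threads
exit. A one-line sketch of that check (illustrative, not part of the diff):

    import threading

    def must_exit(pending):
        # mirrors the test added in Serialized.acquire()
        return (type(pending) is frozenset
                and threading.currentThread() in pending)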
...
@@ -411,6 +411,21 @@ class Test(NEOThreadedTest):
         finally:
             cluster.stop()
 
+    def test2Clusters(self):
+        cluster1 = NEOCluster()
+        cluster2 = NEOCluster()
+        try:
+            cluster1.start()
+            cluster2.start()
+            t1, c1 = cluster1.getTransaction()
+            t2, c2 = cluster2.getTransaction()
+            c1.root()['1'] = c2.root()['2'] = ''
+            t1.commit()
+            t2.commit()
+        finally:
+            cluster1.stop()
+            cluster2.stop()
+
 if __name__ == "__main__":
     unittest.main()