Commit 0bed5cd0 authored by Julien Muchembled's avatar Julien Muchembled

tests: refactor and fix setup of MySQL DB

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2821 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent b621e0a7
...@@ -32,9 +32,9 @@ from neo.lib.util import getAddressType ...@@ -32,9 +32,9 @@ from neo.lib.util import getAddressType
from time import time, gmtime from time import time, gmtime
from struct import pack, unpack from struct import pack, unpack
DB_PREFIX = os.getenv('NEO_DB_PREFIX', 'test_neo_') DB_PREFIX = os.getenv('NEO_DB_PREFIX', 'test_neo')
DB_ADMIN = os.getenv('NEO_DB_ADMIN', 'root') DB_ADMIN = os.getenv('NEO_DB_ADMIN', 'root')
DB_PASSWD = os.getenv('NEO_DB_PASSWD', None) DB_PASSWD = os.getenv('NEO_DB_PASSWD', '')
DB_USER = os.getenv('NEO_DB_USER', 'test') DB_USER = os.getenv('NEO_DB_USER', 'test')
IP_VERSION_FORMAT_DICT = { IP_VERSION_FORMAT_DICT = {
...@@ -88,6 +88,26 @@ def getTempDirectory(): ...@@ -88,6 +88,26 @@ def getTempDirectory():
print 'Using temp directory %r.' % temp_dir print 'Using temp directory %r.' % temp_dir
return temp_dir return temp_dir
def setupMySQLdb(db_list, user=DB_USER, password='', clear_databases=True):
from MySQLdb.constants.ER import BAD_DB_ERROR
conn = MySQLdb.Connect(user=DB_ADMIN, passwd=DB_PASSWD)
cursor = conn.cursor()
for database in db_list:
try:
conn.select_db(database)
if not clear_databases:
continue
cursor.execute('DROP DATABASE `%s`' % database)
except MySQLdb.OperationalError, (code, _):
if code != BAD_DB_ERROR:
raise
cursor.execute('GRANT ALL ON `%s`.* TO "%s"@"localhost" IDENTIFIED'
' BY "%s"' % (database, user, password))
cursor.execute('CREATE DATABASE `%s`' % database)
cursor.close()
conn.commit()
conn.close()
class NeoTestBase(unittest.TestCase): class NeoTestBase(unittest.TestCase):
def setUp(self): def setUp(self):
logger.PACKET_LOGGER.enable(True) logger.PACKET_LOGGER.enable(True)
...@@ -123,24 +143,9 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -123,24 +143,9 @@ class NeoUnitTestBase(NeoTestBase):
local_ip = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE] local_ip = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]
def prepareDatabase(self, number, admin=DB_ADMIN, password=DB_PASSWD, def prepareDatabase(self, number, prefix='test_neo'):
user=DB_USER, prefix=DB_PREFIX, address_type = ADDRESS_TYPE):
""" create empties databases """ """ create empties databases """
# SQL connection setupMySQLdb(['%s%u' % (prefix, i) for i in xrange(number)])
connect_arg_dict = {'user': admin}
if password is not None:
connect_arg_dict['passwd'] = password
sql_connection = MySQLdb.Connect(**connect_arg_dict)
cursor = sql_connection.cursor()
# drop and create each database
for i in xrange(number):
database = "%s%d" % (prefix, i)
cursor.execute('DROP DATABASE IF EXISTS %s' % (database, ))
cursor.execute('CREATE DATABASE %s' % (database, ))
cursor.execute('GRANT ALL ON %s.* TO "%s"@"localhost" IDENTIFIED BY ""' %
(database, user))
cursor.close()
sql_connection.close()
def getMasterConfiguration(self, cluster='main', master_number=2, def getMasterConfiguration(self, cluster='main', master_number=2,
replicas=2, partitions=1009, uuid=None): replicas=2, partitions=1009, uuid=None):
......
...@@ -37,7 +37,7 @@ from neo.neoctl.neoctl import NeoCTL, NotReadyException ...@@ -37,7 +37,7 @@ from neo.neoctl.neoctl import NeoCTL, NotReadyException
from neo.lib import setupLog from neo.lib import setupLog
from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates
from neo.lib.util import dump from neo.lib.util import dump
from neo.tests import DB_ADMIN, DB_PASSWD, NeoTestBase, buildUrlFromString, \ from neo.tests import DB_USER, setupMySQLdb, NeoTestBase, buildUrlFromString, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, getTempDirectory ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, getTempDirectory
from neo.tests.cluster import SocketLock from neo.tests.cluster import SocketLock
from neo.client.Storage import Storage from neo.client.Storage import Storage
...@@ -230,8 +230,7 @@ class NEOProcess(object): ...@@ -230,8 +230,7 @@ class NEOProcess(object):
class NEOCluster(object): class NEOCluster(object):
def __init__(self, db_list, master_count=1, partitions=1, replicas=0, def __init__(self, db_list, master_count=1, partitions=1, replicas=0,
db_user='neo', db_password='neo', db_user=DB_USER, db_password='',
db_super_user=DB_ADMIN, db_super_password=DB_PASSWD,
cleanup_on_delete=False, temp_dir=None, clear_databases=True, cleanup_on_delete=False, temp_dir=None, clear_databases=True,
adapter=os.getenv('NEO_TESTS_ADAPTER'), adapter=os.getenv('NEO_TESTS_ADAPTER'),
verbose=True, verbose=True,
...@@ -244,8 +243,6 @@ class NEOCluster(object): ...@@ -244,8 +243,6 @@ class NEOCluster(object):
self.cleanup_on_delete = cleanup_on_delete self.cleanup_on_delete = cleanup_on_delete
self.verbose = verbose self.verbose = verbose
self.uuid_set = set() self.uuid_set = set()
self.db_super_user = db_super_user
self.db_super_password = db_super_password
self.db_user = db_user self.db_user = db_user
self.db_password = db_password self.db_password = db_password
self.db_list = db_list self.db_list = db_list
...@@ -316,36 +313,10 @@ class NEOCluster(object): ...@@ -316,36 +313,10 @@ class NEOCluster(object):
self.uuid_set.add(uuid) self.uuid_set.add(uuid)
return uuid return uuid
def __getSuperSQLConnection(self):
# Cleanup or bootstrap databases
connect_arg_dict = {'user': self.db_super_user}
password = self.db_super_password
if password is not None:
connect_arg_dict['passwd'] = password
return MySQLdb.Connect(**connect_arg_dict)
def setupDB(self, clear_databases=True): def setupDB(self, clear_databases=True):
if self.adapter == 'MySQL': if self.adapter == 'MySQL':
from MySQLdb.constants.ER import DB_CREATE_EXISTS setupMySQLdb(self.db_list, self.db_user, self.db_password,
sql_connection = self.__getSuperSQLConnection() clear_databases)
cursor = sql_connection.cursor()
for database in self.db_list:
create = 'CREATE DATABASE `%s`' % database
try:
cursor.execute(create)
except MySQLdb.ProgrammingError, (code, _):
if code != DB_CREATE_EXISTS:
raise
if clear_databases:
cursor.execute('DROP DATABASE `%s`' % database)
cursor.execute(create)
continue
cursor.execute('GRANT ALL ON `%s`.* TO "%s"@"localhost" '
'IDENTIFIED BY "%s"' % (database, self.db_user,
self.db_password))
cursor.close()
sql_connection.commit()
sql_connection.close()
def run(self, except_storages=()): def run(self, except_storages=()):
""" Start cluster processes except some storage nodes """ """ Start cluster processes except some storage nodes """
......
...@@ -32,8 +32,8 @@ from neo.lib.connector import SocketConnector, \ ...@@ -32,8 +32,8 @@ from neo.lib.connector import SocketConnector, \
from neo.lib.event import EventManager from neo.lib.event import EventManager
from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes from neo.lib.protocol import CellStates, ClusterStates, NodeStates, NodeTypes
from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList from neo.lib.util import SOCKET_CONNECTORS_DICT, parseMasterList
from neo.tests import NeoUnitTestBase, getTempDirectory, \ from neo.tests import NeoUnitTestBase, getTempDirectory, setupMySQLdb, \
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX, DB_USER
BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0 BIND = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE], 0
LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]) LOCAL_IP = socket.inet_pton(ADDRESS_TYPE, IP_VERSION_FORMAT_DICT[ADDRESS_TYPE])
...@@ -354,7 +354,7 @@ class NEOCluster(object): ...@@ -354,7 +354,7 @@ class NEOCluster(object):
def __init__(self, master_count=1, partitions=1, replicas=0, def __init__(self, master_count=1, partitions=1, replicas=0,
adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'), adapter=os.getenv('NEO_TESTS_ADAPTER', 'BTree'),
storage_count=None, db_list=None, clear_databases=True, storage_count=None, db_list=None, clear_databases=True,
db_user='neo', db_password='neo', verbose=None): db_user=DB_USER, db_password='', verbose=None):
if verbose is not None: if verbose is not None:
temp_dir = os.getenv('TEMP') or \ temp_dir = os.getenv('TEMP') or \
os.path.join(tempfile.gettempdir(), 'neo_tests') os.path.join(tempfile.gettempdir(), 'neo_tests')
...@@ -374,7 +374,8 @@ class NEOCluster(object): ...@@ -374,7 +374,8 @@ class NEOCluster(object):
if db_list is None: if db_list is None:
if storage_count is None: if storage_count is None:
storage_count = replicas + 1 storage_count = replicas + 1
db_list = ['test_neo%u' % i for i in xrange(storage_count)] db_list = ['%s%u' % (DB_PREFIX, i) for i in xrange(storage_count)]
setupMySQLdb(db_list, db_user, db_password, clear_databases)
db = '%s:%s@%%s' % (db_user, db_password) db = '%s:%s@%%s' % (db_user, db_password)
self.storage_list = [StorageApplication(address=(ip, i), self.storage_list = [StorageApplication(address=(ip, i),
getDatabase=db % x, **kw) getDatabase=db % x, **kw)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment