Commit 6153a752 authored by Julien Muchembled's avatar Julien Muchembled

Add support for PyPy & PyMySQL

parent c29b8b2d
...@@ -48,9 +48,10 @@ Requirements ...@@ -48,9 +48,10 @@ Requirements
- Python 2.7.x (2.7.9 or later for SSL support) - Python 2.7.x (2.7.9 or later for SSL support)
- For storage nodes using MySQL backend: - For storage nodes using MySQL, one of the following backends:
- MySQLdb: https://github.com/PyMySQL/mysqlclient-python - MySQLdb: https://github.com/PyMySQL/mysqlclient
- PyMySQL: https://github.com/PyMySQL/PyMySQL
- For client nodes: ZODB 4.4.5 or later - For client nodes: ZODB 4.4.5 or later
......
...@@ -33,6 +33,7 @@ def patch(): ...@@ -33,6 +33,7 @@ def patch():
assert H(Connection.afterCompletion) in ( assert H(Connection.afterCompletion) in (
'cd3a080b80fd957190ff3bb867149448', # Python 2.7 'cd3a080b80fd957190ff3bb867149448', # Python 2.7
'b1d9685c13967d4b6d74c7ef86f68f17', # PyPy 2.7
) )
def afterCompletion(self, *ignored): def afterCompletion(self, *ignored):
......
...@@ -35,7 +35,7 @@ if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]): ...@@ -35,7 +35,7 @@ if filter(re.compile(r'--coverage$|-\w*c').match, sys.argv[1:]):
coverage.start() coverage.start()
from neo.lib import logging from neo.lib import logging
from neo.tests import getTempDirectory, NeoTestBase, Patch, \ from neo.tests import adapter, getTempDirectory, NeoTestBase, Patch, \
__dict__ as neo_tests__dict__ __dict__ as neo_tests__dict__
from neo.tests.benchmark import BenchmarkRunner from neo.tests.benchmark import BenchmarkRunner
...@@ -216,9 +216,11 @@ class NeoTestRunner(unittest.TextTestResult): ...@@ -216,9 +216,11 @@ class NeoTestRunner(unittest.TextTestResult):
add_status('Directory', self.temp_directory) add_status('Directory', self.temp_directory)
if self.testsRun: if self.testsRun:
add_status('Status', '%.3f%%' % (success * 100.0 / self.testsRun)) add_status('Status', '%.3f%%' % (success * 100.0 / self.testsRun))
for var in os.environ: for k, v in os.environ.iteritems():
if var.startswith('NEO_TEST'): if k.startswith('NEO_TEST'):
add_status(var, os.environ[var]) if k == 'NEO_TESTS_ADAPTER' and v == 'MySQL':
from neo.storage.database.mysql import binding_name as v
add_status(k, v)
# visual # visual
header = "%25s | run | unexpected | expected | skipped | time \n" % 'Test Module' header = "%25s | run | unexpected | expected | skipped | time \n" % 'Test Module'
separator = "%25s-+-------+------------+----------+---------+----------\n" % ('-' * 25) separator = "%25s-+-------+------------+----------+---------+----------\n" % ('-' * 25)
...@@ -318,7 +320,7 @@ class TestRunner(BenchmarkRunner): ...@@ -318,7 +320,7 @@ class TestRunner(BenchmarkRunner):
" passed.") " passed.")
parser.epilog = """ parser.epilog = """
Environment Variables: Environment Variables:
NEO_PYPY PyPy executable to run master nodes in functional NEOMASTER_PYPY PyPy executable to run master nodes in functional
tests (and also in zodb tests depending on tests (and also in zodb tests depending on
NEO_TEST_ZODB_FUNCTIONAL). NEO_TEST_ZODB_FUNCTIONAL).
NEO_TESTS_ADAPTER Default is SQLite for threaded clusters, NEO_TESTS_ADAPTER Default is SQLite for threaded clusters,
......
...@@ -26,7 +26,7 @@ from neo.lib.pt import PartitionTable ...@@ -26,7 +26,7 @@ from neo.lib.pt import PartitionTable
from neo.lib.util import dump from neo.lib.util import dump
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from .checker import Checker from .checker import Checker
from .database import buildDatabaseManager, DATABASE_MANAGER_DICT from .database import buildDatabaseManager, DATABASE_MANAGERS
from .handlers import identification, initialization, master from .handlers import identification, initialization, master
from .replicator import Replicator from .replicator import Replicator
from .transactions import TransactionManager from .transactions import TransactionManager
...@@ -37,7 +37,7 @@ option_defaults = { ...@@ -37,7 +37,7 @@ option_defaults = {
'adapter': 'MySQL', 'adapter': 'MySQL',
'wait': 0, 'wait': 0,
} }
assert option_defaults['adapter'] in DATABASE_MANAGER_DICT assert option_defaults['adapter'] in DATABASE_MANAGERS
@buildOptionParser @buildOptionParser
class Application(BaseApplication): class Application(BaseApplication):
...@@ -52,7 +52,7 @@ class Application(BaseApplication): ...@@ -52,7 +52,7 @@ class Application(BaseApplication):
cls.addCommonServerOptions('storage', '127.0.0.1') cls.addCommonServerOptions('storage', '127.0.0.1')
_ = parser.group('storage') _ = parser.group('storage')
_('a', 'adapter', choices=sorted(DATABASE_MANAGER_DICT), _('a', 'adapter', choices=DATABASE_MANAGERS,
help="database adapter to use") help="database adapter to use")
_('d', 'database', required=True, _('d', 'database', required=True,
help="database connections string") help="database connections string")
......
...@@ -16,18 +16,51 @@ ...@@ -16,18 +16,51 @@
LOG_QUERIES = False LOG_QUERIES = False
DATABASE_MANAGER_DICT = { def useMySQLdb():
'Importer': 'importer.ImporterDatabaseManager', import platform
'MySQL': 'mysql.MySQLDatabaseManager', py = platform.python_implementation() == 'PyPy'
'SQLite': 'sqlite.SQLiteDatabaseManager',
}
def getAdapterKlass(name):
try: try:
module, name = DATABASE_MANAGER_DICT[name or 'MySQL'].split('.') if py:
except KeyError: import pymysql
raise DatabaseFailure('Cannot find a database adapter <%s>' % name) else:
return getattr(__import__(module, globals(), level=1), name) import MySQLdb
except ImportError:
return py
return not py
class getAdapterKlass(object):
def __new__(cls, name):
try:
m = getattr(cls, name or 'MySQL')
except AttributeError:
raise DatabaseFailure('Cannot find a database adapter <%s>' % name)
return m()
@staticmethod
def Importer():
from .importer import ImporterDatabaseManager as DM
return DM
@classmethod
def MySQL(cls, MySQLdb=None):
if MySQLdb is not None:
global useMySQLdb
useMySQLdb = lambda: MySQLdb
from .mysql import binding_name, MySQLDatabaseManager as DM
assert hasattr(cls, binding_name)
return DM
MySQLdb = classmethod(lambda cls: cls.MySQL(True))
PyMySQL = classmethod(lambda cls: cls.MySQL(False))
@staticmethod
def SQLite():
from .sqlite import SQLiteDatabaseManager as DM
return DM
DATABASE_MANAGERS = tuple(sorted(
x for x in dir(getAdapterKlass) if not x.startswith('_')))
def buildDatabaseManager(name, args=(), kw={}): def buildDatabaseManager(name, args=(), kw={}):
return getAdapterKlass(name)(*args, **kw) return getAdapterKlass(name)(*args, **kw)
......
...@@ -22,7 +22,7 @@ from cStringIO import StringIO ...@@ -22,7 +22,7 @@ from cStringIO import StringIO
from ConfigParser import SafeConfigParser from ConfigParser import SafeConfigParser
from ZConfig import loadConfigFile from ZConfig import loadConfigFile
from ZODB import BaseStorage from ZODB import BaseStorage
from ZODB._compat import dumps, loads, _protocol from ZODB._compat import dumps, loads, _protocol, PersistentPickler
from ZODB.config import getStorageSchema, storageFromString from ZODB.config import getStorageSchema, storageFromString
from ZODB.POSException import POSKeyError from ZODB.POSException import POSKeyError
from ZODB.FileStorage import FileStorage from ZODB.FileStorage import FileStorage
...@@ -44,6 +44,35 @@ def transactionAsTuple(txn): ...@@ -44,6 +44,35 @@ def transactionAsTuple(txn):
dumps(ext, _protocol) if ext else '', dumps(ext, _protocol) if ext else '',
txn.status == 'p', txn.tid) txn.status == 'p', txn.tid)
@apply
def patch_save_reduce(): # for _noload.__reduce__
Pickler = PersistentPickler(None, StringIO()).__class__
try:
orig_save_reduce = Pickler.save_reduce.__func__
except AttributeError: # both cPickle and C zodbpickle accept
return # that first reduce argument is None
BUILD = pickle.BUILD
REDUCE = pickle.REDUCE
def save_reduce(self, func, args, state=None,
listitems=None, dictitems=None, obj=None):
if func is not None:
return orig_save_reduce(self,
func, args, state, listitems, dictitems, obj)
assert args is ()
save = self.save
write = self.write
save(func)
save(args)
self.write(REDUCE)
if obj is not None:
self.memoize(obj)
self._batch_appends(listitems)
self._batch_setitems(dictitems)
if state is not None:
save(state)
write(BUILD)
Pickler.save_reduce = save_reduce
class Reference(object): class Reference(object):
...@@ -59,17 +88,15 @@ class Repickler(pickle.Unpickler): ...@@ -59,17 +88,15 @@ class Repickler(pickle.Unpickler):
# Use python implementation for unpickling because loading can not # Use python implementation for unpickling because loading can not
# be customized enough with cPickle. # be customized enough with cPickle.
pickle.Unpickler.__init__(self, self._f) pickle.Unpickler.__init__(self, self._f)
# For pickling, it is possible to use the fastest implementation,
# which also generates fewer useless PUT opcodes.
self._p = cPickle.Pickler(self._f, 1)
self.memo = self._p.memo # just a tiny optimization
def persistent_id(obj): def persistent_id(obj):
if isinstance(obj, Reference): if isinstance(obj, Reference):
r = obj.value r = obj.value
del obj.value # minimize refcnt like for deque+popleft del obj.value # minimize refcnt like for deque+popleft
return r return r
self._p.inst_persistent_id = persistent_id # For pickling, it is possible to use the fastest implementation,
# which also generates fewer useless PUT opcodes.
self._p = PersistentPickler(persistent_id, self._f, 1)
self.memo = self._p.memo # just a tiny optimization
def persistent_load(obj): def persistent_load(obj):
new_obj = persistent_map(obj) new_obj = persistent_map(obj)
...@@ -96,8 +123,10 @@ class Repickler(pickle.Unpickler): ...@@ -96,8 +123,10 @@ class Repickler(pickle.Unpickler):
self.memo.clear() self.memo.clear()
if self._changed: if self._changed:
f.truncate(0) f.truncate(0)
dump = self._p.dump
try: try:
self._p.dump(classmeta).dump(state) dump(classmeta)
dump(state)
finally: finally:
self.memo.clear() self.memo.clear()
return f.getvalue() return f.getvalue()
......
...@@ -800,10 +800,9 @@ class DatabaseManager(object): ...@@ -800,10 +800,9 @@ class DatabaseManager(object):
if found_undone_tid is None: if found_undone_tid is None:
return return
if transaction_object: if transaction_object:
try: transaction_tid = transaction_object[2]
current_tid = current_data_tid = u64(transaction_object[2]) current_tid = current_data_tid = \
except struct.error: tid if transaction_tid is None else u64(transaction_tid)
current_tid = current_data_tid = tid
else: else:
current_tid, current_data_tid = getDataTID(before_tid=ltid) current_tid, current_data_tid = getDataTID(before_tid=ltid)
if current_tid is None: if current_tid is None:
......
...@@ -14,25 +14,45 @@ ...@@ -14,25 +14,45 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import os, re, string, struct, sys, time
from binascii import a2b_hex from binascii import a2b_hex
from collections import OrderedDict from collections import OrderedDict
from functools import wraps from functools import wraps
import MySQLdb from . import useMySQLdb
from MySQLdb import DataError, IntegrityError, \ if useMySQLdb():
OperationalError, ProgrammingError binding_name = 'MySQLdb'
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST from MySQLdb.connections import Connection
from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE from MySQLdb import __version__ as binding_version, DataError, \
IntegrityError, OperationalError, ProgrammingError
InternalOrOperationalError = OperationalError
from MySQLdb.constants.CR import SERVER_GONE_ERROR, SERVER_LOST
from MySQLdb.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE
def fetch_all(conn):
r = conn.store_result()
return r.fetch_row(r.num_rows())
# for tests
from MySQLdb import NotSupportedError
from MySQLdb.constants.ER import BAD_DB_ERROR, UNKNOWN_STORAGE_ENGINE
else:
binding_name = 'PyMySQL'
from pymysql.connections import Connection
from pymysql import __version__ as binding_version, DataError, \
IntegrityError, InternalError, OperationalError, ProgrammingError
InternalOrOperationalError = InternalError, OperationalError
from pymysql.constants.CR import (
CR_SERVER_GONE_ERROR as SERVER_GONE_ERROR,
CR_SERVER_LOST as SERVER_LOST)
from pymysql.constants.ER import DATA_TOO_LONG, DUP_ENTRY, NO_SUCH_TABLE
def fetch_all(conn):
return conn._result.rows
# for tests
from pymysql import NotSupportedError
from pymysql.constants.ER import BAD_DB_ERROR, UNKNOWN_STORAGE_ENGINE
# BBB: the following 2 constants were added to mysqlclient 1.3.8 # BBB: the following 2 constants were added to mysqlclient 1.3.8
DROP_LAST_PARTITION = 1508 DROP_LAST_PARTITION = 1508
SAME_NAME_PARTITION = 1517 SAME_NAME_PARTITION = 1517
from array import array from array import array
from hashlib import sha1 from hashlib import sha1
import os
import re
import string
import struct
import sys
import time
from . import LOG_QUERIES, DatabaseFailure from . import LOG_QUERIES, DatabaseFailure
from .manager import DatabaseManager, splitOIDField from .manager import DatabaseManager, splitOIDField
...@@ -68,7 +88,7 @@ def auto_reconnect(wrapped): ...@@ -68,7 +88,7 @@ def auto_reconnect(wrapped):
while 1: while 1:
try: try:
return wrapped(self, *args) return wrapped(self, *args)
except OperationalError as m: except InternalOrOperationalError as m:
# IDEA: Is it safe to retry in case of DISK_FULL ? # IDEA: Is it safe to retry in case of DISK_FULL ?
# XXX: However, this would another case of failure that would # XXX: However, this would another case of failure that would
# be unnoticed by other nodes (ADMIN & MASTER). When # be unnoticed by other nodes (ADMIN & MASTER). When
...@@ -121,6 +141,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -121,6 +141,7 @@ class MySQLDatabaseManager(DatabaseManager):
return super(MySQLDatabaseManager, self).__getattr__(attr) return super(MySQLDatabaseManager, self).__getattr__(attr)
def _tryConnect(self): def _tryConnect(self):
# BBB: db/passwd are deprecated favour of database/password since 1.3.8
kwd = {'db' : self.db} kwd = {'db' : self.db}
if self.user: if self.user:
kwd['user'] = self.user kwd['user'] = self.user
...@@ -128,8 +149,8 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -128,8 +149,8 @@ class MySQLDatabaseManager(DatabaseManager):
kwd['passwd'] = self.passwd kwd['passwd'] = self.passwd
if self.socket: if self.socket:
kwd['unix_socket'] = os.path.expanduser(self.socket) kwd['unix_socket'] = os.path.expanduser(self.socket)
logging.info('connecting to MySQL on the database %s with user %s', logging.info('Using %s %s to connect to the database %s with user %s',
self.db, self.user) binding_name, binding_version, self.db, self.user)
self._active = 0 self._active = 0
if self._wait < 0: if self._wait < 0:
timeout_at = None timeout_at = None
...@@ -138,7 +159,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -138,7 +159,7 @@ class MySQLDatabaseManager(DatabaseManager):
last = None last = None
while True: while True:
try: try:
self.conn = MySQLdb.connect(**kwd) self.conn = Connection(**kwd)
break break
except Exception as e: except Exception as e:
if None is not timeout_at <= time.time(): if None is not timeout_at <= time.time():
...@@ -154,15 +175,15 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -154,15 +175,15 @@ class MySQLDatabaseManager(DatabaseManager):
self._config = {} self._config = {}
conn = self.conn conn = self.conn
conn.autocommit(False) conn.autocommit(False)
conn.set_sql_mode("TRADITIONAL,NO_ENGINE_SUBSTITUTION") conn.query("SET"
conn.query("SET SESSION group_concat_max_len = %u" % (2**32-1)) " SESSION sql_mode = 'TRADITIONAL,NO_ENGINE_SUBSTITUTION',"
" SESSION group_concat_max_len = %u" % (2**32-1))
if self._engine == 'RocksDB': if self._engine == 'RocksDB':
# Maximum value for _deleteRange. # Maximum value for _deleteRange.
conn.query("SET SESSION rocksdb_max_row_locks = %u" % 2**30) conn.query("SET SESSION rocksdb_max_row_locks = %u" % 2**30)
def query(sql): def query(sql):
conn.query(sql) conn.query(sql)
r = conn.store_result() return fetch_all(conn)
return r.fetch_row(r.num_rows())
if self.LOCK: if self.LOCK:
(locked,), = query("SELECT GET_LOCK('%s.%s', 0)" (locked,), = query("SELECT GET_LOCK('%s.%s', 0)"
% (self.db, self.LOCK)) % (self.db, self.LOCK))
...@@ -220,8 +241,7 @@ class MySQLDatabaseManager(DatabaseManager): ...@@ -220,8 +241,7 @@ class MySQLDatabaseManager(DatabaseManager):
conn = self.conn conn = self.conn
conn.query(query) conn.query(query)
if query.startswith("SELECT "): if query.startswith("SELECT "):
r = conn.store_result() return fetch_all(conn)
return r.fetch_row(r.num_rows())
r = query.split(None, 1)[0] r = query.split(None, 1)[0]
if r in ("INSERT", "REPLACE", "DELETE", "UPDATE"): if r in ("INSERT", "REPLACE", "DELETE", "UPDATE"):
self._active = 1 self._active = 1
......
...@@ -83,8 +83,8 @@ class SQLiteDatabaseManager(DatabaseManager): ...@@ -83,8 +83,8 @@ class SQLiteDatabaseManager(DatabaseManager):
self.lock(self.db) self.lock(self.db)
if self.UNSAFE: if self.UNSAFE:
q = self.query q = self.query
q("PRAGMA synchronous = OFF") q("PRAGMA synchronous = OFF").fetchall()
q("PRAGMA journal_mode = MEMORY") q("PRAGMA journal_mode = MEMORY").fetchall()
self._config = {} self._config = {}
def _getDevPath(self): def _getDevPath(self):
......
...@@ -28,7 +28,7 @@ import unittest ...@@ -28,7 +28,7 @@ import unittest
import weakref import weakref
import transaction import transaction
from contextlib import contextmanager from contextlib import closing, contextmanager
from ConfigParser import SafeConfigParser from ConfigParser import SafeConfigParser
from cStringIO import StringIO from cStringIO import StringIO
try: try:
...@@ -76,6 +76,12 @@ DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db') ...@@ -76,6 +76,12 @@ DB_INSTALL = os.getenv('NEO_DB_INSTALL', 'mysql_install_db')
DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld') DB_MYSQLD = os.getenv('NEO_DB_MYSQLD', '/usr/sbin/mysqld')
DB_MYCNF = os.getenv('NEO_DB_MYCNF') DB_MYCNF = os.getenv('NEO_DB_MYCNF')
adapter = os.getenv('NEO_TESTS_ADAPTER')
if adapter:
from neo.storage.database import getAdapterKlass
if getAdapterKlass(adapter).__name__ == 'MySQLDatabaseManager':
os.environ['NEO_TESTS_ADAPTER'] = 'MySQL'
IP_VERSION_FORMAT_DICT = { IP_VERSION_FORMAT_DICT = {
socket.AF_INET: '127.0.0.1', socket.AF_INET: '127.0.0.1',
socket.AF_INET6: '::1', socket.AF_INET6: '::1',
...@@ -137,31 +143,28 @@ def getTempDirectory(): ...@@ -137,31 +143,28 @@ def getTempDirectory():
print 'Using temp directory %r.' % temp_dir print 'Using temp directory %r.' % temp_dir
return temp_dir return temp_dir
def setupMySQLdb(db_list, clear_databases=True): def setupMySQL(db_list, clear_databases=True):
if mysql_pool: if mysql_pool:
return mysql_pool.setup(db_list, clear_databases) return mysql_pool.setup(db_list, clear_databases)
import MySQLdb from neo.storage.database.mysql import \
from MySQLdb.constants.ER import BAD_DB_ERROR Connection, OperationalError, BAD_DB_ERROR
user = DB_USER user = DB_USER
password = '' password = ''
kw = {'unix_socket': os.path.expanduser(DB_SOCKET)} if DB_SOCKET else {} kw = {'unix_socket': os.path.expanduser(DB_SOCKET)} if DB_SOCKET else {}
conn = MySQLdb.connect(user=DB_ADMIN, passwd=DB_PASSWD, **kw) # BBB: passwd is deprecated favour of password since 1.3.8
cursor = conn.cursor() with closing(Connection(user=DB_ADMIN, passwd=DB_PASSWD, **kw)) as conn:
for database in db_list: for database in db_list:
try: try:
conn.select_db(database) conn.select_db(database)
if not clear_databases: if not clear_databases:
continue continue
cursor.execute('DROP DATABASE `%s`' % database) conn.query('DROP DATABASE `%s`' % database)
except MySQLdb.OperationalError, (code, _): except OperationalError, (code, _):
if code != BAD_DB_ERROR: if code != BAD_DB_ERROR:
raise raise
cursor.execute('GRANT ALL ON `%s`.* TO "%s"@"localhost" IDENTIFIED' conn.query('GRANT ALL ON `%s`.* TO "%s"@"localhost" IDENTIFIED'
' BY "%s"' % (database, user, password)) ' BY "%s"' % (database, user, password))
cursor.execute('CREATE DATABASE `%s`' % database) conn.query('CREATE DATABASE `%s`' % database)
cursor.close()
conn.commit()
conn.close()
return '{}:{}@%s{}'.format(user, password, DB_SOCKET).__mod__ return '{}:{}@%s{}'.format(user, password, DB_SOCKET).__mod__
class MySQLPool(object): class MySQLPool(object):
...@@ -178,7 +181,7 @@ class MySQLPool(object): ...@@ -178,7 +181,7 @@ class MySQLPool(object):
self.kill(*self._mysqld_dict) self.kill(*self._mysqld_dict)
def setup(self, db_list, clear_databases): def setup(self, db_list, clear_databases):
import MySQLdb from neo.storage.database.mysql import Connection
start_list = set(db_list).difference(self._mysqld_dict) start_list = set(db_list).difference(self._mysqld_dict)
if start_list: if start_list:
start_list = sorted(start_list) start_list = sorted(start_list)
...@@ -221,12 +224,11 @@ class MySQLPool(object): ...@@ -221,12 +224,11 @@ class MySQLPool(object):
if x is not None: if x is not None:
raise subprocess.CalledProcessError(x, DB_MYSQLD) raise subprocess.CalledProcessError(x, DB_MYSQLD)
for db in db_list: for db in db_list:
db = MySQLdb.connect(unix_socket=self._sock_template % db, with closing(Connection(unix_socket=self._sock_template % db,
user='root') user='root')) as db:
if clear_databases: if clear_databases:
db.query('DROP DATABASE IF EXISTS neo') db.query('DROP DATABASE IF EXISTS neo')
db.query('CREATE DATABASE IF NOT EXISTS neo') db.query('CREATE DATABASE IF NOT EXISTS neo')
db.close()
return ('root@neo' + self._sock_template).__mod__ return ('root@neo' + self._sock_template).__mod__
def start(self, *db, **kw): def start(self, *db, **kw):
...@@ -274,6 +276,8 @@ class NeoTestBase(unittest.TestCase): ...@@ -274,6 +276,8 @@ class NeoTestBase(unittest.TestCase):
assert self.tearDown.im_func is NeoTestBase.tearDown.im_func assert self.tearDown.im_func is NeoTestBase.tearDown.im_func
self._tearDown(sys._getframe(1).f_locals['success']) self._tearDown(sys._getframe(1).f_locals['success'])
assert not gc.garbage, gc.garbage assert not gc.garbage, gc.garbage
# XXX: I tried the following line to avoid random freezes on PyPy...
gc.collect()
def _tearDown(self, success): def _tearDown(self, success):
# Kill all unfinished transactions for next test. # Kill all unfinished transactions for next test.
...@@ -335,7 +339,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -335,7 +339,7 @@ class NeoUnitTestBase(NeoTestBase):
""" create empty databases """ """ create empty databases """
adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL') adapter = os.getenv('NEO_TESTS_ADAPTER', 'MySQL')
if adapter == 'MySQL': if adapter == 'MySQL':
db_template = setupMySQLdb( db_template = setupMySQL(
[prefix + str(i) for i in xrange(number)]) [prefix + str(i) for i in xrange(number)])
self.db_template = lambda i: db_template(prefix + str(i)) self.db_template = lambda i: db_template(prefix + str(i))
elif adapter == 'SQLite': elif adapter == 'SQLite':
......
...@@ -51,13 +51,20 @@ class BenchmarkRunner(object): ...@@ -51,13 +51,20 @@ class BenchmarkRunner(object):
def build_report(self, content): def build_report(self, content):
fmt = "%-25s : %s" fmt = "%-25s : %s"
py_impl = platform.python_implementation()
if py_impl == 'PyPy':
info = sys.pypy_version_info
py_impl += ' %s.%s.%s' % info[:3]
kind = info.releaselevel
if kind != 'final':
py_impl += kind[0] + str(info.serial)
status = "\n".join([fmt % item for item in [ status = "\n".join([fmt % item for item in [
('Title', self._config.title), ('Title', self._config.title),
('Date', datetime.date.today().isoformat()), ('Date', datetime.date.today().isoformat()),
('Node', platform.node()), ('Node', platform.node()),
('Machine', platform.machine()), ('Machine', platform.machine()),
('System', platform.system()), ('System', platform.system()),
('Python', platform.python_version()), ('Python', '%s [%s]' % (platform.python_version(), py_impl)),
]]) ]])
status += '\n\n' status += '\n\n'
status += "\n".join([fmt % item for item in self._status]) status += "\n".join([fmt % item for item in self._status])
......
...@@ -36,7 +36,7 @@ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \ ...@@ -36,7 +36,7 @@ from neo.lib.protocol import ClusterStates, NodeTypes, CellStates, NodeStates, \
UUID_NAMESPACES UUID_NAMESPACES
from neo.lib.util import dump, setproctitle from neo.lib.util import dump, setproctitle
from .. import (ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, SSL, from .. import (ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, SSL,
buildUrlFromString, cluster, getTempDirectory, setupMySQLdb, buildUrlFromString, cluster, getTempDirectory, setupMySQL,
ImporterConfigParser, NeoTestBase, Patch) ImporterConfigParser, NeoTestBase, Patch)
from neo.client.Storage import Storage from neo.client.Storage import Storage
from neo.storage.database import manager, buildDatabaseManager from neo.storage.database import manager, buildDatabaseManager
...@@ -55,8 +55,8 @@ command_dict = { ...@@ -55,8 +55,8 @@ command_dict = {
DELAY_SAFETY_MARGIN = 10 DELAY_SAFETY_MARGIN = 10
MAX_START_TIME = 30 MAX_START_TIME = 30
PYPY_EXECUTABLE = os.getenv('NEO_PYPY') NEOMASTER_PYPY = os.getenv('NEOMASTER_PYPY')
if PYPY_EXECUTABLE: if NEOMASTER_PYPY:
import neo, msgpack import neo, msgpack
PYPY_TEMPLATE = """\ PYPY_TEMPLATE = """\
import os, signal, sys import os, signal, sys
...@@ -194,8 +194,8 @@ class Process(object): ...@@ -194,8 +194,8 @@ class Process(object):
from coverage import Coverage from coverage import Coverage
coverage = Coverage(coverage_data_path) coverage = Coverage(coverage_data_path)
coverage.start() coverage.start()
elif PYPY_EXECUTABLE and command == 'neomaster': elif NEOMASTER_PYPY and command == 'neomaster':
os.execlp(PYPY_EXECUTABLE, PYPY_EXECUTABLE, '-c', os.execlp(NEOMASTER_PYPY, NEOMASTER_PYPY, '-c',
PYPY_TEMPLATE % ( PYPY_TEMPLATE % (
w, self._coverage_fd, w, w, self._coverage_fd, w,
logging._max_size, logging._max_packet, logging._max_size, logging._max_packet,
...@@ -348,7 +348,7 @@ class NEOCluster(object): ...@@ -348,7 +348,7 @@ class NEOCluster(object):
temp_dir = tempfile.mkdtemp(prefix='neo_') temp_dir = tempfile.mkdtemp(prefix='neo_')
print 'Using temp directory ' + temp_dir print 'Using temp directory ' + temp_dir
if adapter == 'MySQL': if adapter == 'MySQL':
self.db_template = setupMySQLdb(db_list, clear_databases) self.db_template = setupMySQL(db_list, clear_databases)
elif adapter == 'SQLite': elif adapter == 'SQLite':
self.db_template = (lambda t: lambda db: self.db_template = (lambda t: lambda db:
':memory:' if db is None else db if os.sep in db else t % db ':memory:' if db is None else db if os.sep in db else t % db
......
...@@ -47,10 +47,15 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -47,10 +47,15 @@ class StorageClientHandlerTests(NeoUnitTestBase):
def _getConnection(self, uuid=None): def _getConnection(self, uuid=None):
return self.getFakeConnection(uuid=uuid, address=('127.0.0.1', 1000)) return self.getFakeConnection(uuid=uuid, address=('127.0.0.1', 1000))
def fakeDM(self, **kw):
self.app.dm.close()
self.app.dm = dm = Mock(kw)
return dm
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = self._getConnection() conn = self._getConnection()
self.app.dm = Mock({'getNumPartitions': 1}) self.fakeDM(getNumPartitions=1)
self.operation.askTransactionInformation(conn, INVALID_TID) self.operation.askTransactionInformation(conn, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
...@@ -58,7 +63,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -58,7 +63,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
# invalid offsets => error # invalid offsets => error
app = self.app app = self.app
app.pt = Mock() app.pt = Mock()
app.dm = Mock() self.fakeDM()
conn = self._getConnection() conn = self._getConnection()
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None) self.checkProtocolErrorRaised(self.operation.askTIDs, conn, 1, 1, None)
self.assertEqual(len(app.pt.mockGetNamedCalls('getCellList')), 0) self.assertEqual(len(app.pt.mockGetNamedCalls('getCellList')), 0)
...@@ -68,7 +73,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -68,7 +73,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
# well case => answer # well case => answer
conn = self._getConnection() conn = self._getConnection()
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.fakeDM(getTIDList=(INVALID_TID,))
self.operation.askTIDs(conn, 1, 2, 1) self.operation.askTIDs(conn, 1, 2, 1)
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
...@@ -77,12 +82,11 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -77,12 +82,11 @@ class StorageClientHandlerTests(NeoUnitTestBase):
def test_26_askObjectHistory1(self): def test_26_askObjectHistory1(self):
# invalid offsets => error # invalid offsets => error
app = self.app dm = self.fakeDM()
app.dm = Mock()
conn = self._getConnection() conn = self._getConnection()
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn, self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn,
1, 1, None) 1, 1, None)
self.assertEqual(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0) self.assertEqual(len(dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_askObjectUndoSerial(self): def test_askObjectUndoSerial(self):
conn = self._getConnection(uuid=self.getClientUUID()) conn = self._getConnection(uuid=self.getClientUUID())
...@@ -94,9 +98,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -94,9 +98,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
self.app.tm = Mock({ self.app.tm = Mock({
'getObjectFromTransaction': None, 'getObjectFromTransaction': None,
}) })
self.app.dm = Mock({ self.fakeDM(findUndoTID=ReturnValues((None, None, False),))
'findUndoTID': ReturnValues((None, None, False), )
})
self.operation.askObjectUndoSerial(conn, tid, ltid, undone_tid, oid_list) self.operation.askObjectUndoSerial(conn, tid, ltid, undone_tid, oid_list)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
......
...@@ -82,8 +82,9 @@ class StorageMasterHandlerTests(NeoUnitTestBase): ...@@ -82,8 +82,9 @@ class StorageMasterHandlerTests(NeoUnitTestBase):
app.pt = PartitionTable(3, 1) app.pt = PartitionTable(3, 1)
app.pt._id = 1 app.pt._id = 1
ptid = 2 ptid = 2
app.dm = Mock({ }) app.dm.close()
app.replicator = Mock({}) app.dm = Mock()
app.replicator = Mock()
self.operation.notifyPartitionChanges(conn, ptid, 1, cells) self.operation.notifyPartitionChanges(conn, ptid, 1, cells)
# ptid set # ptid set
self.assertEqual(app.pt.getID(), ptid) self.assertEqual(app.pt.getID(), ptid)
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
from binascii import a2b_hex from binascii import a2b_hex
from contextlib import contextmanager from contextlib import closing, contextmanager
import unittest import unittest
from neo.lib.util import add64, p64, u64 from neo.lib.util import add64, p64, u64
from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID from neo.lib.protocol import CellStates, ZERO_HASH, ZERO_OID, ZERO_TID, MAX_TID
...@@ -34,22 +34,18 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -34,22 +34,18 @@ class StorageDBTests(NeoUnitTestBase):
try: try:
return self._db return self._db
except AttributeError: except AttributeError:
self.setNumPartitions(1) self.setupDB(1)
return self._db return self._db
def _tearDown(self, success): def _getDB(self, reset):
try:
self.__dict__.pop('_db', None).close()
except AttributeError:
pass
NeoUnitTestBase._tearDown(self, success)
def getDB(self, reset=0):
raise NotImplementedError raise NotImplementedError
def setNumPartitions(self, num_partitions, reset=0): def setupDB(self, num_partitions=None, reset=False):
assert not hasattr(self, '_db') assert not hasattr(self, '_db')
self._db = db = self.getDB(reset) self._db = db = self._getDB(reset)
self.addCleanup(db.close)
if num_partitions is None:
return
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
db.setUUID(uuid) db.setUUID(uuid)
self.assertEqual(uuid, db.getUUID()) self.assertEqual(uuid, db.getUUID())
...@@ -80,12 +76,12 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -80,12 +76,12 @@ class StorageDBTests(NeoUnitTestBase):
self.db.abortTransaction(ttid) self.db.abortTransaction(ttid)
def test_UUID(self): def test_UUID(self):
db = self.getDB() self.setupDB()
self.checkConfigEntry(db.getUUID, db.setUUID, 123) self.checkConfigEntry(self.db.getUUID, self.db.setUUID, 123)
def test_Name(self): def test_Name(self):
db = self.getDB() self.setupDB()
self.checkConfigEntry(db.getName, db.setName, 'TEST_NAME') self.checkConfigEntry(self.db.getName, self.db.setName, 'TEST_NAME')
def getOIDs(self, count): def getOIDs(self, count):
return map(p64, xrange(count)) return map(p64, xrange(count))
...@@ -111,9 +107,8 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -111,9 +107,8 @@ class StorageDBTests(NeoUnitTestBase):
raise NotImplementedError raise NotImplementedError
def test_lockDatabase(self): def test_lockDatabase(self):
db = self._test_lockDatabase_open() with closing(self._test_lockDatabase_open()) as db:
self.assertRaises(SystemExit, self._test_lockDatabase_open) self.assertRaises(SystemExit, self._test_lockDatabase_open)
db.close()
self._test_lockDatabase_open().close() self._test_lockDatabase_open().close()
def test_getUnfinishedTIDDict(self): def test_getUnfinishedTIDDict(self):
...@@ -237,7 +232,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -237,7 +232,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_deleteRange(self): def test_deleteRange(self):
np = 4 np = 4
self.setNumPartitions(np) self.setupDB(np)
t1, t2, t3 = map(p64, (1, 2, 3)) t1, t2, t3 = map(p64, (1, 2, 3))
oid_list = self.getOIDs(np * 2) oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3: for tid in t1, t2, t3:
...@@ -310,7 +305,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -310,7 +305,7 @@ class StorageDBTests(NeoUnitTestBase):
return tid_list return tid_list
def test_getTIDList(self): def test_getTIDList(self):
self.setNumPartitions(2, True) self.setupDB(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# get tids # get tids
# - all partitions # - all partitions
...@@ -330,7 +325,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -330,7 +325,7 @@ class StorageDBTests(NeoUnitTestBase):
self.checkSet(result, []) self.checkSet(result, [])
def test_getReplicationTIDList(self): def test_getReplicationTIDList(self):
self.setNumPartitions(2, True) self.setupDB(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
# - one partition # - one partition
result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0) result = self.db.getReplicationTIDList(ZERO_TID, MAX_TID, 10, 0)
...@@ -352,7 +347,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -352,7 +347,7 @@ class StorageDBTests(NeoUnitTestBase):
def check(trans, obj, *args): def check(trans, obj, *args):
self.assertEqual(trans, self.db.checkTIDRange(*args)) self.assertEqual(trans, self.db.checkTIDRange(*args))
self.assertEqual(obj, self.db.checkSerialRange(*(args+(ZERO_OID,)))) self.assertEqual(obj, self.db.checkSerialRange(*(args+(ZERO_OID,))))
self.setNumPartitions(2, True) self.setupDB(2, True)
tid1, tid2, tid3, tid4 = self._storeTransactions(4) tid1, tid2, tid3, tid4 = self._storeTransactions(4)
z = 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID z = 0, ZERO_HASH, ZERO_TID, ZERO_HASH, ZERO_OID
# - one partition # - one partition
...@@ -380,7 +375,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -380,7 +375,7 @@ class StorageDBTests(NeoUnitTestBase):
check(y, x + y[1:], 1, 1, ZERO_TID, MAX_TID) check(y, x + y[1:], 1, 1, ZERO_TID, MAX_TID)
def test_findUndoTID(self): def test_findUndoTID(self):
self.setNumPartitions(4, True) self.setupDB(4, True)
db = self.db db = self.db
tid1 = self.getNextTID() tid1 = self.getNextTID()
tid2 = self.getNextTID() tid2 = self.getNextTID()
......
...@@ -15,17 +15,16 @@ ...@@ -15,17 +15,16 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from contextlib import contextmanager from contextlib import closing, contextmanager
from MySQLdb import NotSupportedError, OperationalError, ProgrammingError
from MySQLdb.constants.CR import SERVER_GONE_ERROR
from MySQLdb.constants.ER import UNKNOWN_STORAGE_ENGINE
from ..mock import Mock from ..mock import Mock
from neo.lib.protocol import ZERO_OID from neo.lib.protocol import ZERO_OID
from neo.lib.util import p64 from neo.lib.util import p64
from .. import DB_PREFIX, DB_USER, Patch, setupMySQLdb from .. import DB_PREFIX, DB_USER, Patch, setupMySQL
from .testStorageDBTests import StorageDBTests from .testStorageDBTests import StorageDBTests
from neo.storage.database import DatabaseFailure from neo.storage.database import DatabaseFailure
from neo.storage.database.mysql import MySQLDatabaseManager from neo.storage.database.mysql import (MySQLDatabaseManager,
NotSupportedError, OperationalError, ProgrammingError,
SERVER_GONE_ERROR, UNKNOWN_STORAGE_ENGINE)
class ServerGone(object): class ServerGone(object):
...@@ -50,17 +49,21 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -50,17 +49,21 @@ class StorageMySQLdbTests(StorageDBTests):
database = self.db_template(0) database = self.db_template(0)
return MySQLDatabaseManager(database, self.engine) return MySQLDatabaseManager(database, self.engine)
def getDB(self, reset=0): def _getDB(self, reset):
db = self._test_lockDatabase_open() db = self._test_lockDatabase_open()
self.assertEqual(db.db, DB_PREFIX + '0')
self.assertEqual(db.user, DB_USER)
try: try:
db.setup(reset, True) self.assertEqual(db.db, DB_PREFIX + '0')
except NotSupportedError as m: self.assertEqual(db.user, DB_USER)
code, m = m.args try:
if code != UNKNOWN_STORAGE_ENGINE: db.setup(reset, True)
raise except NotSupportedError as m:
raise unittest.SkipTest(m) code, m = m.args
if code != UNKNOWN_STORAGE_ENGINE:
raise
raise unittest.SkipTest(m)
except:
db.close()
raise
return db return db
def test_ServerGone(self): def test_ServerGone(self):
...@@ -75,8 +78,9 @@ class StorageMySQLdbTests(StorageDBTests): ...@@ -75,8 +78,9 @@ class StorageMySQLdbTests(StorageDBTests):
pass pass
def query(*args): def query(*args):
raise OperationalError(-1, 'this is a test') raise OperationalError(-1, 'this is a test')
self.db.conn = FakeConn() with closing(self.db.conn):
self.assertRaises(DatabaseFailure, self.db.query, 'QUERY') self.db.conn = FakeConn()
self.assertRaises(DatabaseFailure, self.db.query, 'QUERY')
def test_escape(self): def test_escape(self):
self.assertEqual(self.db.escape('a"b'), 'a\\"b') self.assertEqual(self.db.escape('a"b'), 'a\\"b')
......
...@@ -25,7 +25,7 @@ class StorageSQLiteTests(StorageDBTests): ...@@ -25,7 +25,7 @@ class StorageSQLiteTests(StorageDBTests):
db = os.path.join(getTempDirectory(), DB_PREFIX + '0.sqlite') db = os.path.join(getTempDirectory(), DB_PREFIX + '0.sqlite')
return SQLiteDatabaseManager(db) return SQLiteDatabaseManager(db)
def getDB(self, reset=0): def _getDB(self, reset=False):
db = SQLiteDatabaseManager(':memory:') db = SQLiteDatabaseManager(':memory:')
db.setup(reset, True) db.setup(reset, True)
return db return db
...@@ -33,8 +33,8 @@ class StorageSQLiteTests(StorageDBTests): ...@@ -33,8 +33,8 @@ class StorageSQLiteTests(StorageDBTests):
def test_lockDatabase(self): def test_lockDatabase(self):
super(StorageSQLiteTests, self).test_lockDatabase() super(StorageSQLiteTests, self).test_lockDatabase()
# No lock on temporary databases. # No lock on temporary databases.
db = self.getDB() db = self._getDB()
self.getDB().close() self._getDB().close()
db.close() db.close()
del StorageDBTests del StorageDBTests
......
...@@ -40,7 +40,7 @@ from neo.lib.protocol import ZERO_OID, ZERO_TID, MAX_TID, uuid_str, \ ...@@ -40,7 +40,7 @@ from neo.lib.protocol import ZERO_OID, ZERO_TID, MAX_TID, uuid_str, \
ClusterStates, Enum, NodeStates, NodeTypes, Packets ClusterStates, Enum, NodeStates, NodeTypes, Packets
from neo.lib.util import cached_property, parseMasterList, p64 from neo.lib.util import cached_property, parseMasterList, p64
from neo.master.recovery import RecoveryManager from neo.master.recovery import RecoveryManager
from .. import (getTempDirectory, setupMySQLdb, from .. import (getTempDirectory, setupMySQL,
ImporterConfigParser, NeoTestBase, Patch, ImporterConfigParser, NeoTestBase, Patch,
ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX) ADDRESS_TYPE, IP_VERSION_FORMAT_DICT, DB_PREFIX)
...@@ -787,7 +787,7 @@ class NEOCluster(object): ...@@ -787,7 +787,7 @@ class NEOCluster(object):
db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index)) db_list = ['%s%u' % (DB_PREFIX, self._allocate('db', index))
for _ in xrange(storage_count)] for _ in xrange(storage_count)]
if adapter == 'MySQL': if adapter == 'MySQL':
db = setupMySQLdb(db_list, clear_databases) db = setupMySQL(db_list, clear_databases)
elif adapter == 'SQLite': elif adapter == 'SQLite':
db = os.path.join(getTempDirectory(), '%s.sqlite').__mod__ db = os.path.join(getTempDirectory(), '%s.sqlite').__mod__
else: else:
......
...@@ -1663,6 +1663,9 @@ class Test(NEOThreadedTest): ...@@ -1663,6 +1663,9 @@ class Test(NEOThreadedTest):
m2c, = cluster.master.getConnectionList(cluster.client) m2c, = cluster.master.getConnectionList(cluster.client)
cluster.client._cache.clear() cluster.client._cache.clear()
c.cacheMinimize() c.cacheMinimize()
if not hasattr(sys, 'getrefcount'): # PyPy
# See persistent commit ff64867cca3179b1a6379c93b6ef90db565da36c
import gc; gc.collect()
# Make the master disconnects the client when the latter is about # Make the master disconnects the client when the latter is about
# to send a AskObject packet to the storage node. # to send a AskObject packet to the storage node.
with cluster.client.filterConnection(cluster.storage) as c2s: with cluster.client.filterConnection(cluster.storage) as c2s:
......
...@@ -128,7 +128,9 @@ class ImporterTests(NEOThreadedTest): ...@@ -128,7 +128,9 @@ class ImporterTests(NEOThreadedTest):
r5["foo"] = "bar" r5["foo"] = "bar"
state = {r2: r3, r4: r5} state = {r2: r3, r4: r5}
p = StringIO() p = StringIO()
Pickler(p, 1).dump(Obj).dump(state) pickler = Pickler(p, 1)
pickler.dump(Obj)
pickler.dump(state)
p = p.getvalue() p = p.getvalue()
r = DummyRepickler()(p) r = DummyRepickler()(p)
load = Unpickler(StringIO(r)).load load = Unpickler(StringIO(r)).load
......
...@@ -10,6 +10,8 @@ Intended Audience :: Developers ...@@ -10,6 +10,8 @@ Intended Audience :: Developers
License :: OSI Approved :: GNU General Public License (GPL) License :: OSI Approved :: GNU General Public License (GPL)
Operating System :: POSIX :: Linux Operating System :: POSIX :: Linux
Programming Language :: Python :: 2.7 Programming Language :: Python :: 2.7
Programming Language :: Python :: Implementation :: CPython
Programming Language :: Python :: Implementation :: PyPy
Topic :: Database Topic :: Database
Topic :: Software Development :: Libraries :: Python Modules Topic :: Software Development :: Libraries :: Python Modules
""" """
...@@ -53,6 +55,7 @@ extras_require = { ...@@ -53,6 +55,7 @@ extras_require = {
'master': [], 'master': [],
'storage-sqlite': [], 'storage-sqlite': [],
'storage-mysqldb': ['mysqlclient'], 'storage-mysqldb': ['mysqlclient'],
'storage-pymysql': ['PyMySQL'],
'storage-importer': zodb_require + ['setproctitle'], 'storage-importer': zodb_require + ['setproctitle'],
} }
extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2', extras_require['tests'] = ['coverage', 'zope.testing', 'psutil>=2',
......
...@@ -18,7 +18,7 @@ from neo.lib.debug import PdbSocket ...@@ -18,7 +18,7 @@ from neo.lib.debug import PdbSocket
from neo.lib.node import Node from neo.lib.node import Node
from neo.lib.protocol import NodeTypes from neo.lib.protocol import NodeTypes
from neo.lib.util import datetimeFromTID, p64, u64 from neo.lib.util import datetimeFromTID, p64, u64
from neo.storage.app import DATABASE_MANAGER_DICT, \ from neo.storage.app import DATABASE_MANAGERS, \
Application as StorageApplication Application as StorageApplication
from neo.tests import getTempDirectory, mysql_pool from neo.tests import getTempDirectory, mysql_pool
from neo.tests.ConflictFree import ConflictFreeLog from neo.tests.ConflictFree import ConflictFreeLog
...@@ -580,7 +580,7 @@ class ArgumentDefaultsHelpFormatter(argparse.HelpFormatter): ...@@ -580,7 +580,7 @@ class ArgumentDefaultsHelpFormatter(argparse.HelpFormatter):
def main(): def main():
adapters = sorted(DATABASE_MANAGER_DICT) adapters = list(DATABASE_MANAGERS)
adapters.remove('Importer') adapters.remove('Importer')
default_adapter = 'SQLite' default_adapter = 'SQLite'
assert default_adapter in adapters assert default_adapter in adapters
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment