Commit a5eebbae authored by Jens Vagelpohl's avatar Jens Vagelpohl

- full linting with flake8

parent 0e19e22b
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# https://github.com/zopefoundation/meta/tree/master/config/pure-python # https://github.com/zopefoundation/meta/tree/master/config/pure-python
[meta] [meta]
template = "pure-python" template = "pure-python"
commit-id = "3b712f305ca8207e971c5bf81f2bdb5872489f2f" commit-id = "0c07a1cfd78d28a07aebd23383ed16959f166574"
[python] [python]
with-windows = false with-windows = false
...@@ -13,7 +13,7 @@ with-docs = true ...@@ -13,7 +13,7 @@ with-docs = true
with-sphinx-doctests = false with-sphinx-doctests = false
[tox] [tox]
use-flake8 = false use-flake8 = true
testenv-commands = [ testenv-commands = [
"# Run unit tests first.", "# Run unit tests first.",
"zope-testrunner -u --test-path=src {posargs:-vc}", "zope-testrunner -u --test-path=src {posargs:-vc}",
......
...@@ -4,6 +4,8 @@ Changelog ...@@ -4,6 +4,8 @@ Changelog
5.4.0 (unreleased) 5.4.0 (unreleased)
------------------ ------------------
- linted the code with flake8
- Add support for Python 3.10. - Add support for Python 3.10.
- Add ``ConflictError`` to the list of unlogged server exceptions - Add ``ConflictError`` to the list of unlogged server exceptions
......
...@@ -11,11 +11,12 @@ ...@@ -11,11 +11,12 @@
# FOR A PARTICULAR PURPOSE. # FOR A PARTICULAR PURPOSE.
# #
############################################################################## ##############################################################################
version = '5.3.1.dev0'
from setuptools import setup, find_packages from setuptools import setup, find_packages
import os import os
version = '5.3.1.dev0'
install_requires = [ install_requires = [
'ZODB >= 5.1.1', 'ZODB >= 5.1.1',
'six', 'six',
...@@ -64,12 +65,14 @@ Operating System :: Unix ...@@ -64,12 +65,14 @@ Operating System :: Unix
Framework :: ZODB Framework :: ZODB
""".strip().split('\n') """.strip().split('\n')
def _modname(path, base, name=''): def _modname(path, base, name=''):
if path == base: if path == base:
return name return name
dirname, basename = os.path.split(path) dirname, basename = os.path.split(path)
return _modname(dirname, base, basename + '.' + name) return _modname(dirname, base, basename + '.' + name)
def _flatten(suite, predicate=lambda *x: True): def _flatten(suite, predicate=lambda *x: True):
from unittest import TestCase from unittest import TestCase
for suite_or_case in suite: for suite_or_case in suite:
...@@ -80,18 +83,20 @@ def _flatten(suite, predicate=lambda *x: True): ...@@ -80,18 +83,20 @@ def _flatten(suite, predicate=lambda *x: True):
for x in _flatten(suite_or_case): for x in _flatten(suite_or_case):
yield x yield x
def _no_layer(suite_or_case): def _no_layer(suite_or_case):
return getattr(suite_or_case, 'layer', None) is None return getattr(suite_or_case, 'layer', None) is None
def _unittests_only(suite, mod_suite): def _unittests_only(suite, mod_suite):
for case in _flatten(mod_suite, _no_layer): for case in _flatten(mod_suite, _no_layer):
suite.addTest(case) suite.addTest(case)
def alltests(): def alltests():
import logging import logging
import pkg_resources import pkg_resources
import unittest import unittest
import ZEO.ClientStorage
class NullHandler(logging.Handler): class NullHandler(logging.Handler):
level = 50 level = 50
...@@ -107,7 +112,8 @@ def alltests(): ...@@ -107,7 +112,8 @@ def alltests():
for dirpath, dirnames, filenames in os.walk(base): for dirpath, dirnames, filenames in os.walk(base):
if os.path.basename(dirpath) == 'tests': if os.path.basename(dirpath) == 'tests':
for filename in filenames: for filename in filenames:
if filename != 'testZEO.py': continue if filename != 'testZEO.py':
continue
if filename.endswith('.py') and filename.startswith('test'): if filename.endswith('.py') and filename.startswith('test'):
mod = __import__( mod = __import__(
_modname(dirpath, base, os.path.splitext(filename)[0]), _modname(dirpath, base, os.path.splitext(filename)[0]),
...@@ -115,11 +121,13 @@ def alltests(): ...@@ -115,11 +121,13 @@ def alltests():
_unittests_only(suite, mod.test_suite()) _unittests_only(suite, mod.test_suite())
return suite return suite
long_description = ( long_description = (
open('README.rst').read() open('README.rst').read()
+ '\n' + + '\n' +
open('CHANGES.rst').read() open('CHANGES.rst').read()
) )
setup(name="ZEO", setup(name="ZEO",
version=version, version=version,
description=long_description.split('\n', 2)[1], description=long_description.split('\n', 2)[1],
...@@ -164,4 +172,4 @@ setup(name="ZEO", ...@@ -164,4 +172,4 @@ setup(name="ZEO",
""", """,
include_package_data=True, include_package_data=True,
python_requires='>=2.7.9,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', python_requires='>=2.7.9,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*',
) )
...@@ -52,9 +52,11 @@ import ZEO.cache ...@@ -52,9 +52,11 @@ import ZEO.cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def tid2time(tid): def tid2time(tid):
return str(TimeStamp(tid)) return str(TimeStamp(tid))
def get_timestamp(prev_ts=None): def get_timestamp(prev_ts=None):
"""Internal helper to return a unique TimeStamp instance. """Internal helper to return a unique TimeStamp instance.
...@@ -69,8 +71,10 @@ def get_timestamp(prev_ts=None): ...@@ -69,8 +71,10 @@ def get_timestamp(prev_ts=None):
t = t.laterThan(prev_ts) t = t.laterThan(prev_ts)
return t return t
MB = 1024**2 MB = 1024**2
@zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage) @zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage)
class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
"""A storage class that is a network client to a remote storage. """A storage class that is a network client to a remote storage.
...@@ -90,7 +94,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -90,7 +94,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
blob_cache_size=None, blob_cache_size_check=10, blob_cache_size=None, blob_cache_size_check=10,
client_label=None, client_label=None,
cache=None, cache=None,
ssl = None, ssl_server_hostname=None, ssl=None, ssl_server_hostname=None,
# Mostly ignored backward-compatability options # Mostly ignored backward-compatability options
client=None, var=None, client=None, var=None,
min_disconnect_poll=1, max_disconnect_poll=None, min_disconnect_poll=1, max_disconnect_poll=None,
...@@ -196,7 +200,8 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -196,7 +200,8 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
raise ValueError("Unix sockets are not available on Windows") raise ValueError("Unix sockets are not available on Windows")
addr = [addr] addr = [addr]
elif (isinstance(addr, tuple) and len(addr) == 2 and elif (isinstance(addr, tuple) and len(addr) == 2 and
isinstance(addr[0], six.string_types) and isinstance(addr[1], int)): isinstance(addr[0], six.string_types) and
isinstance(addr[1], int)):
addr = [addr] addr = [addr]
logger.info( logger.info(
...@@ -266,7 +271,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -266,7 +271,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
addr, self, cache, storage, addr, self, cache, storage,
ZEO.asyncio.client.Fallback if read_only_fallback else read_only, ZEO.asyncio.client.Fallback if read_only_fallback else read_only,
wait_timeout or 30, wait_timeout or 30,
ssl = ssl, ssl_server_hostname=ssl_server_hostname, ssl=ssl, ssl_server_hostname=ssl_server_hostname,
credentials=credentials, credentials=credentials,
) )
self._call = self._server.call self._call = self._server.call
...@@ -308,6 +313,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -308,6 +313,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._check_blob_size_thread.join() self._check_blob_size_thread.join()
_check_blob_size_thread = None _check_blob_size_thread = None
def _check_blob_size(self, bytes=None): def _check_blob_size(self, bytes=None):
if self._blob_cache_size is None: if self._blob_cache_size is None:
return return
...@@ -349,8 +355,8 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -349,8 +355,8 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
pass pass
_connection_generation = 0 _connection_generation = 0
def notify_connected(self, conn, info): def notify_connected(self, conn, info):
reconnected = self._connection_generation
self.set_server_addr(conn.get_peername()) self.set_server_addr(conn.get_peername())
self.protocol_version = conn.protocol_version self.protocol_version = conn.protocol_version
self._is_read_only = conn.is_read_only() self._is_read_only = conn.is_read_only()
...@@ -373,22 +379,20 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -373,22 +379,20 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._info.update(info) self._info.update(info)
for iface in ( for iface in (ZODB.interfaces.IStorageRestoreable,
ZODB.interfaces.IStorageRestoreable,
ZODB.interfaces.IStorageIteration, ZODB.interfaces.IStorageIteration,
ZODB.interfaces.IStorageUndoable, ZODB.interfaces.IStorageUndoable,
ZODB.interfaces.IStorageCurrentRecordIteration, ZODB.interfaces.IStorageCurrentRecordIteration,
ZODB.interfaces.IBlobStorage, ZODB.interfaces.IBlobStorage,
ZODB.interfaces.IExternalGC, ZODB.interfaces.IExternalGC):
): if (iface.__module__, iface.__name__) in \
if (iface.__module__, iface.__name__) in self._info.get( self._info.get('interfaces', ()):
'interfaces', ()):
zope.interface.alsoProvides(self, iface) zope.interface.alsoProvides(self, iface)
if self.protocol_version[1:] >= b'5': if self.protocol_version[1:] >= b'5':
self.ping = lambda : self._call('ping') self.ping = lambda: self._call('ping')
else: else:
self.ping = lambda : self._call('lastTransaction') self.ping = lambda: self._call('lastTransaction')
if self.server_sync: if self.server_sync:
self.sync = self.ping self.sync = self.ping
...@@ -735,7 +739,6 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -735,7 +739,6 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
finally: finally:
lock.close() lock.close()
def temporaryDirectory(self): def temporaryDirectory(self):
return self.fshelper.temp_dir return self.fshelper.temp_dir
...@@ -843,11 +846,11 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -843,11 +846,11 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def tpc_abort(self, txn, timeout=None): def tpc_abort(self, txn, timeout=None):
"""Storage API: abort a transaction. """Storage API: abort a transaction.
(The timeout keyword argument is for tests to wat longer than (The timeout keyword argument is for tests to wait longer than
they normally would.) they normally would.)
""" """
try: try:
tbuf = txn.data(self) tbuf = txn.data(self) # NOQA: F841 unused variable
except KeyError: except KeyError:
return return
...@@ -899,7 +902,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -899,7 +902,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
while blobs: while blobs:
oid, blobfilename = blobs.pop() oid, blobfilename = blobs.pop()
self._blob_data_bytes_loaded += os.stat(blobfilename).st_size self._blob_data_bytes_loaded += os.stat(blobfilename).st_size
targetpath = self.fshelper.getPathForOID(oid, create=True) self.fshelper.getPathForOID(oid, create=True)
target_blob_file_name = self.fshelper.getBlobFilename(oid, tid) target_blob_file_name = self.fshelper.getBlobFilename(oid, tid)
lock = _lock_blob(target_blob_file_name) lock = _lock_blob(target_blob_file_name)
try: try:
...@@ -1037,6 +1040,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -1037,6 +1040,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def server_status(self): def server_status(self):
return self._call('server_status') return self._call('server_status')
class TransactionIterator(object): class TransactionIterator(object):
def __init__(self, storage, iid, *args): def __init__(self, storage, iid, *args):
...@@ -1130,6 +1134,7 @@ class BlobCacheLayout(object): ...@@ -1130,6 +1134,7 @@ class BlobCacheLayout(object):
ZODB.blob.BLOB_SUFFIX) ZODB.blob.BLOB_SUFFIX)
) )
def _accessed(filename): def _accessed(filename):
try: try:
os.utime(filename, (time.time(), os.stat(filename).st_mtime)) os.utime(filename, (time.time(), os.stat(filename).st_mtime))
...@@ -1137,7 +1142,10 @@ def _accessed(filename): ...@@ -1137,7 +1142,10 @@ def _accessed(filename):
pass # We tried. :) pass # We tried. :)
return filename return filename
cache_file_name = re.compile(r'\d+$').match cache_file_name = re.compile(r'\d+$').match
def _check_blob_cache_size(blob_dir, target): def _check_blob_cache_size(blob_dir, target):
logger = logging.getLogger(__name__+'.check_blob_cache') logger = logging.getLogger(__name__+'.check_blob_cache')
...@@ -1222,7 +1230,7 @@ def _check_blob_cache_size(blob_dir, target): ...@@ -1222,7 +1230,7 @@ def _check_blob_cache_size(blob_dir, target):
fsize = os.stat(file_name).st_size fsize = os.stat(file_name).st_size
try: try:
ZODB.blob.remove_committed(file_name) ZODB.blob.remove_committed(file_name)
except OSError as v: except OSError:
pass # probably open on windows pass # probably open on windows
else: else:
size -= fsize size -= fsize
...@@ -1238,12 +1246,14 @@ def _check_blob_cache_size(blob_dir, target): ...@@ -1238,12 +1246,14 @@ def _check_blob_cache_size(blob_dir, target):
finally: finally:
check_lock.close() check_lock.close()
def check_blob_size_script(args=None): def check_blob_size_script(args=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
blob_dir, target = args blob_dir, target = args
_check_blob_cache_size(blob_dir, int(target)) _check_blob_cache_size(blob_dir, int(target))
def _lock_blob(path): def _lock_blob(path):
lockfilename = os.path.join(os.path.dirname(path), '.lock') lockfilename = os.path.join(os.path.dirname(path), '.lock')
n = 0 n = 0
...@@ -1258,6 +1268,7 @@ def _lock_blob(path): ...@@ -1258,6 +1268,7 @@ def _lock_blob(path):
else: else:
break break
def open_cache(cache, var, client, storage, cache_size): def open_cache(cache, var, client, storage, cache_size):
if isinstance(cache, (None.__class__, str)): if isinstance(cache, (None.__class__, str)):
from ZEO.cache import ClientCache from ZEO.cache import ClientCache
......
...@@ -17,27 +17,33 @@ import transaction.interfaces ...@@ -17,27 +17,33 @@ import transaction.interfaces
from ZODB.POSException import StorageError from ZODB.POSException import StorageError
class ClientStorageError(StorageError): class ClientStorageError(StorageError):
"""An error occurred in the ZEO Client Storage. """An error occurred in the ZEO Client Storage.
""" """
class UnrecognizedResult(ClientStorageError): class UnrecognizedResult(ClientStorageError):
"""A server call returned an unrecognized result. """A server call returned an unrecognized result.
""" """
class ClientDisconnected(ClientStorageError, class ClientDisconnected(ClientStorageError,
transaction.interfaces.TransientError): transaction.interfaces.TransientError):
"""The database storage is disconnected from the storage. """The database storage is disconnected from the storage.
""" """
class AuthError(StorageError): class AuthError(StorageError):
"""The client provided invalid authentication credentials. """The client provided invalid authentication credentials.
""" """
class ProtocolError(ClientStorageError): class ProtocolError(ClientStorageError):
"""A client contacted a server with an incomparible protocol """A client contacted a server with an incomparible protocol
""" """
class ServerException(ClientStorageError): class ServerException(ClientStorageError):
""" """
""" """
...@@ -23,13 +23,11 @@ import codecs ...@@ -23,13 +23,11 @@ import codecs
import itertools import itertools
import logging import logging
import os import os
import socket
import sys import sys
import tempfile import tempfile
import threading import threading
import time import time
import warnings import warnings
import ZEO.asyncio.server
import ZODB.blob import ZODB.blob
import ZODB.event import ZODB.event
import ZODB.serialize import ZODB.serialize
...@@ -37,8 +35,7 @@ import ZODB.TimeStamp ...@@ -37,8 +35,7 @@ import ZODB.TimeStamp
import zope.interface import zope.interface
import six import six
from ZEO._compat import Pickler, Unpickler, PY3, BytesIO from ZEO._compat import Pickler, Unpickler, PY3
from ZEO.Exceptions import AuthError
from ZEO.monitor import StorageStats from ZEO.monitor import StorageStats
from ZEO.asyncio.server import Delay, MTDelay, Result from ZEO.asyncio.server import Delay, MTDelay, Result
from ZODB.Connection import TransactionMetaData from ZODB.Connection import TransactionMetaData
...@@ -46,7 +43,7 @@ from ZODB.loglevels import BLATHER ...@@ -46,7 +43,7 @@ from ZODB.loglevels import BLATHER
from ZODB.POSException import StorageError, StorageTransactionError from ZODB.POSException import StorageError, StorageTransactionError
from ZODB.POSException import TransactionError, ReadOnlyError, ConflictError from ZODB.POSException import TransactionError, ReadOnlyError, ConflictError
from ZODB.serialize import referencesf from ZODB.serialize import referencesf
from ZODB.utils import oid_repr, p64, u64, z64, Lock, RLock from ZODB.utils import p64, u64, z64, Lock, RLock
# BBB mtacceptor is unused and will be removed in ZEO version 6 # BBB mtacceptor is unused and will be removed in ZEO version 6
if os.environ.get("ZEO_MTACCEPTOR"): # mainly for tests if os.environ.get("ZEO_MTACCEPTOR"): # mainly for tests
...@@ -58,6 +55,7 @@ else: ...@@ -58,6 +55,7 @@ else:
logger = logging.getLogger('ZEO.StorageServer') logger = logging.getLogger('ZEO.StorageServer')
def log(message, level=logging.INFO, label='', exc_info=False): def log(message, level=logging.INFO, label='', exc_info=False):
"""Internal helper to log a message.""" """Internal helper to log a message."""
if label: if label:
...@@ -68,7 +66,9 @@ def log(message, level=logging.INFO, label='', exc_info=False): ...@@ -68,7 +66,9 @@ def log(message, level=logging.INFO, label='', exc_info=False):
class StorageServerError(StorageError): class StorageServerError(StorageError):
"""Error reported when an unpicklable exception is raised.""" """Error reported when an unpicklable exception is raised."""
registered_methods = set(( 'get_info', 'lastTransaction',
registered_methods = set(
('get_info', 'lastTransaction',
'getInvalidations', 'new_oids', 'pack', 'loadBefore', 'storea', 'getInvalidations', 'new_oids', 'pack', 'loadBefore', 'storea',
'checkCurrentSerialInTransaction', 'restorea', 'storeBlobStart', 'checkCurrentSerialInTransaction', 'restorea', 'storeBlobStart',
'storeBlobChunk', 'storeBlobEnd', 'storeBlobShared', 'storeBlobChunk', 'storeBlobEnd', 'storeBlobShared',
...@@ -78,6 +78,7 @@ registered_methods = set(( 'get_info', 'lastTransaction', ...@@ -78,6 +78,7 @@ registered_methods = set(( 'get_info', 'lastTransaction',
'iterator_next', 'iterator_record_start', 'iterator_record_next', 'iterator_next', 'iterator_record_start', 'iterator_record_next',
'iterator_gc', 'server_status', 'set_client_label', 'ping')) 'iterator_gc', 'server_status', 'set_client_label', 'ping'))
class ZEOStorage(object): class ZEOStorage(object):
"""Proxy to underlying storage for a single remote client.""" """Proxy to underlying storage for a single remote client."""
...@@ -146,7 +147,7 @@ class ZEOStorage(object): ...@@ -146,7 +147,7 @@ class ZEOStorage(object):
info = self.get_info() info = self.get_info()
if not info['supportsUndo']: if not info['supportsUndo']:
self.undoLog = self.undoInfo = lambda *a,**k: () self.undoLog = self.undoInfo = lambda *a, **k: ()
# XXX deprecated: but ZODB tests use getTid. They shouldn't # XXX deprecated: but ZODB tests use getTid. They shouldn't
self.getTid = storage.getTid self.getTid = storage.getTid
...@@ -166,16 +167,16 @@ class ZEOStorage(object): ...@@ -166,16 +167,16 @@ class ZEOStorage(object):
"Falling back to using _transaction attribute, which\n." "Falling back to using _transaction attribute, which\n."
"is icky.", "is icky.",
logging.ERROR) logging.ERROR)
self.tpc_transaction = lambda : storage._transaction self.tpc_transaction = lambda: storage._transaction
else: else:
raise raise
self.connection.methods = registered_methods self.connection.methods = registered_methods
def history(self,tid,size=1): def history(self, tid, size=1):
# This caters for storages which still accept # This caters for storages which still accept
# a version parameter. # a version parameter.
return self.storage.history(tid,size=size) return self.storage.history(tid, size=size)
def _check_tid(self, tid, exc=None): def _check_tid(self, tid, exc=None):
if self.read_only: if self.read_only:
...@@ -235,7 +236,7 @@ class ZEOStorage(object): ...@@ -235,7 +236,7 @@ class ZEOStorage(object):
def get_info(self): def get_info(self):
storage = self.storage storage = self.storage
supportsUndo = (getattr(storage, 'supportsUndo', lambda : False)() supportsUndo = (getattr(storage, 'supportsUndo', lambda: False)()
and self.connection.protocol_version[1:] >= b'310') and self.connection.protocol_version[1:] >= b'310')
# Communicate the backend storage interfaces to the client # Communicate the backend storage interfaces to the client
...@@ -404,14 +405,12 @@ class ZEOStorage(object): ...@@ -404,14 +405,12 @@ class ZEOStorage(object):
oid, oldserial, data, blobfilename = self.blob_log.pop() oid, oldserial, data, blobfilename = self.blob_log.pop()
self._store(oid, oldserial, data, blobfilename) self._store(oid, oldserial, data, blobfilename)
if not self.conflicts: if not self.conflicts:
try: try:
serials = self.storage.tpc_vote(self.transaction) serials = self.storage.tpc_vote(self.transaction)
except ConflictError as err: except ConflictError as err:
if (self.client_conflict_resolution and if self.client_conflict_resolution and \
err.oid and err.serials and err.data err.oid and err.serials and err.data:
):
self.conflicts[err.oid] = dict( self.conflicts[err.oid] = dict(
oid=err.oid, serials=err.serials, data=err.data) oid=err.oid, serials=err.serials, data=err.data)
else: else:
...@@ -485,11 +484,8 @@ class ZEOStorage(object): ...@@ -485,11 +484,8 @@ class ZEOStorage(object):
assert self.txnlog is not None # effectively not allowed after undo assert self.txnlog is not None # effectively not allowed after undo
# Reconstruct the full path from the filename in the OID directory # Reconstruct the full path from the filename in the OID directory
if (os.path.sep in filename if os.path.sep in filename or \
or not (filename.endswith('.tmp') not (filename.endswith('.tmp') or filename[:-1].endswith('.tmp')):
or filename[:-1].endswith('.tmp')
)
):
logger.critical( logger.critical(
"We're under attack! (bad filename to storeBlobShared, %r)", "We're under attack! (bad filename to storeBlobShared, %r)",
filename) filename)
...@@ -623,6 +619,7 @@ class ZEOStorage(object): ...@@ -623,6 +619,7 @@ class ZEOStorage(object):
def ping(self): def ping(self):
pass pass
class StorageServerDB(object): class StorageServerDB(object):
"""Adapter from StorageServerDB to ZODB.interfaces.IStorageWrapper """Adapter from StorageServerDB to ZODB.interfaces.IStorageWrapper
...@@ -649,6 +646,7 @@ class StorageServerDB(object): ...@@ -649,6 +646,7 @@ class StorageServerDB(object):
transform_record_data = untransform_record_data = lambda self, data: data transform_record_data = untransform_record_data = lambda self, data: data
class StorageServer(object): class StorageServer(object):
"""The server side implementation of ZEO. """The server side implementation of ZEO.
...@@ -722,7 +720,6 @@ class StorageServer(object): ...@@ -722,7 +720,6 @@ class StorageServer(object):
log("%s created %s with storages: %s" % log("%s created %s with storages: %s" %
(self.__class__.__name__, read_only and "RO" or "RW", msg)) (self.__class__.__name__, read_only and "RO" or "RW", msg))
self._lock = Lock() self._lock = Lock()
self.ssl = ssl # For dev convenience self.ssl = ssl # For dev convenience
...@@ -895,6 +892,7 @@ class StorageServer(object): ...@@ -895,6 +892,7 @@ class StorageServer(object):
return latest_tid, list(oids) return latest_tid, list(oids)
__thread = None __thread = None
def start_thread(self, daemon=True): def start_thread(self, daemon=True):
self.__thread = thread = threading.Thread(target=self.loop) self.__thread = thread = threading.Thread(target=self.loop)
thread.setName("StorageServer(%s)" % _addr_label(self.addr)) thread.setName("StorageServer(%s)" % _addr_label(self.addr))
...@@ -902,6 +900,7 @@ class StorageServer(object): ...@@ -902,6 +900,7 @@ class StorageServer(object):
thread.start() thread.start()
__closed = False __closed = False
def close(self, join_timeout=1): def close(self, join_timeout=1):
"""Close the dispatcher so that there are no new connections. """Close the dispatcher so that there are no new connections.
...@@ -959,6 +958,7 @@ class StorageServer(object): ...@@ -959,6 +958,7 @@ class StorageServer(object):
return dict((storage_id, self.server_status(storage_id)) return dict((storage_id, self.server_status(storage_id))
for storage_id in self.storages) for storage_id in self.storages)
class StubTimeoutThread(object): class StubTimeoutThread(object):
def begin(self, client): def begin(self, client):
...@@ -967,7 +967,8 @@ class StubTimeoutThread(object): ...@@ -967,7 +967,8 @@ class StubTimeoutThread(object):
def end(self, client): def end(self, client):
pass pass
is_alive = lambda self: 'stub' def is_alive(self):
return 'stub'
class TimeoutThread(threading.Thread): class TimeoutThread(threading.Thread):
...@@ -1020,7 +1021,7 @@ class TimeoutThread(threading.Thread): ...@@ -1020,7 +1021,7 @@ class TimeoutThread(threading.Thread):
self._timeout, logging.CRITICAL) self._timeout, logging.CRITICAL)
try: try:
client.call_soon_threadsafe(client.connection.close) client.call_soon_threadsafe(client.connection.close)
except: except: # NOQA: E722 bare except
client.log("Timeout failure", logging.CRITICAL, client.log("Timeout failure", logging.CRITICAL,
exc_info=sys.exc_info()) exc_info=sys.exc_info())
self.end(client) self.end(client)
...@@ -1074,6 +1075,7 @@ def _addr_label(addr): ...@@ -1074,6 +1075,7 @@ def _addr_label(addr):
host, port = addr host, port = addr
return str(host) + ":" + str(port) return str(host) + ":" + str(port)
class CommitLog(object): class CommitLog(object):
def __init__(self): def __init__(self):
...@@ -1116,23 +1118,28 @@ class CommitLog(object): ...@@ -1116,23 +1118,28 @@ class CommitLog(object):
self.file.close() self.file.close()
self.file = None self.file = None
class ServerEvent(object): class ServerEvent(object):
def __init__(self, server, **kw): def __init__(self, server, **kw):
self.__dict__.update(kw) self.__dict__.update(kw)
self.server = server self.server = server
class Serving(ServerEvent): class Serving(ServerEvent):
pass pass
class Closed(ServerEvent): class Closed(ServerEvent):
pass pass
def never_resolve_conflict(oid, committedSerial, oldSerial, newpickle, def never_resolve_conflict(oid, committedSerial, oldSerial, newpickle,
committedData=b''): committedData=b''):
raise ConflictError(oid=oid, serials=(committedSerial, oldSerial), raise ConflictError(oid=oid, serials=(committedSerial, oldSerial),
data=newpickle) data=newpickle)
class LockManager(object): class LockManager(object):
def __init__(self, storage_id, stats, timeout): def __init__(self, storage_id, stats, timeout):
...@@ -1218,10 +1225,10 @@ class LockManager(object): ...@@ -1218,10 +1225,10 @@ class LockManager(object):
zs, "(%r) dequeue lock: transactions waiting: %s") zs, "(%r) dequeue lock: transactions waiting: %s")
def _log_waiting(self, zs, message): def _log_waiting(self, zs, message):
l = len(self.waiting) length = len(self.waiting)
zs.log(message % (self.storage_id, l), zs.log(message % (self.storage_id, length),
logging.CRITICAL if l > 9 else ( logging.CRITICAL if length > 9 else (
logging.WARNING if l > 3 else logging.DEBUG) logging.WARNING if length > 3 else logging.DEBUG)
) )
def _can_lock(self, zs): def _can_lock(self, zs):
......
...@@ -21,12 +21,11 @@ is used to store the data until a commit or abort. ...@@ -21,12 +21,11 @@ is used to store the data until a commit or abort.
# A faster implementation might store trans data in memory until it # A faster implementation might store trans data in memory until it
# reaches a certain size. # reaches a certain size.
import os
import tempfile import tempfile
import ZODB.blob
from ZEO._compat import Pickler, Unpickler from ZEO._compat import Pickler, Unpickler
class TransactionBuffer(object): class TransactionBuffer(object):
# The TransactionBuffer is used by client storage to hold update # The TransactionBuffer is used by client storage to hold update
...@@ -93,9 +92,7 @@ class TransactionBuffer(object): ...@@ -93,9 +92,7 @@ class TransactionBuffer(object):
if oid not in seen: if oid not in seen:
yield oid, None, True yield oid, None, True
# Support ZEO4: # Support ZEO4:
def serialnos(self, args): def serialnos(self, args):
for oid in args: for oid in args:
if isinstance(oid, bytes): if isinstance(oid, bytes):
......
...@@ -21,6 +21,7 @@ ZEO is now part of ZODB; ZODB's home on the web is ...@@ -21,6 +21,7 @@ ZEO is now part of ZODB; ZODB's home on the web is
""" """
def client(*args, **kw): def client(*args, **kw):
""" """
Shortcut for :class:`ZEO.ClientStorage.ClientStorage`. Shortcut for :class:`ZEO.ClientStorage.ClientStorage`.
...@@ -28,6 +29,7 @@ def client(*args, **kw): ...@@ -28,6 +29,7 @@ def client(*args, **kw):
import ZEO.ClientStorage import ZEO.ClientStorage
return ZEO.ClientStorage.ClientStorage(*args, **kw) return ZEO.ClientStorage.ClientStorage(*args, **kw)
def DB(*args, **kw): def DB(*args, **kw):
""" """
Shortcut for creating a :class:`ZODB.DB` using a ZEO :func:`~ZEO.client`. Shortcut for creating a :class:`ZODB.DB` using a ZEO :func:`~ZEO.client`.
...@@ -40,6 +42,7 @@ def DB(*args, **kw): ...@@ -40,6 +42,7 @@ def DB(*args, **kw):
s.close() s.close()
raise raise
def connection(*args, **kw): def connection(*args, **kw):
db = DB(*args, **kw) db = DB(*args, **kw)
try: try:
...@@ -48,6 +51,7 @@ def connection(*args, **kw): ...@@ -48,6 +51,7 @@ def connection(*args, **kw):
db.close() db.close()
raise raise
def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None, def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None,
port=0, threaded=True, **kw): port=0, threaded=True, **kw):
"""Convenience function to start a server for interactive exploration """Convenience function to start a server for interactive exploration
......
...@@ -16,13 +16,20 @@ ...@@ -16,13 +16,20 @@
import sys import sys
import platform import platform
from ZODB._compat import BytesIO # NOQA: F401 unused import
PY3 = sys.version_info[0] >= 3 PY3 = sys.version_info[0] >= 3
PY32 = sys.version_info[:2] == (3, 2) PY32 = sys.version_info[:2] == (3, 2)
PYPY = getattr(platform, 'python_implementation', lambda: None)() == 'PyPy' PYPY = getattr(platform, 'python_implementation', lambda: None)() == 'PyPy'
WIN = sys.platform.startswith('win') WIN = sys.platform.startswith('win')
if PY3: if PY3:
from zodbpickle.pickle import Pickler, Unpickler as _Unpickler, dump, dumps, loads from zodbpickle.pickle import dump
from zodbpickle.pickle import dumps
from zodbpickle.pickle import loads
from zodbpickle.pickle import Pickler
from zodbpickle.pickle import Unpickler as _Unpickler
class Unpickler(_Unpickler): class Unpickler(_Unpickler):
# Py3: Python 3 doesn't allow assignments to find_global, # Py3: Python 3 doesn't allow assignments to find_global,
# instead, find_class can be overridden # instead, find_class can be overridden
...@@ -44,24 +51,17 @@ else: ...@@ -44,24 +51,17 @@ else:
dumps = cPickle.dumps dumps = cPickle.dumps
loads = cPickle.loads loads = cPickle.loads
# String and Bytes IO
from ZODB._compat import BytesIO
if PY3: if PY3:
import _thread as thread # NOQA: F401 unused import
import _thread as thread
if PY32: if PY32:
from threading import _get_ident as get_ident from threading import _get_ident as get_ident # NOQA: F401 unused
else: else:
from threading import get_ident from threading import get_ident # NOQA: F401 unused import
else: else:
import thread # NOQA: F401 unused import
import thread from thread import get_ident # NOQA: F401 unused import
from thread import get_ident
try: try:
from cStringIO import StringIO from cStringIO import StringIO # NOQA: F401 unused import
except: except ImportError:
from io import StringIO from io import StringIO # NOQA: F401 unused import
...@@ -26,11 +26,10 @@ import six ...@@ -26,11 +26,10 @@ import six
from ZEO._compat import StringIO from ZEO._compat import StringIO
logger = logging.getLogger('ZEO.tests.forker') logger = logging.getLogger('ZEO.tests.forker')
DEBUG = os.environ.get('ZEO_TEST_SERVER_DEBUG') DEBUG = os.environ.get('ZEO_TEST_SERVER_DEBUG')
ZEO4_SERVER = os.environ.get('ZEO4_SERVER') ZEO4_SERVER = os.environ.get('ZEO4_SERVER')
class ZEOConfig(object): class ZEOConfig(object):
"""Class to generate ZEO configuration file. """ """Class to generate ZEO configuration file. """
...@@ -61,8 +60,7 @@ class ZEOConfig(object): ...@@ -61,8 +60,7 @@ class ZEOConfig(object):
for name in ( for name in (
'invalidation_queue_size', 'invalidation_age', 'invalidation_queue_size', 'invalidation_age',
'transaction_timeout', 'pid_filename', 'msgpack', 'transaction_timeout', 'pid_filename', 'msgpack',
'ssl_certificate', 'ssl_key', 'client_conflict_resolution', 'ssl_certificate', 'ssl_key', 'client_conflict_resolution'):
):
v = getattr(self, name, None) v = getattr(self, name, None)
if v: if v:
print(name.replace('_', '-'), v, file=f) print(name.replace('_', '-'), v, file=f)
...@@ -158,6 +156,7 @@ def runner(config, qin, qout, timeout=None, ...@@ -158,6 +156,7 @@ def runner(config, qin, qout, timeout=None,
ZEO.asyncio.server.best_protocol_version = old_protocol ZEO.asyncio.server.best_protocol_version = old_protocol
ZEO.asyncio.server.ServerProtocol.protocols = old_protocols ZEO.asyncio.server.ServerProtocol.protocols = old_protocols
def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None): def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None):
qin.put('stop') qin.put('stop')
try: try:
...@@ -180,6 +179,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None): ...@@ -180,6 +179,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None):
gc.collect() gc.collect()
def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False, def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
path='Data.fs', protocol=None, blob_dir=None, path='Data.fs', protocol=None, blob_dir=None,
suicide=True, debug=False, suicide=True, debug=False,
...@@ -220,7 +220,8 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False, ...@@ -220,7 +220,8 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
print(zeo_conf) print(zeo_conf)
# Store the config info in a temp file. # Store the config info in a temp file.
fd, tmpfile = tempfile.mkstemp(".conf", prefix='ZEO_forker', dir=os.getcwd()) fd, tmpfile = tempfile.mkstemp(".conf", prefix='ZEO_forker',
dir=os.getcwd())
with os.fdopen(fd, 'w') as fp: with os.fdopen(fd, 'w') as fp:
fp.write(zeo_conf) fp.write(zeo_conf)
...@@ -273,10 +274,12 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG): ...@@ -273,10 +274,12 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG):
return stop return stop
def whine(*message): def whine(*message):
print(*message, file=sys.stderr) print(*message, file=sys.stderr)
sys.stderr.flush() sys.stderr.flush()
class ThreadlessQueue(object): class ThreadlessQueue(object):
def __init__(self): def __init__(self):
......
...@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__) ...@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6 INET_FAMILIES = socket.AF_INET, socket.AF_INET6
class Protocol(asyncio.Protocol): class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO base interface """asyncio low-level ZEO base interface
""" """
...@@ -41,6 +42,7 @@ class Protocol(asyncio.Protocol): ...@@ -41,6 +42,7 @@ class Protocol(asyncio.Protocol):
return self.name return self.name
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -50,7 +52,6 @@ class Protocol(asyncio.Protocol): ...@@ -50,7 +52,6 @@ class Protocol(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected %s", self) logger.info("Connected %s", self)
if sys.version_info < (3, 6): if sys.version_info < (3, 6):
sock = transport.get_extra_info('socket') sock = transport.get_extra_info('socket')
if sock is not None and sock.family in INET_FAMILIES: if sock is not None and sock.family in INET_FAMILIES:
...@@ -91,6 +92,7 @@ class Protocol(asyncio.Protocol): ...@@ -91,6 +92,7 @@ class Protocol(asyncio.Protocol):
got = 0 got = 0
want = 4 want = 4
getting_size = True getting_size = True
def data_received(self, data): def data_received(self, data):
# Low-level input handler collects data into sized messages. # Low-level input handler collects data into sized messages.
......
...@@ -21,6 +21,7 @@ Fallback = object() ...@@ -21,6 +21,7 @@ Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests local_random = random.Random() # use separate generator to facilitate tests
def future_generator(func): def future_generator(func):
"""Decorates a generator that generates futures """Decorates a generator that generates futures
""" """
...@@ -52,6 +53,7 @@ def future_generator(func): ...@@ -52,6 +53,7 @@ def future_generator(func):
return call_generator return call_generator
class Protocol(base.Protocol): class Protocol(base.Protocol):
"""asyncio low-level ZEO client interface """asyncio low-level ZEO client interface
""" """
...@@ -132,7 +134,9 @@ class Protocol(base.Protocol): ...@@ -132,7 +134,9 @@ class Protocol(base.Protocol):
elif future.exception() is not None: elif future.exception() is not None:
logger.info("Connection to %r failed, %s", logger.info("Connection to %r failed, %s",
self.addr, future.exception()) self.addr, future.exception())
else: return else:
return
# keep trying # keep trying
if not self.closed: if not self.closed:
logger.info("retry connecting %r", self.addr) logger.info("retry connecting %r", self.addr)
...@@ -141,7 +145,6 @@ class Protocol(base.Protocol): ...@@ -141,7 +145,6 @@ class Protocol(base.Protocol):
self.connect, self.connect,
) )
def connection_made(self, transport): def connection_made(self, transport):
super(Protocol, self).connection_made(transport) super(Protocol, self).connection_made(transport)
self.heartbeat(write=False) self.heartbeat(write=False)
...@@ -190,7 +193,8 @@ class Protocol(base.Protocol): ...@@ -190,7 +193,8 @@ class Protocol(base.Protocol):
try: try:
server_tid = yield self.fut( server_tid = yield self.fut(
'register', self.storage_key, 'register', self.storage_key,
self.read_only if self.read_only is not Fallback else False, (self.read_only if self.read_only is not Fallback
else False),
*credentials) *credentials)
except ZODB.POSException.ReadOnlyError: except ZODB.POSException.ReadOnlyError:
if self.read_only is Fallback: if self.read_only is Fallback:
...@@ -208,6 +212,7 @@ class Protocol(base.Protocol): ...@@ -208,6 +212,7 @@ class Protocol(base.Protocol):
self.client.registered(self, server_tid) self.client.registered(self, server_tid)
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
msgid, async_, name, args = self.decode(data) msgid, async_, name, args = self.decode(data)
if name == '.reply': if name == '.reply':
...@@ -244,6 +249,7 @@ class Protocol(base.Protocol): ...@@ -244,6 +249,7 @@ class Protocol(base.Protocol):
raise AttributeError(name) raise AttributeError(name)
message_id = 0 message_id = 0
def call(self, future, method, args): def call(self, future, method, args):
self.message_id += 1 self.message_id += 1
self.futures[self.message_id] = future self.futures[self.message_id] = future
...@@ -262,6 +268,7 @@ class Protocol(base.Protocol): ...@@ -262,6 +268,7 @@ class Protocol(base.Protocol):
self.futures[message_id] = future self.futures[message_id] = future
self._write( self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid))) self.encode(message_id, False, 'loadBefore', (oid, tid)))
@future.add_done_callback @future.add_done_callback
def _(future): def _(future):
try: try:
...@@ -271,6 +278,7 @@ class Protocol(base.Protocol): ...@@ -271,6 +278,7 @@ class Protocol(base.Protocol):
if data: if data:
data, start, end = data data, start, end = data
self.client.cache.store(oid, start, end, data) self.client.cache.store(oid, start, end, data)
return future return future
# Methods called by the server. # Methods called by the server.
...@@ -290,29 +298,34 @@ class Protocol(base.Protocol): ...@@ -290,29 +298,34 @@ class Protocol(base.Protocol):
self.heartbeat_handle = self.loop.call_later( self.heartbeat_handle = self.loop.call_later(
self.heartbeat_interval, self.heartbeat) self.heartbeat_interval, self.heartbeat)
def create_Exception(class_, args): def create_Exception(class_, args):
return exc_classes[class_](*args) return exc_classes[class_](*args)
def create_ConflictError(class_, args): def create_ConflictError(class_, args):
exc = exc_classes[class_]( exc = exc_classes[class_](
message = args['message'], message=args['message'],
oid = args['oid'], oid=args['oid'],
serials = args['serials'], serials=args['serials'],
) )
exc.class_name = args.get('class_name') exc.class_name = args.get('class_name')
return exc return exc
def create_BTreesConflictError(class_, args): def create_BTreesConflictError(class_, args):
return ZODB.POSException.BTreesConflictError( return ZODB.POSException.BTreesConflictError(
p1 = args['p1'], p1=args['p1'],
p2 = args['p2'], p2=args['p2'],
p3 = args['p3'], p3=args['p3'],
reason = args['reason'], reason=args['reason'],
) )
def create_MultipleUndoErrors(class_, args): def create_MultipleUndoErrors(class_, args):
return ZODB.POSException.MultipleUndoErrors(args['_errs']) return ZODB.POSException.MultipleUndoErrors(args['_errs'])
exc_classes = { exc_classes = {
'builtins.KeyError': KeyError, 'builtins.KeyError': KeyError,
'builtins.TypeError': TypeError, 'builtins.TypeError': TypeError,
...@@ -340,6 +353,8 @@ exc_factories = { ...@@ -340,6 +353,8 @@ exc_factories = {
} }
unlogged_exceptions = (ZODB.POSException.POSKeyError, unlogged_exceptions = (ZODB.POSException.POSKeyError,
ZODB.POSException.ConflictError) ZODB.POSException.ConflictError)
class Client(object): class Client(object):
"""asyncio low-level ZEO client interface """asyncio low-level ZEO client interface
""" """
...@@ -352,8 +367,11 @@ class Client(object): ...@@ -352,8 +367,11 @@ class Client(object):
# connect. # connect.
protocol = None protocol = None
ready = None # Tri-value: None=Never connected, True=connected, # ready can have three values:
# None=Never connected
# True=connected
# False=Disconnected # False=Disconnected
ready = None
def __init__(self, loop, def __init__(self, loop,
addrs, client, cache, storage_key, read_only, connect_poll, addrs, client, cache, storage_key, read_only, connect_poll,
...@@ -404,6 +422,7 @@ class Client(object): ...@@ -404,6 +422,7 @@ class Client(object):
self.is_read_only() and self.read_only is Fallback) self.is_read_only() and self.read_only is Fallback)
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -474,9 +493,8 @@ class Client(object): ...@@ -474,9 +493,8 @@ class Client(object):
if protocol is not self: if protocol is not self:
protocol.close() protocol.close()
logger.exception("Registration or cache validation failed, %s", exc) logger.exception("Registration or cache validation failed, %s", exc)
if (self.protocol is None and not if self.protocol is None and \
any(not p.closed for p in self.protocols) not any(not p.closed for p in self.protocols):
):
self.loop.call_later( self.loop.call_later(
self.register_failed_poll + local_random.random(), self.register_failed_poll + local_random.random(),
self.try_connecting) self.try_connecting)
...@@ -739,6 +757,7 @@ class Client(object): ...@@ -739,6 +757,7 @@ class Client(object):
else: else:
return protocol.read_only return protocol.read_only
class ClientRunner(object): class ClientRunner(object):
def set_options(self, addrs, wrapper, cache, storage_key, read_only, def set_options(self, addrs, wrapper, cache, storage_key, read_only,
...@@ -855,6 +874,7 @@ class ClientRunner(object): ...@@ -855,6 +874,7 @@ class ClientRunner(object):
timeout = self.timeout timeout = self.timeout
self.wait_for_result(self.client.connected, timeout) self.wait_for_result(self.client.connected, timeout)
class ClientThread(ClientRunner): class ClientThread(ClientRunner):
"""Thread wrapper for client interface """Thread wrapper for client interface
...@@ -883,6 +903,7 @@ class ClientThread(ClientRunner): ...@@ -883,6 +903,7 @@ class ClientThread(ClientRunner):
raise self.exception raise self.exception
exception = None exception = None
def run(self): def run(self):
loop = None loop = None
try: try:
...@@ -909,6 +930,7 @@ class ClientThread(ClientRunner): ...@@ -909,6 +930,7 @@ class ClientThread(ClientRunner):
logger.debug('Stopping client thread') logger.debug('Stopping client thread')
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -918,6 +940,7 @@ class ClientThread(ClientRunner): ...@@ -918,6 +940,7 @@ class ClientThread(ClientRunner):
if self.exception: if self.exception:
raise self.exception raise self.exception
class Fut(object): class Fut(object):
"""Lightweight future that calls it's callbacks immediately rather than soon """Lightweight future that calls it's callbacks immediately rather than soon
""" """
...@@ -929,6 +952,7 @@ class Fut(object): ...@@ -929,6 +952,7 @@ class Fut(object):
self.cbv.append(cb) self.cbv.append(cb)
exc = None exc = None
def set_exception(self, exc): def set_exception(self, exc):
self.exc = exc self.exc = exc
for cb in self.cbv: for cb in self.cbv:
......
...@@ -6,5 +6,5 @@ if PY3: ...@@ -6,5 +6,5 @@ if PY3:
except ImportError: except ImportError:
from asyncio import new_event_loop from asyncio import new_event_loop
else: else:
import trollius as asyncio import trollius as asyncio # NOQA: F401 unused import
from trollius import new_event_loop from trollius import new_event_loop # NOQA: F401 unused import
...@@ -21,19 +21,22 @@ Python-independent format, or possibly a minimal pickle subset. ...@@ -21,19 +21,22 @@ Python-independent format, or possibly a minimal pickle subset.
import logging import logging
from .._compat import Unpickler, Pickler, BytesIO, PY3, PYPY from .._compat import Unpickler, Pickler, BytesIO, PY3
from ..shortrepr import short_repr from ..shortrepr import short_repr
PY2 = not PY3 PY2 = not PY3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def encoder(protocol, server=False): def encoder(protocol, server=False):
"""Return a non-thread-safe encoder """Return a non-thread-safe encoder
""" """
if protocol[:1] == b'M': if protocol[:1] == b'M':
from msgpack import packb from msgpack import packb
default = server_default if server else None default = server_default if server else None
def encode(*args): def encode(*args):
return packb( return packb(
args, use_bin_type=True, default=default) args, use_bin_type=True, default=default)
...@@ -49,6 +52,7 @@ def encoder(protocol, server=False): ...@@ -49,6 +52,7 @@ def encoder(protocol, server=False):
pickler = Pickler(f, 3) pickler = Pickler(f, 3)
pickler.fast = 1 pickler.fast = 1
dump = pickler.dump dump = pickler.dump
def encode(*args): def encode(*args):
seek(0) seek(0)
truncate() truncate()
...@@ -57,21 +61,26 @@ def encoder(protocol, server=False): ...@@ -57,21 +61,26 @@ def encoder(protocol, server=False):
return encode return encode
def encode(*args): def encode(*args):
return encoder(b'Z')(*args) return encoder(b'Z')(*args)
def decoder(protocol): def decoder(protocol):
if protocol[:1] == b'M': if protocol[:1] == b'M':
from msgpack import unpackb from msgpack import unpackb
def msgpack_decode(data): def msgpack_decode(data):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
return unpackb(data, raw=False, use_list=False) return unpackb(data, raw=False, use_list=False)
return msgpack_decode return msgpack_decode
else: else:
assert protocol[:1] == b'Z' assert protocol[:1] == b'Z'
return pickle_decode return pickle_decode
def pickle_decode(msg): def pickle_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
...@@ -83,10 +92,11 @@ def pickle_decode(msg): ...@@ -83,10 +92,11 @@ def pickle_decode(msg):
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except: # NOQA: E722 bare except
logger.error("can't decode message: %s" % short_repr(msg)) logger.error("can't decode message: %s" % short_repr(msg))
raise raise
def server_decoder(protocol): def server_decoder(protocol):
if protocol[:1] == b'M': if protocol[:1] == b'M':
return decoder(protocol) return decoder(protocol)
...@@ -94,6 +104,7 @@ def server_decoder(protocol): ...@@ -94,6 +104,7 @@ def server_decoder(protocol):
assert protocol[:1] == b'Z' assert protocol[:1] == b'Z'
return pickle_server_decode return pickle_server_decode
def pickle_server_decode(msg): def pickle_server_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
...@@ -106,21 +117,24 @@ def pickle_server_decode(msg): ...@@ -106,21 +117,24 @@ def pickle_server_decode(msg):
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except: # NOQA: E722 bare except
logger.error("can't decode message: %s" % short_repr(msg)) logger.error("can't decode message: %s" % short_repr(msg))
raise raise
def server_default(obj): def server_default(obj):
if isinstance(obj, Exception): if isinstance(obj, Exception):
return reduce_exception(obj) return reduce_exception(obj)
else: else:
return obj return obj
def reduce_exception(exc): def reduce_exception(exc):
class_ = exc.__class__ class_ = exc.__class__
class_ = "%s.%s" % (class_.__module__, class_.__name__) class_ = "%s.%s" % (class_.__module__, class_.__name__)
return class_, exc.__dict__ or exc.args return class_, exc.__dict__ or exc.args
_globals = globals() _globals = globals()
_silly = ('__doc__',) _silly = ('__doc__',)
...@@ -131,6 +145,7 @@ _SAFE_MODULE_NAMES = ( ...@@ -131,6 +145,7 @@ _SAFE_MODULE_NAMES = (
'builtins', 'copy_reg', '__builtin__', 'builtins', 'copy_reg', '__builtin__',
) )
def find_global(module, name): def find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
try: try:
...@@ -143,7 +158,8 @@ def find_global(module, name): ...@@ -143,7 +158,8 @@ def find_global(module, name):
except AttributeError: except AttributeError:
raise ImportError("module %s has no global %s" % (module, name)) raise ImportError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) or (PY2 and module in _SAFE_MODULE_NAMES) safe = (getattr(r, '__no_side_effects__', 0) or
(PY2 and module in _SAFE_MODULE_NAMES))
if safe: if safe:
return r return r
...@@ -153,6 +169,7 @@ def find_global(module, name): ...@@ -153,6 +169,7 @@ def find_global(module, name):
raise ImportError("Unsafe global: %s.%s" % (module, name)) raise ImportError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name): def server_find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
if module not in _SAFE_MODULE_NAMES: if module not in _SAFE_MODULE_NAMES:
......
...@@ -72,6 +72,7 @@ import logging ...@@ -72,6 +72,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Acceptor(asyncore.dispatcher): class Acceptor(asyncore.dispatcher):
"""A server that accepts incoming RPC connections """A server that accepts incoming RPC connections
...@@ -115,13 +116,13 @@ class Acceptor(asyncore.dispatcher): ...@@ -115,13 +116,13 @@ class Acceptor(asyncore.dispatcher):
for i in range(25): for i in range(25):
try: try:
self.bind(addr) self.bind(addr)
except Exception as exc: except Exception:
logger.info("bind on %s failed %s waiting", addr, i) logger.info("bind on %s failed %s waiting", addr, i)
if i == 24: if i == 24:
raise raise
else: else:
time.sleep(5) time.sleep(5)
except: except: # NOQA: E722 bare except
logger.exception('binding') logger.exception('binding')
raise raise
else: else:
...@@ -146,7 +147,6 @@ class Acceptor(asyncore.dispatcher): ...@@ -146,7 +147,6 @@ class Acceptor(asyncore.dispatcher):
logger.info("accepted failed: %s", msg) logger.info("accepted failed: %s", msg)
return return
# We could short-circuit the attempt below in some edge cases # We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None. # and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below, # Unfortunately, our test for the code below,
...@@ -172,23 +172,25 @@ class Acceptor(asyncore.dispatcher): ...@@ -172,23 +172,25 @@ class Acceptor(asyncore.dispatcher):
protocol.stop = loop.stop protocol.stop = loop.stop
if self.ssl_context is None: if self.ssl_context is None:
cr = loop.create_connection((lambda : protocol), sock=sock) cr = loop.create_connection((lambda: protocol), sock=sock)
else: else:
if hasattr(loop, 'connect_accepted_socket'): if hasattr(loop, 'connect_accepted_socket'):
cr = loop.connect_accepted_socket( cr = loop.connect_accepted_socket(
(lambda : protocol), sock, ssl=self.ssl_context) (lambda: protocol), sock, ssl=self.ssl_context)
else: else:
####################################################### #######################################################
# XXX See http://bugs.python.org/issue27392 :( # XXX See http://bugs.python.org/issue27392 :(
_make_ssl_transport = loop._make_ssl_transport _make_ssl_transport = loop._make_ssl_transport
def make_ssl_transport(*a, **kw): def make_ssl_transport(*a, **kw):
kw['server_side'] = True kw['server_side'] = True
return _make_ssl_transport(*a, **kw) return _make_ssl_transport(*a, **kw)
loop._make_ssl_transport = make_ssl_transport loop._make_ssl_transport = make_ssl_transport
# #
####################################################### #######################################################
cr = loop.create_connection( cr = loop.create_connection(
(lambda : protocol), sock=sock, (lambda: protocol), sock=sock,
ssl=self.ssl_context, ssl=self.ssl_context,
server_hostname='' server_hostname=''
) )
...@@ -217,6 +219,7 @@ class Acceptor(asyncore.dispatcher): ...@@ -217,6 +219,7 @@ class Acceptor(asyncore.dispatcher):
logger.debug('acceptor %s loop stopped', self.addr) logger.debug('acceptor %s loop stopped', self.addr)
__closed = False __closed = False
def close(self): def close(self):
if not self.__closed: if not self.__closed:
self.__closed = True self.__closed = True
......
import json import json
import logging import logging
import os import os
import random
import threading import threading
import ZODB.POSException import ZODB.POSException
logger = logging.getLogger(__name__)
from ..shortrepr import short_repr from ..shortrepr import short_repr
from . import base from . import base
from .compat import asyncio, new_event_loop from .compat import asyncio, new_event_loop
from .marshal import server_decoder, encoder, reduce_exception from .marshal import server_decoder, encoder, reduce_exception
logger = logging.getLogger(__name__)
class ServerProtocol(base.Protocol): class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface """asyncio low-level ZEO server interface
""" """
...@@ -39,6 +40,7 @@ class ServerProtocol(base.Protocol): ...@@ -39,6 +40,7 @@ class ServerProtocol(base.Protocol):
) )
closed = False closed = False
def close(self): def close(self):
logger.debug("Closing server protocol") logger.debug("Closing server protocol")
if not self.closed: if not self.closed:
...@@ -47,6 +49,7 @@ class ServerProtocol(base.Protocol): ...@@ -47,6 +49,7 @@ class ServerProtocol(base.Protocol):
self.transport.close() self.transport.close()
connected = None # for tests connected = None # for tests
def connection_made(self, transport): def connection_made(self, transport):
self.connected = True self.connected = True
super(ServerProtocol, self).connection_made(transport) super(ServerProtocol, self).connection_made(transport)
...@@ -147,16 +150,19 @@ class ServerProtocol(base.Protocol): ...@@ -147,16 +150,19 @@ class ServerProtocol(base.Protocol):
def async_threadsafe(self, method, *args): def async_threadsafe(self, method, *args):
self.call_soon_threadsafe(self.call_async, method, args) self.call_soon_threadsafe(self.call_async, method, args)
best_protocol_version = os.environ.get( best_protocol_version = os.environ.get(
'ZEO_SERVER_PROTOCOL', 'ZEO_SERVER_PROTOCOL',
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8') ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage, msgpack): def new_connection(loop, addr, socket, zeo_storage, msgpack):
protocol = ServerProtocol(loop, addr, zeo_storage, msgpack) protocol = ServerProtocol(loop, addr, zeo_storage, msgpack)
cr = loop.create_connection((lambda : protocol), sock=socket) cr = loop.create_connection((lambda: protocol), sock=socket)
asyncio.ensure_future(cr, loop=loop) asyncio.ensure_future(cr, loop=loop)
class Delay(object): class Delay(object):
"""Used to delay response to client for synchronous calls. """Used to delay response to client for synchronous calls.
...@@ -192,6 +198,7 @@ class Delay(object): ...@@ -192,6 +198,7 @@ class Delay(object):
def __reduce__(self): def __reduce__(self):
raise TypeError("Can't pickle delays.") raise TypeError("Can't pickle delays.")
class Result(Delay): class Result(Delay):
def __init__(self, *args): def __init__(self, *args):
...@@ -202,6 +209,7 @@ class Result(Delay): ...@@ -202,6 +209,7 @@ class Result(Delay):
protocol.send_reply(msgid, reply) protocol.send_reply(msgid, reply)
callback() callback()
class MTDelay(Delay): class MTDelay(Delay):
def __init__(self): def __init__(self):
...@@ -266,6 +274,7 @@ class Acceptor(object): ...@@ -266,6 +274,7 @@ class Acceptor(object):
self.event_loop.close() self.event_loop.close()
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -277,6 +286,7 @@ class Acceptor(object): ...@@ -277,6 +286,7 @@ class Acceptor(object):
self.server.close() self.server.close()
f = asyncio.ensure_future(self.server.wait_closed(), loop=loop) f = asyncio.ensure_future(self.server.wait_closed(), loop=loop)
@f.add_done_callback @f.add_done_callback
def server_closed(f): def server_closed(f):
# stop the loop when the server closes: # stop the loop when the server closes:
......
...@@ -11,7 +11,6 @@ except NameError: ...@@ -11,7 +11,6 @@ except NameError:
class ConnectionRefusedError(OSError): class ConnectionRefusedError(OSError):
pass pass
import pprint
class Loop(object): class Loop(object):
...@@ -19,7 +18,7 @@ class Loop(object): ...@@ -19,7 +18,7 @@ class Loop(object):
def __init__(self, addrs=(), debug=True): def __init__(self, addrs=(), debug=True):
self.addrs = addrs self.addrs = addrs
self.get_debug = lambda : debug self.get_debug = lambda: debug
self.connecting = {} self.connecting = {}
self.later = [] self.later = []
self.exceptions = [] self.exceptions = []
...@@ -45,10 +44,8 @@ class Loop(object): ...@@ -45,10 +44,8 @@ class Loop(object):
if not future.cancelled(): if not future.cancelled():
future.set_exception(ConnectionRefusedError()) future.set_exception(ConnectionRefusedError())
def create_connection( def create_connection(self, protocol_factory, host=None, port=None,
self, protocol_factory, host=None, port=None, sock=None, sock=None, ssl=None, server_hostname=None):
ssl=None, server_hostname=None
):
future = asyncio.Future(loop=self) future = asyncio.Future(loop=self)
if sock is None: if sock is None:
addr = host, port addr = host, port
...@@ -83,13 +80,16 @@ class Loop(object): ...@@ -83,13 +80,16 @@ class Loop(object):
self.exceptions.append(context) self.exceptions.append(context)
closed = False closed = False
def close(self): def close(self):
self.closed = True self.closed = True
stopped = False stopped = False
def stop(self): def stop(self):
self.stopped = True self.stopped = True
class Handle(object): class Handle(object):
cancelled = False cancelled = False
...@@ -97,6 +97,7 @@ class Handle(object): ...@@ -97,6 +97,7 @@ class Handle(object):
def cancel(self): def cancel(self):
self.cancelled = True self.cancelled = True
class Transport(object): class Transport(object):
capacity = 1 << 64 capacity = 1 << 64
...@@ -136,12 +137,14 @@ class Transport(object): ...@@ -136,12 +137,14 @@ class Transport(object):
self.protocol.resume_writing() self.protocol.resume_writing()
closed = False closed = False
def close(self): def close(self):
self.closed = True self.closed = True
def get_extra_info(self, name): def get_extra_info(self, name):
return self.extra[name] return self.extra[name]
class AsyncRPC(object): class AsyncRPC(object):
"""Adapt an asyncio API to an RPC to help hysterical tests """Adapt an asyncio API to an RPC to help hysterical tests
""" """
...@@ -151,6 +154,7 @@ class AsyncRPC(object): ...@@ -151,6 +154,7 @@ class AsyncRPC(object):
def __getattr__(self, name): def __getattr__(self, name):
return lambda *a, **kw: self.api.call(name, *a, **kw) return lambda *a, **kw: self.api.call(name, *a, **kw)
class ClientRunner(object): class ClientRunner(object):
def __init__(self, addr, client, cache, storage, read_only, timeout, def __init__(self, addr, client, cache, storage, read_only, timeout,
......
...@@ -2,17 +2,17 @@ from .._compat import PY3 ...@@ -2,17 +2,17 @@ from .._compat import PY3
if PY3: if PY3:
import asyncio import asyncio
def to_byte(i): def to_byte(i):
return bytes([i]) return bytes([i])
else: else:
import trollius as asyncio import trollius as asyncio # NOQA: F401 unused import
def to_byte(b): def to_byte(b):
return b return b
from zope.testing import setupstack from zope.testing import setupstack
from concurrent.futures import Future
import mock import mock
from ZODB.POSException import ReadOnlyError
from ZODB.utils import maxtid, RLock from ZODB.utils import maxtid, RLock
import collections import collections
...@@ -28,6 +28,7 @@ from .client import ClientRunner, Fallback ...@@ -28,6 +28,7 @@ from .client import ClientRunner, Fallback
from .server import new_connection, best_protocol_version from .server import new_connection, best_protocol_version
from .marshal import encoder, decoder from .marshal import encoder, decoder
class Base(object): class Base(object):
enc = b'Z' enc = b'Z'
...@@ -56,6 +57,7 @@ class Base(object): ...@@ -56,6 +57,7 @@ class Base(object):
return self.unsized(data, True) return self.unsized(data, True)
target = None target = None
def send(self, method, *args, **kw): def send(self, method, *args, **kw):
target = kw.pop('target', self.target) target = kw.pop('target', self.target)
called = kw.pop('called', True) called = kw.pop('called', True)
...@@ -77,6 +79,7 @@ class Base(object): ...@@ -77,6 +79,7 @@ class Base(object):
def pop(self, count=None, parse=True): def pop(self, count=None, parse=True):
return self.unsized(self.loop.transport.pop(count), parse) return self.unsized(self.loop.transport.pop(count), parse)
class ClientTests(Base, setupstack.TestCase, ClientRunner): class ClientTests(Base, setupstack.TestCase, ClientRunner):
maxDiff = None maxDiff = None
...@@ -204,7 +207,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -204,7 +207,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, maxtid) loaded = self.load_before(b'1'*8, maxtid)
# The data wasn't in the cache, so we made a server call: # The data wasn't in the cache, so we made a server call:
self.assertEqual(self.pop(), ((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid))) self.assertEqual(self.pop(),
((b'1'*8, maxtid),
False,
'loadBefore',
(b'1'*8, maxtid)))
# Note load_before uses the oid as the message id. # Note load_before uses the oid as the message id.
self.respond((b'1'*8, maxtid), (b'data', b'a'*8, None)) self.respond((b'1'*8, maxtid), (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None)) self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
...@@ -224,7 +231,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -224,7 +231,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# the requests will be collapsed: # the requests will be collapsed:
loaded2 = self.load_before(b'1'*8, maxtid) loaded2 = self.load_before(b'1'*8, maxtid)
self.assertEqual(self.pop(), ((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid))) self.assertEqual(self.pop(),
((b'1'*8, maxtid),
False,
'loadBefore',
(b'1'*8, maxtid)))
self.respond((b'1'*8, maxtid), (b'data2', b'b'*8, None)) self.respond((b'1'*8, maxtid), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None)) self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None)) self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))
...@@ -238,7 +249,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -238,7 +249,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(transport.data) self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8) loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(self.pop(), ((b'1'*8, b'_'*8), False, 'loadBefore', (b'1'*8, b'_'*8))) self.assertEqual(self.pop(),
((b'1'*8, b'_'*8),
False,
'loadBefore',
(b'1'*8, b'_'*8)))
self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8)) self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8)) self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
...@@ -247,6 +262,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -247,6 +262,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# iteratable to tpc_finish_threadsafe. # iteratable to tpc_finish_threadsafe.
tids = [] tids = []
def finished_cb(tid): def finished_cb(tid):
tids.append(tid) tids.append(tid)
...@@ -349,7 +365,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -349,7 +365,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, ))) self.assertEqual(self.pop(),
(3, False, 'getInvalidations', (b'a'*8, )))
self.respond(3, (b'e'*8, [b'4'*8])) self.respond(3, (b'e'*8, [b'4'*8]))
self.assertEqual(self.pop(), (4, False, 'get_info', ())) self.assertEqual(self.pop(), (4, False, 'get_info', ()))
...@@ -384,7 +401,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -384,7 +401,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, ))) self.assertEqual(self.pop(),
(3, False, 'getInvalidations', (b'a'*8, )))
# We respond None, indicating that we're too far out of date: # We respond None, indicating that we're too far out of date:
self.respond(3, None) self.respond(3, None)
...@@ -512,7 +530,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -512,7 +530,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We connect the second address: # We connect the second address:
loop.connect_connecting(addrs[1]) loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(self.enc + b'3101')) loop.protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), self.enc + b'3101') self.assertEqual(self.unsized(loop.transport.pop(2)),
self.enc + b'3101')
self.assertEqual(self.parse(loop.transport.pop()), self.assertEqual(self.parse(loop.transport.pop()),
(1, False, 'register', ('TEST', False))) (1, False, 'register', ('TEST', False)))
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
...@@ -613,7 +632,6 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -613,7 +632,6 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.data_received(sized(self.enc + b'200')) protocol.data_received(sized(self.enc + b'200'))
self.assertTrue(isinstance(error.call_args[0][1], ProtocolError)) self.assertTrue(isinstance(error.call_args[0][1], ProtocolError))
def test_get_peername(self): def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport = self.start( wrapper, cache, loop, client, protocol, transport = self.start(
finish_start=True) finish_start=True)
...@@ -641,7 +659,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -641,7 +659,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# that caused it to fail badly if errors were raised while # that caused it to fail badly if errors were raised while
# handling data. # handling data.
wrapper, cache, loop, client, protocol, transport =self.start( wrapper, cache, loop, client, protocol, transport = self.start(
finish_start=True) finish_start=True)
wrapper.receiveBlobStart.side_effect = ValueError('test') wrapper.receiveBlobStart.side_effect = ValueError('test')
...@@ -694,10 +712,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -694,10 +712,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.connection_lost(None) protocol.connection_lost(None)
self.assertTrue(handle.cancelled) self.assertTrue(handle.cancelled)
class MsgpackClientTests(ClientTests): class MsgpackClientTests(ClientTests):
enc = b'M' enc = b'M'
seq_type = tuple seq_type = tuple
class MemoryCache(object): class MemoryCache(object):
def __init__(self): def __init__(self):
...@@ -709,6 +729,7 @@ class MemoryCache(object): ...@@ -709,6 +729,7 @@ class MemoryCache(object):
clear = __init__ clear = __init__
closed = False closed = False
def close(self): def close(self):
self.closed = True self.closed = True
...@@ -771,6 +792,7 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -771,6 +792,7 @@ class ServerTests(Base, setupstack.TestCase):
message_id = 0 message_id = 0
target = None target = None
def call(self, meth, *args, **kw): def call(self, meth, *args, **kw):
if kw: if kw:
expect = kw.pop('expect', self) expect = kw.pop('expect', self)
...@@ -835,10 +857,12 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -835,10 +857,12 @@ class ServerTests(Base, setupstack.TestCase):
self.call('foo', target=None) self.call('foo', target=None)
self.assertTrue(protocol.loop.transport.closed) self.assertTrue(protocol.loop.transport.closed)
class MsgpackServerTests(ServerTests): class MsgpackServerTests(ServerTests):
enc = b'M' enc = b'M'
seq_type = tuple seq_type = tuple
def server_protocol(msgpack, def server_protocol(msgpack,
zeo_storage=None, zeo_storage=None,
protocol_version=None, protocol_version=None,
...@@ -853,12 +877,11 @@ def server_protocol(msgpack, ...@@ -853,12 +877,11 @@ def server_protocol(msgpack,
loop.protocol.data_received(sized(protocol_version)) loop.protocol.data_received(sized(protocol_version))
return loop.protocol return loop.protocol
def response(*data):
return sized(self.encode(*data))
def sized(message): def sized(message):
return struct.pack(">I", len(message)) + message return struct.pack(">I", len(message)) + message
class Logging(object): class Logging(object):
def __init__(self, level=logging.ERROR): def __init__(self, level=logging.ERROR):
...@@ -885,9 +908,11 @@ class ProtocolTests(setupstack.TestCase): ...@@ -885,9 +908,11 @@ class ProtocolTests(setupstack.TestCase):
loop = self.loop loop = self.loop
protocol, transport = loop.protocol, loop.transport protocol, transport = loop.protocol, loop.transport
transport.capacity = 1 # single message transport.capacity = 1 # single message
def it(tag): def it(tag):
yield tag yield tag
yield tag yield tag
protocol._writeit(it(b"0")) protocol._writeit(it(b"0"))
protocol._writeit(it(b"1")) protocol._writeit(it(b"1"))
for b in b"0011": for b in b"0011":
......
...@@ -86,7 +86,7 @@ ZEC_HEADER_SIZE = 12 ...@@ -86,7 +86,7 @@ ZEC_HEADER_SIZE = 12
# need to write a free block that is almost twice as big. If we die # need to write a free block that is almost twice as big. If we die
# in the middle of a store, then we need to split the large free records # in the middle of a store, then we need to split the large free records
# while opening. # while opening.
max_block_size = (1<<31) - 1 max_block_size = (1 << 31) - 1
# After the header, the file contains a contiguous sequence of blocks. All # After the header, the file contains a contiguous sequence of blocks. All
...@@ -132,12 +132,13 @@ allocated_record_overhead = 43 ...@@ -132,12 +132,13 @@ allocated_record_overhead = 43
# Under PyPy, the available dict specializations perform significantly # Under PyPy, the available dict specializations perform significantly
# better (faster) than the pure-Python BTree implementation. They may # better (faster) than the pure-Python BTree implementation. They may
# use less memory too. And we don't require any of the special BTree features... # use less memory too. And we don't require any of the special BTree features.
_current_index_type = ZODB.fsIndex.fsIndex if not PYPY else dict _current_index_type = ZODB.fsIndex.fsIndex if not PYPY else dict
_noncurrent_index_type = BTrees.LOBTree.LOBTree if not PYPY else dict _noncurrent_index_type = BTrees.LOBTree.LOBTree if not PYPY else dict
# ...except at this leaf level # ...except at this leaf level
_noncurrent_bucket_type = BTrees.LLBTree.LLBucket _noncurrent_bucket_type = BTrees.LLBTree.LLBucket
class ClientCache(object): class ClientCache(object):
"""A simple in-memory cache.""" """A simple in-memory cache."""
...@@ -209,7 +210,7 @@ class ClientCache(object): ...@@ -209,7 +210,7 @@ class ClientCache(object):
try: try:
self._initfile(fsize) self._initfile(fsize)
except: except: # NOQA: E722 bare except
self.f.close() self.f.close()
if not path: if not path:
raise # unrecoverable temp file error :( raise # unrecoverable temp file error :(
...@@ -271,7 +272,7 @@ class ClientCache(object): ...@@ -271,7 +272,7 @@ class ClientCache(object):
self.current = _current_index_type() self.current = _current_index_type()
self.noncurrent = _noncurrent_index_type() self.noncurrent = _noncurrent_index_type()
l = 0 length = 0
last = ofs = ZEC_HEADER_SIZE last = ofs = ZEC_HEADER_SIZE
first_free_offset = 0 first_free_offset = 0
current = self.current current = self.current
...@@ -290,7 +291,7 @@ class ClientCache(object): ...@@ -290,7 +291,7 @@ class ClientCache(object):
assert start_tid < end_tid, (ofs, f.tell()) assert start_tid < end_tid, (ofs, f.tell())
self._set_noncurrent(oid, start_tid, ofs) self._set_noncurrent(oid, start_tid, ofs)
assert lver == 0, "Versions aren't supported" assert lver == 0, "Versions aren't supported"
l += 1 length += 1
else: else:
# free block # free block
if first_free_offset == 0: if first_free_offset == 0:
...@@ -331,7 +332,7 @@ class ClientCache(object): ...@@ -331,7 +332,7 @@ class ClientCache(object):
break break
if fsize < maxsize: if fsize < maxsize:
assert ofs==fsize assert ofs == fsize
# Make sure the OS really saves enough bytes for the file. # Make sure the OS really saves enough bytes for the file.
seek(self.maxsize - 1) seek(self.maxsize - 1)
write(b'x') write(b'x')
...@@ -349,7 +350,7 @@ class ClientCache(object): ...@@ -349,7 +350,7 @@ class ClientCache(object):
assert last and (status in b' f1234') assert last and (status in b' f1234')
first_free_offset = last first_free_offset = last
else: else:
assert ofs==maxsize assert ofs == maxsize
if maxsize < fsize: if maxsize < fsize:
seek(maxsize) seek(maxsize)
f.truncate() f.truncate()
...@@ -357,7 +358,7 @@ class ClientCache(object): ...@@ -357,7 +358,7 @@ class ClientCache(object):
# We use the first_free_offset because it is most likely the # We use the first_free_offset because it is most likely the
# place where we last wrote. # place where we last wrote.
self.currentofs = first_free_offset or ZEC_HEADER_SIZE self.currentofs = first_free_offset or ZEC_HEADER_SIZE
self._len = l self._len = length
def _set_noncurrent(self, oid, tid, ofs): def _set_noncurrent(self, oid, tid, ofs):
noncurrent_for_oid = self.noncurrent.get(u64(oid)) noncurrent_for_oid = self.noncurrent.get(u64(oid))
...@@ -375,7 +376,6 @@ class ClientCache(object): ...@@ -375,7 +376,6 @@ class ClientCache(object):
except KeyError: except KeyError:
logger.error("Couldn't find non-current %r", (oid, tid)) logger.error("Couldn't find non-current %r", (oid, tid))
def clearStats(self): def clearStats(self):
self._n_adds = self._n_added_bytes = 0 self._n_adds = self._n_added_bytes = 0
self._n_evicts = self._n_evicted_bytes = 0 self._n_evicts = self._n_evicted_bytes = 0
...@@ -384,8 +384,7 @@ class ClientCache(object): ...@@ -384,8 +384,7 @@ class ClientCache(object):
def getStats(self): def getStats(self):
return (self._n_adds, self._n_added_bytes, return (self._n_adds, self._n_added_bytes,
self._n_evicts, self._n_evicted_bytes, self._n_evicts, self._n_evicted_bytes,
self._n_accesses self._n_accesses)
)
## ##
# The number of objects currently in the cache. # The number of objects currently in the cache.
...@@ -403,7 +402,7 @@ class ClientCache(object): ...@@ -403,7 +402,7 @@ class ClientCache(object):
sync(f) sync(f)
f.close() f.close()
if hasattr(self,'_lock_file'): if hasattr(self, '_lock_file'):
self._lock_file.close() self._lock_file.close()
## ##
...@@ -517,9 +516,9 @@ class ClientCache(object): ...@@ -517,9 +516,9 @@ class ClientCache(object):
if ofsofs < 0: if ofsofs < 0:
ofsofs += self.maxsize ofsofs += self.maxsize
if (ofsofs > self.rearrange and if ofsofs > self.rearrange and \
self.maxsize > 10*len(data) and self.maxsize > 10*len(data) and \
size > 4): size > 4:
# The record is far back and might get evicted, but it's # The record is far back and might get evicted, but it's
# valuable, so move it forward. # valuable, so move it forward.
...@@ -619,8 +618,8 @@ class ClientCache(object): ...@@ -619,8 +618,8 @@ class ClientCache(object):
raise ValueError("already have current data for oid") raise ValueError("already have current data for oid")
else: else:
noncurrent_for_oid = self.noncurrent.get(u64(oid)) noncurrent_for_oid = self.noncurrent.get(u64(oid))
if noncurrent_for_oid and ( if noncurrent_for_oid and \
u64(start_tid) in noncurrent_for_oid): u64(start_tid) in noncurrent_for_oid:
return return
size = allocated_record_overhead + len(data) size = allocated_record_overhead + len(data)
...@@ -692,7 +691,6 @@ class ClientCache(object): ...@@ -692,7 +691,6 @@ class ClientCache(object):
self.currentofs += size self.currentofs += size
## ##
# If `tid` is None, # If `tid` is None,
# forget all knowledge of `oid`. (`tid` can be None only for # forget all knowledge of `oid`. (`tid` can be None only for
...@@ -765,8 +763,7 @@ class ClientCache(object): ...@@ -765,8 +763,7 @@ class ClientCache(object):
for oid, tid in L: for oid, tid in L:
print(oid_repr(oid), oid_repr(tid)) print(oid_repr(oid), oid_repr(tid))
print("dll contents") print("dll contents")
L = list(self) L = sorted(list(self), key=lambda x: x.key)
L.sort(lambda x, y: cmp(x.key, y.key))
for x in L: for x in L:
end_tid = x.end_tid or z64 end_tid = x.end_tid or z64
print(oid_repr(x.key[0]), oid_repr(x.key[1]), oid_repr(end_tid)) print(oid_repr(x.key[0]), oid_repr(x.key[1]), oid_repr(end_tid))
...@@ -779,6 +776,7 @@ class ClientCache(object): ...@@ -779,6 +776,7 @@ class ClientCache(object):
# tracing by setting self._trace to a dummy function, and set # tracing by setting self._trace to a dummy function, and set
# self._tracefile to None. # self._tracefile to None.
_tracefile = None _tracefile = None
def _trace(self, *a, **kw): def _trace(self, *a, **kw):
pass pass
...@@ -797,6 +795,7 @@ class ClientCache(object): ...@@ -797,6 +795,7 @@ class ClientCache(object):
return return
now = time.time now = time.time
def _trace(code, oid=b"", tid=z64, end_tid=z64, dlen=0): def _trace(code, oid=b"", tid=z64, end_tid=z64, dlen=0):
# The code argument is two hex digits; bits 0 and 7 must be zero. # The code argument is two hex digits; bits 0 and 7 must be zero.
# The first hex digit shows the operation, the second the outcome. # The first hex digit shows the operation, the second the outcome.
...@@ -812,7 +811,7 @@ class ClientCache(object): ...@@ -812,7 +811,7 @@ class ClientCache(object):
pack(">iiH8s8s", pack(">iiH8s8s",
int(now()), encoded, len(oid), tid, end_tid) + oid, int(now()), encoded, len(oid), tid, end_tid) + oid,
) )
except: except: # NOQA: E722 bare except
print(repr(tid), repr(end_tid)) print(repr(tid), repr(end_tid))
raise raise
...@@ -826,10 +825,7 @@ class ClientCache(object): ...@@ -826,10 +825,7 @@ class ClientCache(object):
self._tracefile.close() self._tracefile.close()
del self._tracefile del self._tracefile
def sync(f):
f.flush()
if hasattr(os, 'fsync'): def sync(f):
def sync(f):
f.flush() f.flush()
os.fsync(f.fileno()) os.fsync(f.fileno())
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import zope.interface import zope.interface
class StaleCache(object): class StaleCache(object):
"""A ZEO cache is stale and requires verification. """A ZEO cache is stale and requires verification.
""" """
...@@ -21,6 +22,7 @@ class StaleCache(object): ...@@ -21,6 +22,7 @@ class StaleCache(object):
def __init__(self, storage): def __init__(self, storage):
self.storage = storage self.storage = storage
class IClientCache(zope.interface.Interface): class IClientCache(zope.interface.Interface):
"""Client cache interface. """Client cache interface.
...@@ -86,6 +88,7 @@ class IClientCache(zope.interface.Interface): ...@@ -86,6 +88,7 @@ class IClientCache(zope.interface.Interface):
"""Clear/empty the cache """Clear/empty the cache
""" """
class IServeable(zope.interface.Interface): class IServeable(zope.interface.Interface):
"""Interface provided by storages that can be served by ZEO """Interface provided by storages that can be served by ZEO
""" """
......
...@@ -30,10 +30,7 @@ from __future__ import print_function ...@@ -30,10 +30,7 @@ from __future__ import print_function
from __future__ import print_function from __future__ import print_function
from __future__ import print_function from __future__ import print_function
import asyncore
import socket
import time import time
import logging
zeo_version = 'unknown' zeo_version = 'unknown'
try: try:
...@@ -47,6 +44,7 @@ else: ...@@ -47,6 +44,7 @@ else:
if zeo_dist is not None: if zeo_dist is not None:
zeo_version = zeo_dist.version zeo_version = zeo_dist.version
class StorageStats(object): class StorageStats(object):
"""Per-storage usage statistics.""" """Per-storage usage statistics."""
......
...@@ -33,6 +33,7 @@ diff_names = 'aborts commits conflicts conflicts_resolved loads stores'.split() ...@@ -33,6 +33,7 @@ diff_names = 'aborts commits conflicts conflicts_resolved loads stores'.split()
per_times = dict(seconds=1.0, minutes=60.0, hours=3600.0, days=86400.0) per_times = dict(seconds=1.0, minutes=60.0, hours=3600.0, days=86400.0)
def new_metric(metrics, storage_id, name, value): def new_metric(metrics, storage_id, name, value):
if storage_id == '1': if storage_id == '1':
label = name label = name
...@@ -43,6 +44,7 @@ def new_metric(metrics, storage_id, name, value): ...@@ -43,6 +44,7 @@ def new_metric(metrics, storage_id, name, value):
label = "%s:%s" % (storage_id, name) label = "%s:%s" % (storage_id, name)
metrics.append("%s=%s" % (label, value)) metrics.append("%s=%s" % (label, value))
def result(messages, metrics=(), status=None): def result(messages, metrics=(), status=None):
if metrics: if metrics:
messages[0] += '|' + metrics[0] messages[0] += '|' + metrics[0]
...@@ -51,12 +53,15 @@ def result(messages, metrics=(), status=None): ...@@ -51,12 +53,15 @@ def result(messages, metrics=(), status=None):
print('\n'.join(messages)) print('\n'.join(messages))
return status return status
def error(message): def error(message):
return result((message, ), (), 2) return result((message, ), (), 2)
def warn(message): def warn(message):
return result((message, ), (), 1) return result((message, ), (), 1)
def check(addr, output_metrics, status, per): def check(addr, output_metrics, status, per):
m = re.match(r'\[(\S+)\]:(\d+)$', addr) m = re.match(r'\[(\S+)\]:(\d+)$', addr)
if m: if m:
...@@ -75,7 +80,7 @@ def check(addr, output_metrics, status, per): ...@@ -75,7 +80,7 @@ def check(addr, output_metrics, status, per):
return error("Can't connect %s" % err) return error("Can't connect %s" % err)
s.sendall(b'\x00\x00\x00\x04ruok') s.sendall(b'\x00\x00\x00\x04ruok')
proto = s.recv(struct.unpack(">I", s.recv(4))[0]) proto = s.recv(struct.unpack(">I", s.recv(4))[0]) # NOQA: F841 unused
datas = s.recv(struct.unpack(">I", s.recv(4))[0]) datas = s.recv(struct.unpack(">I", s.recv(4))[0])
s.close() s.close()
data = json.loads(datas.decode("ascii")) data = json.loads(datas.decode("ascii"))
...@@ -116,6 +121,7 @@ def check(addr, output_metrics, status, per): ...@@ -116,6 +121,7 @@ def check(addr, output_metrics, status, per):
messages.append('OK') messages.append('OK')
return result(messages, metrics, level or None) return result(messages, metrics, level or None)
def main(args=None): def main(args=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
...@@ -139,5 +145,6 @@ def main(args=None): ...@@ -139,5 +145,6 @@ def main(args=None):
return check( return check(
addr, options.output_metrics, options.status_path, options.time_units) addr, options.output_metrics, options.status_path, options.time_units)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -46,21 +46,25 @@ from zdaemon.zdoptions import ZDOptions ...@@ -46,21 +46,25 @@ from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo') logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid()) _pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False): def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function.""" """Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg) message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info) logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg): def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API. # Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg) obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address return obj.family, obj.address
def windows_shutdown_handler(): def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown. # Called by the signal mechanism on Windows to perform shutdown.
import asyncore import asyncore
asyncore.close_all() asyncore.close_all()
class ZEOOptionsMixin(object): class ZEOOptionsMixin(object):
storages = None storages = None
...@@ -70,13 +74,16 @@ class ZEOOptionsMixin(object): ...@@ -70,13 +74,16 @@ class ZEOOptionsMixin(object):
def handle_filename(self, arg): def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*! from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig(object): class FSConfig(object):
def __init__(self, name, path): def __init__(self, name, path):
self._name = name self._name = name
self.path = path self.path = path
self.stop = None self.stop = None
def getSectionName(self): def getSectionName(self):
return self._name return self._name
if not self.storages: if not self.storages:
self.storages = [] self.storages = []
name = str(1 + len(self.storages)) name = str(1 + len(self.storages))
...@@ -84,6 +91,7 @@ class ZEOOptionsMixin(object): ...@@ -84,6 +91,7 @@ class ZEOOptionsMixin(object):
self.storages.append(conf) self.storages.append(conf)
testing_exit_immediately = False testing_exit_immediately = False
def handle_test(self, *args): def handle_test(self, *args):
self.testing_exit_immediately = True self.testing_exit_immediately = True
...@@ -108,6 +116,7 @@ class ZEOOptionsMixin(object): ...@@ -108,6 +116,7 @@ class ZEOOptionsMixin(object):
None, 'pid-file=') None, 'pid-file=')
self.add("ssl", "zeo.ssl") self.add("ssl", "zeo.ssl")
class ZEOOptions(ZDOptions, ZEOOptionsMixin): class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__ __doc__ = __doc__
...@@ -171,8 +180,8 @@ class ZEOServer(object): ...@@ -171,8 +180,8 @@ class ZEOServer(object):
root.addHandler(handler) root.addHandler(handler)
def check_socket(self): def check_socket(self):
if (isinstance(self.options.address, tuple) and if isinstance(self.options.address, tuple) and \
self.options.address[1] is None): self.options.address[1] is None:
self.options.address = self.options.address[0], 0 self.options.address = self.options.address[0], 0
return return
...@@ -278,7 +287,8 @@ class ZEOServer(object): ...@@ -278,7 +287,8 @@ class ZEOServer(object):
def handle_sigusr2(self): def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8... # log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"): if self.options.config_logger is None or \
os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!", log("received SIGUSR2, but it was not handled!",
level=logging.WARNING) level=logging.WARNING)
return return
...@@ -286,13 +296,13 @@ class ZEOServer(object): ...@@ -286,13 +296,13 @@ class ZEOServer(object):
loggers = [self.options.config_logger] loggers = [self.options.config_logger]
if os.name == "posix": if os.name == "posix":
for l in loggers: for logger in loggers:
l.reopen() logger.reopen()
log("Log files reopened successfully", level=logging.INFO) log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers: for logger in loggers:
for f in l.handler_factories: for factory in logger.handler_factories:
handler = f() handler = factory()
if hasattr(handler, 'rotate') and callable(handler.rotate): if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate() handler.rotate()
log("Log files rotation complete", level=logging.INFO) log("Log files rotation complete", level=logging.INFO)
...@@ -350,21 +360,21 @@ def create_server(storages, options): ...@@ -350,21 +360,21 @@ def create_server(storages, options):
return StorageServer( return StorageServer(
options.address, options.address,
storages, storages,
read_only = options.read_only, read_only=options.read_only,
client_conflict_resolution=options.client_conflict_resolution, client_conflict_resolution=options.client_conflict_resolution,
msgpack=(options.msgpack if isinstance(options.msgpack, bool) msgpack=(options.msgpack if isinstance(options.msgpack, bool)
else os.environ.get('ZEO_MSGPACK')), else os.environ.get('ZEO_MSGPACK')),
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size=options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age=options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout=options.transaction_timeout,
ssl = options.ssl, ssl=options.ssl)
)
# Signal names # Signal names
signames = None signames = None
def signame(sig): def signame(sig):
"""Return a symbolic name for a signal. """Return a symbolic name for a signal.
...@@ -376,6 +386,7 @@ def signame(sig): ...@@ -376,6 +386,7 @@ def signame(sig):
init_signames() init_signames()
return signames.get(sig) or "signal %d" % sig return signames.get(sig) or "signal %d" % sig
def init_signames(): def init_signames():
global signames global signames
signames = {} signames = {}
...@@ -395,11 +406,13 @@ def main(args=None): ...@@ -395,11 +406,13 @@ def main(args=None):
s = ZEOServer(options) s = ZEOServer(options)
s.main() s.main()
def run(args): def run(args):
options = ZEOOptions() options = ZEOOptions()
options.realize(args) options.realize(args)
s = ZEOServer(options) s = ZEOServer(options)
s.run() s.run()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -27,6 +27,7 @@ from __future__ import print_function, absolute_import ...@@ -27,6 +27,7 @@ from __future__ import print_function, absolute_import
import bisect import bisect
import struct import struct
import random
import re import re
import sys import sys
import ZEO.cache import ZEO.cache
...@@ -34,6 +35,7 @@ import argparse ...@@ -34,6 +35,7 @@ import argparse
from ZODB.utils import z64 from ZODB.utils import z64
from ..cache import ZEC_HEADER_SIZE
from .cache_stats import add_interval_argument from .cache_stats import add_interval_argument
from .cache_stats import add_tracefile_argument from .cache_stats import add_tracefile_argument
...@@ -46,7 +48,7 @@ def main(args=None): ...@@ -46,7 +48,7 @@ def main(args=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
# Parse options. # Parse options.
MB = 1<<20 MB = 1 << 20
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--size", "-s", parser.add_argument("--size", "-s",
default=20*MB, dest="cachelimit", default=20*MB, dest="cachelimit",
...@@ -115,6 +117,7 @@ def main(args=None): ...@@ -115,6 +117,7 @@ def main(args=None):
interval_sim.report() interval_sim.report()
sim.finish() sim.finish()
class Simulation(object): class Simulation(object):
"""Base class for simulations. """Base class for simulations.
...@@ -270,7 +273,6 @@ class CircularCacheEntry(object): ...@@ -270,7 +273,6 @@ class CircularCacheEntry(object):
self.end_tid = end_tid self.end_tid = end_tid
self.offset = offset self.offset = offset
from ZEO.cache import ZEC_HEADER_SIZE
class CircularCacheSimulation(Simulation): class CircularCacheSimulation(Simulation):
"""Simulate the ZEO 3.0 cache.""" """Simulate the ZEO 3.0 cache."""
...@@ -285,8 +287,6 @@ class CircularCacheSimulation(Simulation): ...@@ -285,8 +287,6 @@ class CircularCacheSimulation(Simulation):
evicts = 0 evicts = 0
def __init__(self, cachelimit, rearrange): def __init__(self, cachelimit, rearrange):
from ZEO import cache
Simulation.__init__(self, cachelimit, rearrange) Simulation.__init__(self, cachelimit, rearrange)
self.total_evicts = 0 # number of cache evictions self.total_evicts = 0 # number of cache evictions
...@@ -322,6 +322,7 @@ class CircularCacheSimulation(Simulation): ...@@ -322,6 +322,7 @@ class CircularCacheSimulation(Simulation):
self.evicted_hit = self.evicted_miss = 0 self.evicted_hit = self.evicted_miss = 0
evicted_hit = evicted_miss = 0 evicted_hit = evicted_miss = 0
def load(self, oid, size, tid, code): def load(self, oid, size, tid, code):
if (code == 0x20) or (code == 0x22): if (code == 0x20) or (code == 0x22):
# Trying to load current revision. # Trying to load current revision.
...@@ -433,7 +434,8 @@ class CircularCacheSimulation(Simulation): ...@@ -433,7 +434,8 @@ class CircularCacheSimulation(Simulation):
# Storing current revision. # Storing current revision.
if oid in self.current: # we already have it in cache if oid in self.current: # we already have it in cache
if evhit: if evhit:
import pdb; pdb.set_trace() import pdb
pdb.set_trace()
raise ValueError('WTF') raise ValueError('WTF')
return return
self.current[oid] = start_tid self.current[oid] = start_tid
...@@ -442,7 +444,8 @@ class CircularCacheSimulation(Simulation): ...@@ -442,7 +444,8 @@ class CircularCacheSimulation(Simulation):
self.add(oid, size, start_tid) self.add(oid, size, start_tid)
return return
if evhit: if evhit:
import pdb; pdb.set_trace() import pdb
pdb.set_trace()
raise ValueError('WTF') raise ValueError('WTF')
# Storing non-current revision. # Storing non-current revision.
L = self.noncurrent.setdefault(oid, []) L = self.noncurrent.setdefault(oid, [])
...@@ -514,7 +517,7 @@ class CircularCacheSimulation(Simulation): ...@@ -514,7 +517,7 @@ class CircularCacheSimulation(Simulation):
self.inuse = round(100.0 * used / total, 1) self.inuse = round(100.0 * used / total, 1)
self.total_inuse = self.inuse self.total_inuse = self.inuse
Simulation.report(self) Simulation.report(self)
#print self.evicted_hit, self.evicted_miss # print self.evicted_hit, self.evicted_miss
def check(self): def check(self):
oidcount = 0 oidcount = 0
...@@ -538,16 +541,18 @@ class CircularCacheSimulation(Simulation): ...@@ -538,16 +541,18 @@ class CircularCacheSimulation(Simulation):
def roundup(size): def roundup(size):
k = MINSIZE k = MINSIZE # NOQA: F821 undefined name
while k < size: while k < size:
k += k k += k
return k return k
def hitrate(loads, hits): def hitrate(loads, hits):
if loads < 1: if loads < 1:
return 'n/a' return 'n/a'
return "%5.1f%%" % (100.0 * hits / loads) return "%5.1f%%" % (100.0 * hits / loads)
def duration(secs): def duration(secs):
mm, ss = divmod(secs, 60) mm, ss = divmod(secs, 60)
hh, mm = divmod(mm, 60) hh, mm = divmod(mm, 60)
...@@ -557,7 +562,10 @@ def duration(secs): ...@@ -557,7 +562,10 @@ def duration(secs):
return "%d:%02d" % (mm, ss) return "%d:%02d" % (mm, ss)
return "%d" % ss return "%d" % ss
nre = re.compile('([=-]?)(\d+)([.]\d*)?').match
nre = re.compile(r'([=-]?)(\d+)([.]\d*)?').match
def addcommas(n): def addcommas(n):
sign, s, d = nre(str(n)).group(1, 2, 3) sign, s, d = nre(str(n)).group(1, 2, 3)
if d == '.0': if d == '.0':
...@@ -571,11 +579,11 @@ def addcommas(n): ...@@ -571,11 +579,11 @@ def addcommas(n):
return (sign or '') + result + (d or '') return (sign or '') + result + (d or '')
import random
def maybe(f, p=0.5): def maybe(f, p=0.5):
if random.random() < p: if random.random() < p:
f() f()
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) sys.exit(main())
...@@ -55,6 +55,7 @@ import gzip ...@@ -55,6 +55,7 @@ import gzip
from time import ctime from time import ctime
import six import six
def add_interval_argument(parser): def add_interval_argument(parser):
def _interval(a): def _interval(a):
interval = int(60 * float(a)) interval = int(60 * float(a))
...@@ -63,10 +64,12 @@ def add_interval_argument(parser): ...@@ -63,10 +64,12 @@ def add_interval_argument(parser):
elif interval > 3600: elif interval > 3600:
interval = 3600 interval = 3600
return interval return interval
parser.add_argument("--interval", "-i", parser.add_argument(
"--interval", "-i",
default=15*60, type=_interval, default=15*60, type=_interval,
help="summarizing interval in minutes (default 15; max 60)") help="summarizing interval in minutes (default 15; max 60)")
def add_tracefile_argument(parser): def add_tracefile_argument(parser):
class GzipFileType(argparse.FileType): class GzipFileType(argparse.FileType):
...@@ -82,11 +85,13 @@ def add_tracefile_argument(parser): ...@@ -82,11 +85,13 @@ def add_tracefile_argument(parser):
parser.add_argument("tracefile", type=GzipFileType(), parser.add_argument("tracefile", type=GzipFileType(),
help="The trace to read; may be gzipped") help="The trace to read; may be gzipped")
def main(args=None): def main(args=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
# Parse options # Parse options
parser = argparse.ArgumentParser(description="Trace file statistics analyzer", parser = argparse.ArgumentParser(
description="Trace file statistics analyzer",
# Our -h, short for --load-histogram # Our -h, short for --load-histogram
# conflicts with default for help, so we handle # conflicts with default for help, so we handle
# manually. # manually.
...@@ -99,18 +104,22 @@ def main(args=None): ...@@ -99,18 +104,22 @@ def main(args=None):
default=False, action='store_true', default=False, action='store_true',
help="Reduce output; don't print summaries") help="Reduce output; don't print summaries")
parser.add_argument("--sizes", '-s', parser.add_argument("--sizes", '-s',
default=False, action="store_true", dest="print_size_histogram", default=False, action="store_true",
dest="print_size_histogram",
help="print histogram of object sizes") help="print histogram of object sizes")
parser.add_argument("--no-stats", '-S', parser.add_argument("--no-stats", '-S',
default=True, action="store_false", dest="dostats", default=True, action="store_false", dest="dostats",
help="don't print statistics") help="don't print statistics")
parser.add_argument("--load-histogram", "-h", parser.add_argument("--load-histogram", "-h",
default=False, action="store_true", dest="print_histogram", default=False, action="store_true",
dest="print_histogram",
help="print histogram of object load frequencies") help="print histogram of object load frequencies")
parser.add_argument("--check", "-X", parser.add_argument("--check", "-X",
default=False, action="store_true", dest="heuristic", default=False, action="store_true", dest="heuristic",
help=" enable heuristic checking for misaligned records: oids > 2**32" help=" enable heuristic checking for misaligned "
" will be rejected; this requires the tracefile to be seekable") "records: oids > 2**32"
" will be rejected; this requires the tracefile "
"to be seekable")
add_interval_argument(parser) add_interval_argument(parser)
add_tracefile_argument(parser) add_tracefile_argument(parser)
...@@ -144,7 +153,8 @@ def main(args=None): ...@@ -144,7 +153,8 @@ def main(args=None):
FMT_SIZE = struct.calcsize(FMT) FMT_SIZE = struct.calcsize(FMT)
assert FMT_SIZE == 26 assert FMT_SIZE == 26
# Read file, gathering statistics, and printing each record if verbose. # Read file, gathering statistics, and printing each record if verbose.
print(' '*16, "%7s %7s %7s %7s" % ('loads', 'hits', 'inv(h)', 'writes'), end=' ') print(' '*16, "%7s %7s %7s %7s" % (
'loads', 'hits', 'inv(h)', 'writes'), end=' ')
print('hitrate') print('hitrate')
try: try:
while 1: while 1:
...@@ -279,6 +289,7 @@ def main(args=None): ...@@ -279,6 +289,7 @@ def main(args=None):
dumpbysize(bysizew, "written", "writes") dumpbysize(bysizew, "written", "writes")
dumpbysize(bysize, "loaded", "loads") dumpbysize(bysize, "loaded", "loads")
def dumpbysize(bysize, how, how2): def dumpbysize(bysize, how, how2):
print() print()
print("Unique sizes %s: %s" % (how, addcommas(len(bysize)))) print("Unique sizes %s: %s" % (how, addcommas(len(bysize))))
...@@ -292,6 +303,7 @@ def dumpbysize(bysize, how, how2): ...@@ -292,6 +303,7 @@ def dumpbysize(bysize, how, how2):
len(bysize.get(size, "")), len(bysize.get(size, "")),
loads)) loads))
def dumpbyinterval(byinterval, h0, he): def dumpbyinterval(byinterval, h0, he):
loads = hits = invals = writes = 0 loads = hits = invals = writes = 0
for code in byinterval: for code in byinterval:
...@@ -315,6 +327,7 @@ def dumpbyinterval(byinterval, h0, he): ...@@ -315,6 +327,7 @@ def dumpbyinterval(byinterval, h0, he):
ctime(h0)[4:-8], ctime(he)[14:-8], ctime(h0)[4:-8], ctime(he)[14:-8],
loads, hits, invals, writes, hr)) loads, hits, invals, writes, hr))
def hitrate(bycode): def hitrate(bycode):
loads = hits = 0 loads = hits = 0
for code in bycode: for code in bycode:
...@@ -328,6 +341,7 @@ def hitrate(bycode): ...@@ -328,6 +341,7 @@ def hitrate(bycode):
else: else:
return 0.0 return 0.0
def histogram(d): def histogram(d):
bins = {} bins = {}
for v in six.itervalues(d): for v in six.itervalues(d):
...@@ -335,15 +349,18 @@ def histogram(d): ...@@ -335,15 +349,18 @@ def histogram(d):
L = sorted(bins.items()) L = sorted(bins.items())
return L return L
def U64(s): def U64(s):
return struct.unpack(">Q", s)[0] return struct.unpack(">Q", s)[0]
def oid_repr(oid): def oid_repr(oid):
if isinstance(oid, six.binary_type) and len(oid) == 8: if isinstance(oid, six.binary_type) and len(oid) == 8:
return '%16x' % U64(oid) return '%16x' % U64(oid)
else: else:
return repr(oid) return repr(oid)
def addcommas(n): def addcommas(n):
sign, s = '', str(n) sign, s = '', str(n)
if s[0] == '-': if s[0] == '-':
...@@ -354,6 +371,7 @@ def addcommas(n): ...@@ -354,6 +371,7 @@ def addcommas(n):
i -= 3 i -= 3
return sign + s return sign + s
explain = { explain = {
# The first hex digit shows the operation, the second the outcome. # The first hex digit shows the operation, the second the outcome.
# If the second digit is in "02468" then it is a 'miss'. # If the second digit is in "02468" then it is a 'miss'.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Parse the BLATHER logging generated by ZEO2. """Parse the BLATHER logging generated by ZEO2.
An example of the log format is: An example of the log format is:
2002-04-15T13:05:29 BLATHER(-100) ZEO Server storea(3235680, [714], 235339406490168806) ('10.0.26.30', 45514) 2002-04-15T13:05:29 BLATHER(-100) ZEO Server storea(3235680, [714], 235339406490168806) ('10.0.26.30', 45514) # NOQA: E501 line too long
""" """
from __future__ import print_function from __future__ import print_function
from __future__ import print_function from __future__ import print_function
...@@ -14,7 +14,8 @@ from __future__ import print_function ...@@ -14,7 +14,8 @@ from __future__ import print_function
import re import re
import time import time
rx_time = re.compile('(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)') rx_time = re.compile(r'(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
def parse_time(line): def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None.""" """Return the time portion of a zLOG line in seconds or None."""
...@@ -26,11 +27,14 @@ def parse_time(line): ...@@ -26,11 +27,14 @@ def parse_time(line):
time_l = [int(elt) for elt in time_.split(':')] time_l = [int(elt) for elt in time_.split(':')]
return int(time.mktime(date_l + time_l + [0, 0, 0])) return int(time.mktime(date_l + time_l + [0, 0, 0]))
rx_meth = re.compile("zrpc:\d+ calling (\w+)\((.*)")
rx_meth = re.compile(r"zrpc:\d+ calling (\w+)\((.*)")
def parse_method(line): def parse_method(line):
pass pass
def parse_line(line): def parse_line(line):
"""Parse a log entry and return time, method info, and client.""" """Parse a log entry and return time, method info, and client."""
t = parse_time(line) t = parse_time(line)
...@@ -47,6 +51,7 @@ def parse_line(line): ...@@ -47,6 +51,7 @@ def parse_line(line):
m = meth_name, tuple(meth_args) m = meth_name, tuple(meth_args)
return t, m return t, m
class TStats(object): class TStats(object):
counter = 1 counter = 1
...@@ -61,7 +66,6 @@ class TStats(object): ...@@ -61,7 +66,6 @@ class TStats(object):
def report(self): def report(self):
"""Print a report about the transaction""" """Print a report about the transaction"""
t = time.ctime(self.begin)
if hasattr(self, "vote"): if hasattr(self, "vote"):
d_vote = self.vote - self.begin d_vote = self.vote - self.begin
else: else:
...@@ -73,6 +77,7 @@ class TStats(object): ...@@ -73,6 +77,7 @@ class TStats(object):
print(self.fmt % (time.ctime(self.begin), d_vote, d_finish, print(self.fmt % (time.ctime(self.begin), d_vote, d_finish,
self.user, self.url)) self.user, self.url))
class TransactionParser(object): class TransactionParser(object):
def __init__(self): def __init__(self):
...@@ -122,6 +127,7 @@ class TransactionParser(object): ...@@ -122,6 +127,7 @@ class TransactionParser(object):
L.sort() L.sort()
return [t for (id, t) in L] return [t for (id, t) in L]
if __name__ == "__main__": if __name__ == "__main__":
import fileinput import fileinput
...@@ -131,7 +137,7 @@ if __name__ == "__main__": ...@@ -131,7 +137,7 @@ if __name__ == "__main__":
i += 1 i += 1
try: try:
p.parse(line) p.parse(line)
except: except: # NOQA: E722 bare except
print("line", i) print("line", i)
raise raise
print("Transaction: %d" % len(p.txns)) print("Transaction: %d" % len(p.txns))
......
...@@ -12,9 +12,13 @@ ...@@ -12,9 +12,13 @@
# #
############################################################################## ##############################################################################
from __future__ import print_function from __future__ import print_function
import doctest, re, unittest import doctest
import re
import unittest
from zope.testing import renormalizing from zope.testing import renormalizing
def test_suite(): def test_suite():
return unittest.TestSuite(( return unittest.TestSuite((
doctest.DocFileSuite( doctest.DocFileSuite(
...@@ -26,4 +30,3 @@ def test_suite(): ...@@ -26,4 +30,3 @@ def test_suite():
globs={'print_function': print_function}, globs={'print_function': print_function},
), ),
)) ))
...@@ -25,6 +25,7 @@ from ZEO.ClientStorage import ClientStorage ...@@ -25,6 +25,7 @@ from ZEO.ClientStorage import ClientStorage
ZERO = '\0'*8 ZERO = '\0'*8
def main(): def main():
if len(sys.argv) not in (3, 4): if len(sys.argv) not in (3, 4):
sys.stderr.write("Usage: timeout.py address delay [storage-name]\n" % sys.stderr.write("Usage: timeout.py address delay [storage-name]\n" %
...@@ -68,5 +69,6 @@ def main(): ...@@ -68,5 +69,6 @@ def main():
time.sleep(delay) time.sleep(delay)
print("Done.") print("Done.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -8,7 +8,6 @@ import time ...@@ -8,7 +8,6 @@ import time
import traceback import traceback
import ZEO.ClientStorage import ZEO.ClientStorage
from six.moves import map from six.moves import map
from six.moves import zip
usage = """Usage: %prog [options] [servers] usage = """Usage: %prog [options] [servers]
...@@ -23,6 +22,7 @@ each is of the form: ...@@ -23,6 +22,7 @@ each is of the form:
WAIT = 10 # wait no more than 10 seconds for client to connect WAIT = 10 # wait no more than 10 seconds for client to connect
def _main(args=None, prog=None): def _main(args=None, prog=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
...@@ -160,10 +160,11 @@ def _main(args=None, prog=None): ...@@ -160,10 +160,11 @@ def _main(args=None, prog=None):
continue continue
cs.pack(packt, wait=True) cs.pack(packt, wait=True)
cs.close() cs.close()
except: except: # NOQA: E722 bare except
traceback.print_exception(*(sys.exc_info()+(99, sys.stderr))) traceback.print_exception(*(sys.exc_info()+(99, sys.stderr)))
error("Error packing storage %s in %r" % (name, addr)) error("Error packing storage %s in %r" % (name, addr))
def main(*args): def main(*args):
root_logger = logging.getLogger() root_logger = logging.getLogger()
old_level = root_logger.getEffectiveLevel() old_level = root_logger.getEffectiveLevel()
...@@ -178,6 +179,6 @@ def main(*args): ...@@ -178,6 +179,6 @@ def main(*args):
logging.getLogger().setLevel(old_level) logging.getLogger().setLevel(old_level)
logging.getLogger().removeHandler(handler) logging.getLogger().removeHandler(handler)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -37,7 +37,6 @@ STATEFILE = 'zeoqueue.pck' ...@@ -37,7 +37,6 @@ STATEFILE = 'zeoqueue.pck'
PROGRAM = sys.argv[0] PROGRAM = sys.argv[0]
tcre = re.compile(r""" tcre = re.compile(r"""
(?P<ymd> (?P<ymd>
\d{4}- # year \d{4}- # year
...@@ -67,7 +66,6 @@ ccre = re.compile(r""" ...@@ -67,7 +66,6 @@ ccre = re.compile(r"""
wcre = re.compile(r'Clients waiting: (?P<num>\d+)') wcre = re.compile(r'Clients waiting: (?P<num>\d+)')
def parse_time(line): def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None.""" """Return the time portion of a zLOG line in seconds or None."""
mo = tcre.match(line) mo = tcre.match(line)
...@@ -97,7 +95,6 @@ class Txn(object): ...@@ -97,7 +95,6 @@ class Txn(object):
return False return False
class Status(object): class Status(object):
"""Track status of ZEO server by replaying log records. """Track status of ZEO server by replaying log records.
...@@ -303,7 +300,6 @@ class Status(object): ...@@ -303,7 +300,6 @@ class Status(object):
break break
def usage(code, msg=''): def usage(code, msg=''):
print(__doc__ % globals(), file=sys.stderr) print(__doc__ % globals(), file=sys.stderr)
if msg: if msg:
......
...@@ -41,25 +41,25 @@ import time ...@@ -41,25 +41,25 @@ import time
import getopt import getopt
import operator import operator
# ZEO logs measure wall-clock time so for consistency we need to do the same # ZEO logs measure wall-clock time so for consistency we need to do the same
#from time import clock as now # from time import clock as now
from time import time as now from time import time as now
from ZODB.FileStorage import FileStorage from ZODB.FileStorage import FileStorage
#from BDBStorage.BDBFullStorage import BDBFullStorage # from BDBStorage.BDBFullStorage import BDBFullStorage
#from Standby.primary import PrimaryStorage # from Standby.primary import PrimaryStorage
#from Standby.config import RS_PORT # from Standby.config import RS_PORT
from ZODB.Connection import TransactionMetaData from ZODB.Connection import TransactionMetaData
from ZODB.utils import p64 from ZODB.utils import p64
from functools import reduce from functools import reduce
datecre = re.compile('(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)') datecre = re.compile(r'(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
methcre = re.compile("ZEO Server (\w+)\((.*)\) \('(.*)', (\d+)") methcre = re.compile(r"ZEO Server (\w+)\((.*)\) \('(.*)', (\d+)")
class StopParsing(Exception): class StopParsing(Exception):
pass pass
def usage(code, msg=''): def usage(code, msg=''):
print(__doc__) print(__doc__)
if msg: if msg:
...@@ -67,7 +67,6 @@ def usage(code, msg=''): ...@@ -67,7 +67,6 @@ def usage(code, msg=''):
sys.exit(code) sys.exit(code)
def parse_time(line): def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None.""" """Return the time portion of a zLOG line in seconds or None."""
mo = datecre.match(line) mo = datecre.match(line)
...@@ -95,7 +94,6 @@ def parse_line(line): ...@@ -95,7 +94,6 @@ def parse_line(line):
return t, m, c return t, m, c
class StoreStat(object): class StoreStat(object):
def __init__(self, when, oid, size): def __init__(self, when, oid, size):
self.when = when self.when = when
...@@ -104,8 +102,10 @@ class StoreStat(object): ...@@ -104,8 +102,10 @@ class StoreStat(object):
# Crufty # Crufty
def __getitem__(self, i): def __getitem__(self, i):
if i == 0: return self.oid if i == 0:
if i == 1: return self.size return self.oid
if i == 1:
return self.size
raise IndexError raise IndexError
...@@ -136,10 +136,10 @@ class TxnStat(object): ...@@ -136,10 +136,10 @@ class TxnStat(object):
self._finishtime = when self._finishtime = when
# Mapping oid -> revid # Mapping oid -> revid
_revids = {} _revids = {}
class ReplayTxn(TxnStat): class ReplayTxn(TxnStat):
def __init__(self, storage): def __init__(self, storage):
self._storage = storage self._storage = storage
...@@ -172,7 +172,6 @@ class ReplayTxn(TxnStat): ...@@ -172,7 +172,6 @@ class ReplayTxn(TxnStat):
self._replaydelta = t1 - t0 - origdelta self._replaydelta = t1 - t0 - origdelta
class ZEOParser(object): class ZEOParser(object):
def __init__(self, maxtxns=-1, report=1, storage=None): def __init__(self, maxtxns=-1, report=1, storage=None):
self.__txns = [] self.__txns = []
...@@ -261,7 +260,6 @@ class ZEOParser(object): ...@@ -261,7 +260,6 @@ class ZEOParser(object):
print('average faster txn was:', float(sum) / len(faster)) print('average faster txn was:', float(sum) / len(faster))
def main(): def main():
try: try:
opts, args = getopt.getopt( opts, args = getopt.getopt(
...@@ -294,8 +292,8 @@ def main(): ...@@ -294,8 +292,8 @@ def main():
if replay: if replay:
storage = FileStorage(storagefile) storage = FileStorage(storagefile)
#storage = BDBFullStorage(storagefile) # storage = BDBFullStorage(storagefile)
#storage = PrimaryStorage('yyz', storage, RS_PORT) # storage = PrimaryStorage('yyz', storage, RS_PORT)
t0 = now() t0 = now()
p = ZEOParser(maxtxns, report, storage) p = ZEOParser(maxtxns, report, storage)
i = 0 i = 0
...@@ -308,7 +306,7 @@ def main(): ...@@ -308,7 +306,7 @@ def main():
p.parse(line) p.parse(line)
except StopParsing: except StopParsing:
break break
except: except: # NOQA: E722 bare except
print('input file line:', i) print('input file line:', i)
raise raise
t1 = now() t1 = now()
...@@ -321,6 +319,5 @@ def main(): ...@@ -321,6 +319,5 @@ def main():
print('total time:', t3-t0) print('total time:', t3-t0)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -169,9 +169,11 @@ from __future__ import print_function ...@@ -169,9 +169,11 @@ from __future__ import print_function
from __future__ import print_function from __future__ import print_function
from __future__ import print_function from __future__ import print_function
import datetime, sys, re, os import datetime
import os
import re
import sys
from six.moves import map from six.moves import map
from six.moves import zip
def time(line): def time(line):
...@@ -187,9 +189,10 @@ def sub(t1, t2): ...@@ -187,9 +189,10 @@ def sub(t1, t2):
return delta.days*86400.0+delta.seconds+delta.microseconds/1000000.0 return delta.days*86400.0+delta.seconds+delta.microseconds/1000000.0
waitre = re.compile(r'Clients waiting: (\d+)') waitre = re.compile(r'Clients waiting: (\d+)')
idre = re.compile(r' ZSS:\d+/(\d+.\d+.\d+.\d+:\d+) ') idre = re.compile(r' ZSS:\d+/(\d+.\d+.\d+.\d+:\d+) ')
def blocked_times(args): def blocked_times(args):
f, thresh = args f, thresh = args
...@@ -217,7 +220,6 @@ def blocked_times(args): ...@@ -217,7 +220,6 @@ def blocked_times(args):
t2 = t1 t2 = t1
if not blocking and last_blocking: if not blocking and last_blocking:
last_wait = 0
t2 = time(line) t2 = time(line)
cid = idre.search(line).group(1) cid = idre.search(line).group(1)
...@@ -225,11 +227,14 @@ def blocked_times(args): ...@@ -225,11 +227,14 @@ def blocked_times(args):
d = sub(t1, time(line)) d = sub(t1, time(line))
if d >= thresh: if d >= thresh:
print(t1, sub(t1, t2), cid, d) print(t1, sub(t1, t2), cid, d)
t1 = t2 = cid = blocking = waiting = last_wait = max_wait = 0 t1 = t2 = cid = blocking = waiting = 0
last_blocking = blocking last_blocking = blocking
connidre = re.compile(r' zrpc-conn:(\d+.\d+.\d+.\d+:\d+) ') connidre = re.compile(r' zrpc-conn:(\d+.\d+.\d+.\d+:\d+) ')
def time_calls(f): def time_calls(f):
f, thresh = f f, thresh = f
if f == '-': if f == '-':
...@@ -255,6 +260,7 @@ def time_calls(f): ...@@ -255,6 +260,7 @@ def time_calls(f):
print(maxd) print(maxd)
def xopen(f): def xopen(f):
if f == '-': if f == '-':
return sys.stdin return sys.stdin
...@@ -262,6 +268,7 @@ def xopen(f): ...@@ -262,6 +268,7 @@ def xopen(f):
return os.popen(f, 'r') return os.popen(f, 'r')
return open(f) return open(f)
def time_tpc(f): def time_tpc(f):
f, thresh = f f, thresh = f
if f == '-': if f == '-':
...@@ -307,11 +314,14 @@ def time_tpc(f): ...@@ -307,11 +314,14 @@ def time_tpc(f):
t = time(line) t = time(line)
d = sub(t1, t) d = sub(t1, t)
if d >= thresh: if d >= thresh:
print('c', t1, cid, sub(t1, t2), vs, sub(t2, t3), sub(t3, t)) print('c', t1, cid, sub(t1, t2),
vs, sub(t2, t3), sub(t3, t))
del transactions[cid] del transactions[cid]
newobre = re.compile(r"storea\(.*, '\\x00\\x00\\x00\\x00\\x00") newobre = re.compile(r"storea\(.*, '\\x00\\x00\\x00\\x00\\x00")
def time_trans(f): def time_trans(f):
f, thresh = f f, thresh = f
if f == '-': if f == '-':
...@@ -363,8 +373,8 @@ def time_trans(f): ...@@ -363,8 +373,8 @@ def time_trans(f):
t = time(line) t = time(line)
d = sub(t1, t) d = sub(t1, t)
if d >= thresh: if d >= thresh:
print(t1, cid, "%s/%s" % (stores, old), \ print(t1, cid, "%s/%s" % (stores, old),
sub(t0, t1), sub(t1, t2), vs, \ sub(t0, t1), sub(t1, t2), vs,
sub(t2, t), 'abort') sub(t2, t), 'abort')
del transactions[cid] del transactions[cid]
elif ' calling tpc_finish(' in line: elif ' calling tpc_finish(' in line:
...@@ -377,11 +387,12 @@ def time_trans(f): ...@@ -377,11 +387,12 @@ def time_trans(f):
t = time(line) t = time(line)
d = sub(t1, t) d = sub(t1, t)
if d >= thresh: if d >= thresh:
print(t1, cid, "%s/%s" % (stores, old), \ print(t1, cid, "%s/%s" % (stores, old),
sub(t0, t1), sub(t1, t2), vs, \ sub(t0, t1), sub(t1, t2), vs,
sub(t2, t3), sub(t3, t)) sub(t2, t3), sub(t3, t))
del transactions[cid] del transactions[cid]
def minute(f, slice=16, detail=1, summary=1): def minute(f, slice=16, detail=1, summary=1):
f, = f f, = f
...@@ -405,10 +416,9 @@ def minute(f, slice=16, detail=1, summary=1): ...@@ -405,10 +416,9 @@ def minute(f, slice=16, detail=1, summary=1):
for line in f: for line in f:
line = line.strip() line = line.strip()
if (line.find('returns') > 0 if line.find('returns') > 0 or \
or line.find('storea') > 0 line.find('storea') > 0 or \
or line.find('tpc_abort') > 0 line.find('tpc_abort') > 0:
):
client = connidre.search(line).group(1) client = connidre.search(line).group(1)
m = line[:slice] m = line[:slice]
if m != mlast: if m != mlast:
...@@ -452,12 +462,13 @@ def minute(f, slice=16, detail=1, summary=1): ...@@ -452,12 +462,13 @@ def minute(f, slice=16, detail=1, summary=1):
print('Summary: \t', '\t'.join(('min', '10%', '25%', 'med', print('Summary: \t', '\t'.join(('min', '10%', '25%', 'med',
'75%', '90%', 'max', 'mean'))) '75%', '90%', 'max', 'mean')))
print("n=%6d\t" % len(cls), '-'*62) print("n=%6d\t" % len(cls), '-'*62)
print('Clients: \t', '\t'.join(map(str,stats(cls)))) print('Clients: \t', '\t'.join(map(str, stats(cls))))
print('Reads: \t', '\t'.join(map(str,stats(rs)))) print('Reads: \t', '\t'.join(map(str, stats(rs))))
print('Stores: \t', '\t'.join(map(str,stats(ss)))) print('Stores: \t', '\t'.join(map(str, stats(ss))))
print('Commits: \t', '\t'.join(map(str,stats(cs)))) print('Commits: \t', '\t'.join(map(str, stats(cs))))
print('Aborts: \t', '\t'.join(map(str,stats(aborts)))) print('Aborts: \t', '\t'.join(map(str, stats(aborts))))
print('Trans: \t', '\t'.join(map(str,stats(ts)))) print('Trans: \t', '\t'.join(map(str, stats(ts))))
def stats(s): def stats(s):
s.sort() s.sort()
...@@ -468,13 +479,14 @@ def stats(s): ...@@ -468,13 +479,14 @@ def stats(s):
ni = n + 1 ni = n + 1
for p in .1, .25, .5, .75, .90: for p in .1, .25, .5, .75, .90:
lp = ni*p lp = ni*p
l = int(lp) lp_int = int(lp)
if lp < 1 or lp > n: if lp < 1 or lp > n:
out.append('-') out.append('-')
elif abs(lp-l) < .00001: elif abs(lp-lp_int) < .00001:
out.append(s[l-1]) out.append(s[lp_int-1])
else: else:
out.append(int(s[l-1] + (lp - l) * (s[l] - s[l-1]))) out.append(
int(s[lp_int-1] + (lp - lp_int) * (s[lp_int] - s[lp_int-1])))
mean = 0.0 mean = 0.0
for v in s: for v in s:
...@@ -484,24 +496,31 @@ def stats(s): ...@@ -484,24 +496,31 @@ def stats(s):
return out return out
def minutes(f): def minutes(f):
minute(f, 16, detail=0) minute(f, 16, detail=0)
def hour(f): def hour(f):
minute(f, 13) minute(f, 13)
def day(f): def day(f):
minute(f, 10) minute(f, 10)
def hours(f): def hours(f):
minute(f, 13, detail=0) minute(f, 13, detail=0)
def days(f): def days(f):
minute(f, 10, detail=0) minute(f, 10, detail=0)
new_connection_idre = re.compile( new_connection_idre = re.compile(
r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):") r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):")
def verify(f): def verify(f):
f, = f f, = f
...@@ -527,6 +546,7 @@ def verify(f): ...@@ -527,6 +546,7 @@ def verify(f):
d = sub(t1, time(line)) d = sub(t1, time(line))
print(cid, t1, n, d, n and (d*1000.0/n) or '-') print(cid, t1, n, d, n and (d*1000.0/n) or '-')
def recovery(f): def recovery(f):
f, = f f, = f
...@@ -542,16 +562,16 @@ def recovery(f): ...@@ -542,16 +562,16 @@ def recovery(f):
n += 1 n += 1
if line.find('RecoveryServer') < 0: if line.find('RecoveryServer') < 0:
continue continue
l = line.find('sending transaction ') pos = line.find('sending transaction ')
if l > 0 and last.find('sending transaction ') > 0: if pos > 0 and last.find('sending transaction ') > 0:
trans.append(line[l+20:].strip()) trans.append(line[pos+20:].strip())
else: else:
if trans: if trans:
if len(trans) > 1: if len(trans) > 1:
print(" ... %s similar records skipped ..." % ( print(" ... %s similar records skipped ..." % (
len(trans) - 1)) len(trans) - 1))
print(n, last.strip()) print(n, last.strip())
trans=[] trans = []
print(n, line.strip()) print(n, line.strip())
last = line last = line
...@@ -561,6 +581,5 @@ def recovery(f): ...@@ -561,6 +581,5 @@ def recovery(f):
print(n, last.strip()) print(n, last.strip())
if __name__ == '__main__': if __name__ == '__main__':
globals()[sys.argv[1]](sys.argv[2:]) globals()[sys.argv[1]](sys.argv[2:])
...@@ -47,6 +47,7 @@ from ZEO.ClientStorage import ClientStorage ...@@ -47,6 +47,7 @@ from ZEO.ClientStorage import ClientStorage
ZEO_VERSION = 2 ZEO_VERSION = 2
def setup_logging(): def setup_logging():
# Set up logging to stderr which will show messages originating # Set up logging to stderr which will show messages originating
# at severity ERROR or higher. # at severity ERROR or higher.
...@@ -59,6 +60,7 @@ def setup_logging(): ...@@ -59,6 +60,7 @@ def setup_logging():
handler.setFormatter(fmt) handler.setFormatter(fmt)
root.addHandler(handler) root.addHandler(handler)
def check_server(addr, storage, write): def check_server(addr, storage, write):
t0 = time.time() t0 = time.time()
if ZEO_VERSION == 2: if ZEO_VERSION == 2:
...@@ -97,11 +99,13 @@ def check_server(addr, storage, write): ...@@ -97,11 +99,13 @@ def check_server(addr, storage, write):
t1 = time.time() t1 = time.time()
print("Elapsed time: %.2f" % (t1 - t0)) print("Elapsed time: %.2f" % (t1 - t0))
def usage(exit=1): def usage(exit=1):
print(__doc__) print(__doc__)
print(" ".join(sys.argv)) print(" ".join(sys.argv))
sys.exit(exit) sys.exit(exit)
def main(): def main():
host = None host = None
port = None port = None
...@@ -123,7 +127,7 @@ def main(): ...@@ -123,7 +127,7 @@ def main():
elif o == '--nowrite': elif o == '--nowrite':
write = 0 write = 0
elif o == '-1': elif o == '-1':
ZEO_VERSION = 1 ZEO_VERSION = 1 # NOQA: F841 unused variable
except Exception as err: except Exception as err:
s = str(err) s = str(err)
if s: if s:
...@@ -143,6 +147,7 @@ def main(): ...@@ -143,6 +147,7 @@ def main():
setup_logging() setup_logging()
check_server(addr, storage, write) check_server(addr, storage, write)
if __name__ == "__main__": if __name__ == "__main__":
try: try:
main() main()
......
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
REPR_LIMIT = 60 REPR_LIMIT = 60
def short_repr(obj): def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes." """Return an object repr limited to REPR_LIMIT bytes."""
# Some of the objects being repr'd are large strings. A lot of memory # Some of the objects being repr'd are large strings. A lot of memory
# would be wasted to repr them and then truncate, so they are treated # would be wasted to repr them and then truncate, so they are treated
......
...@@ -17,6 +17,7 @@ from ZODB.Connection import TransactionMetaData ...@@ -17,6 +17,7 @@ from ZODB.Connection import TransactionMetaData
from ZODB.tests.MinPO import MinPO from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_unpickle from ZODB.tests.StorageTestBase import zodb_unpickle
class TransUndoStorageWithCache(object): class TransUndoStorageWithCache(object):
def checkUndoInvalidation(self): def checkUndoInvalidation(self):
......
...@@ -20,12 +20,12 @@ from persistent.TimeStamp import TimeStamp ...@@ -20,12 +20,12 @@ from persistent.TimeStamp import TimeStamp
from ZODB.Connection import TransactionMetaData from ZODB.Connection import TransactionMetaData
from ZODB.tests.StorageTestBase import zodb_pickle, MinPO from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
import ZEO.ClientStorage
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZEO.tests.TestThread import TestThread from ZEO.tests.TestThread import TestThread
ZERO = b'\0'*8 ZERO = b'\0'*8
class WorkerThread(TestThread): class WorkerThread(TestThread):
# run the entire test in a thread so that the blocking call for # run the entire test in a thread so that the blocking call for
...@@ -62,6 +62,7 @@ class WorkerThread(TestThread): ...@@ -62,6 +62,7 @@ class WorkerThread(TestThread):
self.ready.set() self.ready.set()
future.result(9) future.result(9)
class CommitLockTests(object): class CommitLockTests(object):
NUM_CLIENTS = 5 NUM_CLIENTS = 5
...@@ -99,7 +100,7 @@ class CommitLockTests(object): ...@@ -99,7 +100,7 @@ class CommitLockTests(object):
for i in range(self.NUM_CLIENTS): for i in range(self.NUM_CLIENTS):
storage = self._new_storage_client() storage = self._new_storage_client()
txn = TransactionMetaData() txn = TransactionMetaData()
tid = self._get_timestamp() self._get_timestamp()
t = WorkerThread(self, storage, txn) t = WorkerThread(self, storage, txn)
self._threads.append(t) self._threads.append(t)
...@@ -118,9 +119,10 @@ class CommitLockTests(object): ...@@ -118,9 +119,10 @@ class CommitLockTests(object):
def _get_timestamp(self): def _get_timestamp(self):
t = time.time() t = time.time()
t = TimeStamp(*time.gmtime(t)[:5]+(t%60,)) t = TimeStamp(*time.gmtime(t)[:5]+(t % 60,))
return repr(t) return repr(t)
class CommitLockVoteTests(CommitLockTests): class CommitLockVoteTests(CommitLockTests):
def checkCommitLockVoteFinish(self): def checkCommitLockVoteFinish(self):
......
...@@ -26,11 +26,10 @@ from ZEO.tests import forker ...@@ -26,11 +26,10 @@ from ZEO.tests import forker
from ZODB.Connection import TransactionMetaData from ZODB.Connection import TransactionMetaData
from ZODB.DB import DB from ZODB.DB import DB
from ZODB.POSException import ReadOnlyError, ConflictError from ZODB.POSException import ReadOnlyError
from ZODB.tests.StorageTestBase import StorageTestBase from ZODB.tests.StorageTestBase import StorageTestBase
from ZODB.tests.MinPO import MinPO from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle
import ZODB.tests.util
import transaction import transaction
...@@ -40,6 +39,7 @@ logger = logging.getLogger('ZEO.tests.ConnectionTests') ...@@ -40,6 +39,7 @@ logger = logging.getLogger('ZEO.tests.ConnectionTests')
ZERO = '\0'*8 ZERO = '\0'*8
class TestClientStorage(ClientStorage): class TestClientStorage(ClientStorage):
test_connection = False test_connection = False
...@@ -51,6 +51,7 @@ class TestClientStorage(ClientStorage): ...@@ -51,6 +51,7 @@ class TestClientStorage(ClientStorage):
self.connection_count_for_tests += 1 self.connection_count_for_tests += 1
self.verify_result = conn.verify_result self.verify_result = conn.verify_result
class DummyDB(object): class DummyDB(object):
def invalidate(self, *args, **kwargs): def invalidate(self, *args, **kwargs):
pass pass
...@@ -113,7 +114,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -113,7 +114,7 @@ class CommonSetupTearDown(StorageTestBase):
for dummy in range(5): for dummy in range(5):
try: try:
os.unlink(path) os.unlink(path)
except: except: # NOQA: E722 bare except
time.sleep(0.5) time.sleep(0.5)
else: else:
need_to_delete = False need_to_delete = False
...@@ -188,7 +189,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -188,7 +189,7 @@ class CommonSetupTearDown(StorageTestBase):
stop = self._servers[index] stop = self._servers[index]
if stop is not None: if stop is not None:
stop() stop()
self._servers[index] = lambda : None self._servers[index] = lambda: None
def pollUp(self, timeout=30.0, storage=None): def pollUp(self, timeout=30.0, storage=None):
if storage is None: if storage is None:
...@@ -271,7 +272,6 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -271,7 +272,6 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ReadOnlyError, self._dostore) self.assertRaises(ReadOnlyError, self._dostore)
self._storage.close() self._storage.close()
def checkDisconnectionError(self): def checkDisconnectionError(self):
# Make sure we get a ClientDisconnected when we try to read an # Make sure we get a ClientDisconnected when we try to read an
# object when we're not connected to a storage server and the # object when we're not connected to a storage server and the
...@@ -416,6 +416,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -416,6 +416,7 @@ class ConnectionTests(CommonSetupTearDown):
def checkBadMessage2(self): def checkBadMessage2(self):
# just like a real message, but with an unpicklable argument # just like a real message, but with an unpicklable argument
global Hack global Hack
class Hack(object): class Hack(object):
pass pass
...@@ -551,6 +552,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -551,6 +552,7 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ClientDisconnected, self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '') self._storage.load, b'\0'*8, '')
class SSLConnectionTests(ConnectionTests): class SSLConnectionTests(ConnectionTests):
def getServerConfig(self, addr, ro_svr): def getServerConfig(self, addr, ro_svr):
...@@ -585,13 +587,13 @@ class InvqTests(CommonSetupTearDown): ...@@ -585,13 +587,13 @@ class InvqTests(CommonSetupTearDown):
revid2 = self._dostore(oid2, revid2) revid2 = self._dostore(oid2, revid2)
forker.wait_until( forker.wait_until(
lambda : lambda:
perstorage.lastTransaction() == self._storage.lastTransaction()) perstorage.lastTransaction() == self._storage.lastTransaction())
perstorage.load(oid, '') perstorage.load(oid, '')
perstorage.close() perstorage.close()
forker.wait_until(lambda : os.path.exists('test-1.zec')) forker.wait_until(lambda: os.path.exists('test-1.zec'))
revid = self._dostore(oid, revid) revid = self._dostore(oid, revid)
...@@ -617,7 +619,7 @@ class InvqTests(CommonSetupTearDown): ...@@ -617,7 +619,7 @@ class InvqTests(CommonSetupTearDown):
revid = self._dostore(oid, revid) revid = self._dostore(oid, revid)
forker.wait_until( forker.wait_until(
"Client has seen all of the transactions from the server", "Client has seen all of the transactions from the server",
lambda : lambda:
perstorage.lastTransaction() == self._storage.lastTransaction() perstorage.lastTransaction() == self._storage.lastTransaction()
) )
perstorage.load(oid, '') perstorage.load(oid, '')
...@@ -635,6 +637,7 @@ class InvqTests(CommonSetupTearDown): ...@@ -635,6 +637,7 @@ class InvqTests(CommonSetupTearDown):
perstorage.close() perstorage.close()
class ReconnectionTests(CommonSetupTearDown): class ReconnectionTests(CommonSetupTearDown):
# The setUp() starts a server automatically. In order for its # The setUp() starts a server automatically. In order for its
# state to persist, we set the class variable keep to 1. In # state to persist, we set the class variable keep to 1. In
...@@ -840,7 +843,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -840,7 +843,7 @@ class ReconnectionTests(CommonSetupTearDown):
revid = self._dostore(oid, revid) revid = self._dostore(oid, revid)
forker.wait_until( forker.wait_until(
"Client has seen all of the transactions from the server", "Client has seen all of the transactions from the server",
lambda : lambda:
perstorage.lastTransaction() == self._storage.lastTransaction() perstorage.lastTransaction() == self._storage.lastTransaction()
) )
perstorage.load(oid, '') perstorage.load(oid, '')
...@@ -894,7 +897,6 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -894,7 +897,6 @@ class ReconnectionTests(CommonSetupTearDown):
# Module ZEO.ClientStorage, line 709, in _update_cache # Module ZEO.ClientStorage, line 709, in _update_cache
# KeyError: ... # KeyError: ...
def checkReconnection(self): def checkReconnection(self):
# Check that the client reconnects when a server restarts. # Check that the client reconnects when a server restarts.
...@@ -952,6 +954,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -952,6 +954,7 @@ class ReconnectionTests(CommonSetupTearDown):
self.assertTrue(did_a_store) self.assertTrue(did_a_store)
self._storage.close() self._storage.close()
class TimeoutTests(CommonSetupTearDown): class TimeoutTests(CommonSetupTearDown):
timeout = 1 timeout = 1
...@@ -967,9 +970,8 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -967,9 +970,8 @@ class TimeoutTests(CommonSetupTearDown):
# Make sure it's logged as CRITICAL # Make sure it's logged as CRITICAL
with open("server.log") as f: with open("server.log") as f:
for line in f: for line in f:
if (('Transaction timeout after' in line) and if ('Transaction timeout after' in line) and \
('CRITICAL ZEO.StorageServer' in line) ('CRITICAL ZEO.StorageServer' in line):
):
break break
else: else:
self.fail('bad logging') self.fail('bad logging')
...@@ -1002,7 +1004,7 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -1002,7 +1004,7 @@ class TimeoutTests(CommonSetupTearDown):
t = TransactionMetaData() t = TransactionMetaData()
old_connection_count = storage.connection_count_for_tests old_connection_count = storage.connection_count_for_tests
storage.tpc_begin(t) storage.tpc_begin(t)
revid1 = storage.store(oid, ZERO, zodb_pickle(obj), '', t) storage.store(oid, ZERO, zodb_pickle(obj), '', t)
storage.tpc_vote(t) storage.tpc_vote(t)
# Now sleep long enough for the storage to time out # Now sleep long enough for the storage to time out
time.sleep(3) time.sleep(3)
...@@ -1021,6 +1023,7 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -1021,6 +1023,7 @@ class TimeoutTests(CommonSetupTearDown):
# or the server. # or the server.
self.assertRaises(KeyError, storage.load, oid, '') self.assertRaises(KeyError, storage.load, oid, '')
class MSTThread(threading.Thread): class MSTThread(threading.Thread):
__super_init = threading.Thread.__init__ __super_init = threading.Thread.__init__
...@@ -1054,7 +1057,7 @@ class MSTThread(threading.Thread): ...@@ -1054,7 +1057,7 @@ class MSTThread(threading.Thread):
# Begin a transaction # Begin a transaction
t = TransactionMetaData() t = TransactionMetaData()
for c in clients: for c in clients:
#print("%s.%s.%s begin" % (tname, c.__name, i)) # print("%s.%s.%s begin" % (tname, c.__name, i))
c.tpc_begin(t) c.tpc_begin(t)
for j in range(testcase.nobj): for j in range(testcase.nobj):
...@@ -1063,18 +1066,18 @@ class MSTThread(threading.Thread): ...@@ -1063,18 +1066,18 @@ class MSTThread(threading.Thread):
oid = c.new_oid() oid = c.new_oid()
c.__oids.append(oid) c.__oids.append(oid)
data = MinPO("%s.%s.t%d.o%d" % (tname, c.__name, i, j)) data = MinPO("%s.%s.t%d.o%d" % (tname, c.__name, i, j))
#print(data.value) # print(data.value)
data = zodb_pickle(data) data = zodb_pickle(data)
c.store(oid, ZERO, data, '', t) c.store(oid, ZERO, data, '', t)
# Vote on all servers and handle serials # Vote on all servers and handle serials
for c in clients: for c in clients:
#print("%s.%s.%s vote" % (tname, c.__name, i)) # print("%s.%s.%s vote" % (tname, c.__name, i))
c.tpc_vote(t) c.tpc_vote(t)
# Finish on all servers # Finish on all servers
for c in clients: for c in clients:
#print("%s.%s.%s finish\n" % (tname, c.__name, i)) # print("%s.%s.%s finish\n" % (tname, c.__name, i))
c.tpc_finish(t) c.tpc_finish(t)
for c in clients: for c in clients:
...@@ -1090,7 +1093,7 @@ class MSTThread(threading.Thread): ...@@ -1090,7 +1093,7 @@ class MSTThread(threading.Thread):
for c in self.clients: for c in self.clients:
try: try:
c.close() c.close()
except: except: # NOQA: E722 bare except
pass pass
...@@ -1101,6 +1104,7 @@ def short_timeout(self): ...@@ -1101,6 +1104,7 @@ def short_timeout(self):
yield yield
self._storage._server.timeout = old self._storage._server.timeout = old
# Run IPv6 tests if V6 sockets are supported # Run IPv6 tests if V6 sockets are supported
try: try:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
......
...@@ -41,6 +41,7 @@ from ZODB.POSException import ReadConflictError, ConflictError ...@@ -41,6 +41,7 @@ from ZODB.POSException import ReadConflictError, ConflictError
# thought they added (i.e., the keys for which transaction.commit() # thought they added (i.e., the keys for which transaction.commit()
# did not raise any exception). # did not raise any exception).
class FailableThread(TestThread): class FailableThread(TestThread):
# mixin class # mixin class
...@@ -52,7 +53,7 @@ class FailableThread(TestThread): ...@@ -52,7 +53,7 @@ class FailableThread(TestThread):
def testrun(self): def testrun(self):
try: try:
self._testrun() self._testrun()
except: except: # NOQA: E722 bare except
# Report the failure here to all the other threads, so # Report the failure here to all the other threads, so
# that they stop quickly. # that they stop quickly.
self.stop.set() self.stop.set()
...@@ -81,12 +82,11 @@ class StressTask(object): ...@@ -81,12 +82,11 @@ class StressTask(object):
tree[key] = self.threadnum tree[key] = self.threadnum
def commit(self): def commit(self):
cn = self.cn
key = self.startnum key = self.startnum
self.tm.get().note(u"add key %s" % key) self.tm.get().note(u"add key %s" % key)
try: try:
self.tm.get().commit() self.tm.get().commit()
except ConflictError as msg: except ConflictError:
self.tm.abort() self.tm.abort()
else: else:
if self.sleep: if self.sleep:
...@@ -98,13 +98,16 @@ class StressTask(object): ...@@ -98,13 +98,16 @@ class StressTask(object):
self.tm.get().abort() self.tm.get().abort()
self.cn.close() self.cn.close()
def _runTasks(rounds, *tasks): def _runTasks(rounds, *tasks):
'''run *task* interleaved for *rounds* rounds.''' '''run *task* interleaved for *rounds* rounds.'''
def commit(run, actions): def commit(run, actions):
actions.append(':') actions.append(':')
for t in run: for t in run:
t.commit() t.commit()
del run[:] del run[:]
r = Random() r = Random()
r.seed(1064589285) # make it deterministic r.seed(1064589285) # make it deterministic
run = [] run = []
...@@ -117,7 +120,7 @@ def _runTasks(rounds, *tasks): ...@@ -117,7 +120,7 @@ def _runTasks(rounds, *tasks):
run.append(t) run.append(t)
t.doStep() t.doStep()
actions.append(repr(t.startnum)) actions.append(repr(t.startnum))
commit(run,actions) commit(run, actions)
# stderr.write(' '.join(actions)+'\n') # stderr.write(' '.join(actions)+'\n')
finally: finally:
for t in tasks: for t in tasks:
...@@ -160,13 +163,14 @@ class StressThread(FailableThread): ...@@ -160,13 +163,14 @@ class StressThread(FailableThread):
self.commitdict[self] = 1 self.commitdict[self] = 1
if self.sleep: if self.sleep:
time.sleep(self.sleep) time.sleep(self.sleep)
except (ReadConflictError, ConflictError) as msg: except (ReadConflictError, ConflictError):
tm.abort() tm.abort()
else: else:
self.added_keys.append(key) self.added_keys.append(key)
key += self.step key += self.step
cn.close() cn.close()
class LargeUpdatesThread(FailableThread): class LargeUpdatesThread(FailableThread):
# A thread that performs a lot of updates. It attempts to modify # A thread that performs a lot of updates. It attempts to modify
...@@ -212,7 +216,7 @@ class LargeUpdatesThread(FailableThread): ...@@ -212,7 +216,7 @@ class LargeUpdatesThread(FailableThread):
for key in keys: for key in keys:
try: try:
tree[key] = self.threadnum tree[key] = self.threadnum
except (ReadConflictError, ConflictError) as msg: except (ReadConflictError, ConflictError): # as msg:
# print("%d setting key %s" % (self.threadnum, msg)) # print("%d setting key %s" % (self.threadnum, msg))
transaction.abort() transaction.abort()
break break
...@@ -224,7 +228,7 @@ class LargeUpdatesThread(FailableThread): ...@@ -224,7 +228,7 @@ class LargeUpdatesThread(FailableThread):
self.commitdict[self] = 1 self.commitdict[self] = 1
if self.sleep: if self.sleep:
time.sleep(self.sleep) time.sleep(self.sleep)
except ConflictError as msg: except ConflictError: # as msg
# print("%d commit %s" % (self.threadnum, msg)) # print("%d commit %s" % (self.threadnum, msg))
transaction.abort() transaction.abort()
continue continue
...@@ -234,6 +238,7 @@ class LargeUpdatesThread(FailableThread): ...@@ -234,6 +238,7 @@ class LargeUpdatesThread(FailableThread):
self.added_keys = keys_added.keys() self.added_keys = keys_added.keys()
cn.close() cn.close()
class InvalidationTests(object): class InvalidationTests(object):
# Minimum # of seconds the main thread lets the workers run. The # Minimum # of seconds the main thread lets the workers run. The
...@@ -261,7 +266,7 @@ class InvalidationTests(object): ...@@ -261,7 +266,7 @@ class InvalidationTests(object):
transaction.abort() transaction.abort()
else: else:
raise raise
except: except: # NOQA: E722 bare except
display(tree) display(tree)
raise raise
......
...@@ -21,6 +21,7 @@ from ZODB.Connection import TransactionMetaData ...@@ -21,6 +21,7 @@ from ZODB.Connection import TransactionMetaData
from ..asyncio.testing import AsyncRPC from ..asyncio.testing import AsyncRPC
class IterationTests(object): class IterationTests(object):
def _assertIteratorIdsEmpty(self): def _assertIteratorIdsEmpty(self):
...@@ -147,7 +148,6 @@ class IterationTests(object): ...@@ -147,7 +148,6 @@ class IterationTests(object):
self._dostore() self._dostore()
six.advance_iterator(self._storage.iterator()) six.advance_iterator(self._storage.iterator())
iid = list(self._storage._iterator_ids)[0]
t = TransactionMetaData() t = TransactionMetaData()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
# Show that after disconnecting, the client side GCs the iterators # Show that after disconnecting, the client side GCs the iterators
...@@ -176,12 +176,12 @@ def iterator_sane_after_reconnect(): ...@@ -176,12 +176,12 @@ def iterator_sane_after_reconnect():
Start a server: Start a server:
>>> addr, adminaddr = start_server( >>> addr, adminaddr = start_server( # NOQA: F821 undefined
... '<filestorage>\npath fs\n</filestorage>', keep=1) ... '<filestorage>\npath fs\n</filestorage>', keep=1)
Open a client storage to it and commit a some transactions: Open a client storage to it and commit a some transactions:
>>> import ZEO, ZODB, transaction >>> import ZEO, ZODB
>>> client = ZEO.client(addr) >>> client = ZEO.client(addr)
>>> db = ZODB.DB(client) >>> db = ZODB.DB(client)
>>> conn = db.open() >>> conn = db.open()
...@@ -196,10 +196,11 @@ Create an iterator: ...@@ -196,10 +196,11 @@ Create an iterator:
Restart the storage: Restart the storage:
>>> stop_server(adminaddr) >>> stop_server(adminaddr) # NOQA: F821 undefined
>>> wait_disconnected(client) >>> wait_disconnected(client) # NOQA: F821 undefined
>>> _ = start_server('<filestorage>\npath fs\n</filestorage>', addr=addr) >>> _ = start_server( # NOQA: F821 undefined
>>> wait_connected(client) ... '<filestorage>\npath fs\n</filestorage>', addr=addr)
>>> wait_connected(client) # NOQA: F821 undefined
Now, we'll create a second iterator: Now, we'll create a second iterator:
......
...@@ -16,6 +16,7 @@ import threading ...@@ -16,6 +16,7 @@ import threading
import sys import sys
import six import six
class TestThread(threading.Thread): class TestThread(threading.Thread):
"""Base class for defining threads that run from unittest. """Base class for defining threads that run from unittest.
...@@ -46,12 +47,14 @@ class TestThread(threading.Thread): ...@@ -46,12 +47,14 @@ class TestThread(threading.Thread):
def run(self): def run(self):
try: try:
self.testrun() self.testrun()
except: except: # NOQA: E722 blank except
self._exc_info = sys.exc_info() self._exc_info = sys.exc_info()
def cleanup(self, timeout=15): def cleanup(self, timeout=15):
self.join(timeout) self.join(timeout)
if self._exc_info: if self._exc_info:
six.reraise(self._exc_info[0], self._exc_info[1], self._exc_info[2]) six.reraise(self._exc_info[0],
self._exc_info[1],
self._exc_info[2])
if self.is_alive(): if self.is_alive():
self._testcase.fail("Thread did not finish: %s" % self) self._testcase.fail("Thread did not finish: %s" % self)
...@@ -21,6 +21,7 @@ import ZEO.Exceptions ...@@ -21,6 +21,7 @@ import ZEO.Exceptions
ZERO = '\0'*8 ZERO = '\0'*8
class BasicThread(threading.Thread): class BasicThread(threading.Thread):
def __init__(self, storage, doNextEvent, threadStartedEvent): def __init__(self, storage, doNextEvent, threadStartedEvent):
self.storage = storage self.storage = storage
...@@ -123,7 +124,6 @@ class ThreadTests(object): ...@@ -123,7 +124,6 @@ class ThreadTests(object):
# Helper for checkMTStores # Helper for checkMTStores
def mtstorehelper(self): def mtstorehelper(self):
name = threading.currentThread().getName()
objs = [] objs = []
for i in range(10): for i in range(10):
objs.append(MinPO("X" * 200000)) objs.append(MinPO("X" * 200000))
......
...@@ -41,7 +41,6 @@ from ZEO._compat import Pickler, Unpickler, PY3, BytesIO ...@@ -41,7 +41,6 @@ from ZEO._compat import Pickler, Unpickler, PY3, BytesIO
from ZEO.Exceptions import AuthError from ZEO.Exceptions import AuthError
from .monitor import StorageStats, StatsServer from .monitor import StorageStats, StatsServer
from .zrpc.connection import ManagedServerConnection, Delay, MTDelay, Result from .zrpc.connection import ManagedServerConnection, Delay, MTDelay, Result
from .zrpc.server import Dispatcher
from ZODB.Connection import TransactionMetaData from ZODB.Connection import TransactionMetaData
from ZODB.loglevels import BLATHER from ZODB.loglevels import BLATHER
from ZODB.POSException import StorageError, StorageTransactionError from ZODB.POSException import StorageError, StorageTransactionError
...@@ -53,6 +52,7 @@ ResolvedSerial = b'rs' ...@@ -53,6 +52,7 @@ ResolvedSerial = b'rs'
logger = logging.getLogger('ZEO.StorageServer') logger = logging.getLogger('ZEO.StorageServer')
def log(message, level=logging.INFO, label='', exc_info=False): def log(message, level=logging.INFO, label='', exc_info=False):
"""Internal helper to log a message.""" """Internal helper to log a message."""
if label: if label:
...@@ -152,7 +152,7 @@ class ZEOStorage(object): ...@@ -152,7 +152,7 @@ class ZEOStorage(object):
info = self.get_info() info = self.get_info()
if not info['supportsUndo']: if not info['supportsUndo']:
self.undoLog = self.undoInfo = lambda *a,**k: () self.undoLog = self.undoInfo = lambda *a, **k: ()
self.getTid = storage.getTid self.getTid = storage.getTid
self.load = storage.load self.load = storage.load
...@@ -182,14 +182,14 @@ class ZEOStorage(object): ...@@ -182,14 +182,14 @@ class ZEOStorage(object):
"Falling back to using _transaction attribute, which\n." "Falling back to using _transaction attribute, which\n."
"is icky.", "is icky.",
logging.ERROR) logging.ERROR)
self.tpc_transaction = lambda : storage._transaction self.tpc_transaction = lambda: storage._transaction
else: else:
raise raise
def history(self,tid,size=1): def history(self, tid, size=1):
# This caters for storages which still accept # This caters for storages which still accept
# a version parameter. # a version parameter.
return self.storage.history(tid,size=size) return self.storage.history(tid, size=size)
def _check_tid(self, tid, exc=None): def _check_tid(self, tid, exc=None):
if self.read_only: if self.read_only:
...@@ -253,8 +253,7 @@ class ZEOStorage(object): ...@@ -253,8 +253,7 @@ class ZEOStorage(object):
def get_info(self): def get_info(self):
storage = self.storage storage = self.storage
supportsUndo = (getattr(storage, 'supportsUndo', lambda: False)()
supportsUndo = (getattr(storage, 'supportsUndo', lambda : False)()
and self.connection.peer_protocol_version >= b'Z310') and self.connection.peer_protocol_version >= b'Z310')
# Communicate the backend storage interfaces to the client # Communicate the backend storage interfaces to the client
...@@ -473,7 +472,6 @@ class ZEOStorage(object): ...@@ -473,7 +472,6 @@ class ZEOStorage(object):
if not getattr(self, op)(*args): if not getattr(self, op)(*args):
break break
# Blob support # Blob support
while self.blob_log and not self.store_failed: while self.blob_log and not self.store_failed:
oid, oldserial, data, blobfilename = self.blob_log.pop() oid, oldserial, data, blobfilename = self.blob_log.pop()
...@@ -558,11 +556,9 @@ class ZEOStorage(object): ...@@ -558,11 +556,9 @@ class ZEOStorage(object):
assert self.txnlog is not None # effectively not allowed after undo assert self.txnlog is not None # effectively not allowed after undo
# Reconstruct the full path from the filename in the OID directory # Reconstruct the full path from the filename in the OID directory
if (os.path.sep in filename if os.path.sep in filename or \
or not (filename.endswith('.tmp') not (filename.endswith('.tmp')
or filename[:-1].endswith('.tmp') or filename[:-1].endswith('.tmp')):
)
):
logger.critical( logger.critical(
"We're under attack! (bad filename to storeBlobShared, %r)", "We're under attack! (bad filename to storeBlobShared, %r)",
filename) filename)
...@@ -590,7 +586,7 @@ class ZEOStorage(object): ...@@ -590,7 +586,7 @@ class ZEOStorage(object):
(oid_repr(oid), str(err)), BLATHER) (oid_repr(oid), str(err)), BLATHER)
if not isinstance(err, TransactionError): if not isinstance(err, TransactionError):
# Unexpected errors are logged and passed to the client # Unexpected errors are logged and passed to the client
self.log("%s error: %s, %s" % ((op,)+ sys.exc_info()[:2]), self.log("%s error: %s, %s" % ((op,) + sys.exc_info()[:2]),
logging.ERROR, exc_info=True) logging.ERROR, exc_info=True)
err = self._marshal_error(err) err = self._marshal_error(err)
# The exception is reported back as newserial for this oid # The exception is reported back as newserial for this oid
...@@ -691,7 +687,7 @@ class ZEOStorage(object): ...@@ -691,7 +687,7 @@ class ZEOStorage(object):
pickler.fast = 1 pickler.fast = 1
try: try:
pickler.dump(error) pickler.dump(error)
except: except: # NOQA: E722 bare except
msg = "Couldn't pickle storage exception: %s" % repr(error) msg = "Couldn't pickle storage exception: %s" % repr(error)
self.log(msg, logging.ERROR) self.log(msg, logging.ERROR)
error = StorageServerError(msg) error = StorageServerError(msg)
...@@ -758,6 +754,7 @@ class ZEOStorage(object): ...@@ -758,6 +754,7 @@ class ZEOStorage(object):
def set_client_label(self, label): def set_client_label(self, label):
self.log_label = str(label)+' '+_addr_label(self.connection.addr) self.log_label = str(label)+' '+_addr_label(self.connection.addr)
class StorageServerDB(object): class StorageServerDB(object):
def __init__(self, server, storage_id): def __init__(self, server, storage_id):
...@@ -776,6 +773,7 @@ class StorageServerDB(object): ...@@ -776,6 +773,7 @@ class StorageServerDB(object):
transform_record_data = untransform_record_data = lambda self, data: data transform_record_data = untransform_record_data = lambda self, data: data
class StorageServer(object): class StorageServer(object):
"""The server side implementation of ZEO. """The server side implementation of ZEO.
...@@ -876,7 +874,6 @@ class StorageServer(object): ...@@ -876,7 +874,6 @@ class StorageServer(object):
log("%s created %s with storages: %s" % log("%s created %s with storages: %s" %
(self.__class__.__name__, read_only and "RO" or "RW", msg)) (self.__class__.__name__, read_only and "RO" or "RW", msg))
self._lock = threading.Lock() self._lock = threading.Lock()
self._commit_locks = {} self._commit_locks = {}
self._waiting = dict((name, []) for name in storages) self._waiting = dict((name, []) for name in storages)
...@@ -942,7 +939,6 @@ class StorageServer(object): ...@@ -942,7 +939,6 @@ class StorageServer(object):
self.invq[name] = list(lastInvalidations(self.invq_bound)) self.invq[name] = list(lastInvalidations(self.invq_bound))
self.invq[name].reverse() self.invq[name].reverse()
def _setup_auth(self, protocol): def _setup_auth(self, protocol):
# Can't be done in global scope, because of cyclic references # Can't be done in global scope, because of cyclic references
from .auth import get_module from .auth import get_module
...@@ -976,7 +972,6 @@ class StorageServer(object): ...@@ -976,7 +972,6 @@ class StorageServer(object):
"does not match storage realm %r" "does not match storage realm %r"
% (self.database.realm, self.auth_realm)) % (self.database.realm, self.auth_realm))
def new_connection(self, sock, addr): def new_connection(self, sock, addr):
"""Internal: factory to create a new connection. """Internal: factory to create a new connection.
...@@ -1050,7 +1045,6 @@ class StorageServer(object): ...@@ -1050,7 +1045,6 @@ class StorageServer(object):
except DisconnectedError: except DisconnectedError:
pass pass
def invalidate(self, conn, storage_id, tid, invalidated=(), info=None): def invalidate(self, conn, storage_id, tid, invalidated=(), info=None):
"""Internal: broadcast info and invalidations to clients. """Internal: broadcast info and invalidations to clients.
...@@ -1096,7 +1090,6 @@ class StorageServer(object): ...@@ -1096,7 +1090,6 @@ class StorageServer(object):
# b. A connection is closes while we are iterating. We'll need # b. A connection is closes while we are iterating. We'll need
# to cactch and ignore Disconnected errors. # to cactch and ignore Disconnected errors.
if invalidated: if invalidated:
invq = self.invq[storage_id] invq = self.invq[storage_id]
if len(invq) >= self.invq_bound: if len(invq) >= self.invq_bound:
...@@ -1159,6 +1152,7 @@ class StorageServer(object): ...@@ -1159,6 +1152,7 @@ class StorageServer(object):
raise # Unexpected exc raise # Unexpected exc
__thread = None __thread = None
def start_thread(self, daemon=True): def start_thread(self, daemon=True):
self.__thread = thread = threading.Thread(target=self.loop) self.__thread = thread = threading.Thread(target=self.loop)
thread.setName("StorageServer(%s)" % _addr_label(self.addr)) thread.setName("StorageServer(%s)" % _addr_label(self.addr))
...@@ -1166,6 +1160,7 @@ class StorageServer(object): ...@@ -1166,6 +1160,7 @@ class StorageServer(object):
thread.start() thread.start()
__closed = False __closed = False
def close(self, join_timeout=1): def close(self, join_timeout=1):
"""Close the dispatcher so that there are no new connections. """Close the dispatcher so that there are no new connections.
...@@ -1187,7 +1182,7 @@ class StorageServer(object): ...@@ -1187,7 +1182,7 @@ class StorageServer(object):
for conn in connections[:]: for conn in connections[:]:
try: try:
conn.connection.close() conn.connection.close()
except: except: # NOQA: E722 bare except
pass pass
for name, storage in six.iteritems(self.storages): for name, storage in six.iteritems(self.storages):
...@@ -1282,7 +1277,6 @@ class StorageServer(object): ...@@ -1282,7 +1277,6 @@ class StorageServer(object):
except Exception: except Exception:
logger.exception("Calling unlock callback") logger.exception("Calling unlock callback")
def stop_waiting(self, zeostore): def stop_waiting(self, zeostore):
storage_id = zeostore.storage_id storage_id = zeostore.storage_id
waiting = self._waiting[storage_id] waiting = self._waiting[storage_id]
...@@ -1307,7 +1301,8 @@ class StorageServer(object): ...@@ -1307,7 +1301,8 @@ class StorageServer(object):
status = self.stats[storage_id].__dict__.copy() status = self.stats[storage_id].__dict__.copy()
status['connections'] = len(status['connections']) status['connections'] = len(status['connections'])
status['waiting'] = len(self._waiting[storage_id]) status['waiting'] = len(self._waiting[storage_id])
status['timeout-thread-is-alive'] = self.timeouts[storage_id].is_alive() status['timeout-thread-is-alive'] = \
self.timeouts[storage_id].is_alive()
last_transaction = self.storages[storage_id].lastTransaction() last_transaction = self.storages[storage_id].lastTransaction()
last_transaction_hex = codecs.encode(last_transaction, 'hex_codec') last_transaction_hex = codecs.encode(last_transaction, 'hex_codec')
if PY3: if PY3:
...@@ -1320,6 +1315,7 @@ class StorageServer(object): ...@@ -1320,6 +1315,7 @@ class StorageServer(object):
return dict((storage_id, self.server_status(storage_id)) return dict((storage_id, self.server_status(storage_id))
for storage_id in self.storages) for storage_id in self.storages)
def _level_for_waiting(waiting): def _level_for_waiting(waiting):
if len(waiting) > 9: if len(waiting) > 9:
return logging.CRITICAL return logging.CRITICAL
...@@ -1328,6 +1324,7 @@ def _level_for_waiting(waiting): ...@@ -1328,6 +1324,7 @@ def _level_for_waiting(waiting):
else: else:
return logging.DEBUG return logging.DEBUG
class StubTimeoutThread(object): class StubTimeoutThread(object):
def begin(self, client): def begin(self, client):
...@@ -1336,7 +1333,8 @@ class StubTimeoutThread(object): ...@@ -1336,7 +1333,8 @@ class StubTimeoutThread(object):
def end(self, client): def end(self, client):
pass pass
is_alive = lambda self: 'stub' def is_alive(self):
return 'stub'
class TimeoutThread(threading.Thread): class TimeoutThread(threading.Thread):
...@@ -1389,7 +1387,7 @@ class TimeoutThread(threading.Thread): ...@@ -1389,7 +1387,7 @@ class TimeoutThread(threading.Thread):
self._timeout, logging.CRITICAL) self._timeout, logging.CRITICAL)
try: try:
client.connection.call_from_thread(client.connection.close) client.connection.call_from_thread(client.connection.close)
except: except: # NOQA: E722 bare except
client.log("Timeout failure", logging.CRITICAL, client.log("Timeout failure", logging.CRITICAL,
exc_info=sys.exc_info()) exc_info=sys.exc_info())
self.end(client) self.end(client)
...@@ -1485,6 +1483,7 @@ class ClientStub(object): ...@@ -1485,6 +1483,7 @@ class ClientStub(object):
self.rpc.callAsyncIterator(store()) self.rpc.callAsyncIterator(store())
class ClientStub308(ClientStub): class ClientStub308(ClientStub):
def invalidateTransaction(self, tid, args): def invalidateTransaction(self, tid, args):
...@@ -1494,6 +1493,7 @@ class ClientStub308(ClientStub): ...@@ -1494,6 +1493,7 @@ class ClientStub308(ClientStub):
def invalidateVerify(self, oid): def invalidateVerify(self, oid):
ClientStub.invalidateVerify(self, (oid, '')) ClientStub.invalidateVerify(self, (oid, ''))
class ZEOStorage308Adapter(object): class ZEOStorage308Adapter(object):
def __init__(self, storage): def __init__(self, storage):
...@@ -1573,6 +1573,7 @@ class ZEOStorage308Adapter(object): ...@@ -1573,6 +1573,7 @@ class ZEOStorage308Adapter(object):
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.storage, name) return getattr(self.storage, name)
def _addr_label(addr): def _addr_label(addr):
if isinstance(addr, six.binary_type): if isinstance(addr, six.binary_type):
return addr.decode('ascii') return addr.decode('ascii')
...@@ -1582,6 +1583,7 @@ def _addr_label(addr): ...@@ -1582,6 +1583,7 @@ def _addr_label(addr):
host, port = addr host, port = addr
return str(host) + ":" + str(port) return str(host) + ":" + str(port)
class CommitLog(object): class CommitLog(object):
def __init__(self): def __init__(self):
...@@ -1624,14 +1626,17 @@ class CommitLog(object): ...@@ -1624,14 +1626,17 @@ class CommitLog(object):
self.file.close() self.file.close()
self.file = None self.file = None
class ServerEvent(object): class ServerEvent(object):
def __init__(self, server, **kw): def __init__(self, server, **kw):
self.__dict__.update(kw) self.__dict__.update(kw)
self.server = server self.server = server
class Serving(ServerEvent): class Serving(ServerEvent):
pass pass
class Closed(ServerEvent): class Closed(ServerEvent):
pass pass
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
_auth_modules = {} _auth_modules = {}
def get_module(name): def get_module(name):
if name == 'sha': if name == 'sha':
from auth_sha import StorageClass, SHAClient, Database from auth_sha import StorageClass, SHAClient, Database
...@@ -24,6 +25,7 @@ def get_module(name): ...@@ -24,6 +25,7 @@ def get_module(name):
else: else:
return _auth_modules.get(name) return _auth_modules.get(name)
def register_module(name, storage_class, client, db): def register_module(name, storage_class, client, db):
if name in _auth_modules: if name in _auth_modules:
raise TypeError("%s is already registred" % name) raise TypeError("%s is already registred" % name)
......
...@@ -45,6 +45,7 @@ from ..StorageServer import ZEOStorage ...@@ -45,6 +45,7 @@ from ..StorageServer import ZEOStorage
from ZEO.Exceptions import AuthError from ZEO.Exceptions import AuthError
from ..hash import sha1 from ..hash import sha1
def get_random_bytes(n=8): def get_random_bytes(n=8):
try: try:
b = os.urandom(n) b = os.urandom(n)
...@@ -53,9 +54,11 @@ def get_random_bytes(n=8): ...@@ -53,9 +54,11 @@ def get_random_bytes(n=8):
b = b"".join(L) b = b"".join(L)
return b return b
def hexdigest(s): def hexdigest(s):
return sha1(s.encode()).hexdigest() return sha1(s.encode()).hexdigest()
class DigestDatabase(Database): class DigestDatabase(Database):
def __init__(self, filename, realm=None): def __init__(self, filename, realm=None):
Database.__init__(self, filename, realm) Database.__init__(self, filename, realm)
...@@ -69,6 +72,7 @@ class DigestDatabase(Database): ...@@ -69,6 +72,7 @@ class DigestDatabase(Database):
dig = hexdigest("%s:%s:%s" % (username, self.realm, password)) dig = hexdigest("%s:%s:%s" % (username, self.realm, password))
self._users[username] = dig self._users[username] = dig
def session_key(h_up, nonce): def session_key(h_up, nonce):
# The hash itself is a bit too short to be a session key. # The hash itself is a bit too short to be a session key.
# HMAC wants a 64-byte key. We don't want to use h_up # HMAC wants a 64-byte key. We don't want to use h_up
...@@ -77,6 +81,7 @@ def session_key(h_up, nonce): ...@@ -77,6 +81,7 @@ def session_key(h_up, nonce):
return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() + return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() +
h_up.encode('utf-8')[:44]) h_up.encode('utf-8')[:44])
class StorageClass(ZEOStorage): class StorageClass(ZEOStorage):
def set_database(self, database): def set_database(self, database):
assert isinstance(database, DigestDatabase) assert isinstance(database, DigestDatabase)
...@@ -124,6 +129,7 @@ class StorageClass(ZEOStorage): ...@@ -124,6 +129,7 @@ class StorageClass(ZEOStorage):
extensions = [auth_get_challenge, auth_response] extensions = [auth_get_challenge, auth_response]
class DigestClient(Client): class DigestClient(Client):
extensions = ["auth_get_challenge", "auth_response"] extensions = ["auth_get_challenge", "auth_response"]
......
...@@ -22,6 +22,7 @@ from __future__ import print_function ...@@ -22,6 +22,7 @@ from __future__ import print_function
import os import os
from ..hash import sha1 from ..hash import sha1
class Client(object): class Client(object):
# Subclass should override to list the names of methods that # Subclass should override to list the names of methods that
# will be called on the server. # will be called on the server.
...@@ -32,11 +33,13 @@ class Client(object): ...@@ -32,11 +33,13 @@ class Client(object):
for m in self.extensions: for m in self.extensions:
setattr(self.stub, m, self.stub.extensionMethod(m)) setattr(self.stub, m, self.stub.extensionMethod(m))
def sort(L): def sort(L):
"""Sort a list in-place and return it.""" """Sort a list in-place and return it."""
L.sort() L.sort()
return L return L
class Database(object): class Database(object):
"""Abstracts a password database. """Abstracts a password database.
...@@ -49,6 +52,7 @@ class Database(object): ...@@ -49,6 +52,7 @@ class Database(object):
produced from the password string. produced from the password string.
""" """
realm = None realm = None
def __init__(self, filename, realm=None): def __init__(self, filename, realm=None):
"""Creates a new Database """Creates a new Database
......
...@@ -3,24 +3,26 @@ ...@@ -3,24 +3,26 @@
Implements the HMAC algorithm as described by RFC 2104. Implements the HMAC algorithm as described by RFC 2104.
""" """
from six.moves import map from six.moves import map
from six.moves import zip
def _strxor(s1, s2): def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length). """Utility method. XOR the two strings s1 and s2 (must have same length).
""" """
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2)) return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying # The size of the digests returned by HMAC depends on the underlying
# hashing module used. # hashing module used.
digest_size = None digest_size = None
class HMAC(object): class HMAC(object):
"""RFC2104 HMAC class. """RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247). This supports the API for Cryptographic Hash Functions (PEP 247).
""" """
def __init__(self, key, msg = None, digestmod = None): def __init__(self, key, msg=None, digestmod=None):
"""Create a new HMAC object. """Create a new HMAC object.
key: key for the keyed hash object. key: key for the keyed hash object.
...@@ -49,8 +51,8 @@ class HMAC(object): ...@@ -49,8 +51,8 @@ class HMAC(object):
if msg is not None: if msg is not None:
self.update(msg) self.update(msg)
## def clear(self): # def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.") # raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg): def update(self, msg):
"""Update this hashing object with the string msg. """Update this hashing object with the string msg.
...@@ -85,7 +87,8 @@ class HMAC(object): ...@@ -85,7 +87,8 @@ class HMAC(object):
return "".join([hex(ord(x))[2:].zfill(2) return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())]) for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
def new(key, msg=None, digestmod=None):
"""Create a new hashing object and return it. """Create a new hashing object and return it.
key: The starting key for the hash. key: The starting key for the hash.
......
...@@ -47,6 +47,7 @@ else: ...@@ -47,6 +47,7 @@ else:
if zeo_dist is not None: if zeo_dist is not None:
zeo_version = zeo_dist.version zeo_version = zeo_dist.version
class StorageStats(object): class StorageStats(object):
"""Per-storage usage statistics.""" """Per-storage usage statistics."""
...@@ -113,6 +114,7 @@ class StorageStats(object): ...@@ -113,6 +114,7 @@ class StorageStats(object):
print("Conflicts:", self.conflicts, file=f) print("Conflicts:", self.conflicts, file=f)
print("Conflicts resolved:", self.conflicts_resolved, file=f) print("Conflicts resolved:", self.conflicts_resolved, file=f)
class StatsClient(asyncore.dispatcher): class StatsClient(asyncore.dispatcher):
def __init__(self, sock, addr): def __init__(self, sock, addr):
...@@ -144,6 +146,7 @@ class StatsClient(asyncore.dispatcher): ...@@ -144,6 +146,7 @@ class StatsClient(asyncore.dispatcher):
if self.closed and not self.buf: if self.closed and not self.buf:
asyncore.dispatcher.close(self) asyncore.dispatcher.close(self)
class StatsServer(asyncore.dispatcher): class StatsServer(asyncore.dispatcher):
StatsConnectionClass = StatsClient StatsConnectionClass = StatsClient
......
...@@ -49,21 +49,24 @@ from zdaemon.zdoptions import ZDOptions ...@@ -49,21 +49,24 @@ from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo') logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid()) _pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False): def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function.""" """Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg) message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info) logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg): def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API. # Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg) obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address return obj.family, obj.address
def windows_shutdown_handler(): def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown. # Called by the signal mechanism on Windows to perform shutdown.
import asyncore
asyncore.close_all() asyncore.close_all()
class ZEOOptionsMixin(object): class ZEOOptionsMixin(object):
storages = None storages = None
...@@ -76,13 +79,17 @@ class ZEOOptionsMixin(object): ...@@ -76,13 +79,17 @@ class ZEOOptionsMixin(object):
def handle_filename(self, arg): def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*! from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig(object): class FSConfig(object):
def __init__(self, name, path): def __init__(self, name, path):
self._name = name self._name = name
self.path = path self.path = path
self.stop = None self.stop = None
def getSectionName(self): def getSectionName(self):
return self._name return self._name
if not self.storages: if not self.storages:
self.storages = [] self.storages = []
name = str(1 + len(self.storages)) name = str(1 + len(self.storages))
...@@ -90,6 +97,7 @@ class ZEOOptionsMixin(object): ...@@ -90,6 +97,7 @@ class ZEOOptionsMixin(object):
self.storages.append(conf) self.storages.append(conf)
testing_exit_immediately = False testing_exit_immediately = False
def handle_test(self, *args): def handle_test(self, *args):
self.testing_exit_immediately = True self.testing_exit_immediately = True
...@@ -117,6 +125,7 @@ class ZEOOptionsMixin(object): ...@@ -117,6 +125,7 @@ class ZEOOptionsMixin(object):
self.add('pid_file', 'zeo.pid_filename', self.add('pid_file', 'zeo.pid_filename',
None, 'pid-file=') None, 'pid-file=')
class ZEOOptions(ZDOptions, ZEOOptionsMixin): class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__ __doc__ = __doc__
...@@ -179,8 +188,8 @@ class ZEOServer(object): ...@@ -179,8 +188,8 @@ class ZEOServer(object):
root.addHandler(handler) root.addHandler(handler)
def check_socket(self): def check_socket(self):
if (isinstance(self.options.address, tuple) and if isinstance(self.options.address, tuple) and \
self.options.address[1] is None): self.options.address[1] is None:
self.options.address = self.options.address[0], 0 self.options.address = self.options.address[0], 0
return return
if self.can_connect(self.options.family, self.options.address): if self.can_connect(self.options.family, self.options.address):
...@@ -275,7 +284,8 @@ class ZEOServer(object): ...@@ -275,7 +284,8 @@ class ZEOServer(object):
def handle_sigusr2(self): def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8... # log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"): if self.options.config_logger is None or \
os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!", log("received SIGUSR2, but it was not handled!",
level=logging.WARNING) level=logging.WARNING)
return return
...@@ -283,12 +293,12 @@ class ZEOServer(object): ...@@ -283,12 +293,12 @@ class ZEOServer(object):
loggers = [self.options.config_logger] loggers = [self.options.config_logger]
if os.name == "posix": if os.name == "posix":
for l in loggers: for logger in loggers:
l.reopen() logger.reopen()
log("Log files reopened successfully", level=logging.INFO) log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers: for logger in loggers:
for f in l.handler_factories: for f in logger.handler_factories:
handler = f() handler = f()
if hasattr(handler, 'rotate') and callable(handler.rotate): if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate() handler.rotate()
...@@ -347,14 +357,14 @@ def create_server(storages, options): ...@@ -347,14 +357,14 @@ def create_server(storages, options):
return StorageServer( return StorageServer(
options.address, options.address,
storages, storages,
read_only = options.read_only, read_only=options.read_only,
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size=options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age=options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout=options.transaction_timeout,
monitor_address = options.monitor_address, monitor_address=options.monitor_address,
auth_protocol = options.auth_protocol, auth_protocol=options.auth_protocol,
auth_database = options.auth_database, auth_database=options.auth_database,
auth_realm = options.auth_realm, auth_realm=options.auth_realm,
) )
...@@ -362,6 +372,7 @@ def create_server(storages, options): ...@@ -362,6 +372,7 @@ def create_server(storages, options):
signames = None signames = None
def signame(sig): def signame(sig):
"""Return a symbolic name for a signal. """Return a symbolic name for a signal.
...@@ -373,6 +384,7 @@ def signame(sig): ...@@ -373,6 +384,7 @@ def signame(sig):
init_signames() init_signames()
return signames.get(sig) or "signal %d" % sig return signames.get(sig) or "signal %d" % sig
def init_signames(): def init_signames():
global signames global signames
signames = {} signames = {}
...@@ -392,5 +404,6 @@ def main(args=None): ...@@ -392,5 +404,6 @@ def main(args=None):
s = ZEOServer(options) s = ZEOServer(options)
s.main() s.main()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -6,24 +6,26 @@ ...@@ -6,24 +6,26 @@
Implements the HMAC algorithm as described by RFC 2104. Implements the HMAC algorithm as described by RFC 2104.
""" """
from six.moves import map from six.moves import map
from six.moves import zip
def _strxor(s1, s2): def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length). """Utility method. XOR the two strings s1 and s2 (must have same length).
""" """
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2)) return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying # The size of the digests returned by HMAC depends on the underlying
# hashing module used. # hashing module used.
digest_size = None digest_size = None
class HMAC(object): class HMAC(object):
"""RFC2104 HMAC class. """RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247). This supports the API for Cryptographic Hash Functions (PEP 247).
""" """
def __init__(self, key, msg = None, digestmod = None): def __init__(self, key, msg=None, digestmod=None):
"""Create a new HMAC object. """Create a new HMAC object.
key: key for the keyed hash object. key: key for the keyed hash object.
...@@ -56,8 +58,8 @@ class HMAC(object): ...@@ -56,8 +58,8 @@ class HMAC(object):
if msg is not None: if msg is not None:
self.update(msg) self.update(msg)
## def clear(self): # def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.") # raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg): def update(self, msg):
"""Update this hashing object with the string msg. """Update this hashing object with the string msg.
...@@ -92,7 +94,8 @@ class HMAC(object): ...@@ -92,7 +94,8 @@ class HMAC(object):
return "".join([hex(ord(x))[2:].zfill(2) return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())]) for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
def new(key, msg=None, digestmod=None):
"""Create a new hashing object and return it. """Create a new hashing object and return it.
key: The starting key for the hash. key: The starting key for the hash.
......
...@@ -34,6 +34,7 @@ from six.moves import map ...@@ -34,6 +34,7 @@ from six.moves import map
def client_timeout(): def client_timeout():
return 30.0 return 30.0
def client_loop(map): def client_loop(map):
read = asyncore.read read = asyncore.read
write = asyncore.write write = asyncore.write
...@@ -52,7 +53,7 @@ def client_loop(map): ...@@ -52,7 +53,7 @@ def client_loop(map):
r, w, e = select.select(r, w, e, client_timeout()) r, w, e = select.select(r, w, e, client_timeout())
except (select.error, RuntimeError) as err: except (select.error, RuntimeError) as err:
# Python >= 3.3 makes select.error an alias of OSError, # Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute # which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0] err_errno = getattr(err, 'errno', None) or err[0]
if err_errno != errno.EINTR: if err_errno != errno.EINTR:
if err_errno == errno.EBADF: if err_errno == errno.EBADF:
...@@ -114,14 +115,13 @@ def client_loop(map): ...@@ -114,14 +115,13 @@ def client_loop(map):
continue continue
_exception(obj) _exception(obj)
except: except: # NOQA: E722 bare except
if map: if map:
try: try:
logging.getLogger(__name__+'.client_loop').critical( logging.getLogger(__name__+'.client_loop').critical(
'A ZEO client loop failed.', 'A ZEO client loop failed.',
exc_info=sys.exc_info()) exc_info=sys.exc_info())
except: except: # NOQA: E722 bare except
pass pass
for fd, obj in map.items(): for fd, obj in map.items():
...@@ -129,14 +129,14 @@ def client_loop(map): ...@@ -129,14 +129,14 @@ def client_loop(map):
continue continue
try: try:
obj.mgr.client.close() obj.mgr.client.close()
except: except: # NOQA: E722 bare except
map.pop(fd, None) map.pop(fd, None)
try: try:
logging.getLogger(__name__+'.client_loop' logging.getLogger(__name__+'.client_loop'
).critical( ).critical(
"Couldn't close a dispatcher.", "Couldn't close a dispatcher.",
exc_info=sys.exc_info()) exc_info=sys.exc_info())
except: except: # NOQA: E722 bare except
pass pass
...@@ -189,7 +189,8 @@ class ConnectionManager(object): ...@@ -189,7 +189,8 @@ class ConnectionManager(object):
for addr in addrs: for addr in addrs:
addr_type = self._guess_type(addr) addr_type = self._guess_type(addr)
if addr_type is None: if addr_type is None:
raise ValueError("unknown address in list: %s" % repr(addr)) raise ValueError(
"unknown address in list: %s" % repr(addr))
addrlist.append((addr_type, addr)) addrlist.append((addr_type, addr))
return addrlist return addrlist
...@@ -197,9 +198,9 @@ class ConnectionManager(object): ...@@ -197,9 +198,9 @@ class ConnectionManager(object):
if isinstance(addr, str): if isinstance(addr, str):
return socket.AF_UNIX return socket.AF_UNIX
if (len(addr) == 2 if len(addr) == 2 and \
and isinstance(addr[0], str) isinstance(addr[0], str) and \
and isinstance(addr[1], int)): isinstance(addr[1], int):
return socket.AF_INET # also denotes IPv6 return socket.AF_INET # also denotes IPv6
# not anything I know about # not anything I know about
...@@ -226,7 +227,7 @@ class ConnectionManager(object): ...@@ -226,7 +227,7 @@ class ConnectionManager(object):
if obj is not self.trigger: if obj is not self.trigger:
try: try:
obj.close() obj.close()
except: except: # NOQA: E722 bare except
logging.getLogger(__name__+'.'+self.__class__.__name__ logging.getLogger(__name__+'.'+self.__class__.__name__
).critical( ).critical(
"Couldn't close a dispatcher.", "Couldn't close a dispatcher.",
...@@ -331,6 +332,7 @@ class ConnectionManager(object): ...@@ -331,6 +332,7 @@ class ConnectionManager(object):
finally: finally:
self.cond.release() self.cond.release()
# When trying to do a connect on a non-blocking socket, some outcomes # When trying to do a connect on a non-blocking socket, some outcomes
# are expected. Set _CONNECT_IN_PROGRESS to the errno value(s) expected # are expected. Set _CONNECT_IN_PROGRESS to the errno value(s) expected
# when an initial connect can't complete immediately. Set _CONNECT_OK # when an initial connect can't complete immediately. Set _CONNECT_OK
...@@ -347,6 +349,7 @@ else: # Unix ...@@ -347,6 +349,7 @@ else: # Unix
_CONNECT_IN_PROGRESS = (errno.EINPROGRESS,) _CONNECT_IN_PROGRESS = (errno.EINPROGRESS,)
_CONNECT_OK = (0, errno.EISCONN) _CONNECT_OK = (0, errno.EISCONN)
class ConnectThread(threading.Thread): class ConnectThread(threading.Thread):
"""Thread that tries to connect to server given one or more addresses. """Thread that tries to connect to server given one or more addresses.
...@@ -495,7 +498,7 @@ class ConnectThread(threading.Thread): ...@@ -495,7 +498,7 @@ class ConnectThread(threading.Thread):
break break
try: try:
r, w, x = select.select([], connecting, connecting, 1.0) r, w, x = select.select([], connecting, connecting, 1.0)
log("CT: select() %d, %d, %d" % tuple(map(len, (r,w,x)))) log("CT: select() %d, %d, %d" % tuple(map(len, (r, w, x))))
except select.error as msg: except select.error as msg:
log("CT: select failed; msg=%s" % str(msg), log("CT: select failed; msg=%s" % str(msg),
level=logging.WARNING) level=logging.WARNING)
...@@ -610,7 +613,7 @@ class ConnectWrapper(object): ...@@ -610,7 +613,7 @@ class ConnectWrapper(object):
log("CW: ReadOnlyError in testConnection (%s)" % repr(self.addr)) log("CW: ReadOnlyError in testConnection (%s)" % repr(self.addr))
self.close() self.close()
return return
except: except: # NOQA: E722 bare except
log("CW: error in testConnection (%s)" % repr(self.addr), log("CW: error in testConnection (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True) level=logging.ERROR, exc_info=True)
self.close() self.close()
...@@ -629,7 +632,7 @@ class ConnectWrapper(object): ...@@ -629,7 +632,7 @@ class ConnectWrapper(object):
""" """
try: try:
self.client.notifyConnected(self.conn) self.client.notifyConnected(self.conn)
except: except: # NOQA: E722 bare except
log("CW: error in notifyConnected (%s)" % repr(self.addr), log("CW: error in notifyConnected (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True) level=logging.ERROR, exc_info=True)
self.close() self.close()
......
...@@ -32,6 +32,7 @@ exception_type_type = type(Exception) ...@@ -32,6 +32,7 @@ exception_type_type = type(Exception)
debug_zrpc = False debug_zrpc = False
class Delay(object): class Delay(object):
"""Used to delay response to client for synchronous calls. """Used to delay response to client for synchronous calls.
...@@ -57,7 +58,9 @@ class Delay(object): ...@@ -57,7 +58,9 @@ class Delay(object):
def __repr__(self): def __repr__(self):
return "%s[%s, %r, %r, %r]" % ( return "%s[%s, %r, %r, %r]" % (
self.__class__.__name__, id(self), self.msgid, self.conn, self.sent) self.__class__.__name__, id(self), self.msgid,
self.conn, self.sent)
class Result(Delay): class Result(Delay):
...@@ -69,6 +72,7 @@ class Result(Delay): ...@@ -69,6 +72,7 @@ class Result(Delay):
conn.send_reply(msgid, reply, False) conn.send_reply(msgid, reply, False)
callback() callback()
class MTDelay(Delay): class MTDelay(Delay):
def __init__(self): def __init__(self):
...@@ -147,6 +151,7 @@ class MTDelay(Delay): ...@@ -147,6 +151,7 @@ class MTDelay(Delay):
# supply a handshake() method appropriate for their role in protocol # supply a handshake() method appropriate for their role in protocol
# negotiation. # negotiation.
class Connection(smac.SizedMessageAsyncConnection, object): class Connection(smac.SizedMessageAsyncConnection, object):
"""Dispatcher for RPC on object on both sides of socket. """Dispatcher for RPC on object on both sides of socket.
...@@ -441,7 +446,7 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -441,7 +446,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
try: try:
self.message_output(self.fast_encode(msgid, 0, REPLY, ret)) self.message_output(self.fast_encode(msgid, 0, REPLY, ret))
self.poll() self.poll()
except: except: # NOQA: E722 bare except
# Fall back to normal version for better error handling # Fall back to normal version for better error handling
self.send_reply(msgid, ret) self.send_reply(msgid, ret)
...@@ -520,10 +525,10 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -520,10 +525,10 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# cPickle may raise. # cPickle may raise.
try: try:
msg = self.encode(msgid, 0, REPLY, (err_type, err_value)) msg = self.encode(msgid, 0, REPLY, (err_type, err_value))
except: # see above except: # NOQA: E722 bare except; see above
try: try:
r = short_repr(err_value) r = short_repr(err_value)
except: except: # NOQA: E722 bare except
r = "<unreprable>" r = "<unreprable>"
err = ZRPCError("Couldn't pickle error %.100s" % r) err = ZRPCError("Couldn't pickle error %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err)) msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
...@@ -656,10 +661,10 @@ class ManagedServerConnection(Connection): ...@@ -656,10 +661,10 @@ class ManagedServerConnection(Connection):
# cPickle may raise. # cPickle may raise.
try: try:
msg = self.encode(msgid, 0, REPLY, ret) msg = self.encode(msgid, 0, REPLY, ret)
except: # see above except: # NOQA: E722 bare except; see above
try: try:
r = short_repr(ret) r = short_repr(ret)
except: except: # NOQA: E722 bare except
r = "<unreprable>" r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r) err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err)) msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
...@@ -669,6 +674,7 @@ class ManagedServerConnection(Connection): ...@@ -669,6 +674,7 @@ class ManagedServerConnection(Connection):
poll = smac.SizedMessageAsyncConnection.handle_write poll = smac.SizedMessageAsyncConnection.handle_write
def server_loop(map): def server_loop(map):
while len(map) > 1: while len(map) > 1:
try: try:
...@@ -680,6 +686,7 @@ def server_loop(map): ...@@ -680,6 +686,7 @@ def server_loop(map):
for o in tuple(map.values()): for o in tuple(map.values()):
o.close() o.close()
class ManagedClientConnection(Connection): class ManagedClientConnection(Connection):
"""Client-side Connection subclass.""" """Client-side Connection subclass."""
__super_init = Connection.__init__ __super_init = Connection.__init__
...@@ -778,9 +785,9 @@ class ManagedClientConnection(Connection): ...@@ -778,9 +785,9 @@ class ManagedClientConnection(Connection):
raise DisconnectedError() raise DisconnectedError()
msgid = self.send_call(method, args) msgid = self.send_call(method, args)
r_args = self.wait(msgid) r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1 if isinstance(r_args, tuple) and len(r_args) > 1 and \
and type(r_args[0]) == exception_type_type type(r_args[0]) == exception_type_type and \
and issubclass(r_args[0], Exception)): issubclass(r_args[0], Exception):
inst = r_args[1] inst = r_args[1]
raise inst # error raised by server raise inst # error raised by server
else: else:
...@@ -821,9 +828,9 @@ class ManagedClientConnection(Connection): ...@@ -821,9 +828,9 @@ class ManagedClientConnection(Connection):
def _deferred_wait(self, msgid): def _deferred_wait(self, msgid):
r_args = self.wait(msgid) r_args = self.wait(msgid)
if (isinstance(r_args, tuple) if isinstance(r_args, tuple) and \
and type(r_args[0]) == exception_type_type type(r_args[0]) == exception_type_type and \
and issubclass(r_args[0], Exception)): issubclass(r_args[0], Exception):
inst = r_args[1] inst = r_args[1]
raise inst # error raised by server raise inst # error raised by server
else: else:
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
from ZODB import POSException from ZODB import POSException
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
class ZRPCError(POSException.StorageError): class ZRPCError(POSException.StorageError):
pass pass
class DisconnectedError(ZRPCError, ClientDisconnected): class DisconnectedError(ZRPCError, ClientDisconnected):
"""The database storage is disconnected from the storage server. """The database storage is disconnected from the storage server.
......
...@@ -17,24 +17,29 @@ import logging ...@@ -17,24 +17,29 @@ import logging
from ZODB.loglevels import BLATHER from ZODB.loglevels import BLATHER
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
logger = logging.getLogger('ZEO.zrpc') logger = logging.getLogger('ZEO.zrpc')
_label = "%s" % os.getpid() _label = "%s" % os.getpid()
def new_label(): def new_label():
global _label global _label
_label = str(os.getpid()) _label = str(os.getpid())
def log(message, level=BLATHER, label=None, exc_info=False): def log(message, level=BLATHER, label=None, exc_info=False):
label = label or _label label = label or _label
if LOG_THREAD_ID: if LOG_THREAD_ID:
label = label + ':' + threading.currentThread().getName() label = label + ':' + threading.currentThread().getName()
logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info) logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info)
REPR_LIMIT = 60 REPR_LIMIT = 60
def short_repr(obj): def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes." "Return an object repr limited to REPR_LIMIT bytes."
......
...@@ -19,6 +19,7 @@ from .log import log, short_repr ...@@ -19,6 +19,7 @@ from .log import log, short_repr
PY2 = not PY3 PY2 = not PY3
def encode(*args): # args: (msgid, flags, name, args) def encode(*args): # args: (msgid, flags, name, args)
# (We used to have a global pickler, but that's not thread-safe. :-( ) # (We used to have a global pickler, but that's not thread-safe. :-( )
...@@ -41,7 +42,6 @@ def encode(*args): # args: (msgid, flags, name, args) ...@@ -41,7 +42,6 @@ def encode(*args): # args: (msgid, flags, name, args)
return res return res
if PY3: if PY3:
# XXX: Py3: Needs optimization. # XXX: Py3: Needs optimization.
fast_encode = encode fast_encode = encode
...@@ -50,48 +50,57 @@ elif PYPY: ...@@ -50,48 +50,57 @@ elif PYPY:
# every time, getvalue() only works once # every time, getvalue() only works once
fast_encode = encode fast_encode = encode
else: else:
def fast_encode(): def fast_encode():
# Only use in cases where you *know* the data contains only basic # Only use in cases where you *know* the data contains only basic
# Python objects # Python objects
pickler = Pickler(1) pickler = Pickler(1)
pickler.fast = 1 pickler.fast = 1
dump = pickler.dump dump = pickler.dump
def fast_encode(*args): def fast_encode(*args):
return dump(args, 1) return dump(args, 1)
return fast_encode return fast_encode
fast_encode = fast_encode() fast_encode = fast_encode()
def decode(msg): def decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global unpickler.find_global = find_global
try: try:
unpickler.find_class = find_global # PyPy, zodbpickle, the non-c-accelerated version # PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = find_global
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except: # NOQA: E722 bare except
log("can't decode message: %s" % short_repr(msg), log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR) level=logging.ERROR)
raise raise
def server_decode(msg): def server_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global unpickler.find_global = server_find_global
try: try:
unpickler.find_class = server_find_global # PyPy, zodbpickle, the non-c-accelerated version # PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = server_find_global
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except: # NOQA: E722 bare except
log("can't decode message: %s" % short_repr(msg), log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR) level=logging.ERROR)
raise raise
_globals = globals() _globals = globals()
_silly = ('__doc__',) _silly = ('__doc__',)
...@@ -102,6 +111,7 @@ _SAFE_MODULE_NAMES = ( ...@@ -102,6 +111,7 @@ _SAFE_MODULE_NAMES = (
'builtins', 'copy_reg', '__builtin__', 'builtins', 'copy_reg', '__builtin__',
) )
def find_global(module, name): def find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
try: try:
...@@ -114,7 +124,8 @@ def find_global(module, name): ...@@ -114,7 +124,8 @@ def find_global(module, name):
except AttributeError: except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name)) raise ZRPCError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) or (PY2 and module in _SAFE_MODULE_NAMES) safe = (getattr(r, '__no_side_effects__', 0) or
(PY2 and module in _SAFE_MODULE_NAMES))
if safe: if safe:
return r return r
...@@ -124,6 +135,7 @@ def find_global(module, name): ...@@ -124,6 +135,7 @@ def find_global(module, name):
raise ZRPCError("Unsafe global: %s.%s" % (module, name)) raise ZRPCError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name): def server_find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
if module not in _SAFE_MODULE_NAMES: if module not in _SAFE_MODULE_NAMES:
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
############################################################################## ##############################################################################
import asyncore import asyncore
import socket import socket
import time
# _has_dualstack: True if the dual-stack sockets are supported # _has_dualstack: True if the dual-stack sockets are supported
try: try:
...@@ -39,6 +40,7 @@ import logging ...@@ -39,6 +40,7 @@ import logging
# Export the main asyncore loop # Export the main asyncore loop
loop = asyncore.loop loop = asyncore.loop
class Dispatcher(asyncore.dispatcher): class Dispatcher(asyncore.dispatcher):
"""A server that accepts incoming RPC connections""" """A server that accepts incoming RPC connections"""
__super_init = asyncore.dispatcher.__init__ __super_init = asyncore.dispatcher.__init__
...@@ -74,7 +76,7 @@ class Dispatcher(asyncore.dispatcher): ...@@ -74,7 +76,7 @@ class Dispatcher(asyncore.dispatcher):
for i in range(25): for i in range(25):
try: try:
self.bind(self.addr) self.bind(self.addr)
except Exception as exc: except Exception:
log("bind failed %s waiting", i) log("bind failed %s waiting", i)
if i == 24: if i == 24:
raise raise
...@@ -98,7 +100,6 @@ class Dispatcher(asyncore.dispatcher): ...@@ -98,7 +100,6 @@ class Dispatcher(asyncore.dispatcher):
log("accepted failed: %s" % msg) log("accepted failed: %s" % msg)
return return
# We could short-circuit the attempt below in some edge cases # We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None. # and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below, # Unfortunately, our test for the code below,
...@@ -116,7 +117,7 @@ class Dispatcher(asyncore.dispatcher): ...@@ -116,7 +117,7 @@ class Dispatcher(asyncore.dispatcher):
try: try:
c = self.factory(sock, addr) c = self.factory(sock, addr)
except: except: # NOQA: E722 bare except
if sock.fileno() in asyncore.socket_map: if sock.fileno() in asyncore.socket_map:
del asyncore.socket_map[sock.fileno()] del asyncore.socket_map[sock.fileno()]
logger.exception("Error in handle_accept") logger.exception("Error in handle_accept")
......
...@@ -67,6 +67,7 @@ MAC_BIT = 0x80000000 ...@@ -67,6 +67,7 @@ MAC_BIT = 0x80000000
_close_marker = object() _close_marker = object()
class SizedMessageAsyncConnection(asyncore.dispatcher): class SizedMessageAsyncConnection(asyncore.dispatcher):
__super_init = asyncore.dispatcher.__init__ __super_init = asyncore.dispatcher.__init__
__super_close = asyncore.dispatcher.close __super_close = asyncore.dispatcher.close
...@@ -168,7 +169,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher): ...@@ -168,7 +169,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
d = self.recv(8192) d = self.recv(8192)
except socket.error as err: except socket.error as err:
# Python >= 3.3 makes select.error an alias of OSError, # Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute # which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0] err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_read_errors: if err_errno in expected_socket_read_errors:
return return
...@@ -298,7 +299,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher): ...@@ -298,7 +299,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
# ensure the above mentioned "output" invariant # ensure the above mentioned "output" invariant
output.insert(0, v) output.insert(0, v)
# Python >= 3.3 makes select.error an alias of OSError, # Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute # which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0] err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_write_errors: if err_errno in expected_socket_write_errors:
break # we couldn't write anything break # we couldn't write anything
......
...@@ -21,7 +21,7 @@ import socket ...@@ -21,7 +21,7 @@ import socket
import errno import errno
from ZODB.utils import positive_id from ZODB.utils import positive_id
from ZEO._compat import thread, get_ident from ZEO._compat import thread
# Original comments follow; they're hard to follow in the context of # Original comments follow; they're hard to follow in the context of
# ZEO's use of triggers. TODO: rewrite from a ZEO perspective. # ZEO's use of triggers. TODO: rewrite from a ZEO perspective.
...@@ -56,6 +56,7 @@ from ZEO._compat import thread, get_ident ...@@ -56,6 +56,7 @@ from ZEO._compat import thread, get_ident
# new data onto a channel's outgoing data queue at the same time that # new data onto a channel's outgoing data queue at the same time that
# the main thread is trying to remove some] # the main thread is trying to remove some]
class _triggerbase(object): class _triggerbase(object):
"""OS-independent base class for OS-dependent trigger class.""" """OS-independent base class for OS-dependent trigger class."""
...@@ -127,7 +128,7 @@ class _triggerbase(object): ...@@ -127,7 +128,7 @@ class _triggerbase(object):
return return
try: try:
thunk[0](*thunk[1:]) thunk[0](*thunk[1:])
except: except: # NOQA: E722 bare except
nil, t, v, tbinfo = asyncore.compact_traceback() nil, t, v, tbinfo = asyncore.compact_traceback()
print(('exception in trigger thunk:' print(('exception in trigger thunk:'
' (%s:%s %s)' % (t, v, tbinfo))) ' (%s:%s %s)' % (t, v, tbinfo)))
...@@ -135,6 +136,7 @@ class _triggerbase(object): ...@@ -135,6 +136,7 @@ class _triggerbase(object):
def __repr__(self): def __repr__(self):
return '<select-trigger (%s) at %x>' % (self.kind, positive_id(self)) return '<select-trigger (%s) at %x>' % (self.kind, positive_id(self))
if os.name == 'posix': if os.name == 'posix':
class trigger(_triggerbase, asyncore.file_dispatcher): class trigger(_triggerbase, asyncore.file_dispatcher):
......
...@@ -16,7 +16,6 @@ from __future__ import print_function ...@@ -16,7 +16,6 @@ from __future__ import print_function
import random import random
import sys
import time import time
...@@ -56,8 +55,8 @@ def encode_format(fmt): ...@@ -56,8 +55,8 @@ def encode_format(fmt):
fmt = fmt.replace(*xform) fmt = fmt.replace(*xform)
return fmt return fmt
runner = _forker.runner
runner = _forker.runner
stop_runner = _forker.stop_runner stop_runner = _forker.stop_runner
start_zeo_server = _forker.start_zeo_server start_zeo_server = _forker.start_zeo_server
...@@ -70,6 +69,7 @@ else: ...@@ -70,6 +69,7 @@ else:
shutdown_zeo_server = _forker.shutdown_zeo_server shutdown_zeo_server = _forker.shutdown_zeo_server
def get_port(ignored=None): def get_port(ignored=None):
"""Return a port that is not in use. """Return a port that is not in use.
...@@ -107,6 +107,7 @@ def get_port(ignored=None): ...@@ -107,6 +107,7 @@ def get_port(ignored=None):
s1.close() s1.close()
raise RuntimeError("Can't find port") raise RuntimeError("Can't find port")
def can_connect(port): def can_connect(port):
c = socket.socket(socket.AF_INET, socket.SOCK_STREAM) c = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try: try:
...@@ -119,6 +120,7 @@ def can_connect(port): ...@@ -119,6 +120,7 @@ def can_connect(port):
finally: finally:
c.close() c.close()
def setUp(test): def setUp(test):
ZODB.tests.util.setUp(test) ZODB.tests.util.setUp(test)
...@@ -194,9 +196,11 @@ def wait_until(label=None, func=None, timeout=30, onfail=None): ...@@ -194,9 +196,11 @@ def wait_until(label=None, func=None, timeout=30, onfail=None):
return onfail() return onfail()
time.sleep(0.01) time.sleep(0.01)
def wait_connected(storage): def wait_connected(storage):
wait_until("storage is connected", storage.is_connected) wait_until("storage is connected", storage.is_connected)
def wait_disconnected(storage): def wait_disconnected(storage):
wait_until("storage is disconnected", wait_until("storage is disconnected",
lambda: not storage.is_connected()) lambda: not storage.is_connected())
......
...@@ -34,6 +34,7 @@ import ZEO.asyncio.tests ...@@ -34,6 +34,7 @@ import ZEO.asyncio.tests
import ZEO.StorageServer import ZEO.StorageServer
import ZODB.MappingStorage import ZODB.MappingStorage
class StorageServer(ZEO.StorageServer.StorageServer): class StorageServer(ZEO.StorageServer.StorageServer):
def __init__(self, addr='test_addr', storages=None, **kw): def __init__(self, addr='test_addr', storages=None, **kw):
...@@ -41,6 +42,7 @@ class StorageServer(ZEO.StorageServer.StorageServer): ...@@ -41,6 +42,7 @@ class StorageServer(ZEO.StorageServer.StorageServer):
storages = {'1': ZODB.MappingStorage.MappingStorage()} storages = {'1': ZODB.MappingStorage.MappingStorage()}
ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw) ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw)
def client(server, name='client'): def client(server, name='client'):
zs = ZEO.StorageServer.ZEOStorage(server) zs = ZEO.StorageServer.ZEOStorage(server)
protocol = ZEO.asyncio.tests.server_protocol( protocol = ZEO.asyncio.tests.server_protocol(
......
...@@ -18,7 +18,21 @@ from __future__ import print_function ...@@ -18,7 +18,21 @@ from __future__ import print_function
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
usage="""Test speed of a ZODB storage
import asyncore
import getopt
import os
import sys
import time
import persistent
import transaction
import ZODB
from ZODB.POSException import ConflictError
from ZEO.tests import forker
usage = """Test speed of a ZODB storage
Options: Options:
...@@ -48,41 +62,40 @@ Options: ...@@ -48,41 +62,40 @@ Options:
-t n Number of concurrent threads to run. -t n Number of concurrent threads to run.
""" """
import asyncore
import sys, os, getopt, time
##sys.path.insert(0, os.getcwd())
import persistent
import transaction
import ZODB
from ZODB.POSException import ConflictError
from ZEO.tests import forker
class P(persistent.Persistent): class P(persistent.Persistent):
pass pass
fs_name = "zeo-speed.fs" fs_name = "zeo-speed.fs"
class ZEOExit(asyncore.file_dispatcher): class ZEOExit(asyncore.file_dispatcher):
"""Used to exit ZEO.StorageServer when run is done""" """Used to exit ZEO.StorageServer when run is done"""
def writable(self): def writable(self):
return 0 return 0
def readable(self): def readable(self):
return 1 return 1
def handle_read(self): def handle_read(self):
buf = self.recv(4) buf = self.recv(4)
assert buf == "done" assert buf == "done"
self.delete_fs() self.delete_fs()
os._exit(0) os._exit(0)
def handle_close(self): def handle_close(self):
print("Parent process exited unexpectedly") print("Parent process exited unexpectedly")
self.delete_fs() self.delete_fs()
os._exit(0) os._exit(0)
def delete_fs(self): def delete_fs(self):
os.unlink(fs_name) os.unlink(fs_name)
os.unlink(fs_name + ".lock") os.unlink(fs_name + ".lock")
os.unlink(fs_name + ".tmp") os.unlink(fs_name + ".tmp")
def work(db, results, nrep, compress, data, detailed, minimize, threadno=None): def work(db, results, nrep, compress, data, detailed, minimize, threadno=None):
for j in range(nrep): for j in range(nrep):
for r in 1, 10, 100, 1000: for r in 1, 10, 100, 1000:
...@@ -98,7 +111,7 @@ def work(db, results, nrep, compress, data, detailed, minimize, threadno=None): ...@@ -98,7 +111,7 @@ def work(db, results, nrep, compress, data, detailed, minimize, threadno=None):
if key in rt: if key in rt:
p = rt[key] p = rt[key]
else: else:
rt[key] = p =P() rt[key] = p = P()
for i in range(r): for i in range(r):
v = getattr(p, str(i), P()) v = getattr(p, str(i), P())
if compress is not None: if compress is not None:
...@@ -121,46 +134,49 @@ def work(db, results, nrep, compress, data, detailed, minimize, threadno=None): ...@@ -121,46 +134,49 @@ def work(db, results, nrep, compress, data, detailed, minimize, threadno=None):
print("%s\t%s\t%.4f\t%d\t%d" % (j, r, t, conflicts, print("%s\t%s\t%.4f\t%d\t%d" % (j, r, t, conflicts,
threadno)) threadno))
results[r].append((t, conflicts)) results[r].append((t, conflicts))
rt=d=p=v=None # release all references rt = p = v = None # release all references
if minimize: if minimize:
time.sleep(3) time.sleep(3)
jar.cacheMinimize() jar.cacheMinimize()
def main(args): def main(args):
opts, args = getopt.getopt(args, 'zd:n:Ds:LMt:U') opts, args = getopt.getopt(args, 'zd:n:Ds:LMt:U')
s = None s = None
compress = None compress = None
data=sys.argv[0] data = sys.argv[0]
nrep=5 nrep = 5
minimize=0 minimize = 0
detailed=1 detailed = 1
cache = None cache = None
domain = 'AF_INET' domain = 'AF_INET'
threads = 1 threads = 1
for o, v in opts: for o, v in opts:
if o=='-n': nrep = int(v) if o == '-n':
elif o=='-d': data = v nrep = int(v)
elif o=='-s': s = v elif o == '-d':
elif o=='-z': data = v
elif o == '-s':
s = v
elif o == '-z':
import zlib import zlib
compress = zlib.compress compress = zlib.compress
elif o=='-L': elif o == '-L':
minimize=1 minimize = 1
elif o=='-M': elif o == '-M':
detailed=0 detailed = 0
elif o=='-D': elif o == '-D':
global debug global debug
os.environ['STUPID_LOG_FILE']='' os.environ['STUPID_LOG_FILE'] = ''
os.environ['STUPID_LOG_SEVERITY']='-999' os.environ['STUPID_LOG_SEVERITY'] = '-999'
debug = 1 debug = 1
elif o == '-C': elif o == '-C':
cache = 'speed' cache = 'speed' # NOQA: F841 unused variable
elif o == '-U': elif o == '-U':
domain = 'AF_UNIX' domain = 'AF_UNIX'
elif o == '-t': elif o == '-t':
threads = int(v) threads = int(v)
zeo_pipe = None
if s: if s:
s = __import__(s, globals(), globals(), ('__doc__',)) s = __import__(s, globals(), globals(), ('__doc__',))
s = s.Storage s = s.Storage
...@@ -169,25 +185,25 @@ def main(args): ...@@ -169,25 +185,25 @@ def main(args):
s, server, pid = forker.start_zeo("FileStorage", s, server, pid = forker.start_zeo("FileStorage",
(fs_name, 1), domain=domain) (fs_name, 1), domain=domain)
data=open(data).read() data = open(data).read()
db=ZODB.DB(s, db = ZODB.DB(s,
# disable cache deactivation # disable cache deactivation
cache_size=4000, cache_size=4000,
cache_deactivate_after=6000,) cache_deactivate_after=6000)
print("Beginning work...") print("Beginning work...")
results={1:[], 10:[], 100:[], 1000:[]} results = {1: [], 10: [], 100: [], 1000: []}
if threads > 1: if threads > 1:
import threading import threading
l = [] thread_list = []
for i in range(threads): for i in range(threads):
t = threading.Thread(target=work, t = threading.Thread(target=work,
args=(db, results, nrep, compress, data, args=(db, results, nrep, compress, data,
detailed, minimize, i)) detailed, minimize, i))
l.append(t) thread_list.append(t)
for t in l: for t in thread_list:
t.start() t.start()
for t in l: for t in thread_list:
t.join() t.join()
else: else:
...@@ -202,21 +218,24 @@ def main(args): ...@@ -202,21 +218,24 @@ def main(args):
print("num\tmean\tmin\tmax") print("num\tmean\tmin\tmax")
for r in 1, 10, 100, 1000: for r in 1, 10, 100, 1000:
times = [] times = []
for time, conf in results[r]: for time_val, conf in results[r]:
times.append(time) times.append(time_val)
t = mean(times) t = mean(times)
print("%d\t%.4f\t%.4f\t%.4f" % (r, t, min(times), max(times))) print("%d\t%.4f\t%.4f\t%.4f" % (r, t, min(times), max(times)))
def mean(l):
def mean(lst):
tot = 0 tot = 0
for v in l: for v in lst:
tot = tot + v tot = tot + v
return tot / len(l) return tot / len(lst)
# def compress(s):
# c = zlib.compressobj()
# o = c.compress(s)
# return o + c.flush()
##def compress(s):
## c = zlib.compressobj()
## o = c.compress(s)
## return o + c.flush()
if __name__=='__main__': if __name__ == '__main__':
main(sys.argv[1:]) main(sys.argv[1:])
...@@ -36,20 +36,22 @@ MAX_DEPTH = 20 ...@@ -36,20 +36,22 @@ MAX_DEPTH = 20
MIN_OBJSIZE = 128 MIN_OBJSIZE = 128
MAX_OBJSIZE = 2048 MAX_OBJSIZE = 2048
def an_object(): def an_object():
"""Return an object suitable for a PersistentMapping key""" """Return an object suitable for a PersistentMapping key"""
size = random.randrange(MIN_OBJSIZE, MAX_OBJSIZE) size = random.randrange(MIN_OBJSIZE, MAX_OBJSIZE)
if os.path.exists("/dev/urandom"): if os.path.exists("/dev/urandom"):
f = open("/dev/urandom") fp = open("/dev/urandom")
buf = f.read(size) buf = fp.read(size)
f.close() fp.close()
return buf return buf
else: else:
f = open(MinPO.__file__) fp = open(MinPO.__file__)
l = list(f.read(size)) lst = list(fp.read(size))
f.close() fp.close()
random.shuffle(l) random.shuffle(lst)
return "".join(l) return "".join(lst)
def setup(cn): def setup(cn):
"""Initialize the database with some objects""" """Initialize the database with some objects"""
...@@ -63,6 +65,7 @@ def setup(cn): ...@@ -63,6 +65,7 @@ def setup(cn):
transaction.commit() transaction.commit()
cn.close() cn.close()
def work(cn): def work(cn):
"""Do some work with a transaction""" """Do some work with a transaction"""
cn.sync() cn.sync()
...@@ -74,11 +77,13 @@ def work(cn): ...@@ -74,11 +77,13 @@ def work(cn):
obj.value = an_object() obj.value = an_object()
transaction.commit() transaction.commit()
def main(): def main():
# Yuck! Need to cleanup forker so that the API is consistent # Yuck! Need to cleanup forker so that the API is consistent
# across Unix and Windows, at least if that's possible. # across Unix and Windows, at least if that's possible.
if os.name == "nt": if os.name == "nt":
zaddr, tport, pid = forker.start_zeo_server('MappingStorage', ()) zaddr, tport, pid = forker.start_zeo_server('MappingStorage', ())
def exitserver(): def exitserver():
import socket import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
...@@ -87,6 +92,7 @@ def main(): ...@@ -87,6 +92,7 @@ def main():
else: else:
zaddr = '', random.randrange(20000, 30000) zaddr = '', random.randrange(20000, 30000)
pid, exitobj = forker.start_zeo_server(MappingStorage(), zaddr) pid, exitobj = forker.start_zeo_server(MappingStorage(), zaddr)
def exitserver(): def exitserver():
exitobj.close() exitobj.close()
...@@ -97,6 +103,7 @@ def main(): ...@@ -97,6 +103,7 @@ def main():
exitserver() exitserver()
def start_child(zaddr): def start_child(zaddr):
pid = os.fork() pid = os.fork()
...@@ -107,6 +114,7 @@ def start_child(zaddr): ...@@ -107,6 +114,7 @@ def start_child(zaddr):
finally: finally:
os._exit(0) os._exit(0)
def _start_child(zaddr): def _start_child(zaddr):
storage = ClientStorage(zaddr, debug=1, min_disconnect_poll=0.5, wait=1) storage = ClientStorage(zaddr, debug=1, min_disconnect_poll=0.5, wait=1)
db = ZODB.DB(storage, pool_size=NUM_CONNECTIONS) db = ZODB.DB(storage, pool_size=NUM_CONNECTIONS)
...@@ -133,5 +141,6 @@ def _start_child(zaddr): ...@@ -133,5 +141,6 @@ def _start_child(zaddr):
c.__count += 1 c.__count += 1
work(c) work(c)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -21,6 +21,7 @@ from ZODB.config import storageFromString ...@@ -21,6 +21,7 @@ from ZODB.config import storageFromString
from .forker import start_zeo_server from .forker import start_zeo_server
from .threaded import threaded_server_tests from .threaded import threaded_server_tests
class ZEOConfigTestBase(setupstack.TestCase): class ZEOConfigTestBase(setupstack.TestCase):
setUp = setupstack.setUpDirectory setUp = setupstack.setUpDirectory
...@@ -52,10 +53,9 @@ class ZEOConfigTestBase(setupstack.TestCase): ...@@ -52,10 +53,9 @@ class ZEOConfigTestBase(setupstack.TestCase):
</clientstorage> </clientstorage>
""".format(settings)) """.format(settings))
def _client_assertions( def _client_assertions(self, client, addr,
self, client, addr,
connected=True, connected=True,
cache_size=20 * (1<<20), cache_size=20 * (1 << 20),
cache_path=None, cache_path=None,
blob_dir=None, blob_dir=None,
shared_blob_dir=False, shared_blob_dir=False,
...@@ -67,8 +67,7 @@ class ZEOConfigTestBase(setupstack.TestCase): ...@@ -67,8 +67,7 @@ class ZEOConfigTestBase(setupstack.TestCase):
wait_timeout=30, wait_timeout=30,
client_label=None, client_label=None,
storage='1', storage='1',
name=None, name=None):
):
self.assertEqual(client.is_connected(), connected) self.assertEqual(client.is_connected(), connected)
self.assertEqual(client._addr, [addr]) self.assertEqual(client._addr, [addr])
self.assertEqual(client._cache.maxsize, cache_size) self.assertEqual(client._cache.maxsize, cache_size)
...@@ -88,6 +87,7 @@ class ZEOConfigTestBase(setupstack.TestCase): ...@@ -88,6 +87,7 @@ class ZEOConfigTestBase(setupstack.TestCase):
self.assertEqual(client.__name__, self.assertEqual(client.__name__,
name if name is not None else str(client._addr)) name if name is not None else str(client._addr))
class ZEOConfigTest(ZEOConfigTestBase): class ZEOConfigTest(ZEOConfigTestBase):
def test_default_zeo_config(self, **client_settings): def test_default_zeo_config(self, **client_settings):
...@@ -101,8 +101,7 @@ class ZEOConfigTest(ZEOConfigTestBase): ...@@ -101,8 +101,7 @@ class ZEOConfigTest(ZEOConfigTestBase):
def test_client_variations(self): def test_client_variations(self):
for name, value in dict( for name, value in dict(cache_size=4200,
cache_size=4200,
cache_path='test', cache_path='test',
blob_dir='blobs', blob_dir='blobs',
blob_cache_size=424242, blob_cache_size=424242,
...@@ -111,7 +110,7 @@ class ZEOConfigTest(ZEOConfigTestBase): ...@@ -111,7 +110,7 @@ class ZEOConfigTest(ZEOConfigTestBase):
server_sync=True, server_sync=True,
wait_timeout=33, wait_timeout=33,
client_label='test_client', client_label='test_client',
name='Test' name='Test',
).items(): ).items():
params = {name: value} params = {name: value}
self.test_default_zeo_config(**params) self.test_default_zeo_config(**params)
...@@ -120,6 +119,7 @@ class ZEOConfigTest(ZEOConfigTestBase): ...@@ -120,6 +119,7 @@ class ZEOConfigTest(ZEOConfigTestBase):
self.test_default_zeo_config(blob_cache_size=424242, self.test_default_zeo_config(blob_cache_size=424242,
blob_cache_size_check=50) blob_cache_size_check=50)
def test_suite(): def test_suite():
suite = unittest.makeSuite(ZEOConfigTest) suite = unittest.makeSuite(ZEOConfigTest)
suite.layer = threaded_server_tests suite.layer = threaded_server_tests
......
...@@ -29,10 +29,9 @@ else: ...@@ -29,10 +29,9 @@ else:
import unittest import unittest
import ZODB.tests.util import ZODB.tests.util
import ZEO
from . import forker from . import forker
class FileStorageConfig(object): class FileStorageConfig(object):
def getConfig(self, path, create, read_only): def getConfig(self, path, create, read_only):
return """\ return """\
...@@ -44,6 +43,7 @@ class FileStorageConfig(object): ...@@ -44,6 +43,7 @@ class FileStorageConfig(object):
create and 'yes' or 'no', create and 'yes' or 'no',
read_only and 'yes' or 'no') read_only and 'yes' or 'no')
class MappingStorageConfig(object): class MappingStorageConfig(object):
def getConfig(self, path, create, read_only): def getConfig(self, path, create, read_only):
return """<mappingstorage 1/>""" return """<mappingstorage 1/>"""
...@@ -52,49 +52,47 @@ class MappingStorageConfig(object): ...@@ -52,49 +52,47 @@ class MappingStorageConfig(object):
class FileStorageConnectionTests( class FileStorageConnectionTests(
FileStorageConfig, FileStorageConfig,
ConnectionTests.ConnectionTests, ConnectionTests.ConnectionTests,
InvalidationTests.InvalidationTests InvalidationTests.InvalidationTests):
):
"""FileStorage-specific connection tests.""" """FileStorage-specific connection tests."""
class FileStorageReconnectionTests( class FileStorageReconnectionTests(
FileStorageConfig, FileStorageConfig,
ConnectionTests.ReconnectionTests, ConnectionTests.ReconnectionTests):
):
"""FileStorage-specific re-connection tests.""" """FileStorage-specific re-connection tests."""
# Run this at level 1 because MappingStorage can't do reconnection tests # Run this at level 1 because MappingStorage can't do reconnection tests
class FileStorageInvqTests( class FileStorageInvqTests(
FileStorageConfig, FileStorageConfig,
ConnectionTests.InvqTests ConnectionTests.InvqTests):
):
"""FileStorage-specific invalidation queue tests.""" """FileStorage-specific invalidation queue tests."""
class FileStorageTimeoutTests( class FileStorageTimeoutTests(
FileStorageConfig, FileStorageConfig,
ConnectionTests.TimeoutTests ConnectionTests.TimeoutTests):
):
pass pass
class MappingStorageConnectionTests( class MappingStorageConnectionTests(
MappingStorageConfig, MappingStorageConfig,
ConnectionTests.ConnectionTests ConnectionTests.ConnectionTests):
):
"""Mapping storage connection tests.""" """Mapping storage connection tests."""
# The ReconnectionTests can't work with MappingStorage because it's only an # The ReconnectionTests can't work with MappingStorage because it's only an
# in-memory storage and has no persistent state. # in-memory storage and has no persistent state.
class MappingStorageTimeoutTests( class MappingStorageTimeoutTests(
MappingStorageConfig, MappingStorageConfig,
ConnectionTests.TimeoutTests ConnectionTests.TimeoutTests):
):
pass pass
class SSLConnectionTests( class SSLConnectionTests(
MappingStorageConfig, MappingStorageConfig,
ConnectionTests.SSLConnectionTests, ConnectionTests.SSLConnectionTests):
):
pass pass
...@@ -108,6 +106,7 @@ test_classes = [FileStorageConnectionTests, ...@@ -108,6 +106,7 @@ test_classes = [FileStorageConnectionTests,
if not forker.ZEO4_SERVER: if not forker.ZEO4_SERVER:
test_classes.append(SSLConnectionTests) test_classes.append(SSLConnectionTests)
def invalidations_while_connecting(): def invalidations_while_connecting():
r""" r"""
As soon as a client registers with a server, it will recieve As soon as a client registers with a server, it will recieve
...@@ -122,7 +121,7 @@ This tests tries to provoke this bug by: ...@@ -122,7 +121,7 @@ This tests tries to provoke this bug by:
- starting a server - starting a server
>>> addr, _ = start_server() >>> addr, _ = start_server() # NOQA: F821 undefined name
- opening a client to the server that writes some objects, filling - opening a client to the server that writes some objects, filling
it's cache at the same time, it's cache at the same time,
...@@ -182,7 +181,9 @@ This tests tries to provoke this bug by: ...@@ -182,7 +181,9 @@ This tests tries to provoke this bug by:
... db = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr, client='x')) ... db = ZODB.DB(ZEO.ClientStorage.ClientStorage(addr, client='x'))
... with lock: ... with lock:
... #logging.getLogger('ZEO').debug('Locked %s' % c) ... #logging.getLogger('ZEO').debug('Locked %s' % c)
... @wait_until("connected and we have caught up", timeout=199) ... msg = "connected and we have caught up"
...
... @wait_until(msg, timeout=199) # NOQA: F821 undefined var
... def _(): ... def _():
... if (db.storage.is_connected() ... if (db.storage.is_connected()
... and db.storage.lastTransaction() ... and db.storage.lastTransaction()
...@@ -228,6 +229,7 @@ This tests tries to provoke this bug by: ...@@ -228,6 +229,7 @@ This tests tries to provoke this bug by:
>>> db2.close() >>> db2.close()
""" """
def test_suite(): def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import doctest import doctest
import unittest import unittest
import ZEO.asyncio.testing
class FakeStorageBase(object): class FakeStorageBase(object):
...@@ -31,10 +30,11 @@ class FakeStorageBase(object): ...@@ -31,10 +30,11 @@ class FakeStorageBase(object):
def __len__(self): def __len__(self):
return 4 return 4
class FakeStorage(FakeStorageBase): class FakeStorage(FakeStorageBase):
def record_iternext(self, next=None): def record_iternext(self, next=None):
if next == None: if next is None:
next = '0' next = '0'
next = str(int(next) + 1) next = str(int(next) + 1)
oid = next oid = next
...@@ -43,6 +43,7 @@ class FakeStorage(FakeStorageBase): ...@@ -43,6 +43,7 @@ class FakeStorage(FakeStorageBase):
return oid, oid*8, 'data ' + oid, next return oid, oid*8, 'data ' + oid, next
class FakeServer(object): class FakeServer(object):
storages = { storages = {
'1': FakeStorage(), '1': FakeStorage(),
...@@ -55,13 +56,17 @@ class FakeServer(object): ...@@ -55,13 +56,17 @@ class FakeServer(object):
client_conflict_resolution = False client_conflict_resolution = False
class FakeConnection(object): class FakeConnection(object):
protocol_version = b'Z4' protocol_version = b'Z4'
addr = 'test' addr = 'test'
call_soon_threadsafe = lambda f, *a: f(*a) def call_soon_threadsafe(f, *a):
return f(*a)
async_ = async_threadsafe = None async_ = async_threadsafe = None
def test_server_record_iternext(): def test_server_record_iternext():
""" """
...@@ -99,6 +104,7 @@ The storage info also reflects the fact that record_iternext is supported. ...@@ -99,6 +104,7 @@ The storage info also reflects the fact that record_iternext is supported.
""" """
def test_client_record_iternext(): def test_client_record_iternext():
"""Test client storage delegation to the network client """Test client storage delegation to the network client
...@@ -143,8 +149,10 @@ Now we'll have our way with it's private _server attr: ...@@ -143,8 +149,10 @@ Now we'll have our way with it's private _server attr:
""" """
def test_suite(): def test_suite():
return doctest.DocTestSuite() return doctest.DocTestSuite()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(defaultTest='test_suite') unittest.main(defaultTest='test_suite')
...@@ -16,15 +16,18 @@ import unittest ...@@ -16,15 +16,18 @@ import unittest
from ZEO.TransactionBuffer import TransactionBuffer from ZEO.TransactionBuffer import TransactionBuffer
def random_string(size): def random_string(size):
"""Return a random string of size size.""" """Return a random string of size size."""
l = [chr(random.randrange(256)) for i in range(size)] lst = [chr(random.randrange(256)) for i in range(size)]
return "".join(l) return "".join(lst)
def new_store_data(): def new_store_data():
"""Return arbitrary data to use as argument to store() method.""" """Return arbitrary data to use as argument to store() method."""
return random_string(8), random_string(random.randrange(1000)) return random_string(8), random_string(random.randrange(1000))
def store(tbuf, resolved=False): def store(tbuf, resolved=False):
data = new_store_data() data = new_store_data()
tbuf.store(*data) tbuf.store(*data)
...@@ -32,6 +35,7 @@ def store(tbuf, resolved=False): ...@@ -32,6 +35,7 @@ def store(tbuf, resolved=False):
tbuf.server_resolve(data[0]) tbuf.server_resolve(data[0])
return data return data
class TransBufTests(unittest.TestCase): class TransBufTests(unittest.TestCase):
def checkTypicalUsage(self): def checkTypicalUsage(self):
...@@ -54,5 +58,6 @@ class TransBufTests(unittest.TestCase): ...@@ -54,5 +58,6 @@ class TransBufTests(unittest.TestCase):
self.assertEqual(resolved, data[i][1]) self.assertEqual(resolved, data[i][1])
tbuf.close() tbuf.close()
def test_suite(): def test_suite():
return unittest.makeSuite(TransBufTests, 'check') return unittest.makeSuite(TransBufTests, 'check')
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
"""Test suite for ZEO based on ZODB.tests.""" """Test suite for ZEO based on ZODB.tests."""
from __future__ import print_function from __future__ import print_function
import multiprocessing import multiprocessing
import re
from ZEO.ClientStorage import ClientStorage from ZEO.ClientStorage import ClientStorage
from ZEO.tests import forker, Cache, CommitLockTests, ThreadTests from ZEO.tests import forker, Cache, CommitLockTests, ThreadTests
...@@ -41,7 +40,6 @@ import re ...@@ -41,7 +40,6 @@ import re
import shutil import shutil
import signal import signal
import stat import stat
import ssl
import sys import sys
import tempfile import tempfile
import threading import threading
...@@ -62,11 +60,15 @@ from . import testssl ...@@ -62,11 +60,15 @@ from . import testssl
logger = logging.getLogger('ZEO.tests.testZEO') logger = logging.getLogger('ZEO.tests.testZEO')
class DummyDB(object): class DummyDB(object):
def invalidate(self, *args): def invalidate(self, *args):
pass pass
def invalidateCache(*unused): def invalidateCache(*unused):
pass pass
transform_record_data = untransform_record_data = lambda self, v: v transform_record_data = untransform_record_data = lambda self, v: v
...@@ -76,7 +78,6 @@ class CreativeGetState(persistent.Persistent): ...@@ -76,7 +78,6 @@ class CreativeGetState(persistent.Persistent):
return super(CreativeGetState, self).__getstate__() return super(CreativeGetState, self).__getstate__()
class Test_convenience_functions(unittest.TestCase): class Test_convenience_functions(unittest.TestCase):
def test_ZEO_client_convenience(self): def test_ZEO_client_convenience(self):
...@@ -206,9 +207,8 @@ class MiscZEOTests(object): ...@@ -206,9 +207,8 @@ class MiscZEOTests(object):
for n in range(30): for n in range(30):
time.sleep(.1) time.sleep(.1)
data, serial = storage2.load(oid, '') data, serial = storage2.load(oid, '')
if (serial == revid2 and if serial == revid2 and \
zodb_unpickle(data) == MinPO('second') zodb_unpickle(data) == MinPO('second'):
):
break break
else: else:
raise AssertionError('Invalidation message was not sent!') raise AssertionError('Invalidation message was not sent!')
...@@ -230,6 +230,7 @@ class MiscZEOTests(object): ...@@ -230,6 +230,7 @@ class MiscZEOTests(object):
self.assertNotEqual(ZODB.utils.z64, storage3.lastTransaction()) self.assertNotEqual(ZODB.utils.z64, storage3.lastTransaction())
storage3.close() storage3.close()
class GenericTestBase( class GenericTestBase(
# Base class for all ZODB tests # Base class for all ZODB tests
StorageTestBase.StorageTestBase): StorageTestBase.StorageTestBase):
...@@ -259,8 +260,9 @@ class GenericTestBase( ...@@ -259,8 +260,9 @@ class GenericTestBase(
) )
self._storage.registerDB(DummyDB()) self._storage.registerDB(DummyDB())
# _new_storage_client opens another ClientStorage to the same storage server # _new_storage_client opens another ClientStorage to the same storage
# self._storage is connected to. It is used by both ZEO and ZODB tests. # server self._storage is connected to. It is used by both ZEO and ZODB
# tests.
def _new_storage_client(self): def _new_storage_client(self):
client = ZEO.ClientStorage.ClientStorage( client = ZEO.ClientStorage.ClientStorage(
self._storage._addr, wait=1, **self._client_options()) self._storage._addr, wait=1, **self._client_options())
...@@ -283,9 +285,9 @@ class GenericTestBase( ...@@ -283,9 +285,9 @@ class GenericTestBase(
stop() stop()
StorageTestBase.StorageTestBase.tearDown(self) StorageTestBase.StorageTestBase.tearDown(self)
class GenericTests( class GenericTests(
GenericTestBase, GenericTestBase,
# ZODB test mixin classes (in the same order as imported) # ZODB test mixin classes (in the same order as imported)
BasicStorage.BasicStorage, BasicStorage.BasicStorage,
PackableStorage.PackableStorage, PackableStorage.PackableStorage,
...@@ -296,8 +298,7 @@ class GenericTests( ...@@ -296,8 +298,7 @@ class GenericTests(
CommitLockTests.CommitLockVoteTests, CommitLockTests.CommitLockVoteTests,
ThreadTests.ThreadTests, ThreadTests.ThreadTests,
# Locally defined (see above) # Locally defined (see above)
MiscZEOTests, MiscZEOTests):
):
"""Combine tests from various origins in one class. """Combine tests from various origins in one class.
""" """
...@@ -347,6 +348,7 @@ class GenericTests( ...@@ -347,6 +348,7 @@ class GenericTests(
thread.join(voted and .1 or 9) thread.join(voted and .1 or 9)
return thread return thread
class FullGenericTests( class FullGenericTests(
GenericTests, GenericTests,
Cache.TransUndoStorageWithCache, Cache.TransUndoStorageWithCache,
...@@ -356,8 +358,7 @@ class FullGenericTests( ...@@ -356,8 +358,7 @@ class FullGenericTests(
RevisionStorage.RevisionStorage, RevisionStorage.RevisionStorage,
TransactionalUndoStorage.TransactionalUndoStorage, TransactionalUndoStorage.TransactionalUndoStorage,
IteratorStorage.IteratorStorage, IteratorStorage.IteratorStorage,
IterationTests.IterationTests, IterationTests.IterationTests):
):
"""Extend GenericTests with tests that MappingStorage can't pass.""" """Extend GenericTests with tests that MappingStorage can't pass."""
def checkPackUndoLog(self): def checkPackUndoLog(self):
...@@ -457,8 +458,7 @@ class FileStorageTests(FullGenericTests): ...@@ -457,8 +458,7 @@ class FileStorageTests(FullGenericTests):
self._storage)) self._storage))
# This is communicated using ClientStorage's _info object: # This is communicated using ClientStorage's _info object:
self.assertEqual(self._expected_interfaces, self.assertEqual(self._expected_interfaces,
self._storage._info['interfaces'] self._storage._info['interfaces'])
)
class FileStorageSSLTests(FileStorageTests): class FileStorageSSLTests(FileStorageTests):
...@@ -492,6 +492,7 @@ class FileStorageHexTests(FileStorageTests): ...@@ -492,6 +492,7 @@ class FileStorageHexTests(FileStorageTests):
</hexstorage> </hexstorage>
""" """
class FileStorageClientHexTests(FileStorageHexTests): class FileStorageClientHexTests(FileStorageHexTests):
use_extension_bytes = True use_extension_bytes = True
...@@ -509,10 +510,10 @@ class FileStorageClientHexTests(FileStorageHexTests): ...@@ -509,10 +510,10 @@ class FileStorageClientHexTests(FileStorageHexTests):
def _wrap_client(self, client): def _wrap_client(self, client):
return ZODB.tests.hexstorage.HexStorage(client) return ZODB.tests.hexstorage.HexStorage(client)
class ClientConflictResolutionTests( class ClientConflictResolutionTests(
GenericTestBase, GenericTestBase,
ConflictResolution.ConflictResolvingStorage, ConflictResolution.ConflictResolvingStorage):
):
def getConfig(self): def getConfig(self):
return '<mappingstorage>\n</mappingstorage>\n' return '<mappingstorage>\n</mappingstorage>\n'
...@@ -520,7 +521,9 @@ class ClientConflictResolutionTests( ...@@ -520,7 +521,9 @@ class ClientConflictResolutionTests(
def getZEOConfig(self): def getZEOConfig(self):
# Using '' can result in binding to :: and cause problems # Using '' can result in binding to :: and cause problems
# connecting to the MTAcceptor on Travis CI # connecting to the MTAcceptor on Travis CI
return forker.ZEOConfig(('127.0.0.1', 0), client_conflict_resolution=True) return forker.ZEOConfig(('127.0.0.1', 0),
client_conflict_resolution=True)
class MappingStorageTests(GenericTests): class MappingStorageTests(GenericTests):
"""ZEO backed by a Mapping storage.""" """ZEO backed by a Mapping storage."""
...@@ -538,9 +541,8 @@ class MappingStorageTests(GenericTests): ...@@ -538,9 +541,8 @@ class MappingStorageTests(GenericTests):
# to construct our iterator, which we don't, so we disable this test. # to construct our iterator, which we don't, so we disable this test.
pass pass
class DemoStorageTests(
GenericTests, class DemoStorageTests(GenericTests):
):
def getConfig(self): def getConfig(self):
return """ return """
...@@ -560,6 +562,7 @@ class DemoStorageTests( ...@@ -560,6 +562,7 @@ class DemoStorageTests(
pass # DemoStorage pack doesn't do gc pass # DemoStorage pack doesn't do gc
checkPackAllRevisions = checkPackWithMultiDatabaseReferences checkPackAllRevisions = checkPackWithMultiDatabaseReferences
class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown): class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
def getConfig(self, path, create, read_only): def getConfig(self, path, create, read_only):
...@@ -573,7 +576,6 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown): ...@@ -573,7 +576,6 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
handler = zope.testing.loggingsupport.InstalledHandler( handler = zope.testing.loggingsupport.InstalledHandler(
'ZEO.asyncio.client') 'ZEO.asyncio.client')
# We no longer implement the event loop, we we no longer know # We no longer implement the event loop, we we no longer know
# how to break it. We'll just stop it instead for now. # how to break it. We'll just stop it instead for now.
self._storage._server.loop.call_soon_threadsafe( self._storage._server.loop.call_soon_threadsafe(
...@@ -581,7 +583,7 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown): ...@@ -581,7 +583,7 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
forker.wait_until( forker.wait_until(
'disconnected', 'disconnected',
lambda : not self._storage.is_connected() lambda: not self._storage.is_connected()
) )
log = str(handler) log = str(handler)
...@@ -614,10 +616,13 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown): ...@@ -614,10 +616,13 @@ class ZRPCConnectionTests(ZEO.tests.ConnectionTests.CommonSetupTearDown):
class DummyDB(object): class DummyDB(object):
_invalidatedCache = 0 _invalidatedCache = 0
def invalidateCache(self): def invalidateCache(self):
self._invalidatedCache += 1 self._invalidatedCache += 1
def invalidate(*a, **k): def invalidate(*a, **k):
pass pass
transform_record_data = untransform_record_data = \ transform_record_data = untransform_record_data = \
lambda self, data: data lambda self, data: data
...@@ -660,7 +665,6 @@ class CommonBlobTests(object): ...@@ -660,7 +665,6 @@ class CommonBlobTests(object):
blob_cache_dir = 'blob_cache' blob_cache_dir = 'blob_cache'
def checkStoreBlob(self): def checkStoreBlob(self):
import transaction
from ZODB.blob import Blob from ZODB.blob import Blob
from ZODB.tests.StorageTestBase import ZERO from ZODB.tests.StorageTestBase import ZERO
from ZODB.tests.StorageTestBase import zodb_pickle from ZODB.tests.StorageTestBase import zodb_pickle
...@@ -681,7 +685,7 @@ class CommonBlobTests(object): ...@@ -681,7 +685,7 @@ class CommonBlobTests(object):
self._storage.storeBlob(oid, ZERO, data, tfname, '', t) self._storage.storeBlob(oid, ZERO, data, tfname, '', t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
revid = self._storage.tpc_finish(t) revid = self._storage.tpc_finish(t)
except: except: # NOQA: E722 bare except
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
raise raise
self.assertTrue(not os.path.exists(tfname)) self.assertTrue(not os.path.exists(tfname))
...@@ -703,7 +707,6 @@ class CommonBlobTests(object): ...@@ -703,7 +707,6 @@ class CommonBlobTests(object):
def checkLoadBlob(self): def checkLoadBlob(self):
from ZODB.blob import Blob from ZODB.blob import Blob
from ZODB.tests.StorageTestBase import zodb_pickle, ZERO from ZODB.tests.StorageTestBase import zodb_pickle, ZERO
import transaction
somedata = b'a' * 10 somedata = b'a' * 10
...@@ -720,7 +723,7 @@ class CommonBlobTests(object): ...@@ -720,7 +723,7 @@ class CommonBlobTests(object):
self._storage.storeBlob(oid, ZERO, data, tfname, '', t) self._storage.storeBlob(oid, ZERO, data, tfname, '', t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
serial = self._storage.tpc_finish(t) serial = self._storage.tpc_finish(t)
except: except: # NOQA: E722 bare except
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
raise raise
...@@ -749,7 +752,6 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests): ...@@ -749,7 +752,6 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests):
"""ZEO backed by a BlobStorage-adapted FileStorage.""" """ZEO backed by a BlobStorage-adapted FileStorage."""
def checkStoreAndLoadBlob(self): def checkStoreAndLoadBlob(self):
import transaction
from ZODB.blob import Blob from ZODB.blob import Blob
from ZODB.tests.StorageTestBase import ZERO from ZODB.tests.StorageTestBase import ZERO
from ZODB.tests.StorageTestBase import zodb_pickle from ZODB.tests.StorageTestBase import zodb_pickle
...@@ -785,7 +787,7 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests): ...@@ -785,7 +787,7 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests):
self._storage.storeBlob(oid, ZERO, data, tfname, '', t) self._storage.storeBlob(oid, ZERO, data, tfname, '', t)
self._storage.tpc_vote(t) self._storage.tpc_vote(t)
revid = self._storage.tpc_finish(t) revid = self._storage.tpc_finish(t)
except: except: # NOQA: E722 bare except
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
raise raise
...@@ -812,11 +814,9 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests): ...@@ -812,11 +814,9 @@ class BlobAdaptedFileStorageTests(FullGenericTests, CommonBlobTests):
returns = [] returns = []
threads = [ threads = [
threading.Thread( threading.Thread(
target=lambda : target=lambda:
returns.append(self._storage.loadBlob(oid, revid)) returns.append(self._storage.loadBlob(oid, revid))
) ) for i in range(10)]
for i in range(10)
]
[thread.start() for thread in threads] [thread.start() for thread in threads]
[thread.join() for thread in threads] [thread.join() for thread in threads]
[self.assertEqual(r, filename) for r in returns] [self.assertEqual(r, filename) for r in returns]
...@@ -828,18 +828,21 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests): ...@@ -828,18 +828,21 @@ class BlobWritableCacheTests(FullGenericTests, CommonBlobTests):
blob_cache_dir = 'blobs' blob_cache_dir = 'blobs'
shared_blob_dir = True shared_blob_dir = True
class FauxConn(object): class FauxConn(object):
addr = 'x' addr = 'x'
protocol_version = ZEO.asyncio.server.best_protocol_version protocol_version = ZEO.asyncio.server.best_protocol_version
peer_protocol_version = protocol_version peer_protocol_version = protocol_version
serials = [] serials = []
def async_(self, method, *args): def async_(self, method, *args):
if method == 'serialnos': if method == 'serialnos':
self.serials.extend(args[0]) self.serials.extend(args[0])
call_soon_threadsafe = async_threadsafe = async_ call_soon_threadsafe = async_threadsafe = async_
class StorageServerWrapper(object): class StorageServerWrapper(object):
def __init__(self, server, storage_id): def __init__(self, server, storage_id):
...@@ -881,13 +884,14 @@ class StorageServerWrapper(object): ...@@ -881,13 +884,14 @@ class StorageServerWrapper(object):
def tpc_abort(self, transaction): def tpc_abort(self, transaction):
self.server.tpc_abort(id(transaction)) self.server.tpc_abort(id(transaction))
def tpc_finish(self, transaction, func = lambda: None): def tpc_finish(self, transaction, func=lambda: None):
self.server.tpc_finish(id(transaction)).set_sender(0, self) self.server.tpc_finish(id(transaction)).set_sender(0, self)
return self._result return self._result
def multiple_storages_invalidation_queue_is_not_insane(): def multiple_storages_invalidation_queue_is_not_insane():
""" """
>>> from ZEO.StorageServer import StorageServer, ZEOStorage >>> from ZEO.StorageServer import StorageServer
>>> from ZODB.FileStorage import FileStorage >>> from ZODB.FileStorage import FileStorage
>>> from ZODB.DB import DB >>> from ZODB.DB import DB
>>> from persistent.mapping import PersistentMapping >>> from persistent.mapping import PersistentMapping
...@@ -926,6 +930,7 @@ def multiple_storages_invalidation_queue_is_not_insane(): ...@@ -926,6 +930,7 @@ def multiple_storages_invalidation_queue_is_not_insane():
>>> fs1.close(); fs2.close() >>> fs1.close(); fs2.close()
""" """
def getInvalidationsAfterServerRestart(): def getInvalidationsAfterServerRestart():
""" """
...@@ -969,12 +974,10 @@ If a storage implements the method lastInvalidations, as FileStorage ...@@ -969,12 +974,10 @@ If a storage implements the method lastInvalidations, as FileStorage
does, then the storage server will populate its invalidation data does, then the storage server will populate its invalidation data
structure using lastTransactions. structure using lastTransactions.
>>> tid, oids = s.getInvalidations(last[-10]) >>> tid, oids = s.getInvalidations(last[-10])
>>> tid == last[-1] >>> tid == last[-1]
True True
>>> from ZODB.utils import u64 >>> from ZODB.utils import u64
>>> sorted([int(u64(oid)) for oid in oids]) >>> sorted([int(u64(oid)) for oid in oids])
[0, 92, 93, 94, 95, 96, 97, 98, 99, 100] [0, 92, 93, 94, 95, 96, 97, 98, 99, 100]
...@@ -1023,13 +1026,14 @@ that were only created. ...@@ -1023,13 +1026,14 @@ that were only created.
>>> fs.close() >>> fs.close()
""" """
def tpc_finish_error(): def tpc_finish_error():
r"""Server errors in tpc_finish weren't handled properly. r"""Server errors in tpc_finish weren't handled properly.
If there are errors applying changes to the client cache, don't If there are errors applying changes to the client cache, don't
leave the cache in an inconsistent state. leave the cache in an inconsistent state.
>>> addr, admin = start_server() >>> addr, admin = start_server() # NOQA: F821 undefined
>>> client = ZEO.client(addr) >>> client = ZEO.client(addr)
>>> db = ZODB.DB(client) >>> db = ZODB.DB(client)
...@@ -1070,16 +1074,17 @@ def tpc_finish_error(): ...@@ -1070,16 +1074,17 @@ def tpc_finish_error():
>>> db.close() >>> db.close()
>>> stop_server(admin) >>> stop_server(admin) # NOQA: F821 undefined
""" """
def test_prefetch(self): def test_prefetch(self):
"""The client storage prefetch method pre-fetches from the server """The client storage prefetch method pre-fetches from the server
>>> count = 999 >>> count = 999
>>> import ZEO >>> import ZEO
>>> addr, stop = start_server() >>> addr, stop = start_server() # NOQA: F821 undefined
>>> conn = ZEO.connection(addr) >>> conn = ZEO.connection(addr)
>>> root = conn.root() >>> root = conn.root()
>>> cls = root.__class__ >>> cls = root.__class__
...@@ -1102,7 +1107,7 @@ def test_prefetch(self): ...@@ -1102,7 +1107,7 @@ def test_prefetch(self):
But it is filled eventually: But it is filled eventually:
>>> from zope.testing.wait import wait >>> from zope.testing.wait import wait
>>> wait(lambda : len(storage._cache) > count) >>> wait(lambda: len(storage._cache) > count)
>>> loads = storage.server_status()['loads'] >>> loads = storage.server_status()['loads']
...@@ -1117,15 +1122,16 @@ def test_prefetch(self): ...@@ -1117,15 +1122,16 @@ def test_prefetch(self):
>>> conn.close() >>> conn.close()
""" """
def client_has_newer_data_than_server(): def client_has_newer_data_than_server():
"""It is bad if a client has newer data than the server. """It is bad if a client has newer data than the server.
>>> db = ZODB.DB('Data.fs') >>> db = ZODB.DB('Data.fs')
>>> db.close() >>> db.close()
>>> r = shutil.copyfile('Data.fs', 'Data.save') >>> r = shutil.copyfile('Data.fs', 'Data.save')
>>> addr, admin = start_server(keep=1) >>> addr, admin = start_server(keep=1) # NOQA: F821 undefined
>>> db = ZEO.DB(addr, name='client', max_disconnect_poll=.01) >>> db = ZEO.DB(addr, name='client', max_disconnect_poll=.01)
>>> wait_connected(db.storage) >>> wait_connected(db.storage) # NOQA: F821 undefined
>>> conn = db.open() >>> conn = db.open()
>>> conn.root().x = 1 >>> conn.root().x = 1
>>> transaction.commit() >>> transaction.commit()
...@@ -1134,7 +1140,7 @@ def client_has_newer_data_than_server(): ...@@ -1134,7 +1140,7 @@ def client_has_newer_data_than_server():
the new data. Now, we'll stop the server, put back the old data, and the new data. Now, we'll stop the server, put back the old data, and
see what happens. :) see what happens. :)
>>> stop_server(admin) >>> stop_server(admin) # NOQA: F821 undefined
>>> r = shutil.copyfile('Data.save', 'Data.fs') >>> r = shutil.copyfile('Data.save', 'Data.fs')
>>> import zope.testing.loggingsupport >>> import zope.testing.loggingsupport
...@@ -1142,9 +1148,9 @@ def client_has_newer_data_than_server(): ...@@ -1142,9 +1148,9 @@ def client_has_newer_data_than_server():
... 'ZEO', level=logging.ERROR) ... 'ZEO', level=logging.ERROR)
>>> formatter = logging.Formatter('%(name)s %(levelname)s %(message)s') >>> formatter = logging.Formatter('%(name)s %(levelname)s %(message)s')
>>> _, admin = start_server(addr=addr) >>> _, admin = start_server(addr=addr) # NOQA: F821 undefined
>>> wait_until('got enough errors', lambda: >>> wait_until('got enough errors', lambda: # NOQA: F821 undefined
... len([x for x in handler.records ... len([x for x in handler.records
... if x.levelname == 'CRITICAL' and ... if x.levelname == 'CRITICAL' and
... 'Client cache is out of sync with the server.' in x.msg ... 'Client cache is out of sync with the server.' in x.msg
...@@ -1154,15 +1160,16 @@ def client_has_newer_data_than_server(): ...@@ -1154,15 +1160,16 @@ def client_has_newer_data_than_server():
>>> db.close() >>> db.close()
>>> handler.uninstall() >>> handler.uninstall()
>>> stop_server(admin) >>> stop_server(admin) # NOQA: F821 undefined
""" """
def history_over_zeo(): def history_over_zeo():
""" """
>>> addr, _ = start_server() >>> addr, _ = start_server() # NOQA: F821 undefined
>>> db = ZEO.DB(addr) >>> db = ZEO.DB(addr)
>>> wait_connected(db.storage) >>> wait_connected(db.storage) # NOQA: F821 undefined
>>> conn = db.open() >>> conn = db.open()
>>> conn.root().x = 0 >>> conn.root().x = 0
>>> transaction.commit() >>> transaction.commit()
...@@ -1172,9 +1179,10 @@ def history_over_zeo(): ...@@ -1172,9 +1179,10 @@ def history_over_zeo():
>>> db.close() >>> db.close()
""" """
def dont_log_poskeyerrors_on_server(): def dont_log_poskeyerrors_on_server():
""" """
>>> addr, admin = start_server(log='server.log') >>> addr, admin = start_server(log='server.log') # NOQA: F821 undefined
>>> cs = ClientStorage(addr) >>> cs = ClientStorage(addr)
>>> cs.load(ZODB.utils.p64(1)) >>> cs.load(ZODB.utils.p64(1))
Traceback (most recent call last): Traceback (most recent call last):
...@@ -1182,16 +1190,17 @@ def dont_log_poskeyerrors_on_server(): ...@@ -1182,16 +1190,17 @@ def dont_log_poskeyerrors_on_server():
POSKeyError: 0x01 POSKeyError: 0x01
>>> cs.close() >>> cs.close()
>>> stop_server(admin) >>> stop_server(admin) # NOQA: F821 undefined
>>> with open('server.log') as f: >>> with open('server.log') as f:
... 'POSKeyError' in f.read() ... 'POSKeyError' in f.read()
False False
""" """
def open_convenience(): def open_convenience():
"""Often, we just want to open a single connection. """Often, we just want to open a single connection.
>>> addr, _ = start_server(path='data.fs') >>> addr, _ = start_server(path='data.fs') # NOQA: F821 undefined
>>> conn = ZEO.connection(addr) >>> conn = ZEO.connection(addr)
>>> conn.root() >>> conn.root()
{} {}
...@@ -1210,9 +1219,10 @@ def open_convenience(): ...@@ -1210,9 +1219,10 @@ def open_convenience():
>>> db.close() >>> db.close()
""" """
def client_asyncore_thread_has_name(): def client_asyncore_thread_has_name():
""" """
>>> addr, _ = start_server() >>> addr, _ = start_server() # NOQA: F821 undefined
>>> db = ZEO.DB(addr) >>> db = ZEO.DB(addr)
>>> any(t for t in threading.enumerate() >>> any(t for t in threading.enumerate()
... if ' zeo client networking thread' in t.getName()) ... if ' zeo client networking thread' in t.getName())
...@@ -1220,6 +1230,7 @@ def client_asyncore_thread_has_name(): ...@@ -1220,6 +1230,7 @@ def client_asyncore_thread_has_name():
>>> db.close() >>> db.close()
""" """
def runzeo_without_configfile(): def runzeo_without_configfile():
r""" r"""
>>> with open('runzeo', 'w') as r: >>> with open('runzeo', 'w') as r:
...@@ -1251,11 +1262,12 @@ def runzeo_without_configfile(): ...@@ -1251,11 +1262,12 @@ def runzeo_without_configfile():
>>> proc.stdout.close() >>> proc.stdout.close()
""" """
def close_client_storage_w_invalidations(): def close_client_storage_w_invalidations():
r""" r"""
Invalidations could cause errors when closing client storages, Invalidations could cause errors when closing client storages,
>>> addr, _ = start_server() >>> addr, _ = start_server() # NOQA: F821 undefined
>>> writing = threading.Event() >>> writing = threading.Event()
>>> def mad_write_thread(): >>> def mad_write_thread():
... global writing ... global writing
...@@ -1280,10 +1292,11 @@ Invalidations could cause errors when closing client storages, ...@@ -1280,10 +1292,11 @@ Invalidations could cause errors when closing client storages,
>>> thread.join(1) >>> thread.join(1)
""" """
def convenient_to_pass_port_to_client_and_ZEO_dot_client(): def convenient_to_pass_port_to_client_and_ZEO_dot_client():
"""Jim hates typing """Jim hates typing
>>> addr, _ = start_server() >>> addr, _ = start_server() # NOQA: F821 undefined
>>> client = ZEO.client(addr[1]) >>> client = ZEO.client(addr[1])
>>> client.__name__ == "('127.0.0.1', %s)" % addr[1] >>> client.__name__ == "('127.0.0.1', %s)" % addr[1]
True True
...@@ -1291,12 +1304,14 @@ def convenient_to_pass_port_to_client_and_ZEO_dot_client(): ...@@ -1291,12 +1304,14 @@ def convenient_to_pass_port_to_client_and_ZEO_dot_client():
>>> client.close() >>> client.close()
""" """
@forker.skip_if_testing_client_against_zeo4 @forker.skip_if_testing_client_against_zeo4
def test_server_status(): def test_server_status():
""" """
You can get server status using the server_status method. You can get server status using the server_status method.
>>> addr, _ = start_server(zeo_conf=dict(transaction_timeout=1)) >>> addr, _ = start_server( # NOQA: F821 undefined
... zeo_conf=dict(transaction_timeout=1))
>>> db = ZEO.DB(addr) >>> db = ZEO.DB(addr)
>>> pprint.pprint(db.storage.server_status(), width=40) >>> pprint.pprint(db.storage.server_status(), width=40)
{'aborts': 0, {'aborts': 0,
...@@ -1316,12 +1331,14 @@ def test_server_status(): ...@@ -1316,12 +1331,14 @@ def test_server_status():
>>> db.close() >>> db.close()
""" """
@forker.skip_if_testing_client_against_zeo4 @forker.skip_if_testing_client_against_zeo4
def test_ruok(): def test_ruok():
""" """
You can also get server status using the ruok protocol. You can also get server status using the ruok protocol.
>>> addr, _ = start_server(zeo_conf=dict(transaction_timeout=1)) >>> addr, _ = start_server( # NOQA: F821 undefined
... zeo_conf=dict(transaction_timeout=1))
>>> db = ZEO.DB(addr) # force a transaction :) >>> db = ZEO.DB(addr) # force a transaction :)
>>> import json, socket, struct >>> import json, socket, struct
>>> s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) >>> s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
...@@ -1349,6 +1366,7 @@ def test_ruok(): ...@@ -1349,6 +1366,7 @@ def test_ruok():
>>> db.close(); s.close() >>> db.close(); s.close()
""" """
def client_labels(): def client_labels():
""" """
When looking at server logs, for servers with lots of clients coming When looking at server logs, for servers with lots of clients coming
...@@ -1358,10 +1376,10 @@ log entries with actual clients. It's possible, sort of, but tedious. ...@@ -1358,10 +1376,10 @@ log entries with actual clients. It's possible, sort of, but tedious.
You can make this easier by passing a label to the ClientStorage You can make this easier by passing a label to the ClientStorage
constructor. constructor.
>>> addr, _ = start_server(log='server.log') >>> addr, _ = start_server(log='server.log') # NOQA: F821 undefined
>>> db = ZEO.DB(addr, client_label='test-label-1') >>> db = ZEO.DB(addr, client_label='test-label-1')
>>> db.close() >>> db.close()
>>> @wait_until >>> @wait_until # NOQA: F821 undefined
... def check_for_test_label_1(): ... def check_for_test_label_1():
... with open('server.log') as f: ... with open('server.log') as f:
... for line in f: ... for line in f:
...@@ -1382,7 +1400,7 @@ You can specify the client label via a configuration file as well: ...@@ -1382,7 +1400,7 @@ You can specify the client label via a configuration file as well:
... </zodb> ... </zodb>
... ''' % addr[1]) ... ''' % addr[1])
>>> db.close() >>> db.close()
>>> @wait_until >>> @wait_until # NOQA: F821 undefined
... def check_for_test_label_2(): ... def check_for_test_label_2():
... with open('server.log') as f: ... with open('server.log') as f:
... for line in f: ... for line in f:
...@@ -1393,6 +1411,7 @@ You can specify the client label via a configuration file as well: ...@@ -1393,6 +1411,7 @@ You can specify the client label via a configuration file as well:
""" """
def invalidate_client_cache_entry_on_server_commit_error(): def invalidate_client_cache_entry_on_server_commit_error():
""" """
...@@ -1400,7 +1419,7 @@ When the serials returned during commit includes an error, typically a ...@@ -1400,7 +1419,7 @@ When the serials returned during commit includes an error, typically a
conflict error, invalidate the cache entry. This is important when conflict error, invalidate the cache entry. This is important when
the cache is messed up. the cache is messed up.
>>> addr, _ = start_server() >>> addr, _ = start_server() # NOQA: F821 undefined
>>> conn1 = ZEO.connection(addr) >>> conn1 = ZEO.connection(addr)
>>> conn1.root.x = conn1.root().__class__() >>> conn1.root.x = conn1.root().__class__()
>>> transaction.commit() >>> transaction.commit()
...@@ -1473,6 +1492,8 @@ sys.path[:] = %(path)r ...@@ -1473,6 +1492,8 @@ sys.path[:] = %(path)r
%(src)s %(src)s
""" """
def generate_script(name, src): def generate_script(name, src):
with open(name, 'w') as f: with open(name, 'w') as f:
f.write(script_template % dict( f.write(script_template % dict(
...@@ -1481,10 +1502,12 @@ def generate_script(name, src): ...@@ -1481,10 +1502,12 @@ def generate_script(name, src):
src=src, src=src,
)) ))
def read(filename): def read(filename):
with open(filename) as f: with open(filename) as f:
return f.read() return f.read()
def runzeo_logrotate_on_sigusr2(): def runzeo_logrotate_on_sigusr2():
""" """
>>> from ZEO.tests.forker import get_port >>> from ZEO.tests.forker import get_port
...@@ -1506,10 +1529,10 @@ def runzeo_logrotate_on_sigusr2(): ...@@ -1506,10 +1529,10 @@ def runzeo_logrotate_on_sigusr2():
... import ZEO.runzeo ... import ZEO.runzeo
... ZEO.runzeo.main() ... ZEO.runzeo.main()
... ''') ... ''')
>>> import subprocess, signal >>> import subprocess
>>> p = subprocess.Popen([sys.executable, 's', '-Cc'], close_fds=True) >>> p = subprocess.Popen([sys.executable, 's', '-Cc'], close_fds=True)
>>> wait_until('started', >>> wait_until('started', # NOQA: F821 undefined
... lambda : os.path.exists('l') and ('listening on' in read('l')) ... lambda: os.path.exists('l') and ('listening on' in read('l'))
... ) ... )
>>> oldlog = read('l') >>> oldlog = read('l')
...@@ -1518,7 +1541,8 @@ def runzeo_logrotate_on_sigusr2(): ...@@ -1518,7 +1541,8 @@ def runzeo_logrotate_on_sigusr2():
>>> s = ClientStorage(port) >>> s = ClientStorage(port)
>>> s.close() >>> s.close()
>>> wait_until('See logging', lambda : ('Log files ' in read('l'))) >>> wait_until('See logging', # NOQA: F821 undefined
... lambda: ('Log files ' in read('l')))
>>> read('o') == oldlog # No new data in old log >>> read('o') == oldlog # No new data in old log
True True
...@@ -1528,10 +1552,11 @@ def runzeo_logrotate_on_sigusr2(): ...@@ -1528,10 +1552,11 @@ def runzeo_logrotate_on_sigusr2():
>>> _ = p.wait() >>> _ = p.wait()
""" """
def unix_domain_sockets(): def unix_domain_sockets():
"""Make sure unix domain sockets work """Make sure unix domain sockets work
>>> addr, _ = start_server(port='./sock') >>> addr, _ = start_server(port='./sock') # NOQA: F821 undefined
>>> c = ZEO.connection(addr) >>> c = ZEO.connection(addr)
>>> c.root.x = 1 >>> c.root.x = 1
...@@ -1539,6 +1564,7 @@ def unix_domain_sockets(): ...@@ -1539,6 +1564,7 @@ def unix_domain_sockets():
>>> c.close() >>> c.close()
""" """
def gracefully_handle_abort_while_storing_many_blobs(): def gracefully_handle_abort_while_storing_many_blobs():
r""" r"""
...@@ -1548,7 +1574,7 @@ def gracefully_handle_abort_while_storing_many_blobs(): ...@@ -1548,7 +1574,7 @@ def gracefully_handle_abort_while_storing_many_blobs():
>>> handler = logging.StreamHandler(sys.stdout) >>> handler = logging.StreamHandler(sys.stdout)
>>> logging.getLogger().addHandler(handler) >>> logging.getLogger().addHandler(handler)
>>> addr, _ = start_server(blob_dir='blobs') >>> addr, _ = start_server(blob_dir='blobs') # NOQA: F821 undefined
>>> client = ZEO.client(addr, blob_dir='cblobs') >>> client = ZEO.client(addr, blob_dir='cblobs')
>>> c = ZODB.connection(client) >>> c = ZODB.connection(client)
>>> c.root.x = ZODB.blob.Blob(b'z'*(1<<20)) >>> c.root.x = ZODB.blob.Blob(b'z'*(1<<20))
...@@ -1578,6 +1604,7 @@ call to the server. we'd get some sort of error here. ...@@ -1578,6 +1604,7 @@ call to the server. we'd get some sort of error here.
""" """
def ClientDisconnected_errors_are_TransientErrors(): def ClientDisconnected_errors_are_TransientErrors():
""" """
>>> from ZEO.Exceptions import ClientDisconnected >>> from ZEO.Exceptions import ClientDisconnected
...@@ -1586,6 +1613,7 @@ def ClientDisconnected_errors_are_TransientErrors(): ...@@ -1586,6 +1613,7 @@ def ClientDisconnected_errors_are_TransientErrors():
True True
""" """
if not os.environ.get('ZEO4_SERVER'): if not os.environ.get('ZEO4_SERVER'):
if os.environ.get('ZEO_MSGPACK'): if os.environ.get('ZEO_MSGPACK'):
def test_runzeo_msgpack_support(): def test_runzeo_msgpack_support():
...@@ -1620,11 +1648,13 @@ if WIN: ...@@ -1620,11 +1648,13 @@ if WIN:
del runzeo_logrotate_on_sigusr2 del runzeo_logrotate_on_sigusr2
del unix_domain_sockets del unix_domain_sockets
def work_with_multiprocessing_process(name, addr, q): def work_with_multiprocessing_process(name, addr, q):
conn = ZEO.connection(addr) conn = ZEO.connection(addr)
q.put((name, conn.root.x)) q.put((name, conn.root.x))
conn.close() conn.close()
class MultiprocessingTests(unittest.TestCase): class MultiprocessingTests(unittest.TestCase):
layer = ZODB.tests.util.MininalTestLayer('work_with_multiprocessing') layer = ZODB.tests.util.MininalTestLayer('work_with_multiprocessing')
...@@ -1634,9 +1664,9 @@ class MultiprocessingTests(unittest.TestCase): ...@@ -1634,9 +1664,9 @@ class MultiprocessingTests(unittest.TestCase):
# Gaaa, zope.testing.runner.FakeInputContinueGenerator has no close # Gaaa, zope.testing.runner.FakeInputContinueGenerator has no close
if not hasattr(sys.stdin, 'close'): if not hasattr(sys.stdin, 'close'):
sys.stdin.close = lambda : None sys.stdin.close = lambda: None
if not hasattr(sys.stdin, 'fileno'): if not hasattr(sys.stdin, 'fileno'):
sys.stdin.fileno = lambda : -1 sys.stdin.fileno = lambda: -1
self.globs = {} self.globs = {}
forker.setUp(self) forker.setUp(self)
...@@ -1657,6 +1687,7 @@ class MultiprocessingTests(unittest.TestCase): ...@@ -1657,6 +1687,7 @@ class MultiprocessingTests(unittest.TestCase):
conn.close() conn.close()
zope.testing.setupstack.tearDown(self) zope.testing.setupstack.tearDown(self)
@forker.skip_if_testing_client_against_zeo4 @forker.skip_if_testing_client_against_zeo4
def quick_close_doesnt_kill_server(): def quick_close_doesnt_kill_server():
r""" r"""
...@@ -1664,7 +1695,7 @@ def quick_close_doesnt_kill_server(): ...@@ -1664,7 +1695,7 @@ def quick_close_doesnt_kill_server():
Start a server: Start a server:
>>> from .testssl import server_config, client_ssl >>> from .testssl import server_config, client_ssl
>>> addr, _ = start_server(zeo_conf=server_config) >>> addr, _ = start_server(zeo_conf=server_config) # NOQA: F821 undefined
Now connect and immediately disconnect. This caused the server to Now connect and immediately disconnect. This caused the server to
die in the past: die in the past:
...@@ -1678,7 +1709,10 @@ def quick_close_doesnt_kill_server(): ...@@ -1678,7 +1709,10 @@ def quick_close_doesnt_kill_server():
... s.close() ... s.close()
>>> print("\n\nXXX WARNING: running quick_close_doesnt_kill_server with ssl as hack pending http://bugs.python.org/issue27386\n", file=sys.stderr) # Intentional long line to be annoying till this is fixed >>> print("\n\nXXX WARNING: running quick_close_doesnt_kill_server "
... "with ssl as hack pending http://bugs.python.org/issue27386\n",
... file=sys.stderr) # Intentional long line to be annoying
... # until this is fixed
Now we should be able to connect as normal: Now we should be able to connect as normal:
...@@ -1689,10 +1723,11 @@ def quick_close_doesnt_kill_server(): ...@@ -1689,10 +1723,11 @@ def quick_close_doesnt_kill_server():
>>> db.close() >>> db.close()
""" """
def can_use_empty_string_for_local_host_on_client(): def can_use_empty_string_for_local_host_on_client():
"""We should be able to spell localhost with ''. """We should be able to spell localhost with ''.
>>> (_, port), _ = start_server() >>> (_, port), _ = start_server() # NOQA: F821 undefined name
>>> conn = ZEO.connection(('', port)) >>> conn = ZEO.connection(('', port))
>>> conn.root() >>> conn.root()
{} {}
...@@ -1702,6 +1737,7 @@ def can_use_empty_string_for_local_host_on_client(): ...@@ -1702,6 +1737,7 @@ def can_use_empty_string_for_local_host_on_client():
>>> conn.close() >>> conn.close()
""" """
slow_test_classes = [ slow_test_classes = [
BlobAdaptedFileStorageTests, BlobWritableCacheTests, BlobAdaptedFileStorageTests, BlobWritableCacheTests,
MappingStorageTests, DemoStorageTests, MappingStorageTests, DemoStorageTests,
...@@ -1713,6 +1749,7 @@ if not forker.ZEO4_SERVER: ...@@ -1713,6 +1749,7 @@ if not forker.ZEO4_SERVER:
quick_test_classes = [FileStorageRecoveryTests, ZRPCConnectionTests] quick_test_classes = [FileStorageRecoveryTests, ZRPCConnectionTests]
class ServerManagingClientStorage(ClientStorage): class ServerManagingClientStorage(ClientStorage):
def __init__(self, name, blob_dir, shared=False, extrafsoptions=''): def __init__(self, name, blob_dir, shared=False, extrafsoptions=''):
...@@ -1743,9 +1780,11 @@ class ServerManagingClientStorage(ClientStorage): ...@@ -1743,9 +1780,11 @@ class ServerManagingClientStorage(ClientStorage):
ClientStorage.close(self) ClientStorage.close(self)
zope.testing.setupstack.tearDown(self) zope.testing.setupstack.tearDown(self)
def create_storage_shared(name, blob_dir): def create_storage_shared(name, blob_dir):
return ServerManagingClientStorage(name, blob_dir, True) return ServerManagingClientStorage(name, blob_dir, True)
class ServerManagingClientStorageForIExternalGCTest( class ServerManagingClientStorageForIExternalGCTest(
ServerManagingClientStorage): ServerManagingClientStorage):
...@@ -1756,6 +1795,7 @@ class ServerManagingClientStorageForIExternalGCTest( ...@@ -1756,6 +1795,7 @@ class ServerManagingClientStorageForIExternalGCTest(
self._cache.clear() self._cache.clear()
ZEO.ClientStorage._check_blob_cache_size(self.blob_dir, 0) ZEO.ClientStorage._check_blob_cache_size(self.blob_dir, 0)
def test_suite(): def test_suite():
suite = unittest.TestSuite(( suite = unittest.TestSuite((
unittest.makeSuite(Test_convenience_functions), unittest.makeSuite(Test_convenience_functions),
...@@ -1769,7 +1809,8 @@ def test_suite(): ...@@ -1769,7 +1809,8 @@ def test_suite():
'last-transaction'), 'last-transaction'),
(re.compile("ZODB.POSException.ConflictError"), "ConflictError"), (re.compile("ZODB.POSException.ConflictError"), "ConflictError"),
(re.compile("ZODB.POSException.POSKeyError"), "POSKeyError"), (re.compile("ZODB.POSException.POSKeyError"), "POSKeyError"),
(re.compile("ZEO.Exceptions.ClientStorageError"), "ClientStorageError"), (re.compile("ZEO.Exceptions.ClientStorageError"),
"ClientStorageError"),
(re.compile(r"\[Errno \d+\]"), '[Errno N]'), (re.compile(r"\[Errno \d+\]"), '[Errno N]'),
(re.compile(r"loads=\d+\.\d+"), 'loads=42.42'), (re.compile(r"loads=\d+\.\d+"), 'loads=42.42'),
# Python 3 drops the u prefix # Python 3 drops the u prefix
...@@ -1810,7 +1851,7 @@ def test_suite(): ...@@ -1810,7 +1851,7 @@ def test_suite():
), ),
) )
zeo.addTest(PackableStorage.IExternalGC_suite( zeo.addTest(PackableStorage.IExternalGC_suite(
lambda : lambda:
ServerManagingClientStorageForIExternalGCTest( ServerManagingClientStorageForIExternalGCTest(
'data.fs', 'blobs', extrafsoptions='pack-gc false') 'data.fs', 'blobs', extrafsoptions='pack-gc false')
)) ))
......
...@@ -27,6 +27,7 @@ import ZODB.FileStorage ...@@ -27,6 +27,7 @@ import ZODB.FileStorage
import ZODB.tests.util import ZODB.tests.util
import ZODB.utils import ZODB.utils
def proper_handling_of_blob_conflicts(): def proper_handling_of_blob_conflicts():
r""" r"""
...@@ -108,6 +109,7 @@ The transaction is aborted by the server: ...@@ -108,6 +109,7 @@ The transaction is aborted by the server:
>>> fs.close() >>> fs.close()
""" """
def proper_handling_of_errors_in_restart(): def proper_handling_of_errors_in_restart():
r""" r"""
...@@ -149,6 +151,7 @@ We can start another client and get the storage lock. ...@@ -149,6 +151,7 @@ We can start another client and get the storage lock.
>>> server.close() >>> server.close()
""" """
def errors_in_vote_should_clear_lock(): def errors_in_vote_should_clear_lock():
""" """
...@@ -409,6 +412,7 @@ If clients disconnect while waiting, they will be dequeued: ...@@ -409,6 +412,7 @@ If clients disconnect while waiting, they will be dequeued:
>>> server.close() >>> server.close()
""" """
def lock_sanity_check(): def lock_sanity_check():
r""" r"""
On one occasion with 3.10.0a1 in production, we had a case where a On one occasion with 3.10.0a1 in production, we had a case where a
...@@ -492,6 +496,7 @@ ZEOStorage as closed and see if trying to get a lock cleans it up: ...@@ -492,6 +496,7 @@ ZEOStorage as closed and see if trying to get a lock cleans it up:
>>> server.close() >>> server.close()
""" """
def test_suite(): def test_suite():
return unittest.TestSuite(( return unittest.TestSuite((
doctest.DocTestSuite( doctest.DocTestSuite(
...@@ -506,5 +511,6 @@ def test_suite(): ...@@ -506,5 +511,6 @@ def test_suite():
), ),
)) ))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(defaultTest='test_suite') unittest.main(defaultTest='test_suite')
...@@ -27,6 +27,7 @@ from zdaemon.tests.testzdoptions import TestZDOptions ...@@ -27,6 +27,7 @@ from zdaemon.tests.testzdoptions import TestZDOptions
# supplies the empty string. # supplies the empty string.
DEFAULT_BINDING_HOST = "" DEFAULT_BINDING_HOST = ""
class TestZEOOptions(TestZDOptions): class TestZEOOptions(TestZDOptions):
OptionsClass = ZEOOptions OptionsClass = ZEOOptions
...@@ -106,5 +107,6 @@ def test_suite(): ...@@ -106,5 +107,6 @@ def test_suite():
suite.addTest(unittest.makeSuite(cls)) suite.addTest(unittest.makeSuite(cls))
return suite return suite
if __name__ == "__main__": if __name__ == "__main__":
unittest.main(defaultTest='test_suite') unittest.main(defaultTest='test_suite')
import unittest import unittest
import mock import mock
import os
from ZEO._compat import PY3 from ZEO._compat import PY3
from ZEO.runzeo import ZEOServer from ZEO.runzeo import ZEOServer
...@@ -11,7 +10,8 @@ class TestStorageServer(object): ...@@ -11,7 +10,8 @@ class TestStorageServer(object):
def __init__(self, fail_create_server): def __init__(self, fail_create_server):
self.called = [] self.called = []
if fail_create_server: raise RuntimeError() if fail_create_server:
raise RuntimeError()
def close(self): def close(self):
self.called.append("close") self.called.append("close")
...@@ -49,7 +49,8 @@ class TestZEOServer(ZEOServer): ...@@ -49,7 +49,8 @@ class TestZEOServer(ZEOServer):
def loop_forever(self): def loop_forever(self):
self.called.append("loop_forever") self.called.append("loop_forever")
if self.fail_loop_forever: raise RuntimeError() if self.fail_loop_forever:
raise RuntimeError()
def close_server(self): def close_server(self):
self.called.append("close_server") self.called.append("close_server")
...@@ -138,6 +139,7 @@ class CloseServerTests(unittest.TestCase): ...@@ -138,6 +139,7 @@ class CloseServerTests(unittest.TestCase):
self.assertEqual(hasattr(zeo, "server"), True) self.assertEqual(hasattr(zeo, "server"), True)
self.assertEqual(zeo.server, None) self.assertEqual(zeo.server, None)
@mock.patch('os.unlink') @mock.patch('os.unlink')
class TestZEOServerSocket(unittest.TestCase): class TestZEOServerSocket(unittest.TestCase):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""Basic unit tests for a client cache.""" """Basic unit tests for a client cache."""
from __future__ import print_function from __future__ import print_function
from ZODB.utils import p64, repr_to_oid from ZODB.utils import p64, u64, z64, repr_to_oid
import doctest import doctest
import os import os
import re import re
...@@ -28,8 +28,6 @@ import ZODB.tests.util ...@@ -28,8 +28,6 @@ import ZODB.tests.util
import zope.testing.setupstack import zope.testing.setupstack
import zope.testing.renormalizing import zope.testing.renormalizing
import ZEO.cache
from ZODB.utils import p64, u64, z64
n1 = p64(1) n1 = p64(1)
n2 = p64(2) n2 = p64(2)
...@@ -47,8 +45,8 @@ def hexprint(file): ...@@ -47,8 +45,8 @@ def hexprint(file):
printable = "" printable = ""
hex = "" hex = ""
for character in line: for character in line:
if (character in string.printable if character in string.printable and \
and not ord(character) in [12,13,9]): not ord(character) in [12, 13, 9]:
printable += character printable += character
else: else:
printable += '.' printable += '.'
...@@ -63,8 +61,11 @@ def hexprint(file): ...@@ -63,8 +61,11 @@ def hexprint(file):
def oid(o): def oid(o):
repr = '%016x' % o repr = '%016x' % o
return repr_to_oid(repr) return repr_to_oid(repr)
tid = oid tid = oid
class CacheTests(ZODB.tests.util.TestCase): class CacheTests(ZODB.tests.util.TestCase):
def setUp(self): def setUp(self):
...@@ -207,24 +208,24 @@ class CacheTests(ZODB.tests.util.TestCase): ...@@ -207,24 +208,24 @@ class CacheTests(ZODB.tests.util.TestCase):
self.assertTrue(1 not in cache.noncurrent) self.assertTrue(1 not in cache.noncurrent)
def testVeryLargeCaches(self): def testVeryLargeCaches(self):
cache = ZEO.cache.ClientCache('cache', size=(1<<32)+(1<<20)) cache = ZEO.cache.ClientCache('cache', size=(1 << 32)+(1 << 20))
cache.store(n1, n2, None, b"x") cache.store(n1, n2, None, b"x")
cache.close() cache.close()
cache = ZEO.cache.ClientCache('cache', size=(1<<33)+(1<<20)) cache = ZEO.cache.ClientCache('cache', size=(1 << 33)+(1 << 20))
self.assertEqual(cache.load(n1), (b'x', n2)) self.assertEqual(cache.load(n1), (b'x', n2))
cache.close() cache.close()
def testConversionOfLargeFreeBlocks(self): def testConversionOfLargeFreeBlocks(self):
with open('cache', 'wb') as f: with open('cache', 'wb') as f:
f.write(ZEO.cache.magic+ f.write(ZEO.cache.magic +
b'\0'*8 + b'\0'*8 +
b'f'+struct.pack(">I", (1<<32)-12) b'f'+struct.pack(">I", (1 << 32)-12)
) )
f.seek((1<<32)-1) f.seek((1 << 32)-1)
f.write(b'x') f.write(b'x')
cache = ZEO.cache.ClientCache('cache', size=1<<32) cache = ZEO.cache.ClientCache('cache', size=1 << 32)
cache.close() cache.close()
cache = ZEO.cache.ClientCache('cache', size=1<<32) cache = ZEO.cache.ClientCache('cache', size=1 << 32)
cache.close() cache.close()
with open('cache', 'rb') as f: with open('cache', 'rb') as f:
f.seek(12) f.seek(12)
...@@ -310,7 +311,6 @@ class CacheTests(ZODB.tests.util.TestCase): ...@@ -310,7 +311,6 @@ class CacheTests(ZODB.tests.util.TestCase):
self.assertEqual(set(u64(oid) for (oid, tid) in cache.contents()), self.assertEqual(set(u64(oid) for (oid, tid) in cache.contents()),
expected_oids) expected_oids)
for i in range(200, 305): for i in range(200, 305):
cache.store(p64(i), n1, None, data) cache.store(p64(i), n1, None, data)
...@@ -356,6 +356,7 @@ class CacheTests(ZODB.tests.util.TestCase): ...@@ -356,6 +356,7 @@ class CacheTests(ZODB.tests.util.TestCase):
self.assertEqual(cache.loadBefore(oid, n2), (b'first', n1, n2)) self.assertEqual(cache.loadBefore(oid, n2), (b'first', n1, n2))
self.assertEqual(cache.loadBefore(oid, n3), (b'second', n2, None)) self.assertEqual(cache.loadBefore(oid, n3), (b'second', n2, None))
def kill_does_not_cause_cache_corruption(): def kill_does_not_cause_cache_corruption():
r""" r"""
...@@ -363,7 +364,7 @@ If we kill a process while a cache is being written to, the cache ...@@ -363,7 +364,7 @@ If we kill a process while a cache is being written to, the cache
isn't corrupted. To see this, we'll write a little script that isn't corrupted. To see this, we'll write a little script that
writes records to a cache file repeatedly. writes records to a cache file repeatedly.
>>> import os, random, sys, time >>> import os, sys
>>> with open('t', 'w') as f: >>> with open('t', 'w') as f:
... _ = f.write(''' ... _ = f.write('''
... import os, random, sys, time ... import os, random, sys, time
...@@ -402,6 +403,7 @@ writes records to a cache file repeatedly. ...@@ -402,6 +403,7 @@ writes records to a cache file repeatedly.
""" """
def full_cache_is_valid(): def full_cache_is_valid():
r""" r"""
...@@ -419,6 +421,7 @@ still be used. ...@@ -419,6 +421,7 @@ still be used.
>>> cache.close() >>> cache.close()
""" """
def cannot_open_same_cache_file_twice(): def cannot_open_same_cache_file_twice():
r""" r"""
>>> import ZEO.cache >>> import ZEO.cache
...@@ -432,6 +435,7 @@ LockError: Couldn't lock 'cache.lock' ...@@ -432,6 +435,7 @@ LockError: Couldn't lock 'cache.lock'
>>> cache.close() >>> cache.close()
""" """
def broken_non_current(): def broken_non_current():
r""" r"""
...@@ -467,6 +471,7 @@ Couldn't find non-current ...@@ -467,6 +471,7 @@ Couldn't find non-current
# def bad_magic_number(): See rename_bad_cache_file # def bad_magic_number(): See rename_bad_cache_file
def cache_trace_analysis(): def cache_trace_analysis():
r""" r"""
Check to make sure the cache analysis scripts work. Check to make sure the cache analysis scripts work.
...@@ -585,19 +590,19 @@ Check to make sure the cache analysis scripts work. ...@@ -585,19 +590,19 @@ Check to make sure the cache analysis scripts work.
Jul 11 12:11:43 20 947 0000000000000000 0000000000000000 - Jul 11 12:11:43 20 947 0000000000000000 0000000000000000 -
Jul 11 12:11:43 52 947 0000000000000002 0000000000000000 - 602 Jul 11 12:11:43 52 947 0000000000000002 0000000000000000 - 602
Jul 11 12:11:44 20 124b 0000000000000000 0000000000000000 - Jul 11 12:11:44 20 124b 0000000000000000 0000000000000000 -
Jul 11 12:11:44 52 124b 0000000000000002 0000000000000000 - 1418 Jul 11 12:11:44 52 ... 124b 0000000000000002 0000000000000000 - 1418
... ...
Jul 11 15:14:55 52 10cc 00000000000003e9 0000000000000000 - 1306 Jul 11 15:14:55 52 ... 10cc 00000000000003e9 0000000000000000 - 1306
Jul 11 15:14:56 20 18a7 0000000000000000 0000000000000000 - Jul 11 15:14:56 20 18a7 0000000000000000 0000000000000000 -
Jul 11 15:14:56 52 18a7 00000000000003e9 0000000000000000 - 1610 Jul 11 15:14:56 52 ... 18a7 00000000000003e9 0000000000000000 - 1610
Jul 11 15:14:57 22 18b5 000000000000031d 0000000000000000 - 1636 Jul 11 15:14:57 22 ... 18b5 000000000000031d 0000000000000000 - 1636
Jul 11 15:14:58 20 b8a 0000000000000000 0000000000000000 - Jul 11 15:14:58 20 b8a 0000000000000000 0000000000000000 -
Jul 11 15:14:58 52 b8a 00000000000003e9 0000000000000000 - 838 Jul 11 15:14:58 52 b8a 00000000000003e9 0000000000000000 - 838
Jul 11 15:14:59 22 1085 0000000000000357 0000000000000000 - 217 Jul 11 15:14:59 22 1085 0000000000000357 0000000000000000 - 217
Jul 11 15:00-14 818 291 30 609 35.6% Jul 11 15:00-14 818 291 30 609 35.6%
Jul 11 15:15:00 22 1072 000000000000037e 0000000000000000 - 204 Jul 11 15:15:00 22 1072 000000000000037e 0000000000000000 - 204
Jul 11 15:15:01 20 16c5 0000000000000000 0000000000000000 - Jul 11 15:15:01 20 16c5 0000000000000000 0000000000000000 -
Jul 11 15:15:01 52 16c5 00000000000003e9 0000000000000000 - 1712 Jul 11 15:15:01 52 ... 16c5 00000000000003e9 0000000000000000 - 1712
Jul 11 15:15-15 2 1 0 1 50.0% Jul 11 15:15-15 2 1 0 1 50.0%
<BLANKLINE> <BLANKLINE>
Read 18,876 trace records (641,776 bytes) in 0.0 seconds Read 18,876 trace records (641,776 bytes) in 0.0 seconds
...@@ -1001,6 +1006,7 @@ Cleanup: ...@@ -1001,6 +1006,7 @@ Cleanup:
""" """
def cache_simul_properly_handles_load_miss_after_eviction_and_inval(): def cache_simul_properly_handles_load_miss_after_eviction_and_inval():
r""" r"""
...@@ -1031,6 +1037,7 @@ Now try to do simulation: ...@@ -1031,6 +1037,7 @@ Now try to do simulation:
""" """
def invalidations_with_current_tid_dont_wreck_cache(): def invalidations_with_current_tid_dont_wreck_cache():
""" """
>>> cache = ZEO.cache.ClientCache('cache', 1000) >>> cache = ZEO.cache.ClientCache('cache', 1000)
...@@ -1049,6 +1056,7 @@ def invalidations_with_current_tid_dont_wreck_cache(): ...@@ -1049,6 +1056,7 @@ def invalidations_with_current_tid_dont_wreck_cache():
>>> logging.getLogger().setLevel(old_level) >>> logging.getLogger().setLevel(old_level)
""" """
def rename_bad_cache_file(): def rename_bad_cache_file():
""" """
An attempt to open a bad cache file will cause it to be dropped and recreated. An attempt to open a bad cache file will cause it to be dropped and recreated.
...@@ -1098,6 +1106,7 @@ An attempt to open a bad cache file will cause it to be dropped and recreated. ...@@ -1098,6 +1106,7 @@ An attempt to open a bad cache file will cause it to be dropped and recreated.
>>> logging.getLogger().setLevel(old_level) >>> logging.getLogger().setLevel(old_level)
""" """
def test_suite(): def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(CacheTests)) suite.addTest(unittest.makeSuite(CacheTests))
...@@ -1105,10 +1114,9 @@ def test_suite(): ...@@ -1105,10 +1114,9 @@ def test_suite():
doctest.DocTestSuite( doctest.DocTestSuite(
setUp=zope.testing.setupstack.setUpDirectory, setUp=zope.testing.setupstack.setUpDirectory,
tearDown=zope.testing.setupstack.tearDown, tearDown=zope.testing.setupstack.tearDown,
checker=ZODB.tests.util.checker + \ checker=(ZODB.tests.util.checker +
zope.testing.renormalizing.RENormalizing([ zope.testing.renormalizing.RENormalizing([
(re.compile(r'31\.3%'), '31.2%'), (re.compile(r'31\.3%'), '31.2%')])),
]),
) )
) )
return suite return suite
...@@ -11,6 +11,7 @@ import ZEO.StorageServer ...@@ -11,6 +11,7 @@ import ZEO.StorageServer
from . import forker from . import forker
from .threaded import threaded_server_tests from .threaded import threaded_server_tests
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL") @unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class ClientAuthTests(setupstack.TestCase): class ClientAuthTests(setupstack.TestCase):
...@@ -54,8 +55,8 @@ class ClientAuthTests(setupstack.TestCase): ...@@ -54,8 +55,8 @@ class ClientAuthTests(setupstack.TestCase):
stop() stop()
def test_suite(): def test_suite():
suite = unittest.makeSuite(ClientAuthTests) suite = unittest.makeSuite(ClientAuthTests)
suite.layer = threaded_server_tests suite.layer = threaded_server_tests
return suite return suite
...@@ -15,11 +15,13 @@ import ZEO ...@@ -15,11 +15,13 @@ import ZEO
from . import forker from . import forker
from .utils import StorageServer from .utils import StorageServer
class Var(object): class Var(object):
def __eq__(self, other): def __eq__(self, other):
self.value = other self.value = other
return True return True
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL") @unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase): class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
...@@ -62,7 +64,6 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase): ...@@ -62,7 +64,6 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length') self.assertEqual(reader.getClassName(p), 'BTrees.Length.Length')
self.assertEqual(reader.getState(p), 2) self.assertEqual(reader.getState(p), 2)
# Now, we'll create a server that expects the client to # Now, we'll create a server that expects the client to
# resolve conflicts: # resolve conflicts:
...@@ -119,17 +120,18 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase): ...@@ -119,17 +120,18 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
addr, stop = ZEO.server(os.path.join(path, 'data.fs'), threaded=False) addr, stop = ZEO.server(os.path.join(path, 'data.fs'), threaded=False)
db = ZEO.DB(addr) db = ZEO.DB(addr)
with db.transaction() as conn: with db.transaction() as conn:
conn.root.l = Length(0) conn.root.len = Length(0)
conn2 = db.open() conn2 = db.open()
conn2.root.l.change(1) conn2.root.len.change(1)
with db.transaction() as conn: with db.transaction() as conn:
conn.root.l.change(1) conn.root.len.change(1)
conn2.transaction_manager.commit() conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2) self.assertEqual(conn2.root.len.value, 2)
db.close(); stop() db.close()
stop()
# Now, do conflict resolution on the client. # Now, do conflict resolution on the client.
addr2, stop = ZEO.server( addr2, stop = ZEO.server(
...@@ -140,18 +142,20 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase): ...@@ -140,18 +142,20 @@ class ClientSideConflictResolutionTests(zope.testing.setupstack.TestCase):
db = ZEO.DB(addr2) db = ZEO.DB(addr2)
with db.transaction() as conn: with db.transaction() as conn:
conn.root.l = Length(0) conn.root.len = Length(0)
conn2 = db.open() conn2 = db.open()
conn2.root.l.change(1) conn2.root.len.change(1)
with db.transaction() as conn: with db.transaction() as conn:
conn.root.l.change(1) conn.root.len.change(1)
self.assertEqual(conn2.root.l.value, 1) self.assertEqual(conn2.root.len.value, 1)
conn2.transaction_manager.commit() conn2.transaction_manager.commit()
self.assertEqual(conn2.root.l.value, 2) self.assertEqual(conn2.root.len.value, 2)
db.close()
stop()
db.close(); stop()
def test_suite(): def test_suite():
return unittest.makeSuite(ClientSideConflictResolutionTests) return unittest.makeSuite(ClientSideConflictResolutionTests)
...@@ -17,7 +17,7 @@ class MarshalTests(unittest.TestCase): ...@@ -17,7 +17,7 @@ class MarshalTests(unittest.TestCase):
# this is an example (1) of Zope2's arguments for # this is an example (1) of Zope2's arguments for
# undoInfo call. Arguments are encoded by ZEO client # undoInfo call. Arguments are encoded by ZEO client
# and decoded by server. The operation must be idempotent. # and decoded by server. The operation must be idempotent.
# (1) https://github.com/zopefoundation/Zope/blob/2.13/src/App/Undo.py#L111 # (1) https://github.com/zopefoundation/Zope/blob/2.13/src/App/Undo.py#L111 # NOQA: E501 line too long
args = (0, 20, {'user_name': Prefix('test')}) args = (0, 20, {'user_name': Prefix('test')})
# test against repr because Prefix __eq__ operator # test against repr because Prefix __eq__ operator
# doesn't compare Prefix with Prefix but only # doesn't compare Prefix with Prefix but only
......
import unittest
from zope.testing import setupstack from zope.testing import setupstack
from .. import server, client from .. import server, client
...@@ -13,6 +11,7 @@ else: ...@@ -13,6 +11,7 @@ else:
server_ping_method = 'ping' server_ping_method = 'ping'
server_zss = 'zeo_storages_by_storage_id' server_zss = 'zeo_storages_by_storage_id'
class SyncTests(setupstack.TestCase): class SyncTests(setupstack.TestCase):
def instrument(self): def instrument(self):
...@@ -22,6 +21,7 @@ class SyncTests(setupstack.TestCase): ...@@ -22,6 +21,7 @@ class SyncTests(setupstack.TestCase):
[zs] = getattr(server.server, server_zss)['1'] [zs] = getattr(server.server, server_zss)['1']
orig_ping = getattr(zs, server_ping_method) orig_ping = getattr(zs, server_ping_method)
def ping(): def ping():
self.__ping_calls += 1 self.__ping_calls += 1
return orig_ping() return orig_ping()
......
...@@ -21,6 +21,7 @@ serverpw_key = os.path.join(here, 'serverpw_key.pem') ...@@ -21,6 +21,7 @@ serverpw_key = os.path.join(here, 'serverpw_key.pem')
client_cert = os.path.join(here, 'client.pem') client_cert = os.path.join(here, 'client.pem')
client_key = os.path.join(here, 'client_key.pem') client_key = os.path.join(here, 'client_key.pem')
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL") @unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
class SSLConfigTest(ZEOConfigTestBase): class SSLConfigTest(ZEOConfigTestBase):
...@@ -117,6 +118,7 @@ class SSLConfigTest(ZEOConfigTestBase): ...@@ -117,6 +118,7 @@ class SSLConfigTest(ZEOConfigTestBase):
) )
stop() stop()
@unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL") @unittest.skipIf(forker.ZEO4_SERVER, "ZEO4 servers don't support SSL")
@mock.patch(('asyncio' if PY3 else 'trollius') + '.ensure_future') @mock.patch(('asyncio' if PY3 else 'trollius') + '.ensure_future')
@mock.patch(('asyncio' if PY3 else 'trollius') + '.set_event_loop') @mock.patch(('asyncio' if PY3 else 'trollius') + '.set_event_loop')
...@@ -139,8 +141,7 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase): ...@@ -139,8 +141,7 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
cert=(server_cert, server_key, None), cert=(server_cert, server_key, None),
verify_mode=ssl.CERT_REQUIRED, verify_mode=ssl.CERT_REQUIRED,
check_hostname=False, check_hostname=False,
cafile=None, capath=None, cafile=None, capath=None):
):
factory.assert_called_with( factory.assert_called_with(
ssl.Purpose.CLIENT_AUTH if server else ssl.Purpose.SERVER_AUTH, ssl.Purpose.CLIENT_AUTH if server else ssl.Purpose.SERVER_AUTH,
cafile=cafile, capath=capath) cafile=cafile, capath=capath)
...@@ -181,27 +182,31 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase): ...@@ -181,27 +182,31 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
) )
context = server.acceptor.ssl_context context = server.acceptor.ssl_context
self.assert_context(True, self.assert_context(True,
factory, context, (server_cert, server_key, pwfunc), capath=here) factory,
context,
(server_cert, server_key, pwfunc),
capath=here)
server.close() server.close()
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_no_ssl(self, ClientStorage, factory, *_): def test_ssl_mockiavellian_client_no_ssl(self, ClientStorage, factory, *_):
client = ssl_client() ssl_client()
self.assertFalse('ssl' in ClientStorage.call_args[1]) self.assertFalse('ssl' in ClientStorage.call_args[1])
self.assertFalse('ssl_server_hostname' in ClientStorage.call_args[1]) self.assertFalse('ssl_server_hostname' in ClientStorage.call_args[1])
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_server_signed( def test_ssl_mockiavellian_client_server_signed(
self, ClientStorage, factory, *_ self, ClientStorage, factory, *_):
): ssl_client(certificate=client_cert, key=client_key)
client = ssl_client(certificate=client_cert, key=client_key)
context = ClientStorage.call_args[1]['ssl'] context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'], self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None) None)
self.assert_context(False, self.assert_context(False,
factory, context, (client_cert, client_key, None), factory,
context,
(client_cert, client_key, None),
check_hostname=True) check_hostname=True)
context.load_default_certs.assert_called_with() context.load_default_certs.assert_called_with()
...@@ -209,43 +214,42 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase): ...@@ -209,43 +214,42 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_dir( def test_ssl_mockiavellian_client_auth_dir(
self, ClientStorage, factory, *_ self, ClientStorage, factory, *_):
): ssl_client(
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=here) certificate=client_cert, key=client_key, authenticate=here)
context = ClientStorage.call_args[1]['ssl'] context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'], self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None) None)
self.assert_context(False, self.assert_context(False,
factory, context, (client_cert, client_key, None), factory,
context,
(client_cert, client_key, None),
capath=here, capath=here,
check_hostname=True, check_hostname=True)
)
context.load_default_certs.assert_not_called() context.load_default_certs.assert_not_called()
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_file( def test_ssl_mockiavellian_client_auth_file(
self, ClientStorage, factory, *_ self, ClientStorage, factory, *_):
): ssl_client(
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert) certificate=client_cert, key=client_key, authenticate=server_cert)
context = ClientStorage.call_args[1]['ssl'] context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'], self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None) None)
self.assert_context(False, self.assert_context(False,
factory, context, (client_cert, client_key, None), factory,
context,
(client_cert, client_key, None),
cafile=server_cert, cafile=server_cert,
check_hostname=True, check_hostname=True)
)
context.load_default_certs.assert_not_called() context.load_default_certs.assert_not_called()
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_pw( def test_ssl_mockiavellian_client_pw(
self, ClientStorage, factory, *_ self, ClientStorage, factory, *_):
): ssl_client(
client = ssl_client(
certificate=client_cert, key=client_key, certificate=client_cert, key=client_key,
password_function='ZEO.tests.testssl.pwfunc', password_function='ZEO.tests.testssl.pwfunc',
authenticate=server_cert) authenticate=server_cert)
...@@ -253,48 +257,51 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase): ...@@ -253,48 +257,51 @@ class SSLConfigTestMockiavellian(ZEOConfigTestBase):
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'], self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None) None)
self.assert_context(False, self.assert_context(False,
factory, context, (client_cert, client_key, pwfunc), factory,
context,
(client_cert, client_key, pwfunc),
cafile=server_cert, cafile=server_cert,
check_hostname=True, check_hostname=True)
)
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_server_hostname( def test_ssl_mockiavellian_client_server_hostname(
self, ClientStorage, factory, *_ self, ClientStorage, factory, *_):
): ssl_client(
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert, certificate=client_cert, key=client_key, authenticate=server_cert,
server_hostname='example.com') server_hostname='example.com')
context = ClientStorage.call_args[1]['ssl'] context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'], self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
'example.com') 'example.com')
self.assert_context(False, self.assert_context(False,
factory, context, (client_cert, client_key, None), factory,
context,
(client_cert, client_key, None),
cafile=server_cert, cafile=server_cert,
check_hostname=True, check_hostname=True)
)
@mock.patch('ssl.create_default_context') @mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage') @mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_check_hostname( def test_ssl_mockiavellian_client_check_hostname(
self, ClientStorage, factory, *_ self, ClientStorage, factory, *_):
): ssl_client(
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert, certificate=client_cert, key=client_key, authenticate=server_cert,
check_hostname=False) check_hostname=False)
context = ClientStorage.call_args[1]['ssl'] context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'], self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None) None)
self.assert_context(False, self.assert_context(False,
factory, context, (client_cert, client_key, None), factory,
context,
(client_cert, client_key, None),
cafile=server_cert, cafile=server_cert,
check_hostname=False, check_hostname=False)
)
def args(*a, **kw): def args(*a, **kw):
return a, kw return a, kw
def ssl_conf(**ssl_settings): def ssl_conf(**ssl_settings):
if ssl_settings: if ssl_settings:
ssl_conf = '<ssl>\n' + '\n'.join( ssl_conf = '<ssl>\n' + '\n'.join(
...@@ -306,6 +313,7 @@ def ssl_conf(**ssl_settings): ...@@ -306,6 +313,7 @@ def ssl_conf(**ssl_settings):
return ssl_conf return ssl_conf
def ssl_client(**ssl_settings): def ssl_client(**ssl_settings):
return storageFromString( return storageFromString(
"""%import ZEO """%import ZEO
...@@ -317,6 +325,7 @@ def ssl_client(**ssl_settings): ...@@ -317,6 +325,7 @@ def ssl_client(**ssl_settings):
""".format(ssl_conf(**ssl_settings)) """.format(ssl_conf(**ssl_settings))
) )
def create_server(**ssl_settings): def create_server(**ssl_settings):
with open('conf', 'w') as f: with open('conf', 'w') as f:
f.write( f.write(
...@@ -336,7 +345,9 @@ def create_server(**ssl_settings): ...@@ -336,7 +345,9 @@ def create_server(**ssl_settings):
s.create_server() s.create_server()
return s.server return s.server
pwfunc = lambda : '1234'
def pwfunc():
return '1234'
def test_suite(): def test_suite():
...@@ -347,8 +358,8 @@ def test_suite(): ...@@ -347,8 +358,8 @@ def test_suite():
suite.layer = threaded_server_tests suite.layer = threaded_server_tests
return suite return suite
# Helpers for other tests:
# Helpers for other tests:
server_config = """ server_config = """
<zeo> <zeo>
address 127.0.0.1:0 address 127.0.0.1:0
...@@ -360,6 +371,7 @@ server_config = """ ...@@ -360,6 +371,7 @@ server_config = """
</zeo> </zeo>
""".format(server_cert, server_key, client_cert) """.format(server_cert, server_key, client_cert)
def client_ssl(cafile=server_key, def client_ssl(cafile=server_key,
client_cert=client_cert, client_cert=client_cert,
client_key=client_key, client_key=client_key,
...@@ -373,11 +385,11 @@ def client_ssl(cafile=server_key, ...@@ -373,11 +385,11 @@ def client_ssl(cafile=server_key,
return context return context
# See # See
# https://discuss.pivotal.io/hc/en-us/articles/202653388-How-to-renew-an-expired-Apache-Web-Server-self-signed-certificate-using-the-OpenSSL-tool # https://discuss.pivotal.io/hc/en-us/articles/202653388-How-to-renew-an-expired-Apache-Web-Server-self-signed-certificate-using-the-OpenSSL-tool # NOQA: E501
# for instructions on updating the server.pem (the certificate) if # for instructions on updating the server.pem (the certificate) if
# needed. server.pem.csr is the request. # needed. server.pem.csr is the request.
# This should do it: # This should do it:
# openssl x509 -req -days 999999 -in src/ZEO/tests/server.pem.csr -signkey src/ZEO/tests/server_key.pem -out src/ZEO/tests/server.pem # openssl x509 -req -days 999999 -in src/ZEO/tests/server.pem.csr -signkey src/ZEO/tests/server_key.pem -out src/ZEO/tests/server.pem # NOQA: E501
# If you need to create a new key first: # If you need to create a new key first:
# openssl genrsa -out server_key.pem 2048 # openssl genrsa -out server_key.pem 2048
# These two files should then be copied to client_key.pem and client.pem. # These two files should then be copied to client_key.pem and client.pem.
...@@ -9,4 +9,3 @@ import ZODB.tests.util ...@@ -9,4 +9,3 @@ import ZODB.tests.util
threaded_server_tests = ZODB.tests.util.MininalTestLayer( threaded_server_tests = ZODB.tests.util.MininalTestLayer(
'threaded_server_tests') 'threaded_server_tests')
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import ZEO.StorageServer import ZEO.StorageServer
from ..asyncio.server import best_protocol_version from ..asyncio.server import best_protocol_version
class ServerProtocol(object): class ServerProtocol(object):
method = ('register', ) method = ('register', )
...@@ -17,6 +18,7 @@ class ServerProtocol(object): ...@@ -17,6 +18,7 @@ class ServerProtocol(object):
zs.notify_connected(self) zs.notify_connected(self)
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -30,6 +32,7 @@ class ServerProtocol(object): ...@@ -30,6 +32,7 @@ class ServerProtocol(object):
async_threadsafe = async_ async_threadsafe = async_
class StorageServer(object): class StorageServer(object):
"""Create a client interface to a StorageServer. """Create a client interface to a StorageServer.
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import os import os
def parentdir(p, n=1): def parentdir(p, n=1):
"""Return the ancestor of p from n levels up.""" """Return the ancestor of p from n levels up."""
d = p d = p
...@@ -25,6 +26,7 @@ def parentdir(p, n=1): ...@@ -25,6 +26,7 @@ def parentdir(p, n=1):
n -= 1 n -= 1
return d return d
class Environment(object): class Environment(object):
"""Determine location of the Data.fs & ZEO_SERVER.pid files. """Determine location of the Data.fs & ZEO_SERVER.pid files.
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
import sys import sys
def ssl_config(section, server): def ssl_config(section, server):
import ssl import ssl
...@@ -10,9 +11,9 @@ def ssl_config(section, server): ...@@ -10,9 +11,9 @@ def ssl_config(section, server):
auth = section.authenticate auth = section.authenticate
if auth: if auth:
if os.path.isdir(auth): if os.path.isdir(auth):
capath=auth capath = auth
elif auth != 'DYNAMIC': elif auth != 'DYNAMIC':
cafile=auth cafile = auth
context = ssl.create_default_context( context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH if server else ssl.Purpose.SERVER_AUTH, ssl.Purpose.CLIENT_AUTH if server else ssl.Purpose.SERVER_AUTH,
...@@ -44,12 +45,15 @@ def ssl_config(section, server): ...@@ -44,12 +45,15 @@ def ssl_config(section, server):
return context, section.server_hostname return context, section.server_hostname
def server_ssl(section): def server_ssl(section):
return ssl_config(section, True) return ssl_config(section, True)
def client_ssl(section): def client_ssl(section):
return ssl_config(section, False) return ssl_config(section, False)
class ClientStorageConfig(object): class ClientStorageConfig(object):
def __init__(self, config): def __init__(self, config):
...@@ -86,6 +90,6 @@ class ClientStorageConfig(object): ...@@ -86,6 +90,6 @@ class ClientStorageConfig(object):
name=config.name, name=config.name,
read_only=config.read_only, read_only=config.read_only,
read_only_fallback=config.read_only_fallback, read_only_fallback=config.read_only_fallback,
server_sync = config.server_sync, server_sync=config.server_sync,
wait_timeout=config.wait_timeout, wait_timeout=config.wait_timeout,
**options) **options)
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import ZEO import ZEO
import zdaemon.zdctl import zdaemon.zdctl
# Main program # Main program
def main(args=None): def main(args=None):
options = zdaemon.zdctl.ZDCtlOptions() options = zdaemon.zdctl.ZDCtlOptions()
...@@ -27,5 +28,6 @@ def main(args=None): ...@@ -27,5 +28,6 @@ def main(args=None):
options.schemafile = "zeoctl.xml" options.schemafile = "zeoctl.xml"
zdaemon.zdctl.main(args, options) zdaemon.zdctl.main(args, options)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -38,10 +38,12 @@ extras = ...@@ -38,10 +38,12 @@ extras =
basepython = python3 basepython = python3
skip_install = true skip_install = true
deps = deps =
flake8
check-manifest check-manifest
check-python-versions >= 0.19.1 check-python-versions >= 0.19.1
wheel wheel
commands = commands =
flake8 src setup.py
check-manifest check-manifest
check-python-versions check-python-versions
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment