Commit a5eebbae authored by Jens Vagelpohl's avatar Jens Vagelpohl

- full linting with flake8

parent 0e19e22b
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# https://github.com/zopefoundation/meta/tree/master/config/pure-python # https://github.com/zopefoundation/meta/tree/master/config/pure-python
[meta] [meta]
template = "pure-python" template = "pure-python"
commit-id = "3b712f305ca8207e971c5bf81f2bdb5872489f2f" commit-id = "0c07a1cfd78d28a07aebd23383ed16959f166574"
[python] [python]
with-windows = false with-windows = false
...@@ -13,7 +13,7 @@ with-docs = true ...@@ -13,7 +13,7 @@ with-docs = true
with-sphinx-doctests = false with-sphinx-doctests = false
[tox] [tox]
use-flake8 = false use-flake8 = true
testenv-commands = [ testenv-commands = [
"# Run unit tests first.", "# Run unit tests first.",
"zope-testrunner -u --test-path=src {posargs:-vc}", "zope-testrunner -u --test-path=src {posargs:-vc}",
......
...@@ -4,6 +4,8 @@ Changelog ...@@ -4,6 +4,8 @@ Changelog
5.4.0 (unreleased) 5.4.0 (unreleased)
------------------ ------------------
- linted the code with flake8
- Add support for Python 3.10. - Add support for Python 3.10.
- Add ``ConflictError`` to the list of unlogged server exceptions - Add ``ConflictError`` to the list of unlogged server exceptions
......
...@@ -11,11 +11,12 @@ ...@@ -11,11 +11,12 @@
# FOR A PARTICULAR PURPOSE. # FOR A PARTICULAR PURPOSE.
# #
############################################################################## ##############################################################################
version = '5.3.1.dev0'
from setuptools import setup, find_packages from setuptools import setup, find_packages
import os import os
version = '5.3.1.dev0'
install_requires = [ install_requires = [
'ZODB >= 5.1.1', 'ZODB >= 5.1.1',
'six', 'six',
...@@ -64,12 +65,14 @@ Operating System :: Unix ...@@ -64,12 +65,14 @@ Operating System :: Unix
Framework :: ZODB Framework :: ZODB
""".strip().split('\n') """.strip().split('\n')
def _modname(path, base, name=''): def _modname(path, base, name=''):
if path == base: if path == base:
return name return name
dirname, basename = os.path.split(path) dirname, basename = os.path.split(path)
return _modname(dirname, base, basename + '.' + name) return _modname(dirname, base, basename + '.' + name)
def _flatten(suite, predicate=lambda *x: True): def _flatten(suite, predicate=lambda *x: True):
from unittest import TestCase from unittest import TestCase
for suite_or_case in suite: for suite_or_case in suite:
...@@ -80,18 +83,20 @@ def _flatten(suite, predicate=lambda *x: True): ...@@ -80,18 +83,20 @@ def _flatten(suite, predicate=lambda *x: True):
for x in _flatten(suite_or_case): for x in _flatten(suite_or_case):
yield x yield x
def _no_layer(suite_or_case): def _no_layer(suite_or_case):
return getattr(suite_or_case, 'layer', None) is None return getattr(suite_or_case, 'layer', None) is None
def _unittests_only(suite, mod_suite): def _unittests_only(suite, mod_suite):
for case in _flatten(mod_suite, _no_layer): for case in _flatten(mod_suite, _no_layer):
suite.addTest(case) suite.addTest(case)
def alltests(): def alltests():
import logging import logging
import pkg_resources import pkg_resources
import unittest import unittest
import ZEO.ClientStorage
class NullHandler(logging.Handler): class NullHandler(logging.Handler):
level = 50 level = 50
...@@ -107,7 +112,8 @@ def alltests(): ...@@ -107,7 +112,8 @@ def alltests():
for dirpath, dirnames, filenames in os.walk(base): for dirpath, dirnames, filenames in os.walk(base):
if os.path.basename(dirpath) == 'tests': if os.path.basename(dirpath) == 'tests':
for filename in filenames: for filename in filenames:
if filename != 'testZEO.py': continue if filename != 'testZEO.py':
continue
if filename.endswith('.py') and filename.startswith('test'): if filename.endswith('.py') and filename.startswith('test'):
mod = __import__( mod = __import__(
_modname(dirpath, base, os.path.splitext(filename)[0]), _modname(dirpath, base, os.path.splitext(filename)[0]),
...@@ -115,11 +121,13 @@ def alltests(): ...@@ -115,11 +121,13 @@ def alltests():
_unittests_only(suite, mod.test_suite()) _unittests_only(suite, mod.test_suite())
return suite return suite
long_description = ( long_description = (
open('README.rst').read() open('README.rst').read()
+ '\n' + + '\n' +
open('CHANGES.rst').read() open('CHANGES.rst').read()
) )
setup(name="ZEO", setup(name="ZEO",
version=version, version=version,
description=long_description.split('\n', 2)[1], description=long_description.split('\n', 2)[1],
...@@ -133,7 +141,7 @@ setup(name="ZEO", ...@@ -133,7 +141,7 @@ setup(name="ZEO",
license="ZPL 2.1", license="ZPL 2.1",
platforms=["any"], platforms=["any"],
classifiers=classifiers, classifiers=classifiers,
test_suite="__main__.alltests", # to support "setup.py test" test_suite="__main__.alltests", # to support "setup.py test"
tests_require=tests_require, tests_require=tests_require,
extras_require={ extras_require={
'test': tests_require, 'test': tests_require,
...@@ -164,4 +172,4 @@ setup(name="ZEO", ...@@ -164,4 +172,4 @@ setup(name="ZEO",
""", """,
include_package_data=True, include_package_data=True,
python_requires='>=2.7.9,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', python_requires='>=2.7.9,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*',
) )
...@@ -52,9 +52,11 @@ import ZEO.cache ...@@ -52,9 +52,11 @@ import ZEO.cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def tid2time(tid): def tid2time(tid):
return str(TimeStamp(tid)) return str(TimeStamp(tid))
def get_timestamp(prev_ts=None): def get_timestamp(prev_ts=None):
"""Internal helper to return a unique TimeStamp instance. """Internal helper to return a unique TimeStamp instance.
...@@ -69,8 +71,10 @@ def get_timestamp(prev_ts=None): ...@@ -69,8 +71,10 @@ def get_timestamp(prev_ts=None):
t = t.laterThan(prev_ts) t = t.laterThan(prev_ts)
return t return t
MB = 1024**2 MB = 1024**2
@zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage) @zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage)
class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
"""A storage class that is a network client to a remote storage. """A storage class that is a network client to a remote storage.
...@@ -90,7 +94,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -90,7 +94,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
blob_cache_size=None, blob_cache_size_check=10, blob_cache_size=None, blob_cache_size_check=10,
client_label=None, client_label=None,
cache=None, cache=None,
ssl = None, ssl_server_hostname=None, ssl=None, ssl_server_hostname=None,
# Mostly ignored backward-compatability options # Mostly ignored backward-compatability options
client=None, var=None, client=None, var=None,
min_disconnect_poll=1, max_disconnect_poll=None, min_disconnect_poll=1, max_disconnect_poll=None,
...@@ -189,14 +193,15 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -189,14 +193,15 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
if isinstance(addr, int): if isinstance(addr, int):
addr = ('127.0.0.1', addr) addr = ('127.0.0.1', addr)
self.__name__ = name or str(addr) # Standard convention for storages self.__name__ = name or str(addr) # Standard convention for storages
if isinstance(addr, six.string_types): if isinstance(addr, six.string_types):
if WIN: if WIN:
raise ValueError("Unix sockets are not available on Windows") raise ValueError("Unix sockets are not available on Windows")
addr = [addr] addr = [addr]
elif (isinstance(addr, tuple) and len(addr) == 2 and elif (isinstance(addr, tuple) and len(addr) == 2 and
isinstance(addr[0], six.string_types) and isinstance(addr[1], int)): isinstance(addr[0], six.string_types) and
isinstance(addr[1], int)):
addr = [addr] addr = [addr]
logger.info( logger.info(
...@@ -212,7 +217,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -212,7 +217,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._is_read_only = read_only self._is_read_only = read_only
self._read_only_fallback = read_only_fallback self._read_only_fallback = read_only_fallback
self._addr = addr # For tests self._addr = addr # For tests
self._iterators = weakref.WeakValueDictionary() self._iterators = weakref.WeakValueDictionary()
self._iterator_ids = set() self._iterator_ids = set()
...@@ -228,7 +233,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -228,7 +233,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._db = None self._db = None
self._oids = [] # List of pre-fetched oids from server self._oids = [] # List of pre-fetched oids from server
cache = self._cache = open_cache( cache = self._cache = open_cache(
cache, var, client, storage, cache_size) cache, var, client, storage, cache_size)
...@@ -266,7 +271,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -266,7 +271,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
addr, self, cache, storage, addr, self, cache, storage,
ZEO.asyncio.client.Fallback if read_only_fallback else read_only, ZEO.asyncio.client.Fallback if read_only_fallback else read_only,
wait_timeout or 30, wait_timeout or 30,
ssl = ssl, ssl_server_hostname=ssl_server_hostname, ssl=ssl, ssl_server_hostname=ssl_server_hostname,
credentials=credentials, credentials=credentials,
) )
self._call = self._server.call self._call = self._server.call
...@@ -308,6 +313,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -308,6 +313,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._check_blob_size_thread.join() self._check_blob_size_thread.join()
_check_blob_size_thread = None _check_blob_size_thread = None
def _check_blob_size(self, bytes=None): def _check_blob_size(self, bytes=None):
if self._blob_cache_size is None: if self._blob_cache_size is None:
return return
...@@ -349,8 +355,8 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -349,8 +355,8 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
pass pass
_connection_generation = 0 _connection_generation = 0
def notify_connected(self, conn, info): def notify_connected(self, conn, info):
reconnected = self._connection_generation
self.set_server_addr(conn.get_peername()) self.set_server_addr(conn.get_peername())
self.protocol_version = conn.protocol_version self.protocol_version = conn.protocol_version
self._is_read_only = conn.is_read_only() self._is_read_only = conn.is_read_only()
...@@ -373,22 +379,20 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -373,22 +379,20 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._info.update(info) self._info.update(info)
for iface in ( for iface in (ZODB.interfaces.IStorageRestoreable,
ZODB.interfaces.IStorageRestoreable, ZODB.interfaces.IStorageIteration,
ZODB.interfaces.IStorageIteration, ZODB.interfaces.IStorageUndoable,
ZODB.interfaces.IStorageUndoable, ZODB.interfaces.IStorageCurrentRecordIteration,
ZODB.interfaces.IStorageCurrentRecordIteration, ZODB.interfaces.IBlobStorage,
ZODB.interfaces.IBlobStorage, ZODB.interfaces.IExternalGC):
ZODB.interfaces.IExternalGC, if (iface.__module__, iface.__name__) in \
): self._info.get('interfaces', ()):
if (iface.__module__, iface.__name__) in self._info.get(
'interfaces', ()):
zope.interface.alsoProvides(self, iface) zope.interface.alsoProvides(self, iface)
if self.protocol_version[1:] >= b'5': if self.protocol_version[1:] >= b'5':
self.ping = lambda : self._call('ping') self.ping = lambda: self._call('ping')
else: else:
self.ping = lambda : self._call('lastTransaction') self.ping = lambda: self._call('lastTransaction')
if self.server_sync: if self.server_sync:
self.sync = self.ping self.sync = self.ping
...@@ -536,7 +540,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -536,7 +540,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
try: try:
return self._oids.pop() return self._oids.pop()
except IndexError: except IndexError:
pass # We ran out. We need to get some more. pass # We ran out. We need to get some more.
self._oids[:0] = reversed(self._call('new_oids')) self._oids[:0] = reversed(self._call('new_oids'))
...@@ -735,7 +739,6 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -735,7 +739,6 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
finally: finally:
lock.close() lock.close()
def temporaryDirectory(self): def temporaryDirectory(self):
return self.fshelper.temp_dir return self.fshelper.temp_dir
...@@ -747,7 +750,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -747,7 +750,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
conflicts = True conflicts = True
vote_attempts = 0 vote_attempts = 0
while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop
conflicts = False conflicts = False
for oid in self._call('vote', id(txn)) or (): for oid in self._call('vote', id(txn)) or ():
if isinstance(oid, dict): if isinstance(oid, dict):
...@@ -843,11 +846,11 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -843,11 +846,11 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def tpc_abort(self, txn, timeout=None): def tpc_abort(self, txn, timeout=None):
"""Storage API: abort a transaction. """Storage API: abort a transaction.
(The timeout keyword argument is for tests to wat longer than (The timeout keyword argument is for tests to wait longer than
they normally would.) they normally would.)
""" """
try: try:
tbuf = txn.data(self) tbuf = txn.data(self) # NOQA: F841 unused variable
except KeyError: except KeyError:
return return
...@@ -899,7 +902,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -899,7 +902,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
while blobs: while blobs:
oid, blobfilename = blobs.pop() oid, blobfilename = blobs.pop()
self._blob_data_bytes_loaded += os.stat(blobfilename).st_size self._blob_data_bytes_loaded += os.stat(blobfilename).st_size
targetpath = self.fshelper.getPathForOID(oid, create=True) self.fshelper.getPathForOID(oid, create=True)
target_blob_file_name = self.fshelper.getBlobFilename(oid, tid) target_blob_file_name = self.fshelper.getBlobFilename(oid, tid)
lock = _lock_blob(target_blob_file_name) lock = _lock_blob(target_blob_file_name)
try: try:
...@@ -1037,6 +1040,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage): ...@@ -1037,6 +1040,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def server_status(self): def server_status(self):
return self._call('server_status') return self._call('server_status')
class TransactionIterator(object): class TransactionIterator(object):
def __init__(self, storage, iid, *args): def __init__(self, storage, iid, *args):
...@@ -1130,14 +1134,18 @@ class BlobCacheLayout(object): ...@@ -1130,14 +1134,18 @@ class BlobCacheLayout(object):
ZODB.blob.BLOB_SUFFIX) ZODB.blob.BLOB_SUFFIX)
) )
def _accessed(filename): def _accessed(filename):
try: try:
os.utime(filename, (time.time(), os.stat(filename).st_mtime)) os.utime(filename, (time.time(), os.stat(filename).st_mtime))
except OSError: except OSError:
pass # We tried. :) pass # We tried. :)
return filename return filename
cache_file_name = re.compile(r'\d+$').match cache_file_name = re.compile(r'\d+$').match
def _check_blob_cache_size(blob_dir, target): def _check_blob_cache_size(blob_dir, target):
logger = logging.getLogger(__name__+'.check_blob_cache') logger = logging.getLogger(__name__+'.check_blob_cache')
...@@ -1162,7 +1170,7 @@ def _check_blob_cache_size(blob_dir, target): ...@@ -1162,7 +1170,7 @@ def _check_blob_cache_size(blob_dir, target):
# Someone is already cleaning up, so don't bother # Someone is already cleaning up, so don't bother
logger.debug("%s Another thread is checking the blob cache size.", logger.debug("%s Another thread is checking the blob cache size.",
get_ident()) get_ident())
open(attempt_path, 'w').close() # Mark that we tried open(attempt_path, 'w').close() # Mark that we tried
return return
logger.debug("%s Checking blob cache size. (target: %s)", logger.debug("%s Checking blob cache size. (target: %s)",
...@@ -1200,7 +1208,7 @@ def _check_blob_cache_size(blob_dir, target): ...@@ -1200,7 +1208,7 @@ def _check_blob_cache_size(blob_dir, target):
try: try:
os.remove(attempt_path) os.remove(attempt_path)
except OSError: except OSError:
pass # Sigh, windows pass # Sigh, windows
continue continue
logger.debug("%s -->", get_ident()) logger.debug("%s -->", get_ident())
break break
...@@ -1222,8 +1230,8 @@ def _check_blob_cache_size(blob_dir, target): ...@@ -1222,8 +1230,8 @@ def _check_blob_cache_size(blob_dir, target):
fsize = os.stat(file_name).st_size fsize = os.stat(file_name).st_size
try: try:
ZODB.blob.remove_committed(file_name) ZODB.blob.remove_committed(file_name)
except OSError as v: except OSError:
pass # probably open on windows pass # probably open on windows
else: else:
size -= fsize size -= fsize
finally: finally:
...@@ -1238,12 +1246,14 @@ def _check_blob_cache_size(blob_dir, target): ...@@ -1238,12 +1246,14 @@ def _check_blob_cache_size(blob_dir, target):
finally: finally:
check_lock.close() check_lock.close()
def check_blob_size_script(args=None): def check_blob_size_script(args=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
blob_dir, target = args blob_dir, target = args
_check_blob_cache_size(blob_dir, int(target)) _check_blob_cache_size(blob_dir, int(target))
def _lock_blob(path): def _lock_blob(path):
lockfilename = os.path.join(os.path.dirname(path), '.lock') lockfilename = os.path.join(os.path.dirname(path), '.lock')
n = 0 n = 0
...@@ -1258,6 +1268,7 @@ def _lock_blob(path): ...@@ -1258,6 +1268,7 @@ def _lock_blob(path):
else: else:
break break
def open_cache(cache, var, client, storage, cache_size): def open_cache(cache, var, client, storage, cache_size):
if isinstance(cache, (None.__class__, str)): if isinstance(cache, (None.__class__, str)):
from ZEO.cache import ClientCache from ZEO.cache import ClientCache
......
...@@ -17,27 +17,33 @@ import transaction.interfaces ...@@ -17,27 +17,33 @@ import transaction.interfaces
from ZODB.POSException import StorageError from ZODB.POSException import StorageError
class ClientStorageError(StorageError): class ClientStorageError(StorageError):
"""An error occurred in the ZEO Client Storage. """An error occurred in the ZEO Client Storage.
""" """
class UnrecognizedResult(ClientStorageError): class UnrecognizedResult(ClientStorageError):
"""A server call returned an unrecognized result. """A server call returned an unrecognized result.
""" """
class ClientDisconnected(ClientStorageError, class ClientDisconnected(ClientStorageError,
transaction.interfaces.TransientError): transaction.interfaces.TransientError):
"""The database storage is disconnected from the storage. """The database storage is disconnected from the storage.
""" """
class AuthError(StorageError): class AuthError(StorageError):
"""The client provided invalid authentication credentials. """The client provided invalid authentication credentials.
""" """
class ProtocolError(ClientStorageError): class ProtocolError(ClientStorageError):
"""A client contacted a server with an incomparible protocol """A client contacted a server with an incomparible protocol
""" """
class ServerException(ClientStorageError): class ServerException(ClientStorageError):
""" """
""" """
This diff is collapsed.
...@@ -21,12 +21,11 @@ is used to store the data until a commit or abort. ...@@ -21,12 +21,11 @@ is used to store the data until a commit or abort.
# A faster implementation might store trans data in memory until it # A faster implementation might store trans data in memory until it
# reaches a certain size. # reaches a certain size.
import os
import tempfile import tempfile
import ZODB.blob
from ZEO._compat import Pickler, Unpickler from ZEO._compat import Pickler, Unpickler
class TransactionBuffer(object): class TransactionBuffer(object):
# The TransactionBuffer is used by client storage to hold update # The TransactionBuffer is used by client storage to hold update
...@@ -44,8 +43,8 @@ class TransactionBuffer(object): ...@@ -44,8 +43,8 @@ class TransactionBuffer(object):
# stored are builtin types -- strings or None. # stored are builtin types -- strings or None.
self.pickler = Pickler(self.file, 1) self.pickler = Pickler(self.file, 1)
self.pickler.fast = 1 self.pickler.fast = 1
self.server_resolved = set() # {oid} self.server_resolved = set() # {oid}
self.client_resolved = {} # {oid -> buffer_record_number} self.client_resolved = {} # {oid -> buffer_record_number}
self.exception = None self.exception = None
def close(self): def close(self):
...@@ -93,9 +92,7 @@ class TransactionBuffer(object): ...@@ -93,9 +92,7 @@ class TransactionBuffer(object):
if oid not in seen: if oid not in seen:
yield oid, None, True yield oid, None, True
# Support ZEO4: # Support ZEO4:
def serialnos(self, args): def serialnos(self, args):
for oid in args: for oid in args:
if isinstance(oid, bytes): if isinstance(oid, bytes):
......
...@@ -21,6 +21,7 @@ ZEO is now part of ZODB; ZODB's home on the web is ...@@ -21,6 +21,7 @@ ZEO is now part of ZODB; ZODB's home on the web is
""" """
def client(*args, **kw): def client(*args, **kw):
""" """
Shortcut for :class:`ZEO.ClientStorage.ClientStorage`. Shortcut for :class:`ZEO.ClientStorage.ClientStorage`.
...@@ -28,6 +29,7 @@ def client(*args, **kw): ...@@ -28,6 +29,7 @@ def client(*args, **kw):
import ZEO.ClientStorage import ZEO.ClientStorage
return ZEO.ClientStorage.ClientStorage(*args, **kw) return ZEO.ClientStorage.ClientStorage(*args, **kw)
def DB(*args, **kw): def DB(*args, **kw):
""" """
Shortcut for creating a :class:`ZODB.DB` using a ZEO :func:`~ZEO.client`. Shortcut for creating a :class:`ZODB.DB` using a ZEO :func:`~ZEO.client`.
...@@ -40,6 +42,7 @@ def DB(*args, **kw): ...@@ -40,6 +42,7 @@ def DB(*args, **kw):
s.close() s.close()
raise raise
def connection(*args, **kw): def connection(*args, **kw):
db = DB(*args, **kw) db = DB(*args, **kw)
try: try:
...@@ -48,6 +51,7 @@ def connection(*args, **kw): ...@@ -48,6 +51,7 @@ def connection(*args, **kw):
db.close() db.close()
raise raise
def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None, def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None,
port=0, threaded=True, **kw): port=0, threaded=True, **kw):
"""Convenience function to start a server for interactive exploration """Convenience function to start a server for interactive exploration
......
...@@ -16,13 +16,20 @@ ...@@ -16,13 +16,20 @@
import sys import sys
import platform import platform
from ZODB._compat import BytesIO # NOQA: F401 unused import
PY3 = sys.version_info[0] >= 3 PY3 = sys.version_info[0] >= 3
PY32 = sys.version_info[:2] == (3, 2) PY32 = sys.version_info[:2] == (3, 2)
PYPY = getattr(platform, 'python_implementation', lambda: None)() == 'PyPy' PYPY = getattr(platform, 'python_implementation', lambda: None)() == 'PyPy'
WIN = sys.platform.startswith('win') WIN = sys.platform.startswith('win')
if PY3: if PY3:
from zodbpickle.pickle import Pickler, Unpickler as _Unpickler, dump, dumps, loads from zodbpickle.pickle import dump
from zodbpickle.pickle import dumps
from zodbpickle.pickle import loads
from zodbpickle.pickle import Pickler
from zodbpickle.pickle import Unpickler as _Unpickler
class Unpickler(_Unpickler): class Unpickler(_Unpickler):
# Py3: Python 3 doesn't allow assignments to find_global, # Py3: Python 3 doesn't allow assignments to find_global,
# instead, find_class can be overridden # instead, find_class can be overridden
...@@ -44,24 +51,17 @@ else: ...@@ -44,24 +51,17 @@ else:
dumps = cPickle.dumps dumps = cPickle.dumps
loads = cPickle.loads loads = cPickle.loads
# String and Bytes IO
from ZODB._compat import BytesIO
if PY3: if PY3:
import _thread as thread # NOQA: F401 unused import
import _thread as thread
if PY32: if PY32:
from threading import _get_ident as get_ident from threading import _get_ident as get_ident # NOQA: F401 unused
else: else:
from threading import get_ident from threading import get_ident # NOQA: F401 unused import
else: else:
import thread # NOQA: F401 unused import
import thread from thread import get_ident # NOQA: F401 unused import
from thread import get_ident
try: try:
from cStringIO import StringIO from cStringIO import StringIO # NOQA: F401 unused import
except: except ImportError:
from io import StringIO from io import StringIO # NOQA: F401 unused import
...@@ -26,11 +26,10 @@ import six ...@@ -26,11 +26,10 @@ import six
from ZEO._compat import StringIO from ZEO._compat import StringIO
logger = logging.getLogger('ZEO.tests.forker') logger = logging.getLogger('ZEO.tests.forker')
DEBUG = os.environ.get('ZEO_TEST_SERVER_DEBUG') DEBUG = os.environ.get('ZEO_TEST_SERVER_DEBUG')
ZEO4_SERVER = os.environ.get('ZEO4_SERVER') ZEO4_SERVER = os.environ.get('ZEO4_SERVER')
class ZEOConfig(object): class ZEOConfig(object):
"""Class to generate ZEO configuration file. """ """Class to generate ZEO configuration file. """
...@@ -61,8 +60,7 @@ class ZEOConfig(object): ...@@ -61,8 +60,7 @@ class ZEOConfig(object):
for name in ( for name in (
'invalidation_queue_size', 'invalidation_age', 'invalidation_queue_size', 'invalidation_age',
'transaction_timeout', 'pid_filename', 'msgpack', 'transaction_timeout', 'pid_filename', 'msgpack',
'ssl_certificate', 'ssl_key', 'client_conflict_resolution', 'ssl_certificate', 'ssl_key', 'client_conflict_resolution'):
):
v = getattr(self, name, None) v = getattr(self, name, None)
if v: if v:
print(name.replace('_', '-'), v, file=f) print(name.replace('_', '-'), v, file=f)
...@@ -134,7 +132,7 @@ def runner(config, qin, qout, timeout=None, ...@@ -134,7 +132,7 @@ def runner(config, qin, qout, timeout=None,
os.remove(config) os.remove(config)
try: try:
qin.get(timeout=timeout) # wait for shutdown qin.get(timeout=timeout) # wait for shutdown
except Empty: except Empty:
pass pass
server.server.close() server.server.close()
...@@ -158,6 +156,7 @@ def runner(config, qin, qout, timeout=None, ...@@ -158,6 +156,7 @@ def runner(config, qin, qout, timeout=None,
ZEO.asyncio.server.best_protocol_version = old_protocol ZEO.asyncio.server.best_protocol_version = old_protocol
ZEO.asyncio.server.ServerProtocol.protocols = old_protocols ZEO.asyncio.server.ServerProtocol.protocols = old_protocols
def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None): def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None):
qin.put('stop') qin.put('stop')
try: try:
...@@ -180,6 +179,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None): ...@@ -180,6 +179,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None):
gc.collect() gc.collect()
def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False, def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
path='Data.fs', protocol=None, blob_dir=None, path='Data.fs', protocol=None, blob_dir=None,
suicide=True, debug=False, suicide=True, debug=False,
...@@ -220,7 +220,8 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False, ...@@ -220,7 +220,8 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
print(zeo_conf) print(zeo_conf)
# Store the config info in a temp file. # Store the config info in a temp file.
fd, tmpfile = tempfile.mkstemp(".conf", prefix='ZEO_forker', dir=os.getcwd()) fd, tmpfile = tempfile.mkstemp(".conf", prefix='ZEO_forker',
dir=os.getcwd())
with os.fdopen(fd, 'w') as fp: with os.fdopen(fd, 'w') as fp:
fp.write(zeo_conf) fp.write(zeo_conf)
...@@ -273,10 +274,12 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG): ...@@ -273,10 +274,12 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG):
return stop return stop
def whine(*message): def whine(*message):
print(*message, file=sys.stderr) print(*message, file=sys.stderr)
sys.stderr.flush() sys.stderr.flush()
class ThreadlessQueue(object): class ThreadlessQueue(object):
def __init__(self): def __init__(self):
......
...@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__) ...@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6 INET_FAMILIES = socket.AF_INET, socket.AF_INET6
class Protocol(asyncio.Protocol): class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO base interface """asyncio low-level ZEO base interface
""" """
...@@ -30,9 +31,9 @@ class Protocol(asyncio.Protocol): ...@@ -30,9 +31,9 @@ class Protocol(asyncio.Protocol):
def __init__(self, loop, addr): def __init__(self, loop, addr):
self.loop = loop self.loop = loop
self.addr = addr self.addr = addr
self.input = [] # Input buffer when assembling messages self.input = [] # Input buffer when assembling messages
self.output = [] # Output buffer when paused self.output = [] # Output buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently # Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received self.message_received = self.first_message_received
...@@ -41,6 +42,7 @@ class Protocol(asyncio.Protocol): ...@@ -41,6 +42,7 @@ class Protocol(asyncio.Protocol):
return self.name return self.name
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -50,7 +52,6 @@ class Protocol(asyncio.Protocol): ...@@ -50,7 +52,6 @@ class Protocol(asyncio.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected %s", self) logger.info("Connected %s", self)
if sys.version_info < (3, 6): if sys.version_info < (3, 6):
sock = transport.get_extra_info('socket') sock = transport.get_extra_info('socket')
if sock is not None and sock.family in INET_FAMILIES: if sock is not None and sock.family in INET_FAMILIES:
...@@ -91,6 +92,7 @@ class Protocol(asyncio.Protocol): ...@@ -91,6 +92,7 @@ class Protocol(asyncio.Protocol):
got = 0 got = 0
want = 4 want = 4
getting_size = True getting_size = True
def data_received(self, data): def data_received(self, data):
# Low-level input handler collects data into sized messages. # Low-level input handler collects data into sized messages.
...@@ -135,7 +137,7 @@ class Protocol(asyncio.Protocol): ...@@ -135,7 +137,7 @@ class Protocol(asyncio.Protocol):
def first_message_received(self, protocol_version): def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__ # Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on del self.message_received # use default handler from here on
self.finish_connect(protocol_version) self.finish_connect(protocol_version)
def call_async(self, method, args): def call_async(self, method, args):
...@@ -162,7 +164,7 @@ class Protocol(asyncio.Protocol): ...@@ -162,7 +164,7 @@ class Protocol(asyncio.Protocol):
data = message data = message
for message in data: for message in data:
writelines((pack(">I", len(message)), message)) writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back. if paused: # paused again. Put iter back.
output.insert(0, data) output.insert(0, data)
break break
......
...@@ -19,7 +19,8 @@ logger = logging.getLogger(__name__) ...@@ -19,7 +19,8 @@ logger = logging.getLogger(__name__)
Fallback = object() Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests local_random = random.Random() # use separate generator to facilitate tests
def future_generator(func): def future_generator(func):
"""Decorates a generator that generates futures """Decorates a generator that generates futures
...@@ -52,6 +53,7 @@ def future_generator(func): ...@@ -52,6 +53,7 @@ def future_generator(func):
return call_generator return call_generator
class Protocol(base.Protocol): class Protocol(base.Protocol):
"""asyncio low-level ZEO client interface """asyncio low-level ZEO client interface
""" """
...@@ -85,7 +87,7 @@ class Protocol(base.Protocol): ...@@ -85,7 +87,7 @@ class Protocol(base.Protocol):
self.client = client self.client = client
self.connect_poll = connect_poll self.connect_poll = connect_poll
self.heartbeat_interval = heartbeat_interval self.heartbeat_interval = heartbeat_interval
self.futures = {} # { message_id -> future } self.futures = {} # { message_id -> future }
self.ssl = ssl self.ssl = ssl
self.ssl_server_hostname = ssl_server_hostname self.ssl_server_hostname = ssl_server_hostname
self.credentials = credentials self.credentials = credentials
...@@ -132,7 +134,9 @@ class Protocol(base.Protocol): ...@@ -132,7 +134,9 @@ class Protocol(base.Protocol):
elif future.exception() is not None: elif future.exception() is not None:
logger.info("Connection to %r failed, %s", logger.info("Connection to %r failed, %s",
self.addr, future.exception()) self.addr, future.exception())
else: return else:
return
# keep trying # keep trying
if not self.closed: if not self.closed:
logger.info("retry connecting %r", self.addr) logger.info("retry connecting %r", self.addr)
...@@ -141,7 +145,6 @@ class Protocol(base.Protocol): ...@@ -141,7 +145,6 @@ class Protocol(base.Protocol):
self.connect, self.connect,
) )
def connection_made(self, transport): def connection_made(self, transport):
super(Protocol, self).connection_made(transport) super(Protocol, self).connection_made(transport)
self.heartbeat(write=False) self.heartbeat(write=False)
...@@ -190,7 +193,8 @@ class Protocol(base.Protocol): ...@@ -190,7 +193,8 @@ class Protocol(base.Protocol):
try: try:
server_tid = yield self.fut( server_tid = yield self.fut(
'register', self.storage_key, 'register', self.storage_key,
self.read_only if self.read_only is not Fallback else False, (self.read_only if self.read_only is not Fallback
else False),
*credentials) *credentials)
except ZODB.POSException.ReadOnlyError: except ZODB.POSException.ReadOnlyError:
if self.read_only is Fallback: if self.read_only is Fallback:
...@@ -208,11 +212,12 @@ class Protocol(base.Protocol): ...@@ -208,11 +212,12 @@ class Protocol(base.Protocol):
self.client.registered(self, server_tid) self.client.registered(self, server_tid)
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
msgid, async_, name, args = self.decode(data) msgid, async_, name, args = self.decode(data)
if name == '.reply': if name == '.reply':
future = self.futures.pop(msgid) future = self.futures.pop(msgid)
if async_: # ZEO 5 exception if async_: # ZEO 5 exception
class_, args = args class_, args = args
factory = exc_factories.get(class_) factory = exc_factories.get(class_)
if factory: if factory:
...@@ -237,13 +242,14 @@ class Protocol(base.Protocol): ...@@ -237,13 +242,14 @@ class Protocol(base.Protocol):
else: else:
future.set_result(args) future.set_result(args)
else: else:
assert async_ # clients only get async calls assert async_ # clients only get async calls
if name in self.client_methods: if name in self.client_methods:
getattr(self.client, name)(*args) getattr(self.client, name)(*args)
else: else:
raise AttributeError(name) raise AttributeError(name)
message_id = 0 message_id = 0
def call(self, future, method, args): def call(self, future, method, args):
self.message_id += 1 self.message_id += 1
self.futures[self.message_id] = future self.futures[self.message_id] = future
...@@ -262,6 +268,7 @@ class Protocol(base.Protocol): ...@@ -262,6 +268,7 @@ class Protocol(base.Protocol):
self.futures[message_id] = future self.futures[message_id] = future
self._write( self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid))) self.encode(message_id, False, 'loadBefore', (oid, tid)))
@future.add_done_callback @future.add_done_callback
def _(future): def _(future):
try: try:
...@@ -271,6 +278,7 @@ class Protocol(base.Protocol): ...@@ -271,6 +278,7 @@ class Protocol(base.Protocol):
if data: if data:
data, start, end = data data, start, end = data
self.client.cache.store(oid, start, end, data) self.client.cache.store(oid, start, end, data)
return future return future
# Methods called by the server. # Methods called by the server.
...@@ -290,29 +298,34 @@ class Protocol(base.Protocol): ...@@ -290,29 +298,34 @@ class Protocol(base.Protocol):
self.heartbeat_handle = self.loop.call_later( self.heartbeat_handle = self.loop.call_later(
self.heartbeat_interval, self.heartbeat) self.heartbeat_interval, self.heartbeat)
def create_Exception(class_, args): def create_Exception(class_, args):
return exc_classes[class_](*args) return exc_classes[class_](*args)
def create_ConflictError(class_, args): def create_ConflictError(class_, args):
exc = exc_classes[class_]( exc = exc_classes[class_](
message = args['message'], message=args['message'],
oid = args['oid'], oid=args['oid'],
serials = args['serials'], serials=args['serials'],
) )
exc.class_name = args.get('class_name') exc.class_name = args.get('class_name')
return exc return exc
def create_BTreesConflictError(class_, args): def create_BTreesConflictError(class_, args):
return ZODB.POSException.BTreesConflictError( return ZODB.POSException.BTreesConflictError(
p1 = args['p1'], p1=args['p1'],
p2 = args['p2'], p2=args['p2'],
p3 = args['p3'], p3=args['p3'],
reason = args['reason'], reason=args['reason'],
) )
def create_MultipleUndoErrors(class_, args): def create_MultipleUndoErrors(class_, args):
return ZODB.POSException.MultipleUndoErrors(args['_errs']) return ZODB.POSException.MultipleUndoErrors(args['_errs'])
exc_classes = { exc_classes = {
'builtins.KeyError': KeyError, 'builtins.KeyError': KeyError,
'builtins.TypeError': TypeError, 'builtins.TypeError': TypeError,
...@@ -340,6 +353,8 @@ exc_factories = { ...@@ -340,6 +353,8 @@ exc_factories = {
} }
unlogged_exceptions = (ZODB.POSException.POSKeyError, unlogged_exceptions = (ZODB.POSException.POSKeyError,
ZODB.POSException.ConflictError) ZODB.POSException.ConflictError)
class Client(object): class Client(object):
"""asyncio low-level ZEO client interface """asyncio low-level ZEO client interface
""" """
...@@ -352,8 +367,11 @@ class Client(object): ...@@ -352,8 +367,11 @@ class Client(object):
# connect. # connect.
protocol = None protocol = None
ready = None # Tri-value: None=Never connected, True=connected, # ready can have three values:
# False=Disconnected # None=Never connected
# True=connected
# False=Disconnected
ready = None
def __init__(self, loop, def __init__(self, loop,
addrs, client, cache, storage_key, read_only, connect_poll, addrs, client, cache, storage_key, read_only, connect_poll,
...@@ -404,6 +422,7 @@ class Client(object): ...@@ -404,6 +422,7 @@ class Client(object):
self.is_read_only() and self.read_only is Fallback) self.is_read_only() and self.read_only is Fallback)
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -466,7 +485,7 @@ class Client(object): ...@@ -466,7 +485,7 @@ class Client(object):
self.upgrade(protocol) self.upgrade(protocol)
self.verify(server_tid) self.verify(server_tid)
else: else:
protocol.close() # too late, we went home with another protocol.close() # too late, we went home with another
def register_failed(self, protocol, exc): def register_failed(self, protocol, exc):
# A protocol failed registration. That's weird. If they've all # A protocol failed registration. That's weird. If they've all
...@@ -474,18 +493,17 @@ class Client(object): ...@@ -474,18 +493,17 @@ class Client(object):
if protocol is not self: if protocol is not self:
protocol.close() protocol.close()
logger.exception("Registration or cache validation failed, %s", exc) logger.exception("Registration or cache validation failed, %s", exc)
if (self.protocol is None and not if self.protocol is None and \
any(not p.closed for p in self.protocols) not any(not p.closed for p in self.protocols):
):
self.loop.call_later( self.loop.call_later(
self.register_failed_poll + local_random.random(), self.register_failed_poll + local_random.random(),
self.try_connecting) self.try_connecting)
verify_result = None # for tests verify_result = None # for tests
@future_generator @future_generator
def verify(self, server_tid): def verify(self, server_tid):
self.verify_invalidation_queue = [] # See comment in init :( self.verify_invalidation_queue = [] # See comment in init :(
protocol = self.protocol protocol = self.protocol
try: try:
...@@ -739,6 +757,7 @@ class Client(object): ...@@ -739,6 +757,7 @@ class Client(object):
else: else:
return protocol.read_only return protocol.read_only
class ClientRunner(object): class ClientRunner(object):
def set_options(self, addrs, wrapper, cache, storage_key, read_only, def set_options(self, addrs, wrapper, cache, storage_key, read_only,
...@@ -855,6 +874,7 @@ class ClientRunner(object): ...@@ -855,6 +874,7 @@ class ClientRunner(object):
timeout = self.timeout timeout = self.timeout
self.wait_for_result(self.client.connected, timeout) self.wait_for_result(self.client.connected, timeout)
class ClientThread(ClientRunner): class ClientThread(ClientRunner):
"""Thread wrapper for client interface """Thread wrapper for client interface
...@@ -883,6 +903,7 @@ class ClientThread(ClientRunner): ...@@ -883,6 +903,7 @@ class ClientThread(ClientRunner):
raise self.exception raise self.exception
exception = None exception = None
def run(self): def run(self):
loop = None loop = None
try: try:
...@@ -909,6 +930,7 @@ class ClientThread(ClientRunner): ...@@ -909,6 +930,7 @@ class ClientThread(ClientRunner):
logger.debug('Stopping client thread') logger.debug('Stopping client thread')
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -918,6 +940,7 @@ class ClientThread(ClientRunner): ...@@ -918,6 +940,7 @@ class ClientThread(ClientRunner):
if self.exception: if self.exception:
raise self.exception raise self.exception
class Fut(object): class Fut(object):
"""Lightweight future that calls it's callbacks immediately rather than soon """Lightweight future that calls it's callbacks immediately rather than soon
""" """
...@@ -929,6 +952,7 @@ class Fut(object): ...@@ -929,6 +952,7 @@ class Fut(object):
self.cbv.append(cb) self.cbv.append(cb)
exc = None exc = None
def set_exception(self, exc): def set_exception(self, exc):
self.exc = exc self.exc = exc
for cb in self.cbv: for cb in self.cbv:
......
...@@ -6,5 +6,5 @@ if PY3: ...@@ -6,5 +6,5 @@ if PY3:
except ImportError: except ImportError:
from asyncio import new_event_loop from asyncio import new_event_loop
else: else:
import trollius as asyncio import trollius as asyncio # NOQA: F401 unused import
from trollius import new_event_loop from trollius import new_event_loop # NOQA: F401 unused import
...@@ -21,19 +21,22 @@ Python-independent format, or possibly a minimal pickle subset. ...@@ -21,19 +21,22 @@ Python-independent format, or possibly a minimal pickle subset.
import logging import logging
from .._compat import Unpickler, Pickler, BytesIO, PY3, PYPY from .._compat import Unpickler, Pickler, BytesIO, PY3
from ..shortrepr import short_repr from ..shortrepr import short_repr
PY2 = not PY3 PY2 = not PY3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def encoder(protocol, server=False): def encoder(protocol, server=False):
"""Return a non-thread-safe encoder """Return a non-thread-safe encoder
""" """
if protocol[:1] == b'M': if protocol[:1] == b'M':
from msgpack import packb from msgpack import packb
default = server_default if server else None default = server_default if server else None
def encode(*args): def encode(*args):
return packb( return packb(
args, use_bin_type=True, default=default) args, use_bin_type=True, default=default)
...@@ -49,6 +52,7 @@ def encoder(protocol, server=False): ...@@ -49,6 +52,7 @@ def encoder(protocol, server=False):
pickler = Pickler(f, 3) pickler = Pickler(f, 3)
pickler.fast = 1 pickler.fast = 1
dump = pickler.dump dump = pickler.dump
def encode(*args): def encode(*args):
seek(0) seek(0)
truncate() truncate()
...@@ -57,21 +61,26 @@ def encoder(protocol, server=False): ...@@ -57,21 +61,26 @@ def encoder(protocol, server=False):
return encode return encode
def encode(*args): def encode(*args):
return encoder(b'Z')(*args) return encoder(b'Z')(*args)
def decoder(protocol): def decoder(protocol):
if protocol[:1] == b'M': if protocol[:1] == b'M':
from msgpack import unpackb from msgpack import unpackb
def msgpack_decode(data): def msgpack_decode(data):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
return unpackb(data, raw=False, use_list=False) return unpackb(data, raw=False, use_list=False)
return msgpack_decode return msgpack_decode
else: else:
assert protocol[:1] == b'Z' assert protocol[:1] == b'Z'
return pickle_decode return pickle_decode
def pickle_decode(msg): def pickle_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
...@@ -82,11 +91,12 @@ def pickle_decode(msg): ...@@ -82,11 +91,12 @@ def pickle_decode(msg):
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except: # NOQA: E722 bare except
logger.error("can't decode message: %s" % short_repr(msg)) logger.error("can't decode message: %s" % short_repr(msg))
raise raise
def server_decoder(protocol): def server_decoder(protocol):
if protocol[:1] == b'M': if protocol[:1] == b'M':
return decoder(protocol) return decoder(protocol)
...@@ -94,6 +104,7 @@ def server_decoder(protocol): ...@@ -94,6 +104,7 @@ def server_decoder(protocol):
assert protocol[:1] == b'Z' assert protocol[:1] == b'Z'
return pickle_server_decode return pickle_server_decode
def pickle_server_decode(msg): def pickle_server_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
...@@ -105,22 +116,25 @@ def pickle_server_decode(msg): ...@@ -105,22 +116,25 @@ def pickle_server_decode(msg):
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except: # NOQA: E722 bare except
logger.error("can't decode message: %s" % short_repr(msg)) logger.error("can't decode message: %s" % short_repr(msg))
raise raise
def server_default(obj): def server_default(obj):
if isinstance(obj, Exception): if isinstance(obj, Exception):
return reduce_exception(obj) return reduce_exception(obj)
else: else:
return obj return obj
def reduce_exception(exc): def reduce_exception(exc):
class_ = exc.__class__ class_ = exc.__class__
class_ = "%s.%s" % (class_.__module__, class_.__name__) class_ = "%s.%s" % (class_.__module__, class_.__name__)
return class_, exc.__dict__ or exc.args return class_, exc.__dict__ or exc.args
_globals = globals() _globals = globals()
_silly = ('__doc__',) _silly = ('__doc__',)
...@@ -131,6 +145,7 @@ _SAFE_MODULE_NAMES = ( ...@@ -131,6 +145,7 @@ _SAFE_MODULE_NAMES = (
'builtins', 'copy_reg', '__builtin__', 'builtins', 'copy_reg', '__builtin__',
) )
def find_global(module, name): def find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
try: try:
...@@ -143,7 +158,8 @@ def find_global(module, name): ...@@ -143,7 +158,8 @@ def find_global(module, name):
except AttributeError: except AttributeError:
raise ImportError("module %s has no global %s" % (module, name)) raise ImportError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) or (PY2 and module in _SAFE_MODULE_NAMES) safe = (getattr(r, '__no_side_effects__', 0) or
(PY2 and module in _SAFE_MODULE_NAMES))
if safe: if safe:
return r return r
...@@ -153,6 +169,7 @@ def find_global(module, name): ...@@ -153,6 +169,7 @@ def find_global(module, name):
raise ImportError("Unsafe global: %s.%s" % (module, name)) raise ImportError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name): def server_find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
if module not in _SAFE_MODULE_NAMES: if module not in _SAFE_MODULE_NAMES:
......
...@@ -72,6 +72,7 @@ import logging ...@@ -72,6 +72,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class Acceptor(asyncore.dispatcher): class Acceptor(asyncore.dispatcher):
"""A server that accepts incoming RPC connections """A server that accepts incoming RPC connections
...@@ -115,13 +116,13 @@ class Acceptor(asyncore.dispatcher): ...@@ -115,13 +116,13 @@ class Acceptor(asyncore.dispatcher):
for i in range(25): for i in range(25):
try: try:
self.bind(addr) self.bind(addr)
except Exception as exc: except Exception:
logger.info("bind on %s failed %s waiting", addr, i) logger.info("bind on %s failed %s waiting", addr, i)
if i == 24: if i == 24:
raise raise
else: else:
time.sleep(5) time.sleep(5)
except: except: # NOQA: E722 bare except
logger.exception('binding') logger.exception('binding')
raise raise
else: else:
...@@ -146,7 +147,6 @@ class Acceptor(asyncore.dispatcher): ...@@ -146,7 +147,6 @@ class Acceptor(asyncore.dispatcher):
logger.info("accepted failed: %s", msg) logger.info("accepted failed: %s", msg)
return return
# We could short-circuit the attempt below in some edge cases # We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None. # and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below, # Unfortunately, our test for the code below,
...@@ -159,7 +159,7 @@ class Acceptor(asyncore.dispatcher): ...@@ -159,7 +159,7 @@ class Acceptor(asyncore.dispatcher):
# closed, but I don't see a way to do that. :( # closed, but I don't see a way to do that. :(
# Drop flow-info from IPv6 addresses # Drop flow-info from IPv6 addresses
if addr: # Sometimes None on Mac. See above. if addr: # Sometimes None on Mac. See above.
addr = addr[:2] addr = addr[:2]
try: try:
...@@ -172,23 +172,25 @@ class Acceptor(asyncore.dispatcher): ...@@ -172,23 +172,25 @@ class Acceptor(asyncore.dispatcher):
protocol.stop = loop.stop protocol.stop = loop.stop
if self.ssl_context is None: if self.ssl_context is None:
cr = loop.create_connection((lambda : protocol), sock=sock) cr = loop.create_connection((lambda: protocol), sock=sock)
else: else:
if hasattr(loop, 'connect_accepted_socket'): if hasattr(loop, 'connect_accepted_socket'):
cr = loop.connect_accepted_socket( cr = loop.connect_accepted_socket(
(lambda : protocol), sock, ssl=self.ssl_context) (lambda: protocol), sock, ssl=self.ssl_context)
else: else:
####################################################### #######################################################
# XXX See http://bugs.python.org/issue27392 :( # XXX See http://bugs.python.org/issue27392 :(
_make_ssl_transport = loop._make_ssl_transport _make_ssl_transport = loop._make_ssl_transport
def make_ssl_transport(*a, **kw): def make_ssl_transport(*a, **kw):
kw['server_side'] = True kw['server_side'] = True
return _make_ssl_transport(*a, **kw) return _make_ssl_transport(*a, **kw)
loop._make_ssl_transport = make_ssl_transport loop._make_ssl_transport = make_ssl_transport
# #
####################################################### #######################################################
cr = loop.create_connection( cr = loop.create_connection(
(lambda : protocol), sock=sock, (lambda: protocol), sock=sock,
ssl=self.ssl_context, ssl=self.ssl_context,
server_hostname='' server_hostname=''
) )
...@@ -212,11 +214,12 @@ class Acceptor(asyncore.dispatcher): ...@@ -212,11 +214,12 @@ class Acceptor(asyncore.dispatcher):
asyncore.loop(map=self.__socket_map, timeout=timeout) asyncore.loop(map=self.__socket_map, timeout=timeout)
except Exception: except Exception:
if not self.__closed: if not self.__closed:
raise # Unexpected exc raise # Unexpected exc
logger.debug('acceptor %s loop stopped', self.addr) logger.debug('acceptor %s loop stopped', self.addr)
__closed = False __closed = False
def close(self): def close(self):
if not self.__closed: if not self.__closed:
self.__closed = True self.__closed = True
......
import json import json
import logging import logging
import os import os
import random
import threading import threading
import ZODB.POSException import ZODB.POSException
logger = logging.getLogger(__name__)
from ..shortrepr import short_repr from ..shortrepr import short_repr
from . import base from . import base
from .compat import asyncio, new_event_loop from .compat import asyncio, new_event_loop
from .marshal import server_decoder, encoder, reduce_exception from .marshal import server_decoder, encoder, reduce_exception
logger = logging.getLogger(__name__)
class ServerProtocol(base.Protocol): class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface """asyncio low-level ZEO server interface
""" """
...@@ -39,6 +40,7 @@ class ServerProtocol(base.Protocol): ...@@ -39,6 +40,7 @@ class ServerProtocol(base.Protocol):
) )
closed = False closed = False
def close(self): def close(self):
logger.debug("Closing server protocol") logger.debug("Closing server protocol")
if not self.closed: if not self.closed:
...@@ -46,7 +48,8 @@ class ServerProtocol(base.Protocol): ...@@ -46,7 +48,8 @@ class ServerProtocol(base.Protocol):
if self.transport is not None: if self.transport is not None:
self.transport.close() self.transport.close()
connected = None # for tests connected = None # for tests
def connection_made(self, transport): def connection_made(self, transport):
self.connected = True self.connected = True
super(ServerProtocol, self).connection_made(transport) super(ServerProtocol, self).connection_made(transport)
...@@ -60,7 +63,7 @@ class ServerProtocol(base.Protocol): ...@@ -60,7 +63,7 @@ class ServerProtocol(base.Protocol):
self.stop() self.stop()
def stop(self): def stop(self):
pass # Might be replaced when running a thread per client pass # Might be replaced when running a thread per client
def finish_connect(self, protocol_version): def finish_connect(self, protocol_version):
if protocol_version == b'ruok': if protocol_version == b'ruok':
...@@ -95,7 +98,7 @@ class ServerProtocol(base.Protocol): ...@@ -95,7 +98,7 @@ class ServerProtocol(base.Protocol):
return return
if message_id == -1: if message_id == -1:
return # keep-alive return # keep-alive
if name not in self.methods: if name not in self.methods:
logger.error('Invalid method, %r', name) logger.error('Invalid method, %r', name)
...@@ -109,7 +112,7 @@ class ServerProtocol(base.Protocol): ...@@ -109,7 +112,7 @@ class ServerProtocol(base.Protocol):
"%s`%r` raised exception:", "%s`%r` raised exception:",
'async ' if async_ else '', name) 'async ' if async_ else '', name)
if async_: if async_:
return self.close() # No way to recover/cry for help return self.close() # No way to recover/cry for help
else: else:
return self.send_error(message_id, exc) return self.send_error(message_id, exc)
...@@ -147,16 +150,19 @@ class ServerProtocol(base.Protocol): ...@@ -147,16 +150,19 @@ class ServerProtocol(base.Protocol):
def async_threadsafe(self, method, *args): def async_threadsafe(self, method, *args):
self.call_soon_threadsafe(self.call_async, method, args) self.call_soon_threadsafe(self.call_async, method, args)
best_protocol_version = os.environ.get( best_protocol_version = os.environ.get(
'ZEO_SERVER_PROTOCOL', 'ZEO_SERVER_PROTOCOL',
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8') ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage, msgpack): def new_connection(loop, addr, socket, zeo_storage, msgpack):
protocol = ServerProtocol(loop, addr, zeo_storage, msgpack) protocol = ServerProtocol(loop, addr, zeo_storage, msgpack)
cr = loop.create_connection((lambda : protocol), sock=socket) cr = loop.create_connection((lambda: protocol), sock=socket)
asyncio.ensure_future(cr, loop=loop) asyncio.ensure_future(cr, loop=loop)
class Delay(object): class Delay(object):
"""Used to delay response to client for synchronous calls. """Used to delay response to client for synchronous calls.
...@@ -192,6 +198,7 @@ class Delay(object): ...@@ -192,6 +198,7 @@ class Delay(object):
def __reduce__(self): def __reduce__(self):
raise TypeError("Can't pickle delays.") raise TypeError("Can't pickle delays.")
class Result(Delay): class Result(Delay):
def __init__(self, *args): def __init__(self, *args):
...@@ -202,6 +209,7 @@ class Result(Delay): ...@@ -202,6 +209,7 @@ class Result(Delay):
protocol.send_reply(msgid, reply) protocol.send_reply(msgid, reply)
callback() callback()
class MTDelay(Delay): class MTDelay(Delay):
def __init__(self): def __init__(self):
...@@ -266,6 +274,7 @@ class Acceptor(object): ...@@ -266,6 +274,7 @@ class Acceptor(object):
self.event_loop.close() self.event_loop.close()
closed = False closed = False
def close(self): def close(self):
if not self.closed: if not self.closed:
self.closed = True self.closed = True
...@@ -277,6 +286,7 @@ class Acceptor(object): ...@@ -277,6 +286,7 @@ class Acceptor(object):
self.server.close() self.server.close()
f = asyncio.ensure_future(self.server.wait_closed(), loop=loop) f = asyncio.ensure_future(self.server.wait_closed(), loop=loop)
@f.add_done_callback @f.add_done_callback
def server_closed(f): def server_closed(f):
# stop the loop when the server closes: # stop the loop when the server closes:
......
...@@ -11,7 +11,6 @@ except NameError: ...@@ -11,7 +11,6 @@ except NameError:
class ConnectionRefusedError(OSError): class ConnectionRefusedError(OSError):
pass pass
import pprint
class Loop(object): class Loop(object):
...@@ -19,7 +18,7 @@ class Loop(object): ...@@ -19,7 +18,7 @@ class Loop(object):
def __init__(self, addrs=(), debug=True): def __init__(self, addrs=(), debug=True):
self.addrs = addrs self.addrs = addrs
self.get_debug = lambda : debug self.get_debug = lambda: debug
self.connecting = {} self.connecting = {}
self.later = [] self.later = []
self.exceptions = [] self.exceptions = []
...@@ -31,7 +30,7 @@ class Loop(object): ...@@ -31,7 +30,7 @@ class Loop(object):
func(*args) func(*args)
def _connect(self, future, protocol_factory): def _connect(self, future, protocol_factory):
self.protocol = protocol = protocol_factory() self.protocol = protocol = protocol_factory()
self.transport = transport = Transport(protocol) self.transport = transport = Transport(protocol)
protocol.connection_made(transport) protocol.connection_made(transport)
future.set_result((transport, protocol)) future.set_result((transport, protocol))
...@@ -45,10 +44,8 @@ class Loop(object): ...@@ -45,10 +44,8 @@ class Loop(object):
if not future.cancelled(): if not future.cancelled():
future.set_exception(ConnectionRefusedError()) future.set_exception(ConnectionRefusedError())
def create_connection( def create_connection(self, protocol_factory, host=None, port=None,
self, protocol_factory, host=None, port=None, sock=None, sock=None, ssl=None, server_hostname=None):
ssl=None, server_hostname=None
):
future = asyncio.Future(loop=self) future = asyncio.Future(loop=self)
if sock is None: if sock is None:
addr = host, port addr = host, port
...@@ -83,13 +80,16 @@ class Loop(object): ...@@ -83,13 +80,16 @@ class Loop(object):
self.exceptions.append(context) self.exceptions.append(context)
closed = False closed = False
def close(self): def close(self):
self.closed = True self.closed = True
stopped = False stopped = False
def stop(self): def stop(self):
self.stopped = True self.stopped = True
class Handle(object): class Handle(object):
cancelled = False cancelled = False
...@@ -97,6 +97,7 @@ class Handle(object): ...@@ -97,6 +97,7 @@ class Handle(object):
def cancel(self): def cancel(self):
self.cancelled = True self.cancelled = True
class Transport(object): class Transport(object):
capacity = 1 << 64 capacity = 1 << 64
...@@ -136,12 +137,14 @@ class Transport(object): ...@@ -136,12 +137,14 @@ class Transport(object):
self.protocol.resume_writing() self.protocol.resume_writing()
closed = False closed = False
def close(self): def close(self):
self.closed = True self.closed = True
def get_extra_info(self, name): def get_extra_info(self, name):
return self.extra[name] return self.extra[name]
class AsyncRPC(object): class AsyncRPC(object):
"""Adapt an asyncio API to an RPC to help hysterical tests """Adapt an asyncio API to an RPC to help hysterical tests
""" """
...@@ -151,6 +154,7 @@ class AsyncRPC(object): ...@@ -151,6 +154,7 @@ class AsyncRPC(object):
def __getattr__(self, name): def __getattr__(self, name):
return lambda *a, **kw: self.api.call(name, *a, **kw) return lambda *a, **kw: self.api.call(name, *a, **kw)
class ClientRunner(object): class ClientRunner(object):
def __init__(self, addr, client, cache, storage, read_only, timeout, def __init__(self, addr, client, cache, storage, read_only, timeout,
......
...@@ -2,17 +2,17 @@ from .._compat import PY3 ...@@ -2,17 +2,17 @@ from .._compat import PY3
if PY3: if PY3:
import asyncio import asyncio
def to_byte(i): def to_byte(i):
return bytes([i]) return bytes([i])
else: else:
import trollius as asyncio import trollius as asyncio # NOQA: F401 unused import
def to_byte(b): def to_byte(b):
return b return b
from zope.testing import setupstack from zope.testing import setupstack
from concurrent.futures import Future
import mock import mock
from ZODB.POSException import ReadOnlyError
from ZODB.utils import maxtid, RLock from ZODB.utils import maxtid, RLock
import collections import collections
...@@ -28,6 +28,7 @@ from .client import ClientRunner, Fallback ...@@ -28,6 +28,7 @@ from .client import ClientRunner, Fallback
from .server import new_connection, best_protocol_version from .server import new_connection, best_protocol_version
from .marshal import encoder, decoder from .marshal import encoder, decoder
class Base(object): class Base(object):
enc = b'Z' enc = b'Z'
...@@ -56,6 +57,7 @@ class Base(object): ...@@ -56,6 +57,7 @@ class Base(object):
return self.unsized(data, True) return self.unsized(data, True)
target = None target = None
def send(self, method, *args, **kw): def send(self, method, *args, **kw):
target = kw.pop('target', self.target) target = kw.pop('target', self.target)
called = kw.pop('called', True) called = kw.pop('called', True)
...@@ -77,6 +79,7 @@ class Base(object): ...@@ -77,6 +79,7 @@ class Base(object):
def pop(self, count=None, parse=True): def pop(self, count=None, parse=True):
return self.unsized(self.loop.transport.pop(count), parse) return self.unsized(self.loop.transport.pop(count), parse)
class ClientTests(Base, setupstack.TestCase, ClientRunner): class ClientTests(Base, setupstack.TestCase, ClientRunner):
maxDiff = None maxDiff = None
...@@ -204,7 +207,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -204,7 +207,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, maxtid) loaded = self.load_before(b'1'*8, maxtid)
# The data wasn't in the cache, so we made a server call: # The data wasn't in the cache, so we made a server call:
self.assertEqual(self.pop(), ((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid))) self.assertEqual(self.pop(),
((b'1'*8, maxtid),
False,
'loadBefore',
(b'1'*8, maxtid)))
# Note load_before uses the oid as the message id. # Note load_before uses the oid as the message id.
self.respond((b'1'*8, maxtid), (b'data', b'a'*8, None)) self.respond((b'1'*8, maxtid), (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None)) self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
...@@ -224,7 +231,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -224,7 +231,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# the requests will be collapsed: # the requests will be collapsed:
loaded2 = self.load_before(b'1'*8, maxtid) loaded2 = self.load_before(b'1'*8, maxtid)
self.assertEqual(self.pop(), ((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid))) self.assertEqual(self.pop(),
((b'1'*8, maxtid),
False,
'loadBefore',
(b'1'*8, maxtid)))
self.respond((b'1'*8, maxtid), (b'data2', b'b'*8, None)) self.respond((b'1'*8, maxtid), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None)) self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None)) self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))
...@@ -238,7 +249,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -238,7 +249,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(transport.data) self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8) loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(self.pop(), ((b'1'*8, b'_'*8), False, 'loadBefore', (b'1'*8, b'_'*8))) self.assertEqual(self.pop(),
((b'1'*8, b'_'*8),
False,
'loadBefore',
(b'1'*8, b'_'*8)))
self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8)) self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8)) self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
...@@ -247,6 +262,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -247,6 +262,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# iteratable to tpc_finish_threadsafe. # iteratable to tpc_finish_threadsafe.
tids = [] tids = []
def finished_cb(tid): def finished_cb(tid):
tids.append(tid) tids.append(tid)
...@@ -349,7 +365,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -349,7 +365,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, ))) self.assertEqual(self.pop(),
(3, False, 'getInvalidations', (b'a'*8, )))
self.respond(3, (b'e'*8, [b'4'*8])) self.respond(3, (b'e'*8, [b'4'*8]))
self.assertEqual(self.pop(), (4, False, 'get_info', ())) self.assertEqual(self.pop(), (4, False, 'get_info', ()))
...@@ -361,7 +378,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -361,7 +378,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# And the cache has been updated: # And the cache has been updated:
self.assertEqual(cache.load(b'2'*8), self.assertEqual(cache.load(b'2'*8),
('2 data', b'a'*8)) # unchanged ('2 data', b'a'*8)) # unchanged
self.assertEqual(cache.load(b'4'*8), None) self.assertEqual(cache.load(b'4'*8), None)
# Because we were able to update the cache, we didn't have to # Because we were able to update the cache, we didn't have to
...@@ -384,7 +401,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -384,7 +401,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We have to verify the cache, so we're not done connecting: # We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done()) self.assertFalse(client.connected.done())
self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, ))) self.assertEqual(self.pop(),
(3, False, 'getInvalidations', (b'a'*8, )))
# We respond None, indicating that we're too far out of date: # We respond None, indicating that we're too far out of date:
self.respond(3, None) self.respond(3, None)
...@@ -451,10 +469,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -451,10 +469,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
self.pop() self.pop()
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat
self.assertTrue(8 < delay < 10) self.assertTrue(8 < delay < 10)
self.assertEqual(len(loop.later), 1) # first in later is heartbeat self.assertEqual(len(loop.later), 1) # first in later is heartbeat
func(*args) # connect again func(*args) # connect again
self.assertFalse(protocol is loop.protocol) self.assertFalse(protocol is loop.protocol)
self.assertFalse(transport is loop.transport) self.assertFalse(transport is loop.transport)
protocol = loop.protocol protocol = loop.protocol
...@@ -512,7 +530,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -512,7 +530,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We connect the second address: # We connect the second address:
loop.connect_connecting(addrs[1]) loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(self.enc + b'3101')) loop.protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), self.enc + b'3101') self.assertEqual(self.unsized(loop.transport.pop(2)),
self.enc + b'3101')
self.assertEqual(self.parse(loop.transport.pop()), self.assertEqual(self.parse(loop.transport.pop()),
(1, False, 'register', ('TEST', False))) (1, False, 'register', ('TEST', False)))
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
...@@ -613,7 +632,6 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -613,7 +632,6 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.data_received(sized(self.enc + b'200')) protocol.data_received(sized(self.enc + b'200'))
self.assertTrue(isinstance(error.call_args[0][1], ProtocolError)) self.assertTrue(isinstance(error.call_args[0][1], ProtocolError))
def test_get_peername(self): def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport = self.start( wrapper, cache, loop, client, protocol, transport = self.start(
finish_start=True) finish_start=True)
...@@ -641,7 +659,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -641,7 +659,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# that caused it to fail badly if errors were raised while # that caused it to fail badly if errors were raised while
# handling data. # handling data.
wrapper, cache, loop, client, protocol, transport =self.start( wrapper, cache, loop, client, protocol, transport = self.start(
finish_start=True) finish_start=True)
wrapper.receiveBlobStart.side_effect = ValueError('test') wrapper.receiveBlobStart.side_effect = ValueError('test')
...@@ -694,10 +712,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -694,10 +712,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.connection_lost(None) protocol.connection_lost(None)
self.assertTrue(handle.cancelled) self.assertTrue(handle.cancelled)
class MsgpackClientTests(ClientTests): class MsgpackClientTests(ClientTests):
enc = b'M' enc = b'M'
seq_type = tuple seq_type = tuple
class MemoryCache(object): class MemoryCache(object):
def __init__(self): def __init__(self):
...@@ -709,6 +729,7 @@ class MemoryCache(object): ...@@ -709,6 +729,7 @@ class MemoryCache(object):
clear = __init__ clear = __init__
closed = False closed = False
def close(self): def close(self):
self.closed = True self.closed = True
...@@ -771,6 +792,7 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -771,6 +792,7 @@ class ServerTests(Base, setupstack.TestCase):
message_id = 0 message_id = 0
target = None target = None
def call(self, meth, *args, **kw): def call(self, meth, *args, **kw):
if kw: if kw:
expect = kw.pop('expect', self) expect = kw.pop('expect', self)
...@@ -835,10 +857,12 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -835,10 +857,12 @@ class ServerTests(Base, setupstack.TestCase):
self.call('foo', target=None) self.call('foo', target=None)
self.assertTrue(protocol.loop.transport.closed) self.assertTrue(protocol.loop.transport.closed)
class MsgpackServerTests(ServerTests): class MsgpackServerTests(ServerTests):
enc = b'M' enc = b'M'
seq_type = tuple seq_type = tuple
def server_protocol(msgpack, def server_protocol(msgpack,
zeo_storage=None, zeo_storage=None,
protocol_version=None, protocol_version=None,
...@@ -847,18 +871,17 @@ def server_protocol(msgpack, ...@@ -847,18 +871,17 @@ def server_protocol(msgpack,
if zeo_storage is None: if zeo_storage is None:
zeo_storage = mock.Mock() zeo_storage = mock.Mock()
loop = Loop() loop = Loop()
sock = () # anything not None sock = () # anything not None
new_connection(loop, addr, sock, zeo_storage, msgpack) new_connection(loop, addr, sock, zeo_storage, msgpack)
if protocol_version: if protocol_version:
loop.protocol.data_received(sized(protocol_version)) loop.protocol.data_received(sized(protocol_version))
return loop.protocol return loop.protocol
def response(*data):
return sized(self.encode(*data))
def sized(message): def sized(message):
return struct.pack(">I", len(message)) + message return struct.pack(">I", len(message)) + message
class Logging(object): class Logging(object):
def __init__(self, level=logging.ERROR): def __init__(self, level=logging.ERROR):
...@@ -885,9 +908,11 @@ class ProtocolTests(setupstack.TestCase): ...@@ -885,9 +908,11 @@ class ProtocolTests(setupstack.TestCase):
loop = self.loop loop = self.loop
protocol, transport = loop.protocol, loop.transport protocol, transport = loop.protocol, loop.transport
transport.capacity = 1 # single message transport.capacity = 1 # single message
def it(tag): def it(tag):
yield tag yield tag
yield tag yield tag
protocol._writeit(it(b"0")) protocol._writeit(it(b"0"))
protocol._writeit(it(b"1")) protocol._writeit(it(b"1"))
for b in b"0011": for b in b"0011":
......
...@@ -86,7 +86,7 @@ ZEC_HEADER_SIZE = 12 ...@@ -86,7 +86,7 @@ ZEC_HEADER_SIZE = 12
# need to write a free block that is almost twice as big. If we die # need to write a free block that is almost twice as big. If we die
# in the middle of a store, then we need to split the large free records # in the middle of a store, then we need to split the large free records
# while opening. # while opening.
max_block_size = (1<<31) - 1 max_block_size = (1 << 31) - 1
# After the header, the file contains a contiguous sequence of blocks. All # After the header, the file contains a contiguous sequence of blocks. All
...@@ -132,12 +132,13 @@ allocated_record_overhead = 43 ...@@ -132,12 +132,13 @@ allocated_record_overhead = 43
# Under PyPy, the available dict specializations perform significantly # Under PyPy, the available dict specializations perform significantly
# better (faster) than the pure-Python BTree implementation. They may # better (faster) than the pure-Python BTree implementation. They may
# use less memory too. And we don't require any of the special BTree features... # use less memory too. And we don't require any of the special BTree features.
_current_index_type = ZODB.fsIndex.fsIndex if not PYPY else dict _current_index_type = ZODB.fsIndex.fsIndex if not PYPY else dict
_noncurrent_index_type = BTrees.LOBTree.LOBTree if not PYPY else dict _noncurrent_index_type = BTrees.LOBTree.LOBTree if not PYPY else dict
# ...except at this leaf level # ...except at this leaf level
_noncurrent_bucket_type = BTrees.LLBTree.LLBucket _noncurrent_bucket_type = BTrees.LLBTree.LLBucket
class ClientCache(object): class ClientCache(object):
"""A simple in-memory cache.""" """A simple in-memory cache."""
...@@ -193,7 +194,7 @@ class ClientCache(object): ...@@ -193,7 +194,7 @@ class ClientCache(object):
if path: if path:
self._lock_file = zc.lockfile.LockFile(path + '.lock') self._lock_file = zc.lockfile.LockFile(path + '.lock')
if not os.path.exists(path): if not os.path.exists(path):
# Create a small empty file. We'll make it bigger in _initfile. # Create a small empty file. We'll make it bigger in _initfile.
self.f = open(path, 'wb+') self.f = open(path, 'wb+')
self.f.write(magic+z64) self.f.write(magic+z64)
logger.info("created persistent cache file %r", path) logger.info("created persistent cache file %r", path)
...@@ -209,10 +210,10 @@ class ClientCache(object): ...@@ -209,10 +210,10 @@ class ClientCache(object):
try: try:
self._initfile(fsize) self._initfile(fsize)
except: except: # NOQA: E722 bare except
self.f.close() self.f.close()
if not path: if not path:
raise # unrecoverable temp file error :( raise # unrecoverable temp file error :(
badpath = path+'.bad' badpath = path+'.bad'
if os.path.exists(badpath): if os.path.exists(badpath):
logger.critical( logger.critical(
...@@ -271,7 +272,7 @@ class ClientCache(object): ...@@ -271,7 +272,7 @@ class ClientCache(object):
self.current = _current_index_type() self.current = _current_index_type()
self.noncurrent = _noncurrent_index_type() self.noncurrent = _noncurrent_index_type()
l = 0 length = 0
last = ofs = ZEC_HEADER_SIZE last = ofs = ZEC_HEADER_SIZE
first_free_offset = 0 first_free_offset = 0
current = self.current current = self.current
...@@ -290,7 +291,7 @@ class ClientCache(object): ...@@ -290,7 +291,7 @@ class ClientCache(object):
assert start_tid < end_tid, (ofs, f.tell()) assert start_tid < end_tid, (ofs, f.tell())
self._set_noncurrent(oid, start_tid, ofs) self._set_noncurrent(oid, start_tid, ofs)
assert lver == 0, "Versions aren't supported" assert lver == 0, "Versions aren't supported"
l += 1 length += 1
else: else:
# free block # free block
if first_free_offset == 0: if first_free_offset == 0:
...@@ -331,7 +332,7 @@ class ClientCache(object): ...@@ -331,7 +332,7 @@ class ClientCache(object):
break break
if fsize < maxsize: if fsize < maxsize:
assert ofs==fsize assert ofs == fsize
# Make sure the OS really saves enough bytes for the file. # Make sure the OS really saves enough bytes for the file.
seek(self.maxsize - 1) seek(self.maxsize - 1)
write(b'x') write(b'x')
...@@ -349,7 +350,7 @@ class ClientCache(object): ...@@ -349,7 +350,7 @@ class ClientCache(object):
assert last and (status in b' f1234') assert last and (status in b' f1234')
first_free_offset = last first_free_offset = last
else: else:
assert ofs==maxsize assert ofs == maxsize
if maxsize < fsize: if maxsize < fsize:
seek(maxsize) seek(maxsize)
f.truncate() f.truncate()
...@@ -357,7 +358,7 @@ class ClientCache(object): ...@@ -357,7 +358,7 @@ class ClientCache(object):
# We use the first_free_offset because it is most likely the # We use the first_free_offset because it is most likely the
# place where we last wrote. # place where we last wrote.
self.currentofs = first_free_offset or ZEC_HEADER_SIZE self.currentofs = first_free_offset or ZEC_HEADER_SIZE
self._len = l self._len = length
def _set_noncurrent(self, oid, tid, ofs): def _set_noncurrent(self, oid, tid, ofs):
noncurrent_for_oid = self.noncurrent.get(u64(oid)) noncurrent_for_oid = self.noncurrent.get(u64(oid))
...@@ -375,7 +376,6 @@ class ClientCache(object): ...@@ -375,7 +376,6 @@ class ClientCache(object):
except KeyError: except KeyError:
logger.error("Couldn't find non-current %r", (oid, tid)) logger.error("Couldn't find non-current %r", (oid, tid))
def clearStats(self): def clearStats(self):
self._n_adds = self._n_added_bytes = 0 self._n_adds = self._n_added_bytes = 0
self._n_evicts = self._n_evicted_bytes = 0 self._n_evicts = self._n_evicted_bytes = 0
...@@ -384,8 +384,7 @@ class ClientCache(object): ...@@ -384,8 +384,7 @@ class ClientCache(object):
def getStats(self): def getStats(self):
return (self._n_adds, self._n_added_bytes, return (self._n_adds, self._n_added_bytes,
self._n_evicts, self._n_evicted_bytes, self._n_evicts, self._n_evicted_bytes,
self._n_accesses self._n_accesses)
)
## ##
# The number of objects currently in the cache. # The number of objects currently in the cache.
...@@ -403,7 +402,7 @@ class ClientCache(object): ...@@ -403,7 +402,7 @@ class ClientCache(object):
sync(f) sync(f)
f.close() f.close()
if hasattr(self,'_lock_file'): if hasattr(self, '_lock_file'):
self._lock_file.close() self._lock_file.close()
## ##
...@@ -517,9 +516,9 @@ class ClientCache(object): ...@@ -517,9 +516,9 @@ class ClientCache(object):
if ofsofs < 0: if ofsofs < 0:
ofsofs += self.maxsize ofsofs += self.maxsize
if (ofsofs > self.rearrange and if ofsofs > self.rearrange and \
self.maxsize > 10*len(data) and self.maxsize > 10*len(data) and \
size > 4): size > 4:
# The record is far back and might get evicted, but it's # The record is far back and might get evicted, but it's
# valuable, so move it forward. # valuable, so move it forward.
...@@ -619,8 +618,8 @@ class ClientCache(object): ...@@ -619,8 +618,8 @@ class ClientCache(object):
raise ValueError("already have current data for oid") raise ValueError("already have current data for oid")
else: else:
noncurrent_for_oid = self.noncurrent.get(u64(oid)) noncurrent_for_oid = self.noncurrent.get(u64(oid))
if noncurrent_for_oid and ( if noncurrent_for_oid and \
u64(start_tid) in noncurrent_for_oid): u64(start_tid) in noncurrent_for_oid:
return return
size = allocated_record_overhead + len(data) size = allocated_record_overhead + len(data)
...@@ -692,7 +691,6 @@ class ClientCache(object): ...@@ -692,7 +691,6 @@ class ClientCache(object):
self.currentofs += size self.currentofs += size
## ##
# If `tid` is None, # If `tid` is None,
# forget all knowledge of `oid`. (`tid` can be None only for # forget all knowledge of `oid`. (`tid` can be None only for
...@@ -765,8 +763,7 @@ class ClientCache(object): ...@@ -765,8 +763,7 @@ class ClientCache(object):
for oid, tid in L: for oid, tid in L:
print(oid_repr(oid), oid_repr(tid)) print(oid_repr(oid), oid_repr(tid))
print("dll contents") print("dll contents")
L = list(self) L = sorted(list(self), key=lambda x: x.key)
L.sort(lambda x, y: cmp(x.key, y.key))
for x in L: for x in L:
end_tid = x.end_tid or z64 end_tid = x.end_tid or z64
print(oid_repr(x.key[0]), oid_repr(x.key[1]), oid_repr(end_tid)) print(oid_repr(x.key[0]), oid_repr(x.key[1]), oid_repr(end_tid))
...@@ -779,6 +776,7 @@ class ClientCache(object): ...@@ -779,6 +776,7 @@ class ClientCache(object):
# tracing by setting self._trace to a dummy function, and set # tracing by setting self._trace to a dummy function, and set
# self._tracefile to None. # self._tracefile to None.
_tracefile = None _tracefile = None
def _trace(self, *a, **kw): def _trace(self, *a, **kw):
pass pass
...@@ -797,6 +795,7 @@ class ClientCache(object): ...@@ -797,6 +795,7 @@ class ClientCache(object):
return return
now = time.time now = time.time
def _trace(code, oid=b"", tid=z64, end_tid=z64, dlen=0): def _trace(code, oid=b"", tid=z64, end_tid=z64, dlen=0):
# The code argument is two hex digits; bits 0 and 7 must be zero. # The code argument is two hex digits; bits 0 and 7 must be zero.
# The first hex digit shows the operation, the second the outcome. # The first hex digit shows the operation, the second the outcome.
...@@ -812,7 +811,7 @@ class ClientCache(object): ...@@ -812,7 +811,7 @@ class ClientCache(object):
pack(">iiH8s8s", pack(">iiH8s8s",
int(now()), encoded, len(oid), tid, end_tid) + oid, int(now()), encoded, len(oid), tid, end_tid) + oid,
) )
except: except: # NOQA: E722 bare except
print(repr(tid), repr(end_tid)) print(repr(tid), repr(end_tid))
raise raise
...@@ -826,10 +825,7 @@ class ClientCache(object): ...@@ -826,10 +825,7 @@ class ClientCache(object):
self._tracefile.close() self._tracefile.close()
del self._tracefile del self._tracefile
def sync(f): def sync(f):
f.flush() f.flush()
os.fsync(f.fileno())
if hasattr(os, 'fsync'):
def sync(f):
f.flush()
os.fsync(f.fileno())
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import zope.interface import zope.interface
class StaleCache(object): class StaleCache(object):
"""A ZEO cache is stale and requires verification. """A ZEO cache is stale and requires verification.
""" """
...@@ -21,6 +22,7 @@ class StaleCache(object): ...@@ -21,6 +22,7 @@ class StaleCache(object):
def __init__(self, storage): def __init__(self, storage):
self.storage = storage self.storage = storage
class IClientCache(zope.interface.Interface): class IClientCache(zope.interface.Interface):
"""Client cache interface. """Client cache interface.
...@@ -86,6 +88,7 @@ class IClientCache(zope.interface.Interface): ...@@ -86,6 +88,7 @@ class IClientCache(zope.interface.Interface):
"""Clear/empty the cache """Clear/empty the cache
""" """
class IServeable(zope.interface.Interface): class IServeable(zope.interface.Interface):
"""Interface provided by storages that can be served by ZEO """Interface provided by storages that can be served by ZEO
""" """
......
...@@ -30,10 +30,7 @@ from __future__ import print_function ...@@ -30,10 +30,7 @@ from __future__ import print_function
from __future__ import print_function from __future__ import print_function
from __future__ import print_function from __future__ import print_function
import asyncore
import socket
import time import time
import logging
zeo_version = 'unknown' zeo_version = 'unknown'
try: try:
...@@ -47,6 +44,7 @@ else: ...@@ -47,6 +44,7 @@ else:
if zeo_dist is not None: if zeo_dist is not None:
zeo_version = zeo_dist.version zeo_version = zeo_dist.version
class StorageStats(object): class StorageStats(object):
"""Per-storage usage statistics.""" """Per-storage usage statistics."""
......
...@@ -33,6 +33,7 @@ diff_names = 'aborts commits conflicts conflicts_resolved loads stores'.split() ...@@ -33,6 +33,7 @@ diff_names = 'aborts commits conflicts conflicts_resolved loads stores'.split()
per_times = dict(seconds=1.0, minutes=60.0, hours=3600.0, days=86400.0) per_times = dict(seconds=1.0, minutes=60.0, hours=3600.0, days=86400.0)
def new_metric(metrics, storage_id, name, value): def new_metric(metrics, storage_id, name, value):
if storage_id == '1': if storage_id == '1':
label = name label = name
...@@ -43,6 +44,7 @@ def new_metric(metrics, storage_id, name, value): ...@@ -43,6 +44,7 @@ def new_metric(metrics, storage_id, name, value):
label = "%s:%s" % (storage_id, name) label = "%s:%s" % (storage_id, name)
metrics.append("%s=%s" % (label, value)) metrics.append("%s=%s" % (label, value))
def result(messages, metrics=(), status=None): def result(messages, metrics=(), status=None):
if metrics: if metrics:
messages[0] += '|' + metrics[0] messages[0] += '|' + metrics[0]
...@@ -51,12 +53,15 @@ def result(messages, metrics=(), status=None): ...@@ -51,12 +53,15 @@ def result(messages, metrics=(), status=None):
print('\n'.join(messages)) print('\n'.join(messages))
return status return status
def error(message): def error(message):
return result((message, ), (), 2) return result((message, ), (), 2)
def warn(message): def warn(message):
return result((message, ), (), 1) return result((message, ), (), 1)
def check(addr, output_metrics, status, per): def check(addr, output_metrics, status, per):
m = re.match(r'\[(\S+)\]:(\d+)$', addr) m = re.match(r'\[(\S+)\]:(\d+)$', addr)
if m: if m:
...@@ -75,7 +80,7 @@ def check(addr, output_metrics, status, per): ...@@ -75,7 +80,7 @@ def check(addr, output_metrics, status, per):
return error("Can't connect %s" % err) return error("Can't connect %s" % err)
s.sendall(b'\x00\x00\x00\x04ruok') s.sendall(b'\x00\x00\x00\x04ruok')
proto = s.recv(struct.unpack(">I", s.recv(4))[0]) proto = s.recv(struct.unpack(">I", s.recv(4))[0]) # NOQA: F841 unused
datas = s.recv(struct.unpack(">I", s.recv(4))[0]) datas = s.recv(struct.unpack(">I", s.recv(4))[0])
s.close() s.close()
data = json.loads(datas.decode("ascii")) data = json.loads(datas.decode("ascii"))
...@@ -94,8 +99,8 @@ def check(addr, output_metrics, status, per): ...@@ -94,8 +99,8 @@ def check(addr, output_metrics, status, per):
now = time.time() now = time.time()
if os.path.exists(status): if os.path.exists(status):
dt = now - os.stat(status).st_mtime dt = now - os.stat(status).st_mtime
if dt > 0: # sanity :) if dt > 0: # sanity :)
with open(status) as f: # Read previous with open(status) as f: # Read previous
old = json.loads(f.read()) old = json.loads(f.read())
dt /= per_times[per] dt /= per_times[per]
for storage_id, sdata in sorted(data.items()): for storage_id, sdata in sorted(data.items()):
...@@ -105,7 +110,7 @@ def check(addr, output_metrics, status, per): ...@@ -105,7 +110,7 @@ def check(addr, output_metrics, status, per):
for name in diff_names: for name in diff_names:
v = (sdata[name] - sold[name]) / dt v = (sdata[name] - sold[name]) / dt
new_metric(metrics, storage_id, name, v) new_metric(metrics, storage_id, name, v)
with open(status, 'w') as f: # save current with open(status, 'w') as f: # save current
f.write(json.dumps(data)) f.write(json.dumps(data))
for storage_id, sdata in sorted(data.items()): for storage_id, sdata in sorted(data.items()):
...@@ -116,6 +121,7 @@ def check(addr, output_metrics, status, per): ...@@ -116,6 +121,7 @@ def check(addr, output_metrics, status, per):
messages.append('OK') messages.append('OK')
return result(messages, metrics, level or None) return result(messages, metrics, level or None)
def main(args=None): def main(args=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
...@@ -139,5 +145,6 @@ def main(args=None): ...@@ -139,5 +145,6 @@ def main(args=None):
return check( return check(
addr, options.output_metrics, options.status_path, options.time_units) addr, options.output_metrics, options.status_path, options.time_units)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -46,21 +46,25 @@ from zdaemon.zdoptions import ZDOptions ...@@ -46,21 +46,25 @@ from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo') logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid()) _pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False): def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function.""" """Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg) message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info) logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg): def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API. # Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg) obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address return obj.family, obj.address
def windows_shutdown_handler(): def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown. # Called by the signal mechanism on Windows to perform shutdown.
import asyncore import asyncore
asyncore.close_all() asyncore.close_all()
class ZEOOptionsMixin(object): class ZEOOptionsMixin(object):
storages = None storages = None
...@@ -69,14 +73,17 @@ class ZEOOptionsMixin(object): ...@@ -69,14 +73,17 @@ class ZEOOptionsMixin(object):
self.family, self.address = parse_binding_address(arg) self.family, self.address = parse_binding_address(arg)
def handle_filename(self, arg): def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*! from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig(object): class FSConfig(object):
def __init__(self, name, path): def __init__(self, name, path):
self._name = name self._name = name
self.path = path self.path = path
self.stop = None self.stop = None
def getSectionName(self): def getSectionName(self):
return self._name return self._name
if not self.storages: if not self.storages:
self.storages = [] self.storages = []
name = str(1 + len(self.storages)) name = str(1 + len(self.storages))
...@@ -84,6 +91,7 @@ class ZEOOptionsMixin(object): ...@@ -84,6 +91,7 @@ class ZEOOptionsMixin(object):
self.storages.append(conf) self.storages.append(conf)
testing_exit_immediately = False testing_exit_immediately = False
def handle_test(self, *args): def handle_test(self, *args):
self.testing_exit_immediately = True self.testing_exit_immediately = True
...@@ -108,6 +116,7 @@ class ZEOOptionsMixin(object): ...@@ -108,6 +116,7 @@ class ZEOOptionsMixin(object):
None, 'pid-file=') None, 'pid-file=')
self.add("ssl", "zeo.ssl") self.add("ssl", "zeo.ssl")
class ZEOOptions(ZDOptions, ZEOOptionsMixin): class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__ __doc__ = __doc__
...@@ -164,15 +173,15 @@ class ZEOServer(object): ...@@ -164,15 +173,15 @@ class ZEOServer(object):
root = logging.getLogger() root = logging.getLogger()
root.setLevel(logging.INFO) root.setLevel(logging.INFO)
fmt = logging.Formatter( fmt = logging.Formatter(
"------\n%(asctime)s %(levelname)s %(name)s %(message)s", "------\n%(asctime)s %(levelname)s %(name)s %(message)s",
"%Y-%m-%dT%H:%M:%S") "%Y-%m-%dT%H:%M:%S")
handler = logging.StreamHandler() handler = logging.StreamHandler()
handler.setFormatter(fmt) handler.setFormatter(fmt)
root.addHandler(handler) root.addHandler(handler)
def check_socket(self): def check_socket(self):
if (isinstance(self.options.address, tuple) and if isinstance(self.options.address, tuple) and \
self.options.address[1] is None): self.options.address[1] is None:
self.options.address = self.options.address[0], 0 self.options.address = self.options.address[0], 0
return return
...@@ -217,7 +226,7 @@ class ZEOServer(object): ...@@ -217,7 +226,7 @@ class ZEOServer(object):
self.setup_win32_signals() self.setup_win32_signals()
return return
if hasattr(signal, 'SIGXFSZ'): if hasattr(signal, 'SIGXFSZ'):
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
init_signames() init_signames()
for sig, name in signames.items(): for sig, name in signames.items():
method = getattr(self, "handle_" + name.lower(), None) method = getattr(self, "handle_" + name.lower(), None)
...@@ -237,12 +246,12 @@ class ZEOServer(object): ...@@ -237,12 +246,12 @@ class ZEOServer(object):
"will *not* be installed.") "will *not* be installed.")
return return
SignalHandler = Signals.Signals.SignalHandler SignalHandler = Signals.Signals.SignalHandler
if SignalHandler is not None: # may be None if no pywin32. if SignalHandler is not None: # may be None if no pywin32.
SignalHandler.registerHandler(signal.SIGTERM, SignalHandler.registerHandler(signal.SIGTERM,
windows_shutdown_handler) windows_shutdown_handler)
SignalHandler.registerHandler(signal.SIGINT, SignalHandler.registerHandler(signal.SIGINT,
windows_shutdown_handler) windows_shutdown_handler)
SIGUSR2 = 12 # not in signal module on Windows. SIGUSR2 = 12 # not in signal module on Windows.
SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2) SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2)
def create_server(self): def create_server(self):
...@@ -278,7 +287,8 @@ class ZEOServer(object): ...@@ -278,7 +287,8 @@ class ZEOServer(object):
def handle_sigusr2(self): def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8... # log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"): if self.options.config_logger is None or \
os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!", log("received SIGUSR2, but it was not handled!",
level=logging.WARNING) level=logging.WARNING)
return return
...@@ -286,13 +296,13 @@ class ZEOServer(object): ...@@ -286,13 +296,13 @@ class ZEOServer(object):
loggers = [self.options.config_logger] loggers = [self.options.config_logger]
if os.name == "posix": if os.name == "posix":
for l in loggers: for logger in loggers:
l.reopen() logger.reopen()
log("Log files reopened successfully", level=logging.INFO) log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers: for logger in loggers:
for f in l.handler_factories: for factory in logger.handler_factories:
handler = f() handler = factory()
if hasattr(handler, 'rotate') and callable(handler.rotate): if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate() handler.rotate()
log("Log files rotation complete", level=logging.INFO) log("Log files rotation complete", level=logging.INFO)
...@@ -350,21 +360,21 @@ def create_server(storages, options): ...@@ -350,21 +360,21 @@ def create_server(storages, options):
return StorageServer( return StorageServer(
options.address, options.address,
storages, storages,
read_only = options.read_only, read_only=options.read_only,
client_conflict_resolution=options.client_conflict_resolution, client_conflict_resolution=options.client_conflict_resolution,
msgpack=(options.msgpack if isinstance(options.msgpack, bool) msgpack=(options.msgpack if isinstance(options.msgpack, bool)
else os.environ.get('ZEO_MSGPACK')), else os.environ.get('ZEO_MSGPACK')),
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size=options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age=options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout=options.transaction_timeout,
ssl = options.ssl, ssl=options.ssl)
)
# Signal names # Signal names
signames = None signames = None
def signame(sig): def signame(sig):
"""Return a symbolic name for a signal. """Return a symbolic name for a signal.
...@@ -376,6 +386,7 @@ def signame(sig): ...@@ -376,6 +386,7 @@ def signame(sig):
init_signames() init_signames()
return signames.get(sig) or "signal %d" % sig return signames.get(sig) or "signal %d" % sig
def init_signames(): def init_signames():
global signames global signames
signames = {} signames = {}
...@@ -395,11 +406,13 @@ def main(args=None): ...@@ -395,11 +406,13 @@ def main(args=None):
s = ZEOServer(options) s = ZEOServer(options)
s.main() s.main()
def run(args): def run(args):
options = ZEOOptions() options = ZEOOptions()
options.realize(args) options.realize(args)
s = ZEOServer(options) s = ZEOServer(options)
s.run() s.run()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -27,6 +27,7 @@ from __future__ import print_function, absolute_import ...@@ -27,6 +27,7 @@ from __future__ import print_function, absolute_import
import bisect import bisect
import struct import struct
import random
import re import re
import sys import sys
import ZEO.cache import ZEO.cache
...@@ -34,6 +35,7 @@ import argparse ...@@ -34,6 +35,7 @@ import argparse
from ZODB.utils import z64 from ZODB.utils import z64
from ..cache import ZEC_HEADER_SIZE
from .cache_stats import add_interval_argument from .cache_stats import add_interval_argument
from .cache_stats import add_tracefile_argument from .cache_stats import add_tracefile_argument
...@@ -46,7 +48,7 @@ def main(args=None): ...@@ -46,7 +48,7 @@ def main(args=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
# Parse options. # Parse options.
MB = 1<<20 MB = 1 << 20
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--size", "-s", parser.add_argument("--size", "-s",
default=20*MB, dest="cachelimit", default=20*MB, dest="cachelimit",
...@@ -115,6 +117,7 @@ def main(args=None): ...@@ -115,6 +117,7 @@ def main(args=None):
interval_sim.report() interval_sim.report()
sim.finish() sim.finish()
class Simulation(object): class Simulation(object):
"""Base class for simulations. """Base class for simulations.
...@@ -270,7 +273,6 @@ class CircularCacheEntry(object): ...@@ -270,7 +273,6 @@ class CircularCacheEntry(object):
self.end_tid = end_tid self.end_tid = end_tid
self.offset = offset self.offset = offset
from ZEO.cache import ZEC_HEADER_SIZE
class CircularCacheSimulation(Simulation): class CircularCacheSimulation(Simulation):
"""Simulate the ZEO 3.0 cache.""" """Simulate the ZEO 3.0 cache."""
...@@ -285,8 +287,6 @@ class CircularCacheSimulation(Simulation): ...@@ -285,8 +287,6 @@ class CircularCacheSimulation(Simulation):
evicts = 0 evicts = 0
def __init__(self, cachelimit, rearrange): def __init__(self, cachelimit, rearrange):
from ZEO import cache
Simulation.__init__(self, cachelimit, rearrange) Simulation.__init__(self, cachelimit, rearrange)
self.total_evicts = 0 # number of cache evictions self.total_evicts = 0 # number of cache evictions
...@@ -296,7 +296,7 @@ class CircularCacheSimulation(Simulation): ...@@ -296,7 +296,7 @@ class CircularCacheSimulation(Simulation):
# Map offset in file to (size, CircularCacheEntry) pair, or to # Map offset in file to (size, CircularCacheEntry) pair, or to
# (size, None) if the offset starts a free block. # (size, None) if the offset starts a free block.
self.filemap = {ZEC_HEADER_SIZE: (self.cachelimit - ZEC_HEADER_SIZE, self.filemap = {ZEC_HEADER_SIZE: (self.cachelimit - ZEC_HEADER_SIZE,
None)} None)}
# Map key to CircularCacheEntry. A key is an (oid, tid) pair. # Map key to CircularCacheEntry. A key is an (oid, tid) pair.
self.key2entry = {} self.key2entry = {}
...@@ -322,10 +322,11 @@ class CircularCacheSimulation(Simulation): ...@@ -322,10 +322,11 @@ class CircularCacheSimulation(Simulation):
self.evicted_hit = self.evicted_miss = 0 self.evicted_hit = self.evicted_miss = 0
evicted_hit = evicted_miss = 0 evicted_hit = evicted_miss = 0
def load(self, oid, size, tid, code): def load(self, oid, size, tid, code):
if (code == 0x20) or (code == 0x22): if (code == 0x20) or (code == 0x22):
# Trying to load current revision. # Trying to load current revision.
if oid in self.current: # else it's a cache miss if oid in self.current: # else it's a cache miss
self.hits += 1 self.hits += 1
self.total_hits += 1 self.total_hits += 1
...@@ -433,7 +434,8 @@ class CircularCacheSimulation(Simulation): ...@@ -433,7 +434,8 @@ class CircularCacheSimulation(Simulation):
# Storing current revision. # Storing current revision.
if oid in self.current: # we already have it in cache if oid in self.current: # we already have it in cache
if evhit: if evhit:
import pdb; pdb.set_trace() import pdb
pdb.set_trace()
raise ValueError('WTF') raise ValueError('WTF')
return return
self.current[oid] = start_tid self.current[oid] = start_tid
...@@ -442,7 +444,8 @@ class CircularCacheSimulation(Simulation): ...@@ -442,7 +444,8 @@ class CircularCacheSimulation(Simulation):
self.add(oid, size, start_tid) self.add(oid, size, start_tid)
return return
if evhit: if evhit:
import pdb; pdb.set_trace() import pdb
pdb.set_trace()
raise ValueError('WTF') raise ValueError('WTF')
# Storing non-current revision. # Storing non-current revision.
L = self.noncurrent.setdefault(oid, []) L = self.noncurrent.setdefault(oid, [])
...@@ -514,7 +517,7 @@ class CircularCacheSimulation(Simulation): ...@@ -514,7 +517,7 @@ class CircularCacheSimulation(Simulation):
self.inuse = round(100.0 * used / total, 1) self.inuse = round(100.0 * used / total, 1)
self.total_inuse = self.inuse self.total_inuse = self.inuse
Simulation.report(self) Simulation.report(self)
#print self.evicted_hit, self.evicted_miss # print self.evicted_hit, self.evicted_miss
def check(self): def check(self):
oidcount = 0 oidcount = 0
...@@ -538,16 +541,18 @@ class CircularCacheSimulation(Simulation): ...@@ -538,16 +541,18 @@ class CircularCacheSimulation(Simulation):
def roundup(size): def roundup(size):
k = MINSIZE k = MINSIZE # NOQA: F821 undefined name
while k < size: while k < size:
k += k k += k
return k return k
def hitrate(loads, hits): def hitrate(loads, hits):
if loads < 1: if loads < 1:
return 'n/a' return 'n/a'
return "%5.1f%%" % (100.0 * hits / loads) return "%5.1f%%" % (100.0 * hits / loads)
def duration(secs): def duration(secs):
mm, ss = divmod(secs, 60) mm, ss = divmod(secs, 60)
hh, mm = divmod(mm, 60) hh, mm = divmod(mm, 60)
...@@ -557,7 +562,10 @@ def duration(secs): ...@@ -557,7 +562,10 @@ def duration(secs):
return "%d:%02d" % (mm, ss) return "%d:%02d" % (mm, ss)
return "%d" % ss return "%d" % ss
nre = re.compile('([=-]?)(\d+)([.]\d*)?').match
nre = re.compile(r'([=-]?)(\d+)([.]\d*)?').match
def addcommas(n): def addcommas(n):
sign, s, d = nre(str(n)).group(1, 2, 3) sign, s, d = nre(str(n)).group(1, 2, 3)
if d == '.0': if d == '.0':
...@@ -571,11 +579,11 @@ def addcommas(n): ...@@ -571,11 +579,11 @@ def addcommas(n):
return (sign or '') + result + (d or '') return (sign or '') + result + (d or '')
import random
def maybe(f, p=0.5): def maybe(f, p=0.5):
if random.random() < p: if random.random() < p:
f() f()
if __name__ == "__main__": if __name__ == "__main__":
sys.exit(main()) sys.exit(main())
...@@ -55,6 +55,7 @@ import gzip ...@@ -55,6 +55,7 @@ import gzip
from time import ctime from time import ctime
import six import six
def add_interval_argument(parser): def add_interval_argument(parser):
def _interval(a): def _interval(a):
interval = int(60 * float(a)) interval = int(60 * float(a))
...@@ -63,9 +64,11 @@ def add_interval_argument(parser): ...@@ -63,9 +64,11 @@ def add_interval_argument(parser):
elif interval > 3600: elif interval > 3600:
interval = 3600 interval = 3600
return interval return interval
parser.add_argument("--interval", "-i", parser.add_argument(
default=15*60, type=_interval, "--interval", "-i",
help="summarizing interval in minutes (default 15; max 60)") default=15*60, type=_interval,
help="summarizing interval in minutes (default 15; max 60)")
def add_tracefile_argument(parser): def add_tracefile_argument(parser):
...@@ -82,15 +85,17 @@ def add_tracefile_argument(parser): ...@@ -82,15 +85,17 @@ def add_tracefile_argument(parser):
parser.add_argument("tracefile", type=GzipFileType(), parser.add_argument("tracefile", type=GzipFileType(),
help="The trace to read; may be gzipped") help="The trace to read; may be gzipped")
def main(args=None): def main(args=None):
if args is None: if args is None:
args = sys.argv[1:] args = sys.argv[1:]
# Parse options # Parse options
parser = argparse.ArgumentParser(description="Trace file statistics analyzer", parser = argparse.ArgumentParser(
# Our -h, short for --load-histogram description="Trace file statistics analyzer",
# conflicts with default for help, so we handle # Our -h, short for --load-histogram
# manually. # conflicts with default for help, so we handle
add_help=False) # manually.
add_help=False)
verbose_group = parser.add_mutually_exclusive_group() verbose_group = parser.add_mutually_exclusive_group()
verbose_group.add_argument('--verbose', '-v', verbose_group.add_argument('--verbose', '-v',
default=False, action='store_true', default=False, action='store_true',
...@@ -99,18 +104,22 @@ def main(args=None): ...@@ -99,18 +104,22 @@ def main(args=None):
default=False, action='store_true', default=False, action='store_true',
help="Reduce output; don't print summaries") help="Reduce output; don't print summaries")
parser.add_argument("--sizes", '-s', parser.add_argument("--sizes", '-s',
default=False, action="store_true", dest="print_size_histogram", default=False, action="store_true",
dest="print_size_histogram",
help="print histogram of object sizes") help="print histogram of object sizes")
parser.add_argument("--no-stats", '-S', parser.add_argument("--no-stats", '-S',
default=True, action="store_false", dest="dostats", default=True, action="store_false", dest="dostats",
help="don't print statistics") help="don't print statistics")
parser.add_argument("--load-histogram", "-h", parser.add_argument("--load-histogram", "-h",
default=False, action="store_true", dest="print_histogram", default=False, action="store_true",
dest="print_histogram",
help="print histogram of object load frequencies") help="print histogram of object load frequencies")
parser.add_argument("--check", "-X", parser.add_argument("--check", "-X",
default=False, action="store_true", dest="heuristic", default=False, action="store_true", dest="heuristic",
help=" enable heuristic checking for misaligned records: oids > 2**32" help=" enable heuristic checking for misaligned "
" will be rejected; this requires the tracefile to be seekable") "records: oids > 2**32"
" will be rejected; this requires the tracefile "
"to be seekable")
add_interval_argument(parser) add_interval_argument(parser)
add_tracefile_argument(parser) add_tracefile_argument(parser)
...@@ -123,20 +132,20 @@ def main(args=None): ...@@ -123,20 +132,20 @@ def main(args=None):
f = options.tracefile f = options.tracefile
rt0 = time.time() rt0 = time.time()
bycode = {} # map code to count of occurrences bycode = {} # map code to count of occurrences
byinterval = {} # map code to count in current interval byinterval = {} # map code to count in current interval
records = 0 # number of trace records read records = 0 # number of trace records read
versions = 0 # number of trace records with versions versions = 0 # number of trace records with versions
datarecords = 0 # number of records with dlen set datarecords = 0 # number of records with dlen set
datasize = 0 # sum of dlen across records with dlen set datasize = 0 # sum of dlen across records with dlen set
oids = {} # map oid to number of times it was loaded oids = {} # map oid to number of times it was loaded
bysize = {} # map data size to number of loads bysize = {} # map data size to number of loads
bysizew = {} # map data size to number of writes bysizew = {} # map data size to number of writes
total_loads = 0 total_loads = 0
t0 = None # first timestamp seen t0 = None # first timestamp seen
te = None # most recent timestamp seen te = None # most recent timestamp seen
h0 = None # timestamp at start of current interval h0 = None # timestamp at start of current interval
he = None # timestamp at end of current interval he = None # timestamp at end of current interval
thisinterval = None # generally te//interval thisinterval = None # generally te//interval
f_read = f.read f_read = f.read
unpack = struct.unpack unpack = struct.unpack
...@@ -144,7 +153,8 @@ def main(args=None): ...@@ -144,7 +153,8 @@ def main(args=None):
FMT_SIZE = struct.calcsize(FMT) FMT_SIZE = struct.calcsize(FMT)
assert FMT_SIZE == 26 assert FMT_SIZE == 26
# Read file, gathering statistics, and printing each record if verbose. # Read file, gathering statistics, and printing each record if verbose.
print(' '*16, "%7s %7s %7s %7s" % ('loads', 'hits', 'inv(h)', 'writes'), end=' ') print(' '*16, "%7s %7s %7s %7s" % (
'loads', 'hits', 'inv(h)', 'writes'), end=' ')
print('hitrate') print('hitrate')
try: try:
while 1: while 1:
...@@ -187,10 +197,10 @@ def main(args=None): ...@@ -187,10 +197,10 @@ def main(args=None):
bycode[code] = bycode.get(code, 0) + 1 bycode[code] = bycode.get(code, 0) + 1
byinterval[code] = byinterval.get(code, 0) + 1 byinterval[code] = byinterval.get(code, 0) + 1
if dlen: if dlen:
if code & 0x70 == 0x20: # All loads if code & 0x70 == 0x20: # All loads
bysize[dlen] = d = bysize.get(dlen) or {} bysize[dlen] = d = bysize.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1 d[oid] = d.get(oid, 0) + 1
elif code & 0x70 == 0x50: # All stores elif code & 0x70 == 0x50: # All stores
bysizew[dlen] = d = bysizew.get(dlen) or {} bysizew[dlen] = d = bysizew.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1 d[oid] = d.get(oid, 0) + 1
if options.verbose: if options.verbose:
...@@ -205,7 +215,7 @@ def main(args=None): ...@@ -205,7 +215,7 @@ def main(args=None):
if code & 0x70 == 0x20: if code & 0x70 == 0x20:
oids[oid] = oids.get(oid, 0) + 1 oids[oid] = oids.get(oid, 0) + 1
total_loads += 1 total_loads += 1
elif code == 0x00: # restart elif code == 0x00: # restart
if not options.quiet: if not options.quiet:
dumpbyinterval(byinterval, h0, he) dumpbyinterval(byinterval, h0, he)
byinterval = {} byinterval = {}
...@@ -279,6 +289,7 @@ def main(args=None): ...@@ -279,6 +289,7 @@ def main(args=None):
dumpbysize(bysizew, "written", "writes") dumpbysize(bysizew, "written", "writes")
dumpbysize(bysize, "loaded", "loads") dumpbysize(bysize, "loaded", "loads")
def dumpbysize(bysize, how, how2): def dumpbysize(bysize, how, how2):
print() print()
print("Unique sizes %s: %s" % (how, addcommas(len(bysize)))) print("Unique sizes %s: %s" % (how, addcommas(len(bysize))))
...@@ -292,6 +303,7 @@ def dumpbysize(bysize, how, how2): ...@@ -292,6 +303,7 @@ def dumpbysize(bysize, how, how2):
len(bysize.get(size, "")), len(bysize.get(size, "")),
loads)) loads))
def dumpbyinterval(byinterval, h0, he): def dumpbyinterval(byinterval, h0, he):
loads = hits = invals = writes = 0 loads = hits = invals = writes = 0
for code in byinterval: for code in byinterval:
...@@ -301,7 +313,7 @@ def dumpbyinterval(byinterval, h0, he): ...@@ -301,7 +313,7 @@ def dumpbyinterval(byinterval, h0, he):
if code in (0x22, 0x26): if code in (0x22, 0x26):
hits += n hits += n
elif code & 0x40: elif code & 0x40:
writes += byinterval[code] writes += byinterval[code]
elif code & 0x10: elif code & 0x10:
if code != 0x10: if code != 0x10:
invals += byinterval[code] invals += byinterval[code]
...@@ -315,6 +327,7 @@ def dumpbyinterval(byinterval, h0, he): ...@@ -315,6 +327,7 @@ def dumpbyinterval(byinterval, h0, he):
ctime(h0)[4:-8], ctime(he)[14:-8], ctime(h0)[4:-8], ctime(he)[14:-8],
loads, hits, invals, writes, hr)) loads, hits, invals, writes, hr))
def hitrate(bycode): def hitrate(bycode):
loads = hits = 0 loads = hits = 0
for code in bycode: for code in bycode:
...@@ -328,6 +341,7 @@ def hitrate(bycode): ...@@ -328,6 +341,7 @@ def hitrate(bycode):
else: else:
return 0.0 return 0.0
def histogram(d): def histogram(d):
bins = {} bins = {}
for v in six.itervalues(d): for v in six.itervalues(d):
...@@ -335,15 +349,18 @@ def histogram(d): ...@@ -335,15 +349,18 @@ def histogram(d):
L = sorted(bins.items()) L = sorted(bins.items())
return L return L
def U64(s): def U64(s):
return struct.unpack(">Q", s)[0] return struct.unpack(">Q", s)[0]
def oid_repr(oid): def oid_repr(oid):
if isinstance(oid, six.binary_type) and len(oid) == 8: if isinstance(oid, six.binary_type) and len(oid) == 8:
return '%16x' % U64(oid) return '%16x' % U64(oid)
else: else:
return repr(oid) return repr(oid)
def addcommas(n): def addcommas(n):
sign, s = '', str(n) sign, s = '', str(n)
if s[0] == '-': if s[0] == '-':
...@@ -354,6 +371,7 @@ def addcommas(n): ...@@ -354,6 +371,7 @@ def addcommas(n):
i -= 3 i -= 3
return sign + s return sign + s
explain = { explain = {
# The first hex digit shows the operation, the second the outcome. # The first hex digit shows the operation, the second the outcome.
# If the second digit is in "02468" then it is a 'miss'. # If the second digit is in "02468" then it is a 'miss'.
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Parse the BLATHER logging generated by ZEO2. """Parse the BLATHER logging generated by ZEO2.
An example of the log format is: An example of the log format is:
2002-04-15T13:05:29 BLATHER(-100) ZEO Server storea(3235680, [714], 235339406490168806) ('10.0.26.30', 45514) 2002-04-15T13:05:29 BLATHER(-100) ZEO Server storea(3235680, [714], 235339406490168806) ('10.0.26.30', 45514) # NOQA: E501 line too long
""" """
from __future__ import print_function from __future__ import print_function
from __future__ import print_function from __future__ import print_function
...@@ -14,7 +14,8 @@ from __future__ import print_function ...@@ -14,7 +14,8 @@ from __future__ import print_function
import re import re
import time import time
rx_time = re.compile('(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)') rx_time = re.compile(r'(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
def parse_time(line): def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None.""" """Return the time portion of a zLOG line in seconds or None."""
...@@ -26,11 +27,14 @@ def parse_time(line): ...@@ -26,11 +27,14 @@ def parse_time(line):
time_l = [int(elt) for elt in time_.split(':')] time_l = [int(elt) for elt in time_.split(':')]
return int(time.mktime(date_l + time_l + [0, 0, 0])) return int(time.mktime(date_l + time_l + [0, 0, 0]))
rx_meth = re.compile("zrpc:\d+ calling (\w+)\((.*)")
rx_meth = re.compile(r"zrpc:\d+ calling (\w+)\((.*)")
def parse_method(line): def parse_method(line):
pass pass
def parse_line(line): def parse_line(line):
"""Parse a log entry and return time, method info, and client.""" """Parse a log entry and return time, method info, and client."""
t = parse_time(line) t = parse_time(line)
...@@ -47,6 +51,7 @@ def parse_line(line): ...@@ -47,6 +51,7 @@ def parse_line(line):
m = meth_name, tuple(meth_args) m = meth_name, tuple(meth_args)
return t, m return t, m
class TStats(object): class TStats(object):
counter = 1 counter = 1
...@@ -61,7 +66,6 @@ class TStats(object): ...@@ -61,7 +66,6 @@ class TStats(object):
def report(self): def report(self):
"""Print a report about the transaction""" """Print a report about the transaction"""
t = time.ctime(self.begin)
if hasattr(self, "vote"): if hasattr(self, "vote"):
d_vote = self.vote - self.begin d_vote = self.vote - self.begin
else: else:
...@@ -69,10 +73,11 @@ class TStats(object): ...@@ -69,10 +73,11 @@ class TStats(object):
if hasattr(self, "finish"): if hasattr(self, "finish"):
d_finish = self.finish - self.begin d_finish = self.finish - self.begin
else: else:
d_finish = "*" d_finish = "*"
print(self.fmt % (time.ctime(self.begin), d_vote, d_finish, print(self.fmt % (time.ctime(self.begin), d_vote, d_finish,
self.user, self.url)) self.user, self.url))
class TransactionParser(object): class TransactionParser(object):
def __init__(self): def __init__(self):
...@@ -122,6 +127,7 @@ class TransactionParser(object): ...@@ -122,6 +127,7 @@ class TransactionParser(object):
L.sort() L.sort()
return [t for (id, t) in L] return [t for (id, t) in L]
if __name__ == "__main__": if __name__ == "__main__":
import fileinput import fileinput
...@@ -131,7 +137,7 @@ if __name__ == "__main__": ...@@ -131,7 +137,7 @@ if __name__ == "__main__":
i += 1 i += 1
try: try:
p.parse(line) p.parse(line)
except: except: # NOQA: E722 bare except
print("line", i) print("line", i)
raise raise
print("Transaction: %d" % len(p.txns)) print("Transaction: %d" % len(p.txns))
......
...@@ -12,18 +12,21 @@ ...@@ -12,18 +12,21 @@
# #
############################################################################## ##############################################################################
from __future__ import print_function from __future__ import print_function
import doctest, re, unittest import doctest
import re
import unittest
from zope.testing import renormalizing from zope.testing import renormalizing
def test_suite(): def test_suite():
return unittest.TestSuite(( return unittest.TestSuite((
doctest.DocFileSuite( doctest.DocFileSuite(
'zeopack.test', 'zeopack.test',
checker=renormalizing.RENormalizing([ checker=renormalizing.RENormalizing([
(re.compile('usage: Usage: '), 'Usage: '), # Py 2.4 (re.compile('usage: Usage: '), 'Usage: '), # Py 2.4
(re.compile('options:'), 'Options:'), # Py 2.4 (re.compile('options:'), 'Options:'), # Py 2.4
]), ]),
globs={'print_function': print_function}, globs={'print_function': print_function},
), ),
)) ))
...@@ -25,6 +25,7 @@ from ZEO.ClientStorage import ClientStorage ...@@ -25,6 +25,7 @@ from ZEO.ClientStorage import ClientStorage
ZERO = '\0'*8 ZERO = '\0'*8
def main(): def main():
if len(sys.argv) not in (3, 4): if len(sys.argv) not in (3, 4):
sys.stderr.write("Usage: timeout.py address delay [storage-name]\n" % sys.stderr.write("Usage: timeout.py address delay [storage-name]\n" %
...@@ -68,5 +69,6 @@ def main(): ...@@ -68,5 +69,6 @@ def main():
time.sleep(delay) time.sleep(delay)
print("Done.") print("Done.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -8,7 +8,6 @@ import time ...@@ -8,7 +8,6 @@ import time
import traceback import traceback
import ZEO.ClientStorage import ZEO.ClientStorage
from six.moves import map from six.moves import map
from six.moves import zip
usage = """Usage: %prog [options] [servers] usage = """Usage: %prog [options] [servers]
...@@ -21,7 +20,8 @@ each is of the form: ...@@ -21,7 +20,8 @@ each is of the form:
""" """
WAIT = 10 # wait no more than 10 seconds for client to connect WAIT = 10 # wait no more than 10 seconds for client to connect
def _main(args=None, prog=None): def _main(args=None, prog=None):
if args is None: if args is None:
...@@ -160,10 +160,11 @@ def _main(args=None, prog=None): ...@@ -160,10 +160,11 @@ def _main(args=None, prog=None):
continue continue
cs.pack(packt, wait=True) cs.pack(packt, wait=True)
cs.close() cs.close()
except: except: # NOQA: E722 bare except
traceback.print_exception(*(sys.exc_info()+(99, sys.stderr))) traceback.print_exception(*(sys.exc_info()+(99, sys.stderr)))
error("Error packing storage %s in %r" % (name, addr)) error("Error packing storage %s in %r" % (name, addr))
def main(*args): def main(*args):
root_logger = logging.getLogger() root_logger = logging.getLogger()
old_level = root_logger.getEffectiveLevel() old_level = root_logger.getEffectiveLevel()
...@@ -178,6 +179,6 @@ def main(*args): ...@@ -178,6 +179,6 @@ def main(*args):
logging.getLogger().setLevel(old_level) logging.getLogger().setLevel(old_level)
logging.getLogger().removeHandler(handler) logging.getLogger().removeHandler(handler)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -37,7 +37,6 @@ STATEFILE = 'zeoqueue.pck' ...@@ -37,7 +37,6 @@ STATEFILE = 'zeoqueue.pck'
PROGRAM = sys.argv[0] PROGRAM = sys.argv[0]
tcre = re.compile(r""" tcre = re.compile(r"""
(?P<ymd> (?P<ymd>
\d{4}- # year \d{4}- # year
...@@ -67,7 +66,6 @@ ccre = re.compile(r""" ...@@ -67,7 +66,6 @@ ccre = re.compile(r"""
wcre = re.compile(r'Clients waiting: (?P<num>\d+)') wcre = re.compile(r'Clients waiting: (?P<num>\d+)')
def parse_time(line): def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None.""" """Return the time portion of a zLOG line in seconds or None."""
mo = tcre.match(line) mo = tcre.match(line)
...@@ -97,7 +95,6 @@ class Txn(object): ...@@ -97,7 +95,6 @@ class Txn(object):
return False return False
class Status(object): class Status(object):
"""Track status of ZEO server by replaying log records. """Track status of ZEO server by replaying log records.
...@@ -303,7 +300,6 @@ class Status(object): ...@@ -303,7 +300,6 @@ class Status(object):
break break
def usage(code, msg=''): def usage(code, msg=''):
print(__doc__ % globals(), file=sys.stderr) print(__doc__ % globals(), file=sys.stderr)
if msg: if msg:
......
...@@ -41,25 +41,25 @@ import time ...@@ -41,25 +41,25 @@ import time
import getopt import getopt
import operator import operator
# ZEO logs measure wall-clock time so for consistency we need to do the same # ZEO logs measure wall-clock time so for consistency we need to do the same
#from time import clock as now # from time import clock as now
from time import time as now from time import time as now
from ZODB.FileStorage import FileStorage from ZODB.FileStorage import FileStorage
#from BDBStorage.BDBFullStorage import BDBFullStorage # from BDBStorage.BDBFullStorage import BDBFullStorage
#from Standby.primary import PrimaryStorage # from Standby.primary import PrimaryStorage
#from Standby.config import RS_PORT # from Standby.config import RS_PORT
from ZODB.Connection import TransactionMetaData from ZODB.Connection import TransactionMetaData
from ZODB.utils import p64 from ZODB.utils import p64
from functools import reduce from functools import reduce
datecre = re.compile('(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)') datecre = re.compile(r'(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
methcre = re.compile("ZEO Server (\w+)\((.*)\) \('(.*)', (\d+)") methcre = re.compile(r"ZEO Server (\w+)\((.*)\) \('(.*)', (\d+)")
class StopParsing(Exception): class StopParsing(Exception):
pass pass
def usage(code, msg=''): def usage(code, msg=''):
print(__doc__) print(__doc__)
if msg: if msg:
...@@ -67,7 +67,6 @@ def usage(code, msg=''): ...@@ -67,7 +67,6 @@ def usage(code, msg=''):
sys.exit(code) sys.exit(code)
def parse_time(line): def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None.""" """Return the time portion of a zLOG line in seconds or None."""
mo = datecre.match(line) mo = datecre.match(line)
...@@ -95,7 +94,6 @@ def parse_line(line): ...@@ -95,7 +94,6 @@ def parse_line(line):
return t, m, c return t, m, c
class StoreStat(object): class StoreStat(object):
def __init__(self, when, oid, size): def __init__(self, when, oid, size):
self.when = when self.when = when
...@@ -104,8 +102,10 @@ class StoreStat(object): ...@@ -104,8 +102,10 @@ class StoreStat(object):
# Crufty # Crufty
def __getitem__(self, i): def __getitem__(self, i):
if i == 0: return self.oid if i == 0:
if i == 1: return self.size return self.oid
if i == 1:
return self.size
raise IndexError raise IndexError
...@@ -136,10 +136,10 @@ class TxnStat(object): ...@@ -136,10 +136,10 @@ class TxnStat(object):
self._finishtime = when self._finishtime = when
# Mapping oid -> revid # Mapping oid -> revid
_revids = {} _revids = {}
class ReplayTxn(TxnStat): class ReplayTxn(TxnStat):
def __init__(self, storage): def __init__(self, storage):
self._storage = storage self._storage = storage
...@@ -157,7 +157,7 @@ class ReplayTxn(TxnStat): ...@@ -157,7 +157,7 @@ class ReplayTxn(TxnStat):
# BAW: simulate a pickle of the given size # BAW: simulate a pickle of the given size
data = 'x' * obj.size data = 'x' * obj.size
# BAW: ignore versions for now # BAW: ignore versions for now
newrevid = self._storage.store(p64(oid), revid, data, '', t) newrevid = self._storage.store(p64(oid), revid, data, '', t)
_revids[oid] = newrevid _revids[oid] = newrevid
if self._aborttime: if self._aborttime:
self._storage.tpc_abort(t) self._storage.tpc_abort(t)
...@@ -172,7 +172,6 @@ class ReplayTxn(TxnStat): ...@@ -172,7 +172,6 @@ class ReplayTxn(TxnStat):
self._replaydelta = t1 - t0 - origdelta self._replaydelta = t1 - t0 - origdelta
class ZEOParser(object): class ZEOParser(object):
def __init__(self, maxtxns=-1, report=1, storage=None): def __init__(self, maxtxns=-1, report=1, storage=None):
self.__txns = [] self.__txns = []
...@@ -261,7 +260,6 @@ class ZEOParser(object): ...@@ -261,7 +260,6 @@ class ZEOParser(object):
print('average faster txn was:', float(sum) / len(faster)) print('average faster txn was:', float(sum) / len(faster))
def main(): def main():
try: try:
opts, args = getopt.getopt( opts, args = getopt.getopt(
...@@ -294,8 +292,8 @@ def main(): ...@@ -294,8 +292,8 @@ def main():
if replay: if replay:
storage = FileStorage(storagefile) storage = FileStorage(storagefile)
#storage = BDBFullStorage(storagefile) # storage = BDBFullStorage(storagefile)
#storage = PrimaryStorage('yyz', storage, RS_PORT) # storage = PrimaryStorage('yyz', storage, RS_PORT)
t0 = now() t0 = now()
p = ZEOParser(maxtxns, report, storage) p = ZEOParser(maxtxns, report, storage)
i = 0 i = 0
...@@ -308,7 +306,7 @@ def main(): ...@@ -308,7 +306,7 @@ def main():
p.parse(line) p.parse(line)
except StopParsing: except StopParsing:
break break
except: except: # NOQA: E722 bare except
print('input file line:', i) print('input file line:', i)
raise raise
t1 = now() t1 = now()
...@@ -321,6 +319,5 @@ def main(): ...@@ -321,6 +319,5 @@ def main():
print('total time:', t3-t0) print('total time:', t3-t0)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -169,9 +169,11 @@ from __future__ import print_function ...@@ -169,9 +169,11 @@ from __future__ import print_function
from __future__ import print_function from __future__ import print_function
from __future__ import print_function from __future__ import print_function
import datetime, sys, re, os import datetime
import os
import re
import sys
from six.moves import map from six.moves import map
from six.moves import zip
def time(line): def time(line):
...@@ -187,9 +189,10 @@ def sub(t1, t2): ...@@ -187,9 +189,10 @@ def sub(t1, t2):
return delta.days*86400.0+delta.seconds+delta.microseconds/1000000.0 return delta.days*86400.0+delta.seconds+delta.microseconds/1000000.0
waitre = re.compile(r'Clients waiting: (\d+)') waitre = re.compile(r'Clients waiting: (\d+)')
idre = re.compile(r' ZSS:\d+/(\d+.\d+.\d+.\d+:\d+) ') idre = re.compile(r' ZSS:\d+/(\d+.\d+.\d+.\d+:\d+) ')
def blocked_times(args): def blocked_times(args):
f, thresh = args f, thresh = args
...@@ -217,7 +220,6 @@ def blocked_times(args): ...@@ -217,7 +220,6 @@ def blocked_times(args):
t2 = t1 t2 = t1
if not blocking and last_blocking: if not blocking and last_blocking:
last_wait = 0
t2 = time(line) t2 = time(line)
cid = idre.search(line).group(1) cid = idre.search(line).group(1)
...@@ -225,11 +227,14 @@ def blocked_times(args): ...@@ -225,11 +227,14 @@ def blocked_times(args):
d = sub(t1, time(line)) d = sub(t1, time(line))
if d >= thresh: if d >= thresh:
print(t1, sub(t1, t2), cid, d) print(t1, sub(t1, t2), cid, d)
t1 = t2 = cid = blocking = waiting = last_wait = max_wait = 0 t1 = t2 = cid = blocking = waiting = 0
last_blocking = blocking last_blocking = blocking
connidre = re.compile(r' zrpc-conn:(\d+.\d+.\d+.\d+:\d+) ') connidre = re.compile(r' zrpc-conn:(\d+.\d+.\d+.\d+:\d+) ')
def time_calls(f): def time_calls(f):
f, thresh = f f, thresh = f
if f == '-': if f == '-':
...@@ -255,6 +260,7 @@ def time_calls(f): ...@@ -255,6 +260,7 @@ def time_calls(f):
print(maxd) print(maxd)
def xopen(f): def xopen(f):
if f == '-': if f == '-':
return sys.stdin return sys.stdin
...@@ -262,6 +268,7 @@ def xopen(f): ...@@ -262,6 +268,7 @@ def xopen(f):
return os.popen(f, 'r') return os.popen(f, 'r')
return open(f) return open(f)
def time_tpc(f): def time_tpc(f):
f, thresh = f f, thresh = f
if f == '-': if f == '-':
...@@ -307,11 +314,14 @@ def time_tpc(f): ...@@ -307,11 +314,14 @@ def time_tpc(f):
t = time(line) t = time(line)
d = sub(t1, t) d = sub(t1, t)
if d >= thresh: if d >= thresh:
print('c', t1, cid, sub(t1, t2), vs, sub(t2, t3), sub(t3, t)) print('c', t1, cid, sub(t1, t2),
vs, sub(t2, t3), sub(t3, t))
del transactions[cid] del transactions[cid]
newobre = re.compile(r"storea\(.*, '\\x00\\x00\\x00\\x00\\x00") newobre = re.compile(r"storea\(.*, '\\x00\\x00\\x00\\x00\\x00")
def time_trans(f): def time_trans(f):
f, thresh = f f, thresh = f
if f == '-': if f == '-':
...@@ -363,8 +373,8 @@ def time_trans(f): ...@@ -363,8 +373,8 @@ def time_trans(f):
t = time(line) t = time(line)
d = sub(t1, t) d = sub(t1, t)
if d >= thresh: if d >= thresh:
print(t1, cid, "%s/%s" % (stores, old), \ print(t1, cid, "%s/%s" % (stores, old),
sub(t0, t1), sub(t1, t2), vs, \ sub(t0, t1), sub(t1, t2), vs,
sub(t2, t), 'abort') sub(t2, t), 'abort')
del transactions[cid] del transactions[cid]
elif ' calling tpc_finish(' in line: elif ' calling tpc_finish(' in line:
...@@ -377,11 +387,12 @@ def time_trans(f): ...@@ -377,11 +387,12 @@ def time_trans(f):
t = time(line) t = time(line)
d = sub(t1, t) d = sub(t1, t)
if d >= thresh: if d >= thresh:
print(t1, cid, "%s/%s" % (stores, old), \ print(t1, cid, "%s/%s" % (stores, old),
sub(t0, t1), sub(t1, t2), vs, \ sub(t0, t1), sub(t1, t2), vs,
sub(t2, t3), sub(t3, t)) sub(t2, t3), sub(t3, t))
del transactions[cid] del transactions[cid]
def minute(f, slice=16, detail=1, summary=1): def minute(f, slice=16, detail=1, summary=1):
f, = f f, = f
...@@ -405,10 +416,9 @@ def minute(f, slice=16, detail=1, summary=1): ...@@ -405,10 +416,9 @@ def minute(f, slice=16, detail=1, summary=1):
for line in f: for line in f:
line = line.strip() line = line.strip()
if (line.find('returns') > 0 if line.find('returns') > 0 or \
or line.find('storea') > 0 line.find('storea') > 0 or \
or line.find('tpc_abort') > 0 line.find('tpc_abort') > 0:
):
client = connidre.search(line).group(1) client = connidre.search(line).group(1)
m = line[:slice] m = line[:slice]
if m != mlast: if m != mlast:
...@@ -452,12 +462,13 @@ def minute(f, slice=16, detail=1, summary=1): ...@@ -452,12 +462,13 @@ def minute(f, slice=16, detail=1, summary=1):
print('Summary: \t', '\t'.join(('min', '10%', '25%', 'med', print('Summary: \t', '\t'.join(('min', '10%', '25%', 'med',
'75%', '90%', 'max', 'mean'))) '75%', '90%', 'max', 'mean')))
print("n=%6d\t" % len(cls), '-'*62) print("n=%6d\t" % len(cls), '-'*62)
print('Clients: \t', '\t'.join(map(str,stats(cls)))) print('Clients: \t', '\t'.join(map(str, stats(cls))))
print('Reads: \t', '\t'.join(map(str,stats(rs)))) print('Reads: \t', '\t'.join(map(str, stats(rs))))
print('Stores: \t', '\t'.join(map(str,stats(ss)))) print('Stores: \t', '\t'.join(map(str, stats(ss))))
print('Commits: \t', '\t'.join(map(str,stats(cs)))) print('Commits: \t', '\t'.join(map(str, stats(cs))))
print('Aborts: \t', '\t'.join(map(str,stats(aborts)))) print('Aborts: \t', '\t'.join(map(str, stats(aborts))))
print('Trans: \t', '\t'.join(map(str,stats(ts)))) print('Trans: \t', '\t'.join(map(str, stats(ts))))
def stats(s): def stats(s):
s.sort() s.sort()
...@@ -468,13 +479,14 @@ def stats(s): ...@@ -468,13 +479,14 @@ def stats(s):
ni = n + 1 ni = n + 1
for p in .1, .25, .5, .75, .90: for p in .1, .25, .5, .75, .90:
lp = ni*p lp = ni*p
l = int(lp) lp_int = int(lp)
if lp < 1 or lp > n: if lp < 1 or lp > n:
out.append('-') out.append('-')
elif abs(lp-l) < .00001: elif abs(lp-lp_int) < .00001:
out.append(s[l-1]) out.append(s[lp_int-1])
else: else:
out.append(int(s[l-1] + (lp - l) * (s[l] - s[l-1]))) out.append(
int(s[lp_int-1] + (lp - lp_int) * (s[lp_int] - s[lp_int-1])))
mean = 0.0 mean = 0.0
for v in s: for v in s:
...@@ -484,24 +496,31 @@ def stats(s): ...@@ -484,24 +496,31 @@ def stats(s):
return out return out
def minutes(f): def minutes(f):
minute(f, 16, detail=0) minute(f, 16, detail=0)
def hour(f): def hour(f):
minute(f, 13) minute(f, 13)
def day(f): def day(f):
minute(f, 10) minute(f, 10)
def hours(f): def hours(f):
minute(f, 13, detail=0) minute(f, 13, detail=0)
def days(f): def days(f):
minute(f, 10, detail=0) minute(f, 10, detail=0)
new_connection_idre = re.compile( new_connection_idre = re.compile(
r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):") r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):")
def verify(f): def verify(f):
f, = f f, = f
...@@ -527,6 +546,7 @@ def verify(f): ...@@ -527,6 +546,7 @@ def verify(f):
d = sub(t1, time(line)) d = sub(t1, time(line))
print(cid, t1, n, d, n and (d*1000.0/n) or '-') print(cid, t1, n, d, n and (d*1000.0/n) or '-')
def recovery(f): def recovery(f):
f, = f f, = f
...@@ -542,16 +562,16 @@ def recovery(f): ...@@ -542,16 +562,16 @@ def recovery(f):
n += 1 n += 1
if line.find('RecoveryServer') < 0: if line.find('RecoveryServer') < 0:
continue continue
l = line.find('sending transaction ') pos = line.find('sending transaction ')
if l > 0 and last.find('sending transaction ') > 0: if pos > 0 and last.find('sending transaction ') > 0:
trans.append(line[l+20:].strip()) trans.append(line[pos+20:].strip())
else: else:
if trans: if trans:
if len(trans) > 1: if len(trans) > 1:
print(" ... %s similar records skipped ..." % ( print(" ... %s similar records skipped ..." % (
len(trans) - 1)) len(trans) - 1))
print(n, last.strip()) print(n, last.strip())
trans=[] trans = []
print(n, line.strip()) print(n, line.strip())
last = line last = line
...@@ -561,6 +581,5 @@ def recovery(f): ...@@ -561,6 +581,5 @@ def recovery(f):
print(n, last.strip()) print(n, last.strip())
if __name__ == '__main__': if __name__ == '__main__':
globals()[sys.argv[1]](sys.argv[2:]) globals()[sys.argv[1]](sys.argv[2:])
...@@ -47,6 +47,7 @@ from ZEO.ClientStorage import ClientStorage ...@@ -47,6 +47,7 @@ from ZEO.ClientStorage import ClientStorage
ZEO_VERSION = 2 ZEO_VERSION = 2
def setup_logging(): def setup_logging():
# Set up logging to stderr which will show messages originating # Set up logging to stderr which will show messages originating
# at severity ERROR or higher. # at severity ERROR or higher.
...@@ -59,6 +60,7 @@ def setup_logging(): ...@@ -59,6 +60,7 @@ def setup_logging():
handler.setFormatter(fmt) handler.setFormatter(fmt)
root.addHandler(handler) root.addHandler(handler)
def check_server(addr, storage, write): def check_server(addr, storage, write):
t0 = time.time() t0 = time.time()
if ZEO_VERSION == 2: if ZEO_VERSION == 2:
...@@ -97,11 +99,13 @@ def check_server(addr, storage, write): ...@@ -97,11 +99,13 @@ def check_server(addr, storage, write):
t1 = time.time() t1 = time.time()
print("Elapsed time: %.2f" % (t1 - t0)) print("Elapsed time: %.2f" % (t1 - t0))
def usage(exit=1): def usage(exit=1):
print(__doc__) print(__doc__)
print(" ".join(sys.argv)) print(" ".join(sys.argv))
sys.exit(exit) sys.exit(exit)
def main(): def main():
host = None host = None
port = None port = None
...@@ -123,7 +127,7 @@ def main(): ...@@ -123,7 +127,7 @@ def main():
elif o == '--nowrite': elif o == '--nowrite':
write = 0 write = 0
elif o == '-1': elif o == '-1':
ZEO_VERSION = 1 ZEO_VERSION = 1 # NOQA: F841 unused variable
except Exception as err: except Exception as err:
s = str(err) s = str(err)
if s: if s:
...@@ -143,6 +147,7 @@ def main(): ...@@ -143,6 +147,7 @@ def main():
setup_logging() setup_logging()
check_server(addr, storage, write) check_server(addr, storage, write)
if __name__ == "__main__": if __name__ == "__main__":
try: try:
main() main()
......
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
REPR_LIMIT = 60 REPR_LIMIT = 60
def short_repr(obj): def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes." """Return an object repr limited to REPR_LIMIT bytes."""
# Some of the objects being repr'd are large strings. A lot of memory # Some of the objects being repr'd are large strings. A lot of memory
# would be wasted to repr them and then truncate, so they are treated # would be wasted to repr them and then truncate, so they are treated
......
...@@ -17,6 +17,7 @@ from ZODB.Connection import TransactionMetaData ...@@ -17,6 +17,7 @@ from ZODB.Connection import TransactionMetaData
from ZODB.tests.MinPO import MinPO from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_unpickle from ZODB.tests.StorageTestBase import zodb_unpickle
class TransUndoStorageWithCache(object): class TransUndoStorageWithCache(object):
def checkUndoInvalidation(self): def checkUndoInvalidation(self):
......
...@@ -20,12 +20,12 @@ from persistent.TimeStamp import TimeStamp ...@@ -20,12 +20,12 @@ from persistent.TimeStamp import TimeStamp
from ZODB.Connection import TransactionMetaData from ZODB.Connection import TransactionMetaData
from ZODB.tests.StorageTestBase import zodb_pickle, MinPO from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
import ZEO.ClientStorage
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
from ZEO.tests.TestThread import TestThread from ZEO.tests.TestThread import TestThread
ZERO = b'\0'*8 ZERO = b'\0'*8
class WorkerThread(TestThread): class WorkerThread(TestThread):
# run the entire test in a thread so that the blocking call for # run the entire test in a thread so that the blocking call for
...@@ -62,6 +62,7 @@ class WorkerThread(TestThread): ...@@ -62,6 +62,7 @@ class WorkerThread(TestThread):
self.ready.set() self.ready.set()
future.result(9) future.result(9)
class CommitLockTests(object): class CommitLockTests(object):
NUM_CLIENTS = 5 NUM_CLIENTS = 5
...@@ -99,7 +100,7 @@ class CommitLockTests(object): ...@@ -99,7 +100,7 @@ class CommitLockTests(object):
for i in range(self.NUM_CLIENTS): for i in range(self.NUM_CLIENTS):
storage = self._new_storage_client() storage = self._new_storage_client()
txn = TransactionMetaData() txn = TransactionMetaData()
tid = self._get_timestamp() self._get_timestamp()
t = WorkerThread(self, storage, txn) t = WorkerThread(self, storage, txn)
self._threads.append(t) self._threads.append(t)
...@@ -118,9 +119,10 @@ class CommitLockTests(object): ...@@ -118,9 +119,10 @@ class CommitLockTests(object):
def _get_timestamp(self): def _get_timestamp(self):
t = time.time() t = time.time()
t = TimeStamp(*time.gmtime(t)[:5]+(t%60,)) t = TimeStamp(*time.gmtime(t)[:5]+(t % 60,))
return repr(t) return repr(t)
class CommitLockVoteTests(CommitLockTests): class CommitLockVoteTests(CommitLockTests):
def checkCommitLockVoteFinish(self): def checkCommitLockVoteFinish(self):
......
...@@ -26,11 +26,10 @@ from ZEO.tests import forker ...@@ -26,11 +26,10 @@ from ZEO.tests import forker
from ZODB.Connection import TransactionMetaData from ZODB.Connection import TransactionMetaData
from ZODB.DB import DB from ZODB.DB import DB
from ZODB.POSException import ReadOnlyError, ConflictError from ZODB.POSException import ReadOnlyError
from ZODB.tests.StorageTestBase import StorageTestBase from ZODB.tests.StorageTestBase import StorageTestBase
from ZODB.tests.MinPO import MinPO from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle
import ZODB.tests.util
import transaction import transaction
...@@ -40,6 +39,7 @@ logger = logging.getLogger('ZEO.tests.ConnectionTests') ...@@ -40,6 +39,7 @@ logger = logging.getLogger('ZEO.tests.ConnectionTests')
ZERO = '\0'*8 ZERO = '\0'*8
class TestClientStorage(ClientStorage): class TestClientStorage(ClientStorage):
test_connection = False test_connection = False
...@@ -51,6 +51,7 @@ class TestClientStorage(ClientStorage): ...@@ -51,6 +51,7 @@ class TestClientStorage(ClientStorage):
self.connection_count_for_tests += 1 self.connection_count_for_tests += 1
self.verify_result = conn.verify_result self.verify_result = conn.verify_result
class DummyDB(object): class DummyDB(object):
def invalidate(self, *args, **kwargs): def invalidate(self, *args, **kwargs):
pass pass
...@@ -93,7 +94,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -93,7 +94,7 @@ class CommonSetupTearDown(StorageTestBase):
self._storage.close() self._storage.close()
if hasattr(self._storage, 'cleanup'): if hasattr(self._storage, 'cleanup'):
logging.debug("cleanup storage %s" % logging.debug("cleanup storage %s" %
self._storage.__name__) self._storage.__name__)
self._storage.cleanup() self._storage.cleanup()
for stop in self._servers: for stop in self._servers:
stop() stop()
...@@ -113,7 +114,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -113,7 +114,7 @@ class CommonSetupTearDown(StorageTestBase):
for dummy in range(5): for dummy in range(5):
try: try:
os.unlink(path) os.unlink(path)
except: except: # NOQA: E722 bare except
time.sleep(0.5) time.sleep(0.5)
else: else:
need_to_delete = False need_to_delete = False
...@@ -188,7 +189,7 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -188,7 +189,7 @@ class CommonSetupTearDown(StorageTestBase):
stop = self._servers[index] stop = self._servers[index]
if stop is not None: if stop is not None:
stop() stop()
self._servers[index] = lambda : None self._servers[index] = lambda: None
def pollUp(self, timeout=30.0, storage=None): def pollUp(self, timeout=30.0, storage=None):
if storage is None: if storage is None:
...@@ -271,7 +272,6 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -271,7 +272,6 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ReadOnlyError, self._dostore) self.assertRaises(ReadOnlyError, self._dostore)
self._storage.close() self._storage.close()
def checkDisconnectionError(self): def checkDisconnectionError(self):
# Make sure we get a ClientDisconnected when we try to read an # Make sure we get a ClientDisconnected when we try to read an
# object when we're not connected to a storage server and the # object when we're not connected to a storage server and the
...@@ -374,7 +374,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -374,7 +374,7 @@ class ConnectionTests(CommonSetupTearDown):
pickle, rev = self._storage.load(oid, '') pickle, rev = self._storage.load(oid, '')
newobj = zodb_unpickle(pickle) newobj = zodb_unpickle(pickle)
self.assertEqual(newobj, obj) self.assertEqual(newobj, obj)
newobj.value = 42 # .value *should* be 42 forever after now, not 13 newobj.value = 42 # .value *should* be 42 forever after now, not 13
self._dostore(oid, data=newobj, revid=rev) self._dostore(oid, data=newobj, revid=rev)
self._storage.close() self._storage.close()
...@@ -416,6 +416,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -416,6 +416,7 @@ class ConnectionTests(CommonSetupTearDown):
def checkBadMessage2(self): def checkBadMessage2(self):
# just like a real message, but with an unpicklable argument # just like a real message, but with an unpicklable argument
global Hack global Hack
class Hack(object): class Hack(object):
pass pass
...@@ -505,7 +506,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -505,7 +506,7 @@ class ConnectionTests(CommonSetupTearDown):
r1["a"] = MinPO("a") r1["a"] = MinPO("a")
transaction.commit() transaction.commit()
self.assertEqual(r1._p_state, 0) # up-to-date self.assertEqual(r1._p_state, 0) # up-to-date
db2 = DB(self.openClientStorage()) db2 = DB(self.openClientStorage())
r2 = db2.open().root() r2 = db2.open().root()
...@@ -524,9 +525,9 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -524,9 +525,9 @@ class ConnectionTests(CommonSetupTearDown):
if r1._p_state == -1: if r1._p_state == -1:
break break
time.sleep(i / 10.0) time.sleep(i / 10.0)
self.assertEqual(r1._p_state, -1) # ghost self.assertEqual(r1._p_state, -1) # ghost
r1.keys() # unghostify r1.keys() # unghostify
self.assertEqual(r1._p_serial, r2._p_serial) self.assertEqual(r1._p_serial, r2._p_serial)
self.assertEqual(r1["b"].value, "b") self.assertEqual(r1["b"].value, "b")
...@@ -551,6 +552,7 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -551,6 +552,7 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ClientDisconnected, self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '') self._storage.load, b'\0'*8, '')
class SSLConnectionTests(ConnectionTests): class SSLConnectionTests(ConnectionTests):
def getServerConfig(self, addr, ro_svr): def getServerConfig(self, addr, ro_svr):
...@@ -585,13 +587,13 @@ class InvqTests(CommonSetupTearDown): ...@@ -585,13 +587,13 @@ class InvqTests(CommonSetupTearDown):
revid2 = self._dostore(oid2, revid2) revid2 = self._dostore(oid2, revid2)
forker.wait_until( forker.wait_until(
lambda : lambda:
perstorage.lastTransaction() == self._storage.lastTransaction()) perstorage.lastTransaction() == self._storage.lastTransaction())
perstorage.load(oid, '') perstorage.load(oid, '')
perstorage.close() perstorage.close()
forker.wait_until(lambda : os.path.exists('test-1.zec')) forker.wait_until(lambda: os.path.exists('test-1.zec'))
revid = self._dostore(oid, revid) revid = self._dostore(oid, revid)
...@@ -617,7 +619,7 @@ class InvqTests(CommonSetupTearDown): ...@@ -617,7 +619,7 @@ class InvqTests(CommonSetupTearDown):
revid = self._dostore(oid, revid) revid = self._dostore(oid, revid)
forker.wait_until( forker.wait_until(
"Client has seen all of the transactions from the server", "Client has seen all of the transactions from the server",
lambda : lambda:
perstorage.lastTransaction() == self._storage.lastTransaction() perstorage.lastTransaction() == self._storage.lastTransaction()
) )
perstorage.load(oid, '') perstorage.load(oid, '')
...@@ -635,6 +637,7 @@ class InvqTests(CommonSetupTearDown): ...@@ -635,6 +637,7 @@ class InvqTests(CommonSetupTearDown):
perstorage.close() perstorage.close()
class ReconnectionTests(CommonSetupTearDown): class ReconnectionTests(CommonSetupTearDown):
# The setUp() starts a server automatically. In order for its # The setUp() starts a server automatically. In order for its
# state to persist, we set the class variable keep to 1. In # state to persist, we set the class variable keep to 1. In
...@@ -798,7 +801,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -798,7 +801,7 @@ class ReconnectionTests(CommonSetupTearDown):
# Start a read-write server # Start a read-write server
self.startServer(index=1, read_only=0, keep=0) self.startServer(index=1, read_only=0, keep=0)
# After a while, stores should work # After a while, stores should work
for i in range(300): # Try for 30 seconds for i in range(300): # Try for 30 seconds
try: try:
self._dostore() self._dostore()
break break
...@@ -840,7 +843,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -840,7 +843,7 @@ class ReconnectionTests(CommonSetupTearDown):
revid = self._dostore(oid, revid) revid = self._dostore(oid, revid)
forker.wait_until( forker.wait_until(
"Client has seen all of the transactions from the server", "Client has seen all of the transactions from the server",
lambda : lambda:
perstorage.lastTransaction() == self._storage.lastTransaction() perstorage.lastTransaction() == self._storage.lastTransaction()
) )
perstorage.load(oid, '') perstorage.load(oid, '')
...@@ -894,7 +897,6 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -894,7 +897,6 @@ class ReconnectionTests(CommonSetupTearDown):
# Module ZEO.ClientStorage, line 709, in _update_cache # Module ZEO.ClientStorage, line 709, in _update_cache
# KeyError: ... # KeyError: ...
def checkReconnection(self): def checkReconnection(self):
# Check that the client reconnects when a server restarts. # Check that the client reconnects when a server restarts.
...@@ -952,6 +954,7 @@ class ReconnectionTests(CommonSetupTearDown): ...@@ -952,6 +954,7 @@ class ReconnectionTests(CommonSetupTearDown):
self.assertTrue(did_a_store) self.assertTrue(did_a_store)
self._storage.close() self._storage.close()
class TimeoutTests(CommonSetupTearDown): class TimeoutTests(CommonSetupTearDown):
timeout = 1 timeout = 1
...@@ -967,9 +970,8 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -967,9 +970,8 @@ class TimeoutTests(CommonSetupTearDown):
# Make sure it's logged as CRITICAL # Make sure it's logged as CRITICAL
with open("server.log") as f: with open("server.log") as f:
for line in f: for line in f:
if (('Transaction timeout after' in line) and if ('Transaction timeout after' in line) and \
('CRITICAL ZEO.StorageServer' in line) ('CRITICAL ZEO.StorageServer' in line):
):
break break
else: else:
self.fail('bad logging') self.fail('bad logging')
...@@ -1002,7 +1004,7 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -1002,7 +1004,7 @@ class TimeoutTests(CommonSetupTearDown):
t = TransactionMetaData() t = TransactionMetaData()
old_connection_count = storage.connection_count_for_tests old_connection_count = storage.connection_count_for_tests
storage.tpc_begin(t) storage.tpc_begin(t)
revid1 = storage.store(oid, ZERO, zodb_pickle(obj), '', t) storage.store(oid, ZERO, zodb_pickle(obj), '', t)
storage.tpc_vote(t) storage.tpc_vote(t)
# Now sleep long enough for the storage to time out # Now sleep long enough for the storage to time out
time.sleep(3) time.sleep(3)
...@@ -1021,6 +1023,7 @@ class TimeoutTests(CommonSetupTearDown): ...@@ -1021,6 +1023,7 @@ class TimeoutTests(CommonSetupTearDown):
# or the server. # or the server.
self.assertRaises(KeyError, storage.load, oid, '') self.assertRaises(KeyError, storage.load, oid, '')
class MSTThread(threading.Thread): class MSTThread(threading.Thread):
__super_init = threading.Thread.__init__ __super_init = threading.Thread.__init__
...@@ -1054,7 +1057,7 @@ class MSTThread(threading.Thread): ...@@ -1054,7 +1057,7 @@ class MSTThread(threading.Thread):
# Begin a transaction # Begin a transaction
t = TransactionMetaData() t = TransactionMetaData()
for c in clients: for c in clients:
#print("%s.%s.%s begin" % (tname, c.__name, i)) # print("%s.%s.%s begin" % (tname, c.__name, i))
c.tpc_begin(t) c.tpc_begin(t)
for j in range(testcase.nobj): for j in range(testcase.nobj):
...@@ -1063,18 +1066,18 @@ class MSTThread(threading.Thread): ...@@ -1063,18 +1066,18 @@ class MSTThread(threading.Thread):
oid = c.new_oid() oid = c.new_oid()
c.__oids.append(oid) c.__oids.append(oid)
data = MinPO("%s.%s.t%d.o%d" % (tname, c.__name, i, j)) data = MinPO("%s.%s.t%d.o%d" % (tname, c.__name, i, j))
#print(data.value) # print(data.value)
data = zodb_pickle(data) data = zodb_pickle(data)
c.store(oid, ZERO, data, '', t) c.store(oid, ZERO, data, '', t)
# Vote on all servers and handle serials # Vote on all servers and handle serials
for c in clients: for c in clients:
#print("%s.%s.%s vote" % (tname, c.__name, i)) # print("%s.%s.%s vote" % (tname, c.__name, i))
c.tpc_vote(t) c.tpc_vote(t)
# Finish on all servers # Finish on all servers
for c in clients: for c in clients:
#print("%s.%s.%s finish\n" % (tname, c.__name, i)) # print("%s.%s.%s finish\n" % (tname, c.__name, i))
c.tpc_finish(t) c.tpc_finish(t)
for c in clients: for c in clients:
...@@ -1090,7 +1093,7 @@ class MSTThread(threading.Thread): ...@@ -1090,7 +1093,7 @@ class MSTThread(threading.Thread):
for c in self.clients: for c in self.clients:
try: try:
c.close() c.close()
except: except: # NOQA: E722 bare except
pass pass
...@@ -1101,6 +1104,7 @@ def short_timeout(self): ...@@ -1101,6 +1104,7 @@ def short_timeout(self):
yield yield
self._storage._server.timeout = old self._storage._server.timeout = old
# Run IPv6 tests if V6 sockets are supported # Run IPv6 tests if V6 sockets are supported
try: try:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
......
...@@ -41,6 +41,7 @@ from ZODB.POSException import ReadConflictError, ConflictError ...@@ -41,6 +41,7 @@ from ZODB.POSException import ReadConflictError, ConflictError
# thought they added (i.e., the keys for which transaction.commit() # thought they added (i.e., the keys for which transaction.commit()
# did not raise any exception). # did not raise any exception).
class FailableThread(TestThread): class FailableThread(TestThread):
# mixin class # mixin class
...@@ -52,7 +53,7 @@ class FailableThread(TestThread): ...@@ -52,7 +53,7 @@ class FailableThread(TestThread):
def testrun(self): def testrun(self):
try: try:
self._testrun() self._testrun()
except: except: # NOQA: E722 bare except
# Report the failure here to all the other threads, so # Report the failure here to all the other threads, so
# that they stop quickly. # that they stop quickly.
self.stop.set() self.stop.set()
...@@ -81,12 +82,11 @@ class StressTask(object): ...@@ -81,12 +82,11 @@ class StressTask(object):
tree[key] = self.threadnum tree[key] = self.threadnum
def commit(self): def commit(self):
cn = self.cn
key = self.startnum key = self.startnum
self.tm.get().note(u"add key %s" % key) self.tm.get().note(u"add key %s" % key)
try: try:
self.tm.get().commit() self.tm.get().commit()
except ConflictError as msg: except ConflictError:
self.tm.abort() self.tm.abort()
else: else:
if self.sleep: if self.sleep:
...@@ -98,15 +98,18 @@ class StressTask(object): ...@@ -98,15 +98,18 @@ class StressTask(object):
self.tm.get().abort() self.tm.get().abort()
self.cn.close() self.cn.close()
def _runTasks(rounds, *tasks): def _runTasks(rounds, *tasks):
'''run *task* interleaved for *rounds* rounds.''' '''run *task* interleaved for *rounds* rounds.'''
def commit(run, actions): def commit(run, actions):
actions.append(':') actions.append(':')
for t in run: for t in run:
t.commit() t.commit()
del run[:] del run[:]
r = Random() r = Random()
r.seed(1064589285) # make it deterministic r.seed(1064589285) # make it deterministic
run = [] run = []
actions = [] actions = []
try: try:
...@@ -117,7 +120,7 @@ def _runTasks(rounds, *tasks): ...@@ -117,7 +120,7 @@ def _runTasks(rounds, *tasks):
run.append(t) run.append(t)
t.doStep() t.doStep()
actions.append(repr(t.startnum)) actions.append(repr(t.startnum))
commit(run,actions) commit(run, actions)
# stderr.write(' '.join(actions)+'\n') # stderr.write(' '.join(actions)+'\n')
finally: finally:
for t in tasks: for t in tasks:
...@@ -160,13 +163,14 @@ class StressThread(FailableThread): ...@@ -160,13 +163,14 @@ class StressThread(FailableThread):
self.commitdict[self] = 1 self.commitdict[self] = 1
if self.sleep: if self.sleep:
time.sleep(self.sleep) time.sleep(self.sleep)
except (ReadConflictError, ConflictError) as msg: except (ReadConflictError, ConflictError):
tm.abort() tm.abort()
else: else:
self.added_keys.append(key) self.added_keys.append(key)
key += self.step key += self.step
cn.close() cn.close()
class LargeUpdatesThread(FailableThread): class LargeUpdatesThread(FailableThread):
# A thread that performs a lot of updates. It attempts to modify # A thread that performs a lot of updates. It attempts to modify
...@@ -195,7 +199,7 @@ class LargeUpdatesThread(FailableThread): ...@@ -195,7 +199,7 @@ class LargeUpdatesThread(FailableThread):
# print("%d getting tree abort" % self.threadnum) # print("%d getting tree abort" % self.threadnum)
transaction.abort() transaction.abort()
keys_added = {} # set of keys we commit keys_added = {} # set of keys we commit
tkeys = [] tkeys = []
while not self.stop.isSet(): while not self.stop.isSet():
...@@ -212,7 +216,7 @@ class LargeUpdatesThread(FailableThread): ...@@ -212,7 +216,7 @@ class LargeUpdatesThread(FailableThread):
for key in keys: for key in keys:
try: try:
tree[key] = self.threadnum tree[key] = self.threadnum
except (ReadConflictError, ConflictError) as msg: except (ReadConflictError, ConflictError): # as msg:
# print("%d setting key %s" % (self.threadnum, msg)) # print("%d setting key %s" % (self.threadnum, msg))
transaction.abort() transaction.abort()
break break
...@@ -224,7 +228,7 @@ class LargeUpdatesThread(FailableThread): ...@@ -224,7 +228,7 @@ class LargeUpdatesThread(FailableThread):
self.commitdict[self] = 1 self.commitdict[self] = 1
if self.sleep: if self.sleep:
time.sleep(self.sleep) time.sleep(self.sleep)
except ConflictError as msg: except ConflictError: # as msg
# print("%d commit %s" % (self.threadnum, msg)) # print("%d commit %s" % (self.threadnum, msg))
transaction.abort() transaction.abort()
continue continue
...@@ -234,6 +238,7 @@ class LargeUpdatesThread(FailableThread): ...@@ -234,6 +238,7 @@ class LargeUpdatesThread(FailableThread):
self.added_keys = keys_added.keys() self.added_keys = keys_added.keys()
cn.close() cn.close()
class InvalidationTests(object): class InvalidationTests(object):
# Minimum # of seconds the main thread lets the workers run. The # Minimum # of seconds the main thread lets the workers run. The
...@@ -261,7 +266,7 @@ class InvalidationTests(object): ...@@ -261,7 +266,7 @@ class InvalidationTests(object):
transaction.abort() transaction.abort()
else: else:
raise raise
except: except: # NOQA: E722 bare except
display(tree) display(tree)
raise raise
......
...@@ -21,6 +21,7 @@ from ZODB.Connection import TransactionMetaData ...@@ -21,6 +21,7 @@ from ZODB.Connection import TransactionMetaData
from ..asyncio.testing import AsyncRPC from ..asyncio.testing import AsyncRPC
class IterationTests(object): class IterationTests(object):
def _assertIteratorIdsEmpty(self): def _assertIteratorIdsEmpty(self):
...@@ -44,7 +45,7 @@ class IterationTests(object): ...@@ -44,7 +45,7 @@ class IterationTests(object):
# everything goes away as expected. # everything goes away as expected.
gc.enable() gc.enable()
gc.collect() gc.collect()
gc.collect() # sometimes PyPy needs it twice to clear weak refs gc.collect() # sometimes PyPy needs it twice to clear weak refs
self._storage._iterator_gc() self._storage._iterator_gc()
...@@ -147,7 +148,6 @@ class IterationTests(object): ...@@ -147,7 +148,6 @@ class IterationTests(object):
self._dostore() self._dostore()
six.advance_iterator(self._storage.iterator()) six.advance_iterator(self._storage.iterator())
iid = list(self._storage._iterator_ids)[0]
t = TransactionMetaData() t = TransactionMetaData()
self._storage.tpc_begin(t) self._storage.tpc_begin(t)
# Show that after disconnecting, the client side GCs the iterators # Show that after disconnecting, the client side GCs the iterators
...@@ -176,12 +176,12 @@ def iterator_sane_after_reconnect(): ...@@ -176,12 +176,12 @@ def iterator_sane_after_reconnect():
Start a server: Start a server:
>>> addr, adminaddr = start_server( >>> addr, adminaddr = start_server( # NOQA: F821 undefined
... '<filestorage>\npath fs\n</filestorage>', keep=1) ... '<filestorage>\npath fs\n</filestorage>', keep=1)
Open a client storage to it and commit a some transactions: Open a client storage to it and commit a some transactions:
>>> import ZEO, ZODB, transaction >>> import ZEO, ZODB
>>> client = ZEO.client(addr) >>> client = ZEO.client(addr)
>>> db = ZODB.DB(client) >>> db = ZODB.DB(client)
>>> conn = db.open() >>> conn = db.open()
...@@ -196,10 +196,11 @@ Create an iterator: ...@@ -196,10 +196,11 @@ Create an iterator:
Restart the storage: Restart the storage:
>>> stop_server(adminaddr) >>> stop_server(adminaddr) # NOQA: F821 undefined
>>> wait_disconnected(client) >>> wait_disconnected(client) # NOQA: F821 undefined
>>> _ = start_server('<filestorage>\npath fs\n</filestorage>', addr=addr) >>> _ = start_server( # NOQA: F821 undefined
>>> wait_connected(client) ... '<filestorage>\npath fs\n</filestorage>', addr=addr)
>>> wait_connected(client) # NOQA: F821 undefined
Now, we'll create a second iterator: Now, we'll create a second iterator:
......
...@@ -16,6 +16,7 @@ import threading ...@@ -16,6 +16,7 @@ import threading
import sys import sys
import six import six
class TestThread(threading.Thread): class TestThread(threading.Thread):
"""Base class for defining threads that run from unittest. """Base class for defining threads that run from unittest.
...@@ -46,12 +47,14 @@ class TestThread(threading.Thread): ...@@ -46,12 +47,14 @@ class TestThread(threading.Thread):
def run(self): def run(self):
try: try:
self.testrun() self.testrun()
except: except: # NOQA: E722 blank except
self._exc_info = sys.exc_info() self._exc_info = sys.exc_info()
def cleanup(self, timeout=15): def cleanup(self, timeout=15):
self.join(timeout) self.join(timeout)
if self._exc_info: if self._exc_info:
six.reraise(self._exc_info[0], self._exc_info[1], self._exc_info[2]) six.reraise(self._exc_info[0],
self._exc_info[1],
self._exc_info[2])
if self.is_alive(): if self.is_alive():
self._testcase.fail("Thread did not finish: %s" % self) self._testcase.fail("Thread did not finish: %s" % self)
...@@ -21,6 +21,7 @@ import ZEO.Exceptions ...@@ -21,6 +21,7 @@ import ZEO.Exceptions
ZERO = '\0'*8 ZERO = '\0'*8
class BasicThread(threading.Thread): class BasicThread(threading.Thread):
def __init__(self, storage, doNextEvent, threadStartedEvent): def __init__(self, storage, doNextEvent, threadStartedEvent):
self.storage = storage self.storage = storage
...@@ -123,7 +124,6 @@ class ThreadTests(object): ...@@ -123,7 +124,6 @@ class ThreadTests(object):
# Helper for checkMTStores # Helper for checkMTStores
def mtstorehelper(self): def mtstorehelper(self):
name = threading.currentThread().getName()
objs = [] objs = []
for i in range(10): for i in range(10):
objs.append(MinPO("X" * 200000)) objs.append(MinPO("X" * 200000))
......
This diff is collapsed.
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
_auth_modules = {} _auth_modules = {}
def get_module(name): def get_module(name):
if name == 'sha': if name == 'sha':
from auth_sha import StorageClass, SHAClient, Database from auth_sha import StorageClass, SHAClient, Database
...@@ -24,6 +25,7 @@ def get_module(name): ...@@ -24,6 +25,7 @@ def get_module(name):
else: else:
return _auth_modules.get(name) return _auth_modules.get(name)
def register_module(name, storage_class, client, db): def register_module(name, storage_class, client, db):
if name in _auth_modules: if name in _auth_modules:
raise TypeError("%s is already registred" % name) raise TypeError("%s is already registred" % name)
......
...@@ -45,6 +45,7 @@ from ..StorageServer import ZEOStorage ...@@ -45,6 +45,7 @@ from ..StorageServer import ZEOStorage
from ZEO.Exceptions import AuthError from ZEO.Exceptions import AuthError
from ..hash import sha1 from ..hash import sha1
def get_random_bytes(n=8): def get_random_bytes(n=8):
try: try:
b = os.urandom(n) b = os.urandom(n)
...@@ -53,9 +54,11 @@ def get_random_bytes(n=8): ...@@ -53,9 +54,11 @@ def get_random_bytes(n=8):
b = b"".join(L) b = b"".join(L)
return b return b
def hexdigest(s): def hexdigest(s):
return sha1(s.encode()).hexdigest() return sha1(s.encode()).hexdigest()
class DigestDatabase(Database): class DigestDatabase(Database):
def __init__(self, filename, realm=None): def __init__(self, filename, realm=None):
Database.__init__(self, filename, realm) Database.__init__(self, filename, realm)
...@@ -69,6 +72,7 @@ class DigestDatabase(Database): ...@@ -69,6 +72,7 @@ class DigestDatabase(Database):
dig = hexdigest("%s:%s:%s" % (username, self.realm, password)) dig = hexdigest("%s:%s:%s" % (username, self.realm, password))
self._users[username] = dig self._users[username] = dig
def session_key(h_up, nonce): def session_key(h_up, nonce):
# The hash itself is a bit too short to be a session key. # The hash itself is a bit too short to be a session key.
# HMAC wants a 64-byte key. We don't want to use h_up # HMAC wants a 64-byte key. We don't want to use h_up
...@@ -77,6 +81,7 @@ def session_key(h_up, nonce): ...@@ -77,6 +81,7 @@ def session_key(h_up, nonce):
return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() + return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() +
h_up.encode('utf-8')[:44]) h_up.encode('utf-8')[:44])
class StorageClass(ZEOStorage): class StorageClass(ZEOStorage):
def set_database(self, database): def set_database(self, database):
assert isinstance(database, DigestDatabase) assert isinstance(database, DigestDatabase)
...@@ -124,6 +129,7 @@ class StorageClass(ZEOStorage): ...@@ -124,6 +129,7 @@ class StorageClass(ZEOStorage):
extensions = [auth_get_challenge, auth_response] extensions = [auth_get_challenge, auth_response]
class DigestClient(Client): class DigestClient(Client):
extensions = ["auth_get_challenge", "auth_response"] extensions = ["auth_get_challenge", "auth_response"]
......
...@@ -22,6 +22,7 @@ from __future__ import print_function ...@@ -22,6 +22,7 @@ from __future__ import print_function
import os import os
from ..hash import sha1 from ..hash import sha1
class Client(object): class Client(object):
# Subclass should override to list the names of methods that # Subclass should override to list the names of methods that
# will be called on the server. # will be called on the server.
...@@ -32,11 +33,13 @@ class Client(object): ...@@ -32,11 +33,13 @@ class Client(object):
for m in self.extensions: for m in self.extensions:
setattr(self.stub, m, self.stub.extensionMethod(m)) setattr(self.stub, m, self.stub.extensionMethod(m))
def sort(L): def sort(L):
"""Sort a list in-place and return it.""" """Sort a list in-place and return it."""
L.sort() L.sort()
return L return L
class Database(object): class Database(object):
"""Abstracts a password database. """Abstracts a password database.
...@@ -49,6 +52,7 @@ class Database(object): ...@@ -49,6 +52,7 @@ class Database(object):
produced from the password string. produced from the password string.
""" """
realm = None realm = None
def __init__(self, filename, realm=None): def __init__(self, filename, realm=None):
"""Creates a new Database """Creates a new Database
......
...@@ -3,24 +3,26 @@ ...@@ -3,24 +3,26 @@
Implements the HMAC algorithm as described by RFC 2104. Implements the HMAC algorithm as described by RFC 2104.
""" """
from six.moves import map from six.moves import map
from six.moves import zip
def _strxor(s1, s2): def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length). """Utility method. XOR the two strings s1 and s2 (must have same length).
""" """
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2)) return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying # The size of the digests returned by HMAC depends on the underlying
# hashing module used. # hashing module used.
digest_size = None digest_size = None
class HMAC(object): class HMAC(object):
"""RFC2104 HMAC class. """RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247). This supports the API for Cryptographic Hash Functions (PEP 247).
""" """
def __init__(self, key, msg = None, digestmod = None): def __init__(self, key, msg=None, digestmod=None):
"""Create a new HMAC object. """Create a new HMAC object.
key: key for the keyed hash object. key: key for the keyed hash object.
...@@ -49,8 +51,8 @@ class HMAC(object): ...@@ -49,8 +51,8 @@ class HMAC(object):
if msg is not None: if msg is not None:
self.update(msg) self.update(msg)
## def clear(self): # def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.") # raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg): def update(self, msg):
"""Update this hashing object with the string msg. """Update this hashing object with the string msg.
...@@ -85,7 +87,8 @@ class HMAC(object): ...@@ -85,7 +87,8 @@ class HMAC(object):
return "".join([hex(ord(x))[2:].zfill(2) return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())]) for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
def new(key, msg=None, digestmod=None):
"""Create a new hashing object and return it. """Create a new hashing object and return it.
key: The starting key for the hash. key: The starting key for the hash.
......
...@@ -47,6 +47,7 @@ else: ...@@ -47,6 +47,7 @@ else:
if zeo_dist is not None: if zeo_dist is not None:
zeo_version = zeo_dist.version zeo_version = zeo_dist.version
class StorageStats(object): class StorageStats(object):
"""Per-storage usage statistics.""" """Per-storage usage statistics."""
...@@ -113,6 +114,7 @@ class StorageStats(object): ...@@ -113,6 +114,7 @@ class StorageStats(object):
print("Conflicts:", self.conflicts, file=f) print("Conflicts:", self.conflicts, file=f)
print("Conflicts resolved:", self.conflicts_resolved, file=f) print("Conflicts resolved:", self.conflicts_resolved, file=f)
class StatsClient(asyncore.dispatcher): class StatsClient(asyncore.dispatcher):
def __init__(self, sock, addr): def __init__(self, sock, addr):
...@@ -144,6 +146,7 @@ class StatsClient(asyncore.dispatcher): ...@@ -144,6 +146,7 @@ class StatsClient(asyncore.dispatcher):
if self.closed and not self.buf: if self.closed and not self.buf:
asyncore.dispatcher.close(self) asyncore.dispatcher.close(self)
class StatsServer(asyncore.dispatcher): class StatsServer(asyncore.dispatcher):
StatsConnectionClass = StatsClient StatsConnectionClass = StatsClient
......
...@@ -49,21 +49,24 @@ from zdaemon.zdoptions import ZDOptions ...@@ -49,21 +49,24 @@ from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo') logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid()) _pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False): def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function.""" """Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg) message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info) logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg): def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API. # Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg) obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address return obj.family, obj.address
def windows_shutdown_handler(): def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown. # Called by the signal mechanism on Windows to perform shutdown.
import asyncore
asyncore.close_all() asyncore.close_all()
class ZEOOptionsMixin(object): class ZEOOptionsMixin(object):
storages = None storages = None
...@@ -75,14 +78,18 @@ class ZEOOptionsMixin(object): ...@@ -75,14 +78,18 @@ class ZEOOptionsMixin(object):
self.monitor_family, self.monitor_address = parse_binding_address(arg) self.monitor_family, self.monitor_address = parse_binding_address(arg)
def handle_filename(self, arg): def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*! from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig(object): class FSConfig(object):
def __init__(self, name, path): def __init__(self, name, path):
self._name = name self._name = name
self.path = path self.path = path
self.stop = None self.stop = None
def getSectionName(self): def getSectionName(self):
return self._name return self._name
if not self.storages: if not self.storages:
self.storages = [] self.storages = []
name = str(1 + len(self.storages)) name = str(1 + len(self.storages))
...@@ -90,6 +97,7 @@ class ZEOOptionsMixin(object): ...@@ -90,6 +97,7 @@ class ZEOOptionsMixin(object):
self.storages.append(conf) self.storages.append(conf)
testing_exit_immediately = False testing_exit_immediately = False
def handle_test(self, *args): def handle_test(self, *args):
self.testing_exit_immediately = True self.testing_exit_immediately = True
...@@ -117,6 +125,7 @@ class ZEOOptionsMixin(object): ...@@ -117,6 +125,7 @@ class ZEOOptionsMixin(object):
self.add('pid_file', 'zeo.pid_filename', self.add('pid_file', 'zeo.pid_filename',
None, 'pid-file=') None, 'pid-file=')
class ZEOOptions(ZDOptions, ZEOOptionsMixin): class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__ __doc__ = __doc__
...@@ -179,8 +188,8 @@ class ZEOServer(object): ...@@ -179,8 +188,8 @@ class ZEOServer(object):
root.addHandler(handler) root.addHandler(handler)
def check_socket(self): def check_socket(self):
if (isinstance(self.options.address, tuple) and if isinstance(self.options.address, tuple) and \
self.options.address[1] is None): self.options.address[1] is None:
self.options.address = self.options.address[0], 0 self.options.address = self.options.address[0], 0
return return
if self.can_connect(self.options.family, self.options.address): if self.can_connect(self.options.family, self.options.address):
...@@ -224,7 +233,7 @@ class ZEOServer(object): ...@@ -224,7 +233,7 @@ class ZEOServer(object):
self.setup_win32_signals() self.setup_win32_signals()
return return
if hasattr(signal, 'SIGXFSZ'): if hasattr(signal, 'SIGXFSZ'):
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
init_signames() init_signames()
for sig, name in signames.items(): for sig, name in signames.items():
method = getattr(self, "handle_" + name.lower(), None) method = getattr(self, "handle_" + name.lower(), None)
...@@ -244,12 +253,12 @@ class ZEOServer(object): ...@@ -244,12 +253,12 @@ class ZEOServer(object):
"will *not* be installed.") "will *not* be installed.")
return return
SignalHandler = Signals.Signals.SignalHandler SignalHandler = Signals.Signals.SignalHandler
if SignalHandler is not None: # may be None if no pywin32. if SignalHandler is not None: # may be None if no pywin32.
SignalHandler.registerHandler(signal.SIGTERM, SignalHandler.registerHandler(signal.SIGTERM,
windows_shutdown_handler) windows_shutdown_handler)
SignalHandler.registerHandler(signal.SIGINT, SignalHandler.registerHandler(signal.SIGINT,
windows_shutdown_handler) windows_shutdown_handler)
SIGUSR2 = 12 # not in signal module on Windows. SIGUSR2 = 12 # not in signal module on Windows.
SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2) SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2)
def create_server(self): def create_server(self):
...@@ -275,20 +284,21 @@ class ZEOServer(object): ...@@ -275,20 +284,21 @@ class ZEOServer(object):
def handle_sigusr2(self): def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8... # log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"): if self.options.config_logger is None or \
log("received SIGUSR2, but it was not handled!", os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!",
level=logging.WARNING) level=logging.WARNING)
return return
loggers = [self.options.config_logger] loggers = [self.options.config_logger]
if os.name == "posix": if os.name == "posix":
for l in loggers: for logger in loggers:
l.reopen() logger.reopen()
log("Log files reopened successfully", level=logging.INFO) log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers: for logger in loggers:
for f in l.handler_factories: for f in logger.handler_factories:
handler = f() handler = f()
if hasattr(handler, 'rotate') and callable(handler.rotate): if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate() handler.rotate()
...@@ -347,14 +357,14 @@ def create_server(storages, options): ...@@ -347,14 +357,14 @@ def create_server(storages, options):
return StorageServer( return StorageServer(
options.address, options.address,
storages, storages,
read_only = options.read_only, read_only=options.read_only,
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size=options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age=options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout=options.transaction_timeout,
monitor_address = options.monitor_address, monitor_address=options.monitor_address,
auth_protocol = options.auth_protocol, auth_protocol=options.auth_protocol,
auth_database = options.auth_database, auth_database=options.auth_database,
auth_realm = options.auth_realm, auth_realm=options.auth_realm,
) )
...@@ -362,6 +372,7 @@ def create_server(storages, options): ...@@ -362,6 +372,7 @@ def create_server(storages, options):
signames = None signames = None
def signame(sig): def signame(sig):
"""Return a symbolic name for a signal. """Return a symbolic name for a signal.
...@@ -373,6 +384,7 @@ def signame(sig): ...@@ -373,6 +384,7 @@ def signame(sig):
init_signames() init_signames()
return signames.get(sig) or "signal %d" % sig return signames.get(sig) or "signal %d" % sig
def init_signames(): def init_signames():
global signames global signames
signames = {} signames = {}
...@@ -392,5 +404,6 @@ def main(args=None): ...@@ -392,5 +404,6 @@ def main(args=None):
s = ZEOServer(options) s = ZEOServer(options)
s.main() s.main()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
...@@ -6,24 +6,26 @@ ...@@ -6,24 +6,26 @@
Implements the HMAC algorithm as described by RFC 2104. Implements the HMAC algorithm as described by RFC 2104.
""" """
from six.moves import map from six.moves import map
from six.moves import zip
def _strxor(s1, s2): def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length). """Utility method. XOR the two strings s1 and s2 (must have same length).
""" """
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2)) return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying # The size of the digests returned by HMAC depends on the underlying
# hashing module used. # hashing module used.
digest_size = None digest_size = None
class HMAC(object): class HMAC(object):
"""RFC2104 HMAC class. """RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247). This supports the API for Cryptographic Hash Functions (PEP 247).
""" """
def __init__(self, key, msg = None, digestmod = None): def __init__(self, key, msg=None, digestmod=None):
"""Create a new HMAC object. """Create a new HMAC object.
key: key for the keyed hash object. key: key for the keyed hash object.
...@@ -56,8 +58,8 @@ class HMAC(object): ...@@ -56,8 +58,8 @@ class HMAC(object):
if msg is not None: if msg is not None:
self.update(msg) self.update(msg)
## def clear(self): # def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.") # raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg): def update(self, msg):
"""Update this hashing object with the string msg. """Update this hashing object with the string msg.
...@@ -92,7 +94,8 @@ class HMAC(object): ...@@ -92,7 +94,8 @@ class HMAC(object):
return "".join([hex(ord(x))[2:].zfill(2) return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())]) for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
def new(key, msg=None, digestmod=None):
"""Create a new hashing object and return it. """Create a new hashing object and return it.
key: The starting key for the hash. key: The starting key for the hash.
......
...@@ -34,6 +34,7 @@ from six.moves import map ...@@ -34,6 +34,7 @@ from six.moves import map
def client_timeout(): def client_timeout():
return 30.0 return 30.0
def client_loop(map): def client_loop(map):
read = asyncore.read read = asyncore.read
write = asyncore.write write = asyncore.write
...@@ -52,7 +53,7 @@ def client_loop(map): ...@@ -52,7 +53,7 @@ def client_loop(map):
r, w, e = select.select(r, w, e, client_timeout()) r, w, e = select.select(r, w, e, client_timeout())
except (select.error, RuntimeError) as err: except (select.error, RuntimeError) as err:
# Python >= 3.3 makes select.error an alias of OSError, # Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute # which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0] err_errno = getattr(err, 'errno', None) or err[0]
if err_errno != errno.EINTR: if err_errno != errno.EINTR:
if err_errno == errno.EBADF: if err_errno == errno.EBADF:
...@@ -114,14 +115,13 @@ def client_loop(map): ...@@ -114,14 +115,13 @@ def client_loop(map):
continue continue
_exception(obj) _exception(obj)
except: except: # NOQA: E722 bare except
if map: if map:
try: try:
logging.getLogger(__name__+'.client_loop').critical( logging.getLogger(__name__+'.client_loop').critical(
'A ZEO client loop failed.', 'A ZEO client loop failed.',
exc_info=sys.exc_info()) exc_info=sys.exc_info())
except: except: # NOQA: E722 bare except
pass pass
for fd, obj in map.items(): for fd, obj in map.items():
...@@ -129,14 +129,14 @@ def client_loop(map): ...@@ -129,14 +129,14 @@ def client_loop(map):
continue continue
try: try:
obj.mgr.client.close() obj.mgr.client.close()
except: except: # NOQA: E722 bare except
map.pop(fd, None) map.pop(fd, None)
try: try:
logging.getLogger(__name__+'.client_loop' logging.getLogger(__name__+'.client_loop'
).critical( ).critical(
"Couldn't close a dispatcher.", "Couldn't close a dispatcher.",
exc_info=sys.exc_info()) exc_info=sys.exc_info())
except: except: # NOQA: E722 bare except
pass pass
...@@ -152,11 +152,11 @@ class ConnectionManager(object): ...@@ -152,11 +152,11 @@ class ConnectionManager(object):
self.tmin = min(tmin, tmax) self.tmin = min(tmin, tmax)
self.tmax = tmax self.tmax = tmax
self.cond = threading.Condition(threading.Lock()) self.cond = threading.Condition(threading.Lock())
self.connection = None # Protected by self.cond self.connection = None # Protected by self.cond
self.closed = 0 self.closed = 0
# If thread is not None, then there is a helper thread # If thread is not None, then there is a helper thread
# attempting to connect. # attempting to connect.
self.thread = None # Protected by self.cond self.thread = None # Protected by self.cond
def new_addrs(self, addrs): def new_addrs(self, addrs):
self.addrlist = self._parse_addrs(addrs) self.addrlist = self._parse_addrs(addrs)
...@@ -189,7 +189,8 @@ class ConnectionManager(object): ...@@ -189,7 +189,8 @@ class ConnectionManager(object):
for addr in addrs: for addr in addrs:
addr_type = self._guess_type(addr) addr_type = self._guess_type(addr)
if addr_type is None: if addr_type is None:
raise ValueError("unknown address in list: %s" % repr(addr)) raise ValueError(
"unknown address in list: %s" % repr(addr))
addrlist.append((addr_type, addr)) addrlist.append((addr_type, addr))
return addrlist return addrlist
...@@ -197,10 +198,10 @@ class ConnectionManager(object): ...@@ -197,10 +198,10 @@ class ConnectionManager(object):
if isinstance(addr, str): if isinstance(addr, str):
return socket.AF_UNIX return socket.AF_UNIX
if (len(addr) == 2 if len(addr) == 2 and \
and isinstance(addr[0], str) isinstance(addr[0], str) and \
and isinstance(addr[1], int)): isinstance(addr[1], int):
return socket.AF_INET # also denotes IPv6 return socket.AF_INET # also denotes IPv6
# not anything I know about # not anything I know about
return None return None
...@@ -226,7 +227,7 @@ class ConnectionManager(object): ...@@ -226,7 +227,7 @@ class ConnectionManager(object):
if obj is not self.trigger: if obj is not self.trigger:
try: try:
obj.close() obj.close()
except: except: # NOQA: E722 bare except
logging.getLogger(__name__+'.'+self.__class__.__name__ logging.getLogger(__name__+'.'+self.__class__.__name__
).critical( ).critical(
"Couldn't close a dispatcher.", "Couldn't close a dispatcher.",
...@@ -237,7 +238,7 @@ class ConnectionManager(object): ...@@ -237,7 +238,7 @@ class ConnectionManager(object):
try: try:
self.loop_thread.join(9) self.loop_thread.join(9)
except RuntimeError: except RuntimeError:
pass # we are the thread :) pass # we are the thread :)
self.trigger.close() self.trigger.close()
def attempt_connect(self): def attempt_connect(self):
...@@ -304,7 +305,7 @@ class ConnectionManager(object): ...@@ -304,7 +305,7 @@ class ConnectionManager(object):
self.connection = conn self.connection = conn
if preferred: if preferred:
self.thread = None self.thread = None
self.cond.notifyAll() # Wake up connect(sync=1) self.cond.notifyAll() # Wake up connect(sync=1)
finally: finally:
self.cond.release() self.cond.release()
...@@ -331,6 +332,7 @@ class ConnectionManager(object): ...@@ -331,6 +332,7 @@ class ConnectionManager(object):
finally: finally:
self.cond.release() self.cond.release()
# When trying to do a connect on a non-blocking socket, some outcomes # When trying to do a connect on a non-blocking socket, some outcomes
# are expected. Set _CONNECT_IN_PROGRESS to the errno value(s) expected # are expected. Set _CONNECT_IN_PROGRESS to the errno value(s) expected
# when an initial connect can't complete immediately. Set _CONNECT_OK # when an initial connect can't complete immediately. Set _CONNECT_OK
...@@ -342,10 +344,11 @@ if hasattr(errno, "WSAEWOULDBLOCK"): # Windows ...@@ -342,10 +344,11 @@ if hasattr(errno, "WSAEWOULDBLOCK"): # Windows
# seen this. # seen this.
_CONNECT_IN_PROGRESS = (errno.WSAEWOULDBLOCK,) _CONNECT_IN_PROGRESS = (errno.WSAEWOULDBLOCK,)
# Win98: WSAEISCONN; Win2K: WSAEINVAL # Win98: WSAEISCONN; Win2K: WSAEINVAL
_CONNECT_OK = (0, errno.WSAEISCONN, errno.WSAEINVAL) _CONNECT_OK = (0, errno.WSAEISCONN, errno.WSAEINVAL)
else: # Unix else: # Unix
_CONNECT_IN_PROGRESS = (errno.EINPROGRESS,) _CONNECT_IN_PROGRESS = (errno.EINPROGRESS,)
_CONNECT_OK = (0, errno.EISCONN) _CONNECT_OK = (0, errno.EISCONN)
class ConnectThread(threading.Thread): class ConnectThread(threading.Thread):
"""Thread that tries to connect to server given one or more addresses. """Thread that tries to connect to server given one or more addresses.
...@@ -455,7 +458,7 @@ class ConnectThread(threading.Thread): ...@@ -455,7 +458,7 @@ class ConnectThread(threading.Thread):
) in socket.getaddrinfo(host or 'localhost', port, ) in socket.getaddrinfo(host or 'localhost', port,
socket.AF_INET, socket.AF_INET,
socket.SOCK_STREAM socket.SOCK_STREAM
): # prune non-TCP results ): # prune non-TCP results
# for IPv6, drop flowinfo, and restrict addresses # for IPv6, drop flowinfo, and restrict addresses
# to [host]:port # to [host]:port
yield family, sockaddr[:2] yield family, sockaddr[:2]
...@@ -495,7 +498,7 @@ class ConnectThread(threading.Thread): ...@@ -495,7 +498,7 @@ class ConnectThread(threading.Thread):
break break
try: try:
r, w, x = select.select([], connecting, connecting, 1.0) r, w, x = select.select([], connecting, connecting, 1.0)
log("CT: select() %d, %d, %d" % tuple(map(len, (r,w,x)))) log("CT: select() %d, %d, %d" % tuple(map(len, (r, w, x))))
except select.error as msg: except select.error as msg:
log("CT: select failed; msg=%s" % str(msg), log("CT: select failed; msg=%s" % str(msg),
level=logging.WARNING) level=logging.WARNING)
...@@ -509,7 +512,7 @@ class ConnectThread(threading.Thread): ...@@ -509,7 +512,7 @@ class ConnectThread(threading.Thread):
for wrap in w: for wrap in w:
wrap.connect_procedure() wrap.connect_procedure()
if wrap.state == "notified": if wrap.state == "notified":
del wrappers[wrap] # Don't close this one del wrappers[wrap] # Don't close this one
for wrap in wrappers.keys(): for wrap in wrappers.keys():
wrap.close() wrap.close()
return 1 return 1
...@@ -526,7 +529,7 @@ class ConnectThread(threading.Thread): ...@@ -526,7 +529,7 @@ class ConnectThread(threading.Thread):
else: else:
wrap.notify_client() wrap.notify_client()
if wrap.state == "notified": if wrap.state == "notified":
del wrappers[wrap] # Don't close this one del wrappers[wrap] # Don't close this one
for wrap in wrappers.keys(): for wrap in wrappers.keys():
wrap.close() wrap.close()
return -1 return -1
...@@ -602,7 +605,7 @@ class ConnectWrapper(object): ...@@ -602,7 +605,7 @@ class ConnectWrapper(object):
to do app-level check of the connection. to do app-level check of the connection.
""" """
self.conn = ManagedClientConnection(self.sock, self.addr, self.mgr) self.conn = ManagedClientConnection(self.sock, self.addr, self.mgr)
self.sock = None # The socket is now owned by the connection self.sock = None # The socket is now owned by the connection
try: try:
self.preferred = self.client.testConnection(self.conn) self.preferred = self.client.testConnection(self.conn)
self.state = "tested" self.state = "tested"
...@@ -610,7 +613,7 @@ class ConnectWrapper(object): ...@@ -610,7 +613,7 @@ class ConnectWrapper(object):
log("CW: ReadOnlyError in testConnection (%s)" % repr(self.addr)) log("CW: ReadOnlyError in testConnection (%s)" % repr(self.addr))
self.close() self.close()
return return
except: except: # NOQA: E722 bare except
log("CW: error in testConnection (%s)" % repr(self.addr), log("CW: error in testConnection (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True) level=logging.ERROR, exc_info=True)
self.close() self.close()
...@@ -629,7 +632,7 @@ class ConnectWrapper(object): ...@@ -629,7 +632,7 @@ class ConnectWrapper(object):
""" """
try: try:
self.client.notifyConnected(self.conn) self.client.notifyConnected(self.conn)
except: except: # NOQA: E722 bare except
log("CW: error in notifyConnected (%s)" % repr(self.addr), log("CW: error in notifyConnected (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True) level=logging.ERROR, exc_info=True)
self.close() self.close()
......
...@@ -26,12 +26,13 @@ from .log import short_repr, log ...@@ -26,12 +26,13 @@ from .log import short_repr, log
from ZODB.loglevels import BLATHER, TRACE from ZODB.loglevels import BLATHER, TRACE
import ZODB.POSException import ZODB.POSException
REPLY = ".reply" # message name used for replies REPLY = ".reply" # message name used for replies
exception_type_type = type(Exception) exception_type_type = type(Exception)
debug_zrpc = False debug_zrpc = False
class Delay(object): class Delay(object):
"""Used to delay response to client for synchronous calls. """Used to delay response to client for synchronous calls.
...@@ -57,7 +58,9 @@ class Delay(object): ...@@ -57,7 +58,9 @@ class Delay(object):
def __repr__(self): def __repr__(self):
return "%s[%s, %r, %r, %r]" % ( return "%s[%s, %r, %r, %r]" % (
self.__class__.__name__, id(self), self.msgid, self.conn, self.sent) self.__class__.__name__, id(self), self.msgid,
self.conn, self.sent)
class Result(Delay): class Result(Delay):
...@@ -69,6 +72,7 @@ class Result(Delay): ...@@ -69,6 +72,7 @@ class Result(Delay):
conn.send_reply(msgid, reply, False) conn.send_reply(msgid, reply, False)
callback() callback()
class MTDelay(Delay): class MTDelay(Delay):
def __init__(self): def __init__(self):
...@@ -147,6 +151,7 @@ class MTDelay(Delay): ...@@ -147,6 +151,7 @@ class MTDelay(Delay):
# supply a handshake() method appropriate for their role in protocol # supply a handshake() method appropriate for their role in protocol
# negotiation. # negotiation.
class Connection(smac.SizedMessageAsyncConnection, object): class Connection(smac.SizedMessageAsyncConnection, object):
"""Dispatcher for RPC on object on both sides of socket. """Dispatcher for RPC on object on both sides of socket.
...@@ -294,7 +299,7 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -294,7 +299,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.fast_encode = marshal.fast_encode self.fast_encode = marshal.fast_encode
self.closed = False self.closed = False
self.peer_protocol_version = None # set in recv_handshake() self.peer_protocol_version = None # set in recv_handshake()
assert tag in b"CS" assert tag in b"CS"
self.tag = tag self.tag = tag
...@@ -359,7 +364,7 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -359,7 +364,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
def __repr__(self): def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.addr) return "<%s %s>" % (self.__class__.__name__, self.addr)
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__ __str__ = __repr__ # Defeat asyncore's dreaded __getattr__
def log(self, message, level=BLATHER, exc_info=False): def log(self, message, level=BLATHER, exc_info=False):
self.logger.log(level, self.log_label + message, exc_info=exc_info) self.logger.log(level, self.log_label + message, exc_info=exc_info)
...@@ -441,7 +446,7 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -441,7 +446,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
try: try:
self.message_output(self.fast_encode(msgid, 0, REPLY, ret)) self.message_output(self.fast_encode(msgid, 0, REPLY, ret))
self.poll() self.poll()
except: except: # NOQA: E722 bare except
# Fall back to normal version for better error handling # Fall back to normal version for better error handling
self.send_reply(msgid, ret) self.send_reply(msgid, ret)
...@@ -520,10 +525,10 @@ class Connection(smac.SizedMessageAsyncConnection, object): ...@@ -520,10 +525,10 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# cPickle may raise. # cPickle may raise.
try: try:
msg = self.encode(msgid, 0, REPLY, (err_type, err_value)) msg = self.encode(msgid, 0, REPLY, (err_type, err_value))
except: # see above except: # NOQA: E722 bare except; see above
try: try:
r = short_repr(err_value) r = short_repr(err_value)
except: except: # NOQA: E722 bare except
r = "<unreprable>" r = "<unreprable>"
err = ZRPCError("Couldn't pickle error %.100s" % r) err = ZRPCError("Couldn't pickle error %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err)) msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
...@@ -656,10 +661,10 @@ class ManagedServerConnection(Connection): ...@@ -656,10 +661,10 @@ class ManagedServerConnection(Connection):
# cPickle may raise. # cPickle may raise.
try: try:
msg = self.encode(msgid, 0, REPLY, ret) msg = self.encode(msgid, 0, REPLY, ret)
except: # see above except: # NOQA: E722 bare except; see above
try: try:
r = short_repr(ret) r = short_repr(ret)
except: except: # NOQA: E722 bare except
r = "<unreprable>" r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r) err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err)) msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
...@@ -669,6 +674,7 @@ class ManagedServerConnection(Connection): ...@@ -669,6 +674,7 @@ class ManagedServerConnection(Connection):
poll = smac.SizedMessageAsyncConnection.handle_write poll = smac.SizedMessageAsyncConnection.handle_write
def server_loop(map): def server_loop(map):
while len(map) > 1: while len(map) > 1:
try: try:
...@@ -680,6 +686,7 @@ def server_loop(map): ...@@ -680,6 +686,7 @@ def server_loop(map):
for o in tuple(map.values()): for o in tuple(map.values()):
o.close() o.close()
class ManagedClientConnection(Connection): class ManagedClientConnection(Connection):
"""Client-side Connection subclass.""" """Client-side Connection subclass."""
__super_init = Connection.__init__ __super_init = Connection.__init__
...@@ -740,7 +747,7 @@ class ManagedClientConnection(Connection): ...@@ -740,7 +747,7 @@ class ManagedClientConnection(Connection):
# are queued for the duration. The client will send its own # are queued for the duration. The client will send its own
# handshake after the server's handshake is seen, in recv_handshake() # handshake after the server's handshake is seen, in recv_handshake()
# below. It will then send any messages queued while waiting. # below. It will then send any messages queued while waiting.
assert self.queue_output # the constructor already set this assert self.queue_output # the constructor already set this
def recv_handshake(self, proto): def recv_handshake(self, proto):
# The protocol to use is the older of our and the server's preferred # The protocol to use is the older of our and the server's preferred
...@@ -778,11 +785,11 @@ class ManagedClientConnection(Connection): ...@@ -778,11 +785,11 @@ class ManagedClientConnection(Connection):
raise DisconnectedError() raise DisconnectedError()
msgid = self.send_call(method, args) msgid = self.send_call(method, args)
r_args = self.wait(msgid) r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1 if isinstance(r_args, tuple) and len(r_args) > 1 and \
and type(r_args[0]) == exception_type_type type(r_args[0]) == exception_type_type and \
and issubclass(r_args[0], Exception)): issubclass(r_args[0], Exception):
inst = r_args[1] inst = r_args[1]
raise inst # error raised by server raise inst # error raised by server
else: else:
return r_args return r_args
...@@ -821,11 +828,11 @@ class ManagedClientConnection(Connection): ...@@ -821,11 +828,11 @@ class ManagedClientConnection(Connection):
def _deferred_wait(self, msgid): def _deferred_wait(self, msgid):
r_args = self.wait(msgid) r_args = self.wait(msgid)
if (isinstance(r_args, tuple) if isinstance(r_args, tuple) and \
and type(r_args[0]) == exception_type_type type(r_args[0]) == exception_type_type and \
and issubclass(r_args[0], Exception)): issubclass(r_args[0], Exception):
inst = r_args[1] inst = r_args[1]
raise inst # error raised by server raise inst # error raised by server
else: else:
return r_args return r_args
......
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
from ZODB import POSException from ZODB import POSException
from ZEO.Exceptions import ClientDisconnected from ZEO.Exceptions import ClientDisconnected
class ZRPCError(POSException.StorageError): class ZRPCError(POSException.StorageError):
pass pass
class DisconnectedError(ZRPCError, ClientDisconnected): class DisconnectedError(ZRPCError, ClientDisconnected):
"""The database storage is disconnected from the storage server. """The database storage is disconnected from the storage server.
......
...@@ -17,24 +17,29 @@ import logging ...@@ -17,24 +17,29 @@ import logging
from ZODB.loglevels import BLATHER from ZODB.loglevels import BLATHER
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
logger = logging.getLogger('ZEO.zrpc') logger = logging.getLogger('ZEO.zrpc')
_label = "%s" % os.getpid() _label = "%s" % os.getpid()
def new_label(): def new_label():
global _label global _label
_label = str(os.getpid()) _label = str(os.getpid())
def log(message, level=BLATHER, label=None, exc_info=False): def log(message, level=BLATHER, label=None, exc_info=False):
label = label or _label label = label or _label
if LOG_THREAD_ID: if LOG_THREAD_ID:
label = label + ':' + threading.currentThread().getName() label = label + ':' + threading.currentThread().getName()
logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info) logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info)
REPR_LIMIT = 60 REPR_LIMIT = 60
def short_repr(obj): def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes." "Return an object repr limited to REPR_LIMIT bytes."
......
...@@ -19,7 +19,8 @@ from .log import log, short_repr ...@@ -19,7 +19,8 @@ from .log import log, short_repr
PY2 = not PY3 PY2 = not PY3
def encode(*args): # args: (msgid, flags, name, args)
def encode(*args): # args: (msgid, flags, name, args)
# (We used to have a global pickler, but that's not thread-safe. :-( ) # (We used to have a global pickler, but that's not thread-safe. :-( )
# It's not thread safe if, in the couse of pickling, we call the # It's not thread safe if, in the couse of pickling, we call the
...@@ -41,7 +42,6 @@ def encode(*args): # args: (msgid, flags, name, args) ...@@ -41,7 +42,6 @@ def encode(*args): # args: (msgid, flags, name, args)
return res return res
if PY3: if PY3:
# XXX: Py3: Needs optimization. # XXX: Py3: Needs optimization.
fast_encode = encode fast_encode = encode
...@@ -50,48 +50,57 @@ elif PYPY: ...@@ -50,48 +50,57 @@ elif PYPY:
# every time, getvalue() only works once # every time, getvalue() only works once
fast_encode = encode fast_encode = encode
else: else:
def fast_encode(): def fast_encode():
# Only use in cases where you *know* the data contains only basic # Only use in cases where you *know* the data contains only basic
# Python objects # Python objects
pickler = Pickler(1) pickler = Pickler(1)
pickler.fast = 1 pickler.fast = 1
dump = pickler.dump dump = pickler.dump
def fast_encode(*args): def fast_encode(*args):
return dump(args, 1) return dump(args, 1)
return fast_encode return fast_encode
fast_encode = fast_encode() fast_encode = fast_encode()
def decode(msg): def decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global unpickler.find_global = find_global
try: try:
unpickler.find_class = find_global # PyPy, zodbpickle, the non-c-accelerated version # PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = find_global
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except: # NOQA: E722 bare except
log("can't decode message: %s" % short_repr(msg), log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR) level=logging.ERROR)
raise raise
def server_decode(msg): def server_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global unpickler.find_global = server_find_global
try: try:
unpickler.find_class = server_find_global # PyPy, zodbpickle, the non-c-accelerated version # PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = server_find_global
except AttributeError: except AttributeError:
pass pass
try: try:
return unpickler.load() # msgid, flags, name, args return unpickler.load() # msgid, flags, name, args
except: except: # NOQA: E722 bare except
log("can't decode message: %s" % short_repr(msg), log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR) level=logging.ERROR)
raise raise
_globals = globals() _globals = globals()
_silly = ('__doc__',) _silly = ('__doc__',)
...@@ -102,6 +111,7 @@ _SAFE_MODULE_NAMES = ( ...@@ -102,6 +111,7 @@ _SAFE_MODULE_NAMES = (
'builtins', 'copy_reg', '__builtin__', 'builtins', 'copy_reg', '__builtin__',
) )
def find_global(module, name): def find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
try: try:
...@@ -114,7 +124,8 @@ def find_global(module, name): ...@@ -114,7 +124,8 @@ def find_global(module, name):
except AttributeError: except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name)) raise ZRPCError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) or (PY2 and module in _SAFE_MODULE_NAMES) safe = (getattr(r, '__no_side_effects__', 0) or
(PY2 and module in _SAFE_MODULE_NAMES))
if safe: if safe:
return r return r
...@@ -124,6 +135,7 @@ def find_global(module, name): ...@@ -124,6 +135,7 @@ def find_global(module, name):
raise ZRPCError("Unsafe global: %s.%s" % (module, name)) raise ZRPCError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name): def server_find_global(module, name):
"""Helper for message unpickler""" """Helper for message unpickler"""
if module not in _SAFE_MODULE_NAMES: if module not in _SAFE_MODULE_NAMES:
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
############################################################################## ##############################################################################
import asyncore import asyncore
import socket import socket
import time
# _has_dualstack: True if the dual-stack sockets are supported # _has_dualstack: True if the dual-stack sockets are supported
try: try:
...@@ -39,6 +40,7 @@ import logging ...@@ -39,6 +40,7 @@ import logging
# Export the main asyncore loop # Export the main asyncore loop
loop = asyncore.loop loop = asyncore.loop
class Dispatcher(asyncore.dispatcher): class Dispatcher(asyncore.dispatcher):
"""A server that accepts incoming RPC connections""" """A server that accepts incoming RPC connections"""
__super_init = asyncore.dispatcher.__init__ __super_init = asyncore.dispatcher.__init__
...@@ -74,7 +76,7 @@ class Dispatcher(asyncore.dispatcher): ...@@ -74,7 +76,7 @@ class Dispatcher(asyncore.dispatcher):
for i in range(25): for i in range(25):
try: try:
self.bind(self.addr) self.bind(self.addr)
except Exception as exc: except Exception:
log("bind failed %s waiting", i) log("bind failed %s waiting", i)
if i == 24: if i == 24:
raise raise
...@@ -98,7 +100,6 @@ class Dispatcher(asyncore.dispatcher): ...@@ -98,7 +100,6 @@ class Dispatcher(asyncore.dispatcher):
log("accepted failed: %s" % msg) log("accepted failed: %s" % msg)
return return
# We could short-circuit the attempt below in some edge cases # We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None. # and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below, # Unfortunately, our test for the code below,
...@@ -111,12 +112,12 @@ class Dispatcher(asyncore.dispatcher): ...@@ -111,12 +112,12 @@ class Dispatcher(asyncore.dispatcher):
# closed, but I don't see a way to do that. :( # closed, but I don't see a way to do that. :(
# Drop flow-info from IPv6 addresses # Drop flow-info from IPv6 addresses
if addr: # Sometimes None on Mac. See above. if addr: # Sometimes None on Mac. See above.
addr = addr[:2] addr = addr[:2]
try: try:
c = self.factory(sock, addr) c = self.factory(sock, addr)
except: except: # NOQA: E722 bare except
if sock.fileno() in asyncore.socket_map: if sock.fileno() in asyncore.socket_map:
del asyncore.socket_map[sock.fileno()] del asyncore.socket_map[sock.fileno()]
logger.exception("Error in handle_accept") logger.exception("Error in handle_accept")
......
...@@ -67,19 +67,20 @@ MAC_BIT = 0x80000000 ...@@ -67,19 +67,20 @@ MAC_BIT = 0x80000000
_close_marker = object() _close_marker = object()
class SizedMessageAsyncConnection(asyncore.dispatcher): class SizedMessageAsyncConnection(asyncore.dispatcher):
__super_init = asyncore.dispatcher.__init__ __super_init = asyncore.dispatcher.__init__
__super_close = asyncore.dispatcher.close __super_close = asyncore.dispatcher.close
__closed = True # Marker indicating that we're closed __closed = True # Marker indicating that we're closed
socket = None # to outwit Sam's getattr socket = None # to outwit Sam's getattr
def __init__(self, sock, addr, map=None): def __init__(self, sock, addr, map=None):
self.addr = addr self.addr = addr
# __input_lock protects __inp, __input_len, __state, __msg_size # __input_lock protects __inp, __input_len, __state, __msg_size
self.__input_lock = threading.Lock() self.__input_lock = threading.Lock()
self.__inp = None # None, a single String, or a list self.__inp = None # None, a single String, or a list
self.__input_len = 0 self.__input_len = 0
# Instance variables __state, __msg_size and __has_mac work together: # Instance variables __state, __msg_size and __has_mac work together:
# when __state == 0: # when __state == 0:
...@@ -168,7 +169,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher): ...@@ -168,7 +169,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
d = self.recv(8192) d = self.recv(8192)
except socket.error as err: except socket.error as err:
# Python >= 3.3 makes select.error an alias of OSError, # Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute # which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0] err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_read_errors: if err_errno in expected_socket_read_errors:
return return
...@@ -190,7 +191,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher): ...@@ -190,7 +191,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
else: else:
self.__inp.append(d) self.__inp.append(d)
self.__input_len = input_len self.__input_len = input_len
return # keep waiting for more input return # keep waiting for more input
# load all previous input and d into single string inp # load all previous input and d into single string inp
if isinstance(inp, six.binary_type): if isinstance(inp, six.binary_type):
...@@ -298,15 +299,15 @@ class SizedMessageAsyncConnection(asyncore.dispatcher): ...@@ -298,15 +299,15 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
# ensure the above mentioned "output" invariant # ensure the above mentioned "output" invariant
output.insert(0, v) output.insert(0, v)
# Python >= 3.3 makes select.error an alias of OSError, # Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute # which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0] err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_write_errors: if err_errno in expected_socket_write_errors:
break # we couldn't write anything break # we couldn't write anything
raise raise
if n < len(v): if n < len(v):
output.append(v[n:]) output.append(v[n:])
break # we can't write any more break # we can't write any more
def handle_close(self): def handle_close(self):
self.close() self.close()
......
...@@ -21,7 +21,7 @@ import socket ...@@ -21,7 +21,7 @@ import socket
import errno import errno
from ZODB.utils import positive_id from ZODB.utils import positive_id
from ZEO._compat import thread, get_ident from ZEO._compat import thread
# Original comments follow; they're hard to follow in the context of # Original comments follow; they're hard to follow in the context of
# ZEO's use of triggers. TODO: rewrite from a ZEO perspective. # ZEO's use of triggers. TODO: rewrite from a ZEO perspective.
...@@ -56,6 +56,7 @@ from ZEO._compat import thread, get_ident ...@@ -56,6 +56,7 @@ from ZEO._compat import thread, get_ident
# new data onto a channel's outgoing data queue at the same time that # new data onto a channel's outgoing data queue at the same time that
# the main thread is trying to remove some] # the main thread is trying to remove some]
class _triggerbase(object): class _triggerbase(object):
"""OS-independent base class for OS-dependent trigger class.""" """OS-independent base class for OS-dependent trigger class."""
...@@ -127,7 +128,7 @@ class _triggerbase(object): ...@@ -127,7 +128,7 @@ class _triggerbase(object):
return return
try: try:
thunk[0](*thunk[1:]) thunk[0](*thunk[1:])
except: except: # NOQA: E722 bare except
nil, t, v, tbinfo = asyncore.compact_traceback() nil, t, v, tbinfo = asyncore.compact_traceback()
print(('exception in trigger thunk:' print(('exception in trigger thunk:'
' (%s:%s %s)' % (t, v, tbinfo))) ' (%s:%s %s)' % (t, v, tbinfo)))
...@@ -135,6 +136,7 @@ class _triggerbase(object): ...@@ -135,6 +136,7 @@ class _triggerbase(object):
def __repr__(self): def __repr__(self):
return '<select-trigger (%s) at %x>' % (self.kind, positive_id(self)) return '<select-trigger (%s) at %x>' % (self.kind, positive_id(self))
if os.name == 'posix': if os.name == 'posix':
class trigger(_triggerbase, asyncore.file_dispatcher): class trigger(_triggerbase, asyncore.file_dispatcher):
...@@ -187,39 +189,39 @@ else: ...@@ -187,39 +189,39 @@ else:
count = 0 count = 0
while 1: while 1:
count += 1 count += 1
# Bind to a local port; for efficiency, let the OS pick # Bind to a local port; for efficiency, let the OS pick
# a free port for us. # a free port for us.
# Unfortunately, stress tests showed that we may not # Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in # be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears # use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation. # to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always # So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at # on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html # http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details. # for hideous details.
a = socket.socket() a = socket.socket()
a.bind(("127.0.0.1", 0)) a.bind(("127.0.0.1", 0))
connect_address = a.getsockname() # assigned (host, port) pair connect_address = a.getsockname() # assigned (host, port) pair
a.listen(1) a.listen(1)
try: try:
w.connect(connect_address) w.connect(connect_address)
break # success break # success
except socket.error as detail: except socket.error as detail:
if detail[0] != errno.WSAEADDRINUSE: if detail[0] != errno.WSAEADDRINUSE:
# "Address already in use" is the only error # "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under # I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1. # Pythons 2.3.5 and 2.4.1.
raise raise
# (10048, 'Address already in use') # (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests # assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2 if count >= 10: # I've never seen it go above 2
a.close() a.close()
w.close() w.close()
raise BindError("Cannot bind trigger!") raise BindError("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short # Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt. # sleep() here, but it didn't appear to help or hurt.
a.close() a.close()
r, addr = a.accept() # r becomes asyncore's (self.)socket r, addr = a.accept() # r becomes asyncore's (self.)socket
a.close() a.close()
......
...@@ -16,7 +16,6 @@ from __future__ import print_function ...@@ -16,7 +16,6 @@ from __future__ import print_function
import random import random
import sys
import time import time
...@@ -56,8 +55,8 @@ def encode_format(fmt): ...@@ -56,8 +55,8 @@ def encode_format(fmt):
fmt = fmt.replace(*xform) fmt = fmt.replace(*xform)
return fmt return fmt
runner = _forker.runner
runner = _forker.runner
stop_runner = _forker.stop_runner stop_runner = _forker.stop_runner
start_zeo_server = _forker.start_zeo_server start_zeo_server = _forker.start_zeo_server
...@@ -70,6 +69,7 @@ else: ...@@ -70,6 +69,7 @@ else:
shutdown_zeo_server = _forker.shutdown_zeo_server shutdown_zeo_server = _forker.shutdown_zeo_server
def get_port(ignored=None): def get_port(ignored=None):
"""Return a port that is not in use. """Return a port that is not in use.
...@@ -107,6 +107,7 @@ def get_port(ignored=None): ...@@ -107,6 +107,7 @@ def get_port(ignored=None):
s1.close() s1.close()
raise RuntimeError("Can't find port") raise RuntimeError("Can't find port")
def can_connect(port): def can_connect(port):
c = socket.socket(socket.AF_INET, socket.SOCK_STREAM) c = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try: try:
...@@ -119,6 +120,7 @@ def can_connect(port): ...@@ -119,6 +120,7 @@ def can_connect(port):
finally: finally:
c.close() c.close()
def setUp(test): def setUp(test):
ZODB.tests.util.setUp(test) ZODB.tests.util.setUp(test)
...@@ -194,9 +196,11 @@ def wait_until(label=None, func=None, timeout=30, onfail=None): ...@@ -194,9 +196,11 @@ def wait_until(label=None, func=None, timeout=30, onfail=None):
return onfail() return onfail()
time.sleep(0.01) time.sleep(0.01)
def wait_connected(storage): def wait_connected(storage):
wait_until("storage is connected", storage.is_connected) wait_until("storage is connected", storage.is_connected)
def wait_disconnected(storage): def wait_disconnected(storage):
wait_until("storage is disconnected", wait_until("storage is disconnected",
lambda: not storage.is_connected()) lambda: not storage.is_connected())
......
...@@ -34,6 +34,7 @@ import ZEO.asyncio.tests ...@@ -34,6 +34,7 @@ import ZEO.asyncio.tests
import ZEO.StorageServer import ZEO.StorageServer
import ZODB.MappingStorage import ZODB.MappingStorage
class StorageServer(ZEO.StorageServer.StorageServer): class StorageServer(ZEO.StorageServer.StorageServer):
def __init__(self, addr='test_addr', storages=None, **kw): def __init__(self, addr='test_addr', storages=None, **kw):
...@@ -41,6 +42,7 @@ class StorageServer(ZEO.StorageServer.StorageServer): ...@@ -41,6 +42,7 @@ class StorageServer(ZEO.StorageServer.StorageServer):
storages = {'1': ZODB.MappingStorage.MappingStorage()} storages = {'1': ZODB.MappingStorage.MappingStorage()}
ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw) ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw)
def client(server, name='client'): def client(server, name='client'):
zs = ZEO.StorageServer.ZEOStorage(server) zs = ZEO.StorageServer.ZEOStorage(server)
protocol = ZEO.asyncio.tests.server_protocol( protocol = ZEO.asyncio.tests.server_protocol(
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -16,15 +16,18 @@ import unittest ...@@ -16,15 +16,18 @@ import unittest
from ZEO.TransactionBuffer import TransactionBuffer from ZEO.TransactionBuffer import TransactionBuffer
def random_string(size): def random_string(size):
"""Return a random string of size size.""" """Return a random string of size size."""
l = [chr(random.randrange(256)) for i in range(size)] lst = [chr(random.randrange(256)) for i in range(size)]
return "".join(l) return "".join(lst)
def new_store_data(): def new_store_data():
"""Return arbitrary data to use as argument to store() method.""" """Return arbitrary data to use as argument to store() method."""
return random_string(8), random_string(random.randrange(1000)) return random_string(8), random_string(random.randrange(1000))
def store(tbuf, resolved=False): def store(tbuf, resolved=False):
data = new_store_data() data = new_store_data()
tbuf.store(*data) tbuf.store(*data)
...@@ -32,6 +35,7 @@ def store(tbuf, resolved=False): ...@@ -32,6 +35,7 @@ def store(tbuf, resolved=False):
tbuf.server_resolve(data[0]) tbuf.server_resolve(data[0])
return data return data
class TransBufTests(unittest.TestCase): class TransBufTests(unittest.TestCase):
def checkTypicalUsage(self): def checkTypicalUsage(self):
...@@ -54,5 +58,6 @@ class TransBufTests(unittest.TestCase): ...@@ -54,5 +58,6 @@ class TransBufTests(unittest.TestCase):
self.assertEqual(resolved, data[i][1]) self.assertEqual(resolved, data[i][1])
tbuf.close() tbuf.close()
def test_suite(): def test_suite():
return unittest.makeSuite(TransBufTests, 'check') return unittest.makeSuite(TransBufTests, 'check')
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -17,7 +17,7 @@ class MarshalTests(unittest.TestCase): ...@@ -17,7 +17,7 @@ class MarshalTests(unittest.TestCase):
# this is an example (1) of Zope2's arguments for # this is an example (1) of Zope2's arguments for
# undoInfo call. Arguments are encoded by ZEO client # undoInfo call. Arguments are encoded by ZEO client
# and decoded by server. The operation must be idempotent. # and decoded by server. The operation must be idempotent.
# (1) https://github.com/zopefoundation/Zope/blob/2.13/src/App/Undo.py#L111 # (1) https://github.com/zopefoundation/Zope/blob/2.13/src/App/Undo.py#L111 # NOQA: E501 line too long
args = (0, 20, {'user_name': Prefix('test')}) args = (0, 20, {'user_name': Prefix('test')})
# test against repr because Prefix __eq__ operator # test against repr because Prefix __eq__ operator
# doesn't compare Prefix with Prefix but only # doesn't compare Prefix with Prefix but only
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment