Commit 5ffb1e89 authored by Jens Vagelpohl's avatar Jens Vagelpohl Committed by GitHub

Merge pull request #197 from zopefoundation/flake8

Full linting with flake8
parents fe9aebc0 f2122aa4
......@@ -2,7 +2,7 @@
# https://github.com/zopefoundation/meta/tree/master/config/pure-python
[meta]
template = "pure-python"
commit-id = "3b712f305ca8207e971c5bf81f2bdb5872489f2f"
commit-id = "0c07a1cfd78d28a07aebd23383ed16959f166574"
[python]
with-windows = false
......@@ -13,7 +13,7 @@ with-docs = true
with-sphinx-doctests = false
[tox]
use-flake8 = false
use-flake8 = true
testenv-commands = [
"# Run unit tests first.",
"zope-testrunner -u --test-path=src {posargs:-vc}",
......
......@@ -4,6 +4,8 @@ Changelog
5.4.0 (unreleased)
------------------
- Lint the code with flake8
- Add support for Python 3.10.
- Add ``ConflictError`` to the list of unlogged server exceptions
......
......@@ -11,11 +11,12 @@
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
version = '5.3.1.dev0'
from setuptools import setup, find_packages
import os
version = '5.4.0.dev0'
install_requires = [
'ZODB >= 5.1.1',
'six',
......@@ -64,12 +65,14 @@ Operating System :: Unix
Framework :: ZODB
""".strip().split('\n')
def _modname(path, base, name=''):
if path == base:
return name
dirname, basename = os.path.split(path)
return _modname(dirname, base, basename + '.' + name)
def _flatten(suite, predicate=lambda *x: True):
from unittest import TestCase
for suite_or_case in suite:
......@@ -80,18 +83,20 @@ def _flatten(suite, predicate=lambda *x: True):
for x in _flatten(suite_or_case):
yield x
def _no_layer(suite_or_case):
return getattr(suite_or_case, 'layer', None) is None
def _unittests_only(suite, mod_suite):
for case in _flatten(mod_suite, _no_layer):
suite.addTest(case)
def alltests():
import logging
import pkg_resources
import unittest
import ZEO.ClientStorage
class NullHandler(logging.Handler):
level = 50
......@@ -107,7 +112,8 @@ def alltests():
for dirpath, dirnames, filenames in os.walk(base):
if os.path.basename(dirpath) == 'tests':
for filename in filenames:
if filename != 'testZEO.py': continue
if filename != 'testZEO.py':
continue
if filename.endswith('.py') and filename.startswith('test'):
mod = __import__(
_modname(dirpath, base, os.path.splitext(filename)[0]),
......@@ -115,11 +121,13 @@ def alltests():
_unittests_only(suite, mod.test_suite())
return suite
long_description = (
open('README.rst').read()
+ '\n' +
open('CHANGES.rst').read()
)
setup(name="ZEO",
version=version,
description=long_description.split('\n', 2)[1],
......@@ -133,7 +141,7 @@ setup(name="ZEO",
license="ZPL 2.1",
platforms=["any"],
classifiers=classifiers,
test_suite="__main__.alltests", # to support "setup.py test"
test_suite="__main__.alltests", # to support "setup.py test"
tests_require=tests_require,
extras_require={
'test': tests_require,
......@@ -164,4 +172,4 @@ setup(name="ZEO",
""",
include_package_data=True,
python_requires='>=2.7.9,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*',
)
)
......@@ -52,9 +52,11 @@ import ZEO.cache
logger = logging.getLogger(__name__)
def tid2time(tid):
return str(TimeStamp(tid))
def get_timestamp(prev_ts=None):
"""Internal helper to return a unique TimeStamp instance.
......@@ -69,8 +71,10 @@ def get_timestamp(prev_ts=None):
t = t.laterThan(prev_ts)
return t
MB = 1024**2
@zope.interface.implementer(ZODB.interfaces.IMultiCommitStorage)
class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
"""A storage class that is a network client to a remote storage.
......@@ -90,7 +94,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
blob_cache_size=None, blob_cache_size_check=10,
client_label=None,
cache=None,
ssl = None, ssl_server_hostname=None,
ssl=None, ssl_server_hostname=None,
# Mostly ignored backward-compatability options
client=None, var=None,
min_disconnect_poll=1, max_disconnect_poll=None,
......@@ -189,14 +193,15 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
if isinstance(addr, int):
addr = ('127.0.0.1', addr)
self.__name__ = name or str(addr) # Standard convention for storages
self.__name__ = name or str(addr) # Standard convention for storages
if isinstance(addr, six.string_types):
if WIN:
raise ValueError("Unix sockets are not available on Windows")
addr = [addr]
elif (isinstance(addr, tuple) and len(addr) == 2 and
isinstance(addr[0], six.string_types) and isinstance(addr[1], int)):
isinstance(addr[0], six.string_types) and
isinstance(addr[1], int)):
addr = [addr]
logger.info(
......@@ -212,7 +217,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._is_read_only = read_only
self._read_only_fallback = read_only_fallback
self._addr = addr # For tests
self._addr = addr # For tests
self._iterators = weakref.WeakValueDictionary()
self._iterator_ids = set()
......@@ -228,7 +233,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._db = None
self._oids = [] # List of pre-fetched oids from server
self._oids = [] # List of pre-fetched oids from server
cache = self._cache = open_cache(
cache, var, client, storage, cache_size)
......@@ -266,7 +271,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
addr, self, cache, storage,
ZEO.asyncio.client.Fallback if read_only_fallback else read_only,
wait_timeout or 30,
ssl = ssl, ssl_server_hostname=ssl_server_hostname,
ssl=ssl, ssl_server_hostname=ssl_server_hostname,
credentials=credentials,
)
self._call = self._server.call
......@@ -308,6 +313,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._check_blob_size_thread.join()
_check_blob_size_thread = None
def _check_blob_size(self, bytes=None):
if self._blob_cache_size is None:
return
......@@ -349,8 +355,8 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
pass
_connection_generation = 0
def notify_connected(self, conn, info):
reconnected = self._connection_generation
self.set_server_addr(conn.get_peername())
self.protocol_version = conn.protocol_version
self._is_read_only = conn.is_read_only()
......@@ -373,22 +379,20 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
self._info.update(info)
for iface in (
ZODB.interfaces.IStorageRestoreable,
ZODB.interfaces.IStorageIteration,
ZODB.interfaces.IStorageUndoable,
ZODB.interfaces.IStorageCurrentRecordIteration,
ZODB.interfaces.IBlobStorage,
ZODB.interfaces.IExternalGC,
):
if (iface.__module__, iface.__name__) in self._info.get(
'interfaces', ()):
for iface in (ZODB.interfaces.IStorageRestoreable,
ZODB.interfaces.IStorageIteration,
ZODB.interfaces.IStorageUndoable,
ZODB.interfaces.IStorageCurrentRecordIteration,
ZODB.interfaces.IBlobStorage,
ZODB.interfaces.IExternalGC):
if (iface.__module__, iface.__name__) in \
self._info.get('interfaces', ()):
zope.interface.alsoProvides(self, iface)
if self.protocol_version[1:] >= b'5':
self.ping = lambda : self._call('ping')
self.ping = lambda: self._call('ping')
else:
self.ping = lambda : self._call('lastTransaction')
self.ping = lambda: self._call('lastTransaction')
if self.server_sync:
self.sync = self.ping
......@@ -536,7 +540,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
try:
return self._oids.pop()
except IndexError:
pass # We ran out. We need to get some more.
pass # We ran out. We need to get some more.
self._oids[:0] = reversed(self._call('new_oids'))
......@@ -735,7 +739,6 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
finally:
lock.close()
def temporaryDirectory(self):
return self.fshelper.temp_dir
......@@ -747,7 +750,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
conflicts = True
vote_attempts = 0
while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop
while conflicts and vote_attempts < 9: # 9? Mainly avoid inf. loop
conflicts = False
for oid in self._call('vote', id(txn)) or ():
if isinstance(oid, dict):
......@@ -843,11 +846,11 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def tpc_abort(self, txn, timeout=None):
"""Storage API: abort a transaction.
(The timeout keyword argument is for tests to wat longer than
(The timeout keyword argument is for tests to wait longer than
they normally would.)
"""
try:
tbuf = txn.data(self)
tbuf = txn.data(self) # NOQA: F841 unused variable
except KeyError:
return
......@@ -899,7 +902,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
while blobs:
oid, blobfilename = blobs.pop()
self._blob_data_bytes_loaded += os.stat(blobfilename).st_size
targetpath = self.fshelper.getPathForOID(oid, create=True)
self.fshelper.getPathForOID(oid, create=True)
target_blob_file_name = self.fshelper.getBlobFilename(oid, tid)
lock = _lock_blob(target_blob_file_name)
try:
......@@ -1037,6 +1040,7 @@ class ClientStorage(ZODB.ConflictResolution.ConflictResolvingStorage):
def server_status(self):
return self._call('server_status')
class TransactionIterator(object):
def __init__(self, storage, iid, *args):
......@@ -1130,14 +1134,18 @@ class BlobCacheLayout(object):
ZODB.blob.BLOB_SUFFIX)
)
def _accessed(filename):
try:
os.utime(filename, (time.time(), os.stat(filename).st_mtime))
except OSError:
pass # We tried. :)
pass # We tried. :)
return filename
cache_file_name = re.compile(r'\d+$').match
def _check_blob_cache_size(blob_dir, target):
logger = logging.getLogger(__name__+'.check_blob_cache')
......@@ -1162,7 +1170,7 @@ def _check_blob_cache_size(blob_dir, target):
# Someone is already cleaning up, so don't bother
logger.debug("%s Another thread is checking the blob cache size.",
get_ident())
open(attempt_path, 'w').close() # Mark that we tried
open(attempt_path, 'w').close() # Mark that we tried
return
logger.debug("%s Checking blob cache size. (target: %s)",
......@@ -1200,7 +1208,7 @@ def _check_blob_cache_size(blob_dir, target):
try:
os.remove(attempt_path)
except OSError:
pass # Sigh, windows
pass # Sigh, windows
continue
logger.debug("%s -->", get_ident())
break
......@@ -1222,8 +1230,8 @@ def _check_blob_cache_size(blob_dir, target):
fsize = os.stat(file_name).st_size
try:
ZODB.blob.remove_committed(file_name)
except OSError as v:
pass # probably open on windows
except OSError:
pass # probably open on windows
else:
size -= fsize
finally:
......@@ -1238,12 +1246,14 @@ def _check_blob_cache_size(blob_dir, target):
finally:
check_lock.close()
def check_blob_size_script(args=None):
if args is None:
args = sys.argv[1:]
blob_dir, target = args
_check_blob_cache_size(blob_dir, int(target))
def _lock_blob(path):
lockfilename = os.path.join(os.path.dirname(path), '.lock')
n = 0
......@@ -1258,6 +1268,7 @@ def _lock_blob(path):
else:
break
def open_cache(cache, var, client, storage, cache_size):
if isinstance(cache, (None.__class__, str)):
from ZEO.cache import ClientCache
......
......@@ -17,27 +17,33 @@ import transaction.interfaces
from ZODB.POSException import StorageError
class ClientStorageError(StorageError):
"""An error occurred in the ZEO Client Storage.
"""
class UnrecognizedResult(ClientStorageError):
"""A server call returned an unrecognized result.
"""
class ClientDisconnected(ClientStorageError,
transaction.interfaces.TransientError):
"""The database storage is disconnected from the storage.
"""
class AuthError(StorageError):
"""The client provided invalid authentication credentials.
"""
class ProtocolError(ClientStorageError):
"""A client contacted a server with an incomparible protocol
"""
class ServerException(ClientStorageError):
"""
"""
This diff is collapsed.
......@@ -21,12 +21,11 @@ is used to store the data until a commit or abort.
# A faster implementation might store trans data in memory until it
# reaches a certain size.
import os
import tempfile
import ZODB.blob
from ZEO._compat import Pickler, Unpickler
class TransactionBuffer(object):
# The TransactionBuffer is used by client storage to hold update
......@@ -44,8 +43,8 @@ class TransactionBuffer(object):
# stored are builtin types -- strings or None.
self.pickler = Pickler(self.file, 1)
self.pickler.fast = 1
self.server_resolved = set() # {oid}
self.client_resolved = {} # {oid -> buffer_record_number}
self.server_resolved = set() # {oid}
self.client_resolved = {} # {oid -> buffer_record_number}
self.exception = None
def close(self):
......@@ -93,9 +92,7 @@ class TransactionBuffer(object):
if oid not in seen:
yield oid, None, True
# Support ZEO4:
def serialnos(self, args):
for oid in args:
if isinstance(oid, bytes):
......
......@@ -21,6 +21,7 @@ ZEO is now part of ZODB; ZODB's home on the web is
"""
def client(*args, **kw):
"""
Shortcut for :class:`ZEO.ClientStorage.ClientStorage`.
......@@ -28,6 +29,7 @@ def client(*args, **kw):
import ZEO.ClientStorage
return ZEO.ClientStorage.ClientStorage(*args, **kw)
def DB(*args, **kw):
"""
Shortcut for creating a :class:`ZODB.DB` using a ZEO :func:`~ZEO.client`.
......@@ -40,6 +42,7 @@ def DB(*args, **kw):
s.close()
raise
def connection(*args, **kw):
db = DB(*args, **kw)
try:
......@@ -48,6 +51,7 @@ def connection(*args, **kw):
db.close()
raise
def server(path=None, blob_dir=None, storage_conf=None, zeo_conf=None,
port=0, threaded=True, **kw):
"""Convenience function to start a server for interactive exploration
......
......@@ -16,13 +16,20 @@
import sys
import platform
from ZODB._compat import BytesIO # NOQA: F401 unused import
PY3 = sys.version_info[0] >= 3
PY32 = sys.version_info[:2] == (3, 2)
PYPY = getattr(platform, 'python_implementation', lambda: None)() == 'PyPy'
WIN = sys.platform.startswith('win')
if PY3:
from zodbpickle.pickle import Pickler, Unpickler as _Unpickler, dump, dumps, loads
from zodbpickle.pickle import dump
from zodbpickle.pickle import dumps
from zodbpickle.pickle import loads
from zodbpickle.pickle import Pickler
from zodbpickle.pickle import Unpickler as _Unpickler
class Unpickler(_Unpickler):
# Py3: Python 3 doesn't allow assignments to find_global,
# instead, find_class can be overridden
......@@ -44,24 +51,17 @@ else:
dumps = cPickle.dumps
loads = cPickle.loads
# String and Bytes IO
from ZODB._compat import BytesIO
if PY3:
import _thread as thread
import _thread as thread # NOQA: F401 unused import
if PY32:
from threading import _get_ident as get_ident
from threading import _get_ident as get_ident # NOQA: F401 unused
else:
from threading import get_ident
from threading import get_ident # NOQA: F401 unused import
else:
import thread
from thread import get_ident
import thread # NOQA: F401 unused import
from thread import get_ident # NOQA: F401 unused import
try:
from cStringIO import StringIO
except:
from io import StringIO
from cStringIO import StringIO # NOQA: F401 unused import
except ImportError:
from io import StringIO # NOQA: F401 unused import
......@@ -26,11 +26,10 @@ import six
from ZEO._compat import StringIO
logger = logging.getLogger('ZEO.tests.forker')
DEBUG = os.environ.get('ZEO_TEST_SERVER_DEBUG')
ZEO4_SERVER = os.environ.get('ZEO4_SERVER')
class ZEOConfig(object):
"""Class to generate ZEO configuration file. """
......@@ -61,8 +60,7 @@ class ZEOConfig(object):
for name in (
'invalidation_queue_size', 'invalidation_age',
'transaction_timeout', 'pid_filename', 'msgpack',
'ssl_certificate', 'ssl_key', 'client_conflict_resolution',
):
'ssl_certificate', 'ssl_key', 'client_conflict_resolution'):
v = getattr(self, name, None)
if v:
print(name.replace('_', '-'), v, file=f)
......@@ -134,7 +132,7 @@ def runner(config, qin, qout, timeout=None,
os.remove(config)
try:
qin.get(timeout=timeout) # wait for shutdown
qin.get(timeout=timeout) # wait for shutdown
except Empty:
pass
server.server.close()
......@@ -158,6 +156,7 @@ def runner(config, qin, qout, timeout=None,
ZEO.asyncio.server.best_protocol_version = old_protocol
ZEO.asyncio.server.ServerProtocol.protocols = old_protocols
def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None):
qin.put('stop')
try:
......@@ -180,6 +179,7 @@ def stop_runner(thread, config, qin, qout, stop_timeout=19, pid=None):
gc.collect()
def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
path='Data.fs', protocol=None, blob_dir=None,
suicide=True, debug=False,
......@@ -220,7 +220,8 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
print(zeo_conf)
# Store the config info in a temp file.
fd, tmpfile = tempfile.mkstemp(".conf", prefix='ZEO_forker', dir=os.getcwd())
fd, tmpfile = tempfile.mkstemp(".conf", prefix='ZEO_forker',
dir=os.getcwd())
with os.fdopen(fd, 'w') as fp:
fp.write(zeo_conf)
......@@ -273,10 +274,12 @@ def debug_logging(logger='ZEO', stream='stderr', level=logging.DEBUG):
return stop
def whine(*message):
print(*message, file=sys.stderr)
sys.stderr.flush()
class ThreadlessQueue(object):
def __init__(self):
......
......@@ -14,6 +14,7 @@ logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6
class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO base interface
"""
......@@ -30,9 +31,9 @@ class Protocol(asyncio.Protocol):
def __init__(self, loop, addr):
self.loop = loop
self.addr = addr
self.input = [] # Input buffer when assembling messages
self.output = [] # Output buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
self.input = [] # Input buffer when assembling messages
self.output = [] # Output buffer when paused
self.paused = [] # Paused indicator, mutable to avoid attr lookup
# Handle the first message, the protocol handshake, differently
self.message_received = self.first_message_received
......@@ -41,6 +42,7 @@ class Protocol(asyncio.Protocol):
return self.name
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -50,7 +52,6 @@ class Protocol(asyncio.Protocol):
def connection_made(self, transport):
logger.info("Connected %s", self)
if sys.version_info < (3, 6):
sock = transport.get_extra_info('socket')
if sock is not None and sock.family in INET_FAMILIES:
......@@ -91,6 +92,7 @@ class Protocol(asyncio.Protocol):
got = 0
want = 4
getting_size = True
def data_received(self, data):
# Low-level input handler collects data into sized messages.
......@@ -135,7 +137,7 @@ class Protocol(asyncio.Protocol):
def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
del self.message_received # use default handler from here on
self.finish_connect(protocol_version)
def call_async(self, method, args):
......@@ -162,7 +164,7 @@ class Protocol(asyncio.Protocol):
data = message
for message in data:
writelines((pack(">I", len(message)), message))
if paused: # paused again. Put iter back.
if paused: # paused again. Put iter back.
output.insert(0, data)
break
......
......@@ -19,7 +19,8 @@ logger = logging.getLogger(__name__)
Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests
local_random = random.Random() # use separate generator to facilitate tests
def future_generator(func):
"""Decorates a generator that generates futures
......@@ -52,6 +53,7 @@ def future_generator(func):
return call_generator
class Protocol(base.Protocol):
"""asyncio low-level ZEO client interface
"""
......@@ -85,7 +87,7 @@ class Protocol(base.Protocol):
self.client = client
self.connect_poll = connect_poll
self.heartbeat_interval = heartbeat_interval
self.futures = {} # { message_id -> future }
self.futures = {} # { message_id -> future }
self.ssl = ssl
self.ssl_server_hostname = ssl_server_hostname
self.credentials = credentials
......@@ -132,7 +134,9 @@ class Protocol(base.Protocol):
elif future.exception() is not None:
logger.info("Connection to %r failed, %s",
self.addr, future.exception())
else: return
else:
return
# keep trying
if not self.closed:
logger.info("retry connecting %r", self.addr)
......@@ -141,7 +145,6 @@ class Protocol(base.Protocol):
self.connect,
)
def connection_made(self, transport):
super(Protocol, self).connection_made(transport)
self.heartbeat(write=False)
......@@ -190,7 +193,8 @@ class Protocol(base.Protocol):
try:
server_tid = yield self.fut(
'register', self.storage_key,
self.read_only if self.read_only is not Fallback else False,
(self.read_only if self.read_only is not Fallback
else False),
*credentials)
except ZODB.POSException.ReadOnlyError:
if self.read_only is Fallback:
......@@ -208,11 +212,12 @@ class Protocol(base.Protocol):
self.client.registered(self, server_tid)
exception_type_type = type(Exception)
def message_received(self, data):
msgid, async_, name, args = self.decode(data)
if name == '.reply':
future = self.futures.pop(msgid)
if async_: # ZEO 5 exception
if async_: # ZEO 5 exception
class_, args = args
factory = exc_factories.get(class_)
if factory:
......@@ -237,13 +242,14 @@ class Protocol(base.Protocol):
else:
future.set_result(args)
else:
assert async_ # clients only get async calls
assert async_ # clients only get async calls
if name in self.client_methods:
getattr(self.client, name)(*args)
else:
raise AttributeError(name)
message_id = 0
def call(self, future, method, args):
self.message_id += 1
self.futures[self.message_id] = future
......@@ -262,6 +268,7 @@ class Protocol(base.Protocol):
self.futures[message_id] = future
self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid)))
@future.add_done_callback
def _(future):
try:
......@@ -271,6 +278,7 @@ class Protocol(base.Protocol):
if data:
data, start, end = data
self.client.cache.store(oid, start, end, data)
return future
# Methods called by the server.
......@@ -290,29 +298,34 @@ class Protocol(base.Protocol):
self.heartbeat_handle = self.loop.call_later(
self.heartbeat_interval, self.heartbeat)
def create_Exception(class_, args):
return exc_classes[class_](*args)
def create_ConflictError(class_, args):
exc = exc_classes[class_](
message = args['message'],
oid = args['oid'],
serials = args['serials'],
message=args['message'],
oid=args['oid'],
serials=args['serials'],
)
exc.class_name = args.get('class_name')
return exc
def create_BTreesConflictError(class_, args):
return ZODB.POSException.BTreesConflictError(
p1 = args['p1'],
p2 = args['p2'],
p3 = args['p3'],
reason = args['reason'],
p1=args['p1'],
p2=args['p2'],
p3=args['p3'],
reason=args['reason'],
)
def create_MultipleUndoErrors(class_, args):
return ZODB.POSException.MultipleUndoErrors(args['_errs'])
exc_classes = {
'builtins.KeyError': KeyError,
'builtins.TypeError': TypeError,
......@@ -340,6 +353,8 @@ exc_factories = {
}
unlogged_exceptions = (ZODB.POSException.POSKeyError,
ZODB.POSException.ConflictError)
class Client(object):
"""asyncio low-level ZEO client interface
"""
......@@ -352,8 +367,11 @@ class Client(object):
# connect.
protocol = None
ready = None # Tri-value: None=Never connected, True=connected,
# False=Disconnected
# ready can have three values:
# None=Never connected
# True=connected
# False=Disconnected
ready = None
def __init__(self, loop,
addrs, client, cache, storage_key, read_only, connect_poll,
......@@ -404,6 +422,7 @@ class Client(object):
self.is_read_only() and self.read_only is Fallback)
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -466,7 +485,7 @@ class Client(object):
self.upgrade(protocol)
self.verify(server_tid)
else:
protocol.close() # too late, we went home with another
protocol.close() # too late, we went home with another
def register_failed(self, protocol, exc):
# A protocol failed registration. That's weird. If they've all
......@@ -474,18 +493,17 @@ class Client(object):
if protocol is not self:
protocol.close()
logger.exception("Registration or cache validation failed, %s", exc)
if (self.protocol is None and not
any(not p.closed for p in self.protocols)
):
if self.protocol is None and \
not any(not p.closed for p in self.protocols):
self.loop.call_later(
self.register_failed_poll + local_random.random(),
self.try_connecting)
verify_result = None # for tests
verify_result = None # for tests
@future_generator
def verify(self, server_tid):
self.verify_invalidation_queue = [] # See comment in init :(
self.verify_invalidation_queue = [] # See comment in init :(
protocol = self.protocol
try:
......@@ -739,6 +757,7 @@ class Client(object):
else:
return protocol.read_only
class ClientRunner(object):
def set_options(self, addrs, wrapper, cache, storage_key, read_only,
......@@ -855,6 +874,7 @@ class ClientRunner(object):
timeout = self.timeout
self.wait_for_result(self.client.connected, timeout)
class ClientThread(ClientRunner):
"""Thread wrapper for client interface
......@@ -883,6 +903,7 @@ class ClientThread(ClientRunner):
raise self.exception
exception = None
def run(self):
loop = None
try:
......@@ -909,6 +930,7 @@ class ClientThread(ClientRunner):
logger.debug('Stopping client thread')
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -918,6 +940,7 @@ class ClientThread(ClientRunner):
if self.exception:
raise self.exception
class Fut(object):
"""Lightweight future that calls it's callbacks immediately rather than soon
"""
......@@ -929,6 +952,7 @@ class Fut(object):
self.cbv.append(cb)
exc = None
def set_exception(self, exc):
self.exc = exc
for cb in self.cbv:
......
......@@ -6,5 +6,5 @@ if PY3:
except ImportError:
from asyncio import new_event_loop
else:
import trollius as asyncio
from trollius import new_event_loop
import trollius as asyncio # NOQA: F401 unused import
from trollius import new_event_loop # NOQA: F401 unused import
......@@ -21,19 +21,22 @@ Python-independent format, or possibly a minimal pickle subset.
import logging
from .._compat import Unpickler, Pickler, BytesIO, PY3, PYPY
from .._compat import Unpickler, Pickler, BytesIO, PY3
from ..shortrepr import short_repr
PY2 = not PY3
logger = logging.getLogger(__name__)
def encoder(protocol, server=False):
"""Return a non-thread-safe encoder
"""
if protocol[:1] == b'M':
from msgpack import packb
default = server_default if server else None
def encode(*args):
return packb(
args, use_bin_type=True, default=default)
......@@ -49,6 +52,7 @@ def encoder(protocol, server=False):
pickler = Pickler(f, 3)
pickler.fast = 1
dump = pickler.dump
def encode(*args):
seek(0)
truncate()
......@@ -57,21 +61,26 @@ def encoder(protocol, server=False):
return encode
def encode(*args):
return encoder(b'Z')(*args)
def decoder(protocol):
if protocol[:1] == b'M':
from msgpack import unpackb
def msgpack_decode(data):
"""Decodes msg and returns its parts"""
return unpackb(data, raw=False, use_list=False)
return msgpack_decode
else:
assert protocol[:1] == b'Z'
return pickle_decode
def pickle_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
......@@ -82,11 +91,12 @@ def pickle_decode(msg):
except AttributeError:
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
return unpickler.load() # msgid, flags, name, args
except: # NOQA: E722 bare except
logger.error("can't decode message: %s" % short_repr(msg))
raise
def server_decoder(protocol):
if protocol[:1] == b'M':
return decoder(protocol)
......@@ -94,6 +104,7 @@ def server_decoder(protocol):
assert protocol[:1] == b'Z'
return pickle_server_decode
def pickle_server_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
......@@ -105,22 +116,25 @@ def pickle_server_decode(msg):
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
return unpickler.load() # msgid, flags, name, args
except: # NOQA: E722 bare except
logger.error("can't decode message: %s" % short_repr(msg))
raise
def server_default(obj):
if isinstance(obj, Exception):
return reduce_exception(obj)
else:
return obj
def reduce_exception(exc):
class_ = exc.__class__
class_ = "%s.%s" % (class_.__module__, class_.__name__)
return class_, exc.__dict__ or exc.args
_globals = globals()
_silly = ('__doc__',)
......@@ -131,6 +145,7 @@ _SAFE_MODULE_NAMES = (
'builtins', 'copy_reg', '__builtin__',
)
def find_global(module, name):
"""Helper for message unpickler"""
try:
......@@ -143,7 +158,8 @@ def find_global(module, name):
except AttributeError:
raise ImportError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) or (PY2 and module in _SAFE_MODULE_NAMES)
safe = (getattr(r, '__no_side_effects__', 0) or
(PY2 and module in _SAFE_MODULE_NAMES))
if safe:
return r
......@@ -153,6 +169,7 @@ def find_global(module, name):
raise ImportError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name):
"""Helper for message unpickler"""
if module not in _SAFE_MODULE_NAMES:
......
......@@ -72,6 +72,7 @@ import logging
logger = logging.getLogger(__name__)
class Acceptor(asyncore.dispatcher):
"""A server that accepts incoming RPC connections
......@@ -115,13 +116,13 @@ class Acceptor(asyncore.dispatcher):
for i in range(25):
try:
self.bind(addr)
except Exception as exc:
except Exception:
logger.info("bind on %s failed %s waiting", addr, i)
if i == 24:
raise
else:
time.sleep(5)
except:
except: # NOQA: E722 bare except
logger.exception('binding')
raise
else:
......@@ -146,7 +147,6 @@ class Acceptor(asyncore.dispatcher):
logger.info("accepted failed: %s", msg)
return
# We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below,
......@@ -159,7 +159,7 @@ class Acceptor(asyncore.dispatcher):
# closed, but I don't see a way to do that. :(
# Drop flow-info from IPv6 addresses
if addr: # Sometimes None on Mac. See above.
if addr: # Sometimes None on Mac. See above.
addr = addr[:2]
try:
......@@ -172,23 +172,25 @@ class Acceptor(asyncore.dispatcher):
protocol.stop = loop.stop
if self.ssl_context is None:
cr = loop.create_connection((lambda : protocol), sock=sock)
cr = loop.create_connection((lambda: protocol), sock=sock)
else:
if hasattr(loop, 'connect_accepted_socket'):
cr = loop.connect_accepted_socket(
(lambda : protocol), sock, ssl=self.ssl_context)
(lambda: protocol), sock, ssl=self.ssl_context)
else:
#######################################################
# XXX See http://bugs.python.org/issue27392 :(
_make_ssl_transport = loop._make_ssl_transport
def make_ssl_transport(*a, **kw):
kw['server_side'] = True
return _make_ssl_transport(*a, **kw)
loop._make_ssl_transport = make_ssl_transport
#
#######################################################
cr = loop.create_connection(
(lambda : protocol), sock=sock,
(lambda: protocol), sock=sock,
ssl=self.ssl_context,
server_hostname=''
)
......@@ -212,11 +214,12 @@ class Acceptor(asyncore.dispatcher):
asyncore.loop(map=self.__socket_map, timeout=timeout)
except Exception:
if not self.__closed:
raise # Unexpected exc
raise # Unexpected exc
logger.debug('acceptor %s loop stopped', self.addr)
__closed = False
def close(self):
if not self.__closed:
self.__closed = True
......
import json
import logging
import os
import random
import threading
import ZODB.POSException
logger = logging.getLogger(__name__)
from ..shortrepr import short_repr
from . import base
from .compat import asyncio, new_event_loop
from .marshal import server_decoder, encoder, reduce_exception
logger = logging.getLogger(__name__)
class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface
"""
......@@ -39,6 +40,7 @@ class ServerProtocol(base.Protocol):
)
closed = False
def close(self):
logger.debug("Closing server protocol")
if not self.closed:
......@@ -46,7 +48,8 @@ class ServerProtocol(base.Protocol):
if self.transport is not None:
self.transport.close()
connected = None # for tests
connected = None # for tests
def connection_made(self, transport):
self.connected = True
super(ServerProtocol, self).connection_made(transport)
......@@ -60,7 +63,7 @@ class ServerProtocol(base.Protocol):
self.stop()
def stop(self):
pass # Might be replaced when running a thread per client
pass # Might be replaced when running a thread per client
def finish_connect(self, protocol_version):
if protocol_version == b'ruok':
......@@ -95,7 +98,7 @@ class ServerProtocol(base.Protocol):
return
if message_id == -1:
return # keep-alive
return # keep-alive
if name not in self.methods:
logger.error('Invalid method, %r', name)
......@@ -109,7 +112,7 @@ class ServerProtocol(base.Protocol):
"%s`%r` raised exception:",
'async ' if async_ else '', name)
if async_:
return self.close() # No way to recover/cry for help
return self.close() # No way to recover/cry for help
else:
return self.send_error(message_id, exc)
......@@ -147,16 +150,19 @@ class ServerProtocol(base.Protocol):
def async_threadsafe(self, method, *args):
self.call_soon_threadsafe(self.call_async, method, args)
best_protocol_version = os.environ.get(
'ZEO_SERVER_PROTOCOL',
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage, msgpack):
protocol = ServerProtocol(loop, addr, zeo_storage, msgpack)
cr = loop.create_connection((lambda : protocol), sock=socket)
cr = loop.create_connection((lambda: protocol), sock=socket)
asyncio.ensure_future(cr, loop=loop)
class Delay(object):
"""Used to delay response to client for synchronous calls.
......@@ -192,6 +198,7 @@ class Delay(object):
def __reduce__(self):
raise TypeError("Can't pickle delays.")
class Result(Delay):
def __init__(self, *args):
......@@ -202,6 +209,7 @@ class Result(Delay):
protocol.send_reply(msgid, reply)
callback()
class MTDelay(Delay):
def __init__(self):
......@@ -266,6 +274,7 @@ class Acceptor(object):
self.event_loop.close()
closed = False
def close(self):
if not self.closed:
self.closed = True
......@@ -277,6 +286,7 @@ class Acceptor(object):
self.server.close()
f = asyncio.ensure_future(self.server.wait_closed(), loop=loop)
@f.add_done_callback
def server_closed(f):
# stop the loop when the server closes:
......
......@@ -11,7 +11,6 @@ except NameError:
class ConnectionRefusedError(OSError):
pass
import pprint
class Loop(object):
......@@ -19,7 +18,7 @@ class Loop(object):
def __init__(self, addrs=(), debug=True):
self.addrs = addrs
self.get_debug = lambda : debug
self.get_debug = lambda: debug
self.connecting = {}
self.later = []
self.exceptions = []
......@@ -31,7 +30,7 @@ class Loop(object):
func(*args)
def _connect(self, future, protocol_factory):
self.protocol = protocol = protocol_factory()
self.protocol = protocol = protocol_factory()
self.transport = transport = Transport(protocol)
protocol.connection_made(transport)
future.set_result((transport, protocol))
......@@ -45,10 +44,8 @@ class Loop(object):
if not future.cancelled():
future.set_exception(ConnectionRefusedError())
def create_connection(
self, protocol_factory, host=None, port=None, sock=None,
ssl=None, server_hostname=None
):
def create_connection(self, protocol_factory, host=None, port=None,
sock=None, ssl=None, server_hostname=None):
future = asyncio.Future(loop=self)
if sock is None:
addr = host, port
......@@ -83,13 +80,16 @@ class Loop(object):
self.exceptions.append(context)
closed = False
def close(self):
self.closed = True
stopped = False
def stop(self):
self.stopped = True
class Handle(object):
cancelled = False
......@@ -97,6 +97,7 @@ class Handle(object):
def cancel(self):
self.cancelled = True
class Transport(object):
capacity = 1 << 64
......@@ -136,12 +137,14 @@ class Transport(object):
self.protocol.resume_writing()
closed = False
def close(self):
self.closed = True
def get_extra_info(self, name):
return self.extra[name]
class AsyncRPC(object):
"""Adapt an asyncio API to an RPC to help hysterical tests
"""
......@@ -151,6 +154,7 @@ class AsyncRPC(object):
def __getattr__(self, name):
return lambda *a, **kw: self.api.call(name, *a, **kw)
class ClientRunner(object):
def __init__(self, addr, client, cache, storage, read_only, timeout,
......
......@@ -2,17 +2,17 @@ from .._compat import PY3
if PY3:
import asyncio
def to_byte(i):
return bytes([i])
else:
import trollius as asyncio
import trollius as asyncio # NOQA: F401 unused import
def to_byte(b):
return b
from zope.testing import setupstack
from concurrent.futures import Future
import mock
from ZODB.POSException import ReadOnlyError
from ZODB.utils import maxtid, RLock
import collections
......@@ -28,6 +28,7 @@ from .client import ClientRunner, Fallback
from .server import new_connection, best_protocol_version
from .marshal import encoder, decoder
class Base(object):
enc = b'Z'
......@@ -56,6 +57,7 @@ class Base(object):
return self.unsized(data, True)
target = None
def send(self, method, *args, **kw):
target = kw.pop('target', self.target)
called = kw.pop('called', True)
......@@ -77,6 +79,7 @@ class Base(object):
def pop(self, count=None, parse=True):
return self.unsized(self.loop.transport.pop(count), parse)
class ClientTests(Base, setupstack.TestCase, ClientRunner):
maxDiff = None
......@@ -204,7 +207,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, maxtid)
# The data wasn't in the cache, so we made a server call:
self.assertEqual(self.pop(), ((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)))
self.assertEqual(self.pop(),
((b'1'*8, maxtid),
False,
'loadBefore',
(b'1'*8, maxtid)))
# Note load_before uses the oid as the message id.
self.respond((b'1'*8, maxtid), (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
......@@ -224,7 +231,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# the requests will be collapsed:
loaded2 = self.load_before(b'1'*8, maxtid)
self.assertEqual(self.pop(), ((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)))
self.assertEqual(self.pop(),
((b'1'*8, maxtid),
False,
'loadBefore',
(b'1'*8, maxtid)))
self.respond((b'1'*8, maxtid), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))
......@@ -238,7 +249,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(self.pop(), ((b'1'*8, b'_'*8), False, 'loadBefore', (b'1'*8, b'_'*8)))
self.assertEqual(self.pop(),
((b'1'*8, b'_'*8),
False,
'loadBefore',
(b'1'*8, b'_'*8)))
self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
......@@ -247,6 +262,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# iteratable to tpc_finish_threadsafe.
tids = []
def finished_cb(tid):
tids.append(tid)
......@@ -349,7 +365,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done())
self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, )))
self.assertEqual(self.pop(),
(3, False, 'getInvalidations', (b'a'*8, )))
self.respond(3, (b'e'*8, [b'4'*8]))
self.assertEqual(self.pop(), (4, False, 'get_info', ()))
......@@ -361,7 +378,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# And the cache has been updated:
self.assertEqual(cache.load(b'2'*8),
('2 data', b'a'*8)) # unchanged
('2 data', b'a'*8)) # unchanged
self.assertEqual(cache.load(b'4'*8), None)
# Because we were able to update the cache, we didn't have to
......@@ -384,7 +401,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done())
self.assertEqual(self.pop(), (3, False, 'getInvalidations', (b'a'*8, )))
self.assertEqual(self.pop(),
(3, False, 'getInvalidations', (b'a'*8, )))
# We respond None, indicating that we're too far out of date:
self.respond(3, None)
......@@ -451,10 +469,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.respond(2, 'a'*8)
self.pop()
self.assertFalse(client.connected.done() or transport.data)
delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat
delay, func, args, _ = loop.later.pop(1) # first in later is heartbeat
self.assertTrue(8 < delay < 10)
self.assertEqual(len(loop.later), 1) # first in later is heartbeat
func(*args) # connect again
self.assertEqual(len(loop.later), 1) # first in later is heartbeat
func(*args) # connect again
self.assertFalse(protocol is loop.protocol)
self.assertFalse(transport is loop.transport)
protocol = loop.protocol
......@@ -512,7 +530,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We connect the second address:
loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), self.enc + b'3101')
self.assertEqual(self.unsized(loop.transport.pop(2)),
self.enc + b'3101')
self.assertEqual(self.parse(loop.transport.pop()),
(1, False, 'register', ('TEST', False)))
self.assertTrue(self.is_read_only())
......@@ -613,7 +632,6 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.data_received(sized(self.enc + b'200'))
self.assertTrue(isinstance(error.call_args[0][1], ProtocolError))
def test_get_peername(self):
wrapper, cache, loop, client, protocol, transport = self.start(
finish_start=True)
......@@ -641,7 +659,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# that caused it to fail badly if errors were raised while
# handling data.
wrapper, cache, loop, client, protocol, transport =self.start(
wrapper, cache, loop, client, protocol, transport = self.start(
finish_start=True)
wrapper.receiveBlobStart.side_effect = ValueError('test')
......@@ -694,10 +712,12 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.connection_lost(None)
self.assertTrue(handle.cancelled)
class MsgpackClientTests(ClientTests):
enc = b'M'
seq_type = tuple
class MemoryCache(object):
def __init__(self):
......@@ -709,6 +729,7 @@ class MemoryCache(object):
clear = __init__
closed = False
def close(self):
self.closed = True
......@@ -771,6 +792,7 @@ class ServerTests(Base, setupstack.TestCase):
message_id = 0
target = None
def call(self, meth, *args, **kw):
if kw:
expect = kw.pop('expect', self)
......@@ -835,10 +857,12 @@ class ServerTests(Base, setupstack.TestCase):
self.call('foo', target=None)
self.assertTrue(protocol.loop.transport.closed)
class MsgpackServerTests(ServerTests):
enc = b'M'
seq_type = tuple
def server_protocol(msgpack,
zeo_storage=None,
protocol_version=None,
......@@ -847,18 +871,17 @@ def server_protocol(msgpack,
if zeo_storage is None:
zeo_storage = mock.Mock()
loop = Loop()
sock = () # anything not None
sock = () # anything not None
new_connection(loop, addr, sock, zeo_storage, msgpack)
if protocol_version:
loop.protocol.data_received(sized(protocol_version))
return loop.protocol
def response(*data):
return sized(self.encode(*data))
def sized(message):
return struct.pack(">I", len(message)) + message
class Logging(object):
def __init__(self, level=logging.ERROR):
......@@ -885,9 +908,11 @@ class ProtocolTests(setupstack.TestCase):
loop = self.loop
protocol, transport = loop.protocol, loop.transport
transport.capacity = 1 # single message
def it(tag):
yield tag
yield tag
protocol._writeit(it(b"0"))
protocol._writeit(it(b"1"))
for b in b"0011":
......
......@@ -86,7 +86,7 @@ ZEC_HEADER_SIZE = 12
# need to write a free block that is almost twice as big. If we die
# in the middle of a store, then we need to split the large free records
# while opening.
max_block_size = (1<<31) - 1
max_block_size = (1 << 31) - 1
# After the header, the file contains a contiguous sequence of blocks. All
......@@ -132,12 +132,13 @@ allocated_record_overhead = 43
# Under PyPy, the available dict specializations perform significantly
# better (faster) than the pure-Python BTree implementation. They may
# use less memory too. And we don't require any of the special BTree features...
# use less memory too. And we don't require any of the special BTree features.
_current_index_type = ZODB.fsIndex.fsIndex if not PYPY else dict
_noncurrent_index_type = BTrees.LOBTree.LOBTree if not PYPY else dict
# ...except at this leaf level
_noncurrent_bucket_type = BTrees.LLBTree.LLBucket
class ClientCache(object):
"""A simple in-memory cache."""
......@@ -193,7 +194,7 @@ class ClientCache(object):
if path:
self._lock_file = zc.lockfile.LockFile(path + '.lock')
if not os.path.exists(path):
# Create a small empty file. We'll make it bigger in _initfile.
# Create a small empty file. We'll make it bigger in _initfile.
self.f = open(path, 'wb+')
self.f.write(magic+z64)
logger.info("created persistent cache file %r", path)
......@@ -209,10 +210,10 @@ class ClientCache(object):
try:
self._initfile(fsize)
except:
except: # NOQA: E722 bare except
self.f.close()
if not path:
raise # unrecoverable temp file error :(
raise # unrecoverable temp file error :(
badpath = path+'.bad'
if os.path.exists(badpath):
logger.critical(
......@@ -271,7 +272,7 @@ class ClientCache(object):
self.current = _current_index_type()
self.noncurrent = _noncurrent_index_type()
l = 0
length = 0
last = ofs = ZEC_HEADER_SIZE
first_free_offset = 0
current = self.current
......@@ -290,7 +291,7 @@ class ClientCache(object):
assert start_tid < end_tid, (ofs, f.tell())
self._set_noncurrent(oid, start_tid, ofs)
assert lver == 0, "Versions aren't supported"
l += 1
length += 1
else:
# free block
if first_free_offset == 0:
......@@ -331,7 +332,7 @@ class ClientCache(object):
break
if fsize < maxsize:
assert ofs==fsize
assert ofs == fsize
# Make sure the OS really saves enough bytes for the file.
seek(self.maxsize - 1)
write(b'x')
......@@ -349,7 +350,7 @@ class ClientCache(object):
assert last and (status in b' f1234')
first_free_offset = last
else:
assert ofs==maxsize
assert ofs == maxsize
if maxsize < fsize:
seek(maxsize)
f.truncate()
......@@ -357,7 +358,7 @@ class ClientCache(object):
# We use the first_free_offset because it is most likely the
# place where we last wrote.
self.currentofs = first_free_offset or ZEC_HEADER_SIZE
self._len = l
self._len = length
def _set_noncurrent(self, oid, tid, ofs):
noncurrent_for_oid = self.noncurrent.get(u64(oid))
......@@ -375,7 +376,6 @@ class ClientCache(object):
except KeyError:
logger.error("Couldn't find non-current %r", (oid, tid))
def clearStats(self):
self._n_adds = self._n_added_bytes = 0
self._n_evicts = self._n_evicted_bytes = 0
......@@ -384,8 +384,7 @@ class ClientCache(object):
def getStats(self):
return (self._n_adds, self._n_added_bytes,
self._n_evicts, self._n_evicted_bytes,
self._n_accesses
)
self._n_accesses)
##
# The number of objects currently in the cache.
......@@ -403,7 +402,7 @@ class ClientCache(object):
sync(f)
f.close()
if hasattr(self,'_lock_file'):
if hasattr(self, '_lock_file'):
self._lock_file.close()
##
......@@ -517,9 +516,9 @@ class ClientCache(object):
if ofsofs < 0:
ofsofs += self.maxsize
if (ofsofs > self.rearrange and
self.maxsize > 10*len(data) and
size > 4):
if ofsofs > self.rearrange and \
self.maxsize > 10*len(data) and \
size > 4:
# The record is far back and might get evicted, but it's
# valuable, so move it forward.
......@@ -619,8 +618,8 @@ class ClientCache(object):
raise ValueError("already have current data for oid")
else:
noncurrent_for_oid = self.noncurrent.get(u64(oid))
if noncurrent_for_oid and (
u64(start_tid) in noncurrent_for_oid):
if noncurrent_for_oid and \
u64(start_tid) in noncurrent_for_oid:
return
size = allocated_record_overhead + len(data)
......@@ -692,7 +691,6 @@ class ClientCache(object):
self.currentofs += size
##
# If `tid` is None,
# forget all knowledge of `oid`. (`tid` can be None only for
......@@ -765,8 +763,7 @@ class ClientCache(object):
for oid, tid in L:
print(oid_repr(oid), oid_repr(tid))
print("dll contents")
L = list(self)
L.sort(lambda x, y: cmp(x.key, y.key))
L = sorted(list(self), key=lambda x: x.key)
for x in L:
end_tid = x.end_tid or z64
print(oid_repr(x.key[0]), oid_repr(x.key[1]), oid_repr(end_tid))
......@@ -779,6 +776,7 @@ class ClientCache(object):
# tracing by setting self._trace to a dummy function, and set
# self._tracefile to None.
_tracefile = None
def _trace(self, *a, **kw):
pass
......@@ -797,6 +795,7 @@ class ClientCache(object):
return
now = time.time
def _trace(code, oid=b"", tid=z64, end_tid=z64, dlen=0):
# The code argument is two hex digits; bits 0 and 7 must be zero.
# The first hex digit shows the operation, the second the outcome.
......@@ -812,7 +811,7 @@ class ClientCache(object):
pack(">iiH8s8s",
int(now()), encoded, len(oid), tid, end_tid) + oid,
)
except:
except: # NOQA: E722 bare except
print(repr(tid), repr(end_tid))
raise
......@@ -826,10 +825,7 @@ class ClientCache(object):
self._tracefile.close()
del self._tracefile
def sync(f):
f.flush()
if hasattr(os, 'fsync'):
def sync(f):
f.flush()
os.fsync(f.fileno())
os.fsync(f.fileno())
......@@ -14,6 +14,7 @@
import zope.interface
class StaleCache(object):
"""A ZEO cache is stale and requires verification.
"""
......@@ -21,6 +22,7 @@ class StaleCache(object):
def __init__(self, storage):
self.storage = storage
class IClientCache(zope.interface.Interface):
"""Client cache interface.
......@@ -86,6 +88,7 @@ class IClientCache(zope.interface.Interface):
"""Clear/empty the cache
"""
class IServeable(zope.interface.Interface):
"""Interface provided by storages that can be served by ZEO
"""
......
......@@ -30,10 +30,7 @@ from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
import asyncore
import socket
import time
import logging
zeo_version = 'unknown'
try:
......@@ -47,6 +44,7 @@ else:
if zeo_dist is not None:
zeo_version = zeo_dist.version
class StorageStats(object):
"""Per-storage usage statistics."""
......
......@@ -33,6 +33,7 @@ diff_names = 'aborts commits conflicts conflicts_resolved loads stores'.split()
per_times = dict(seconds=1.0, minutes=60.0, hours=3600.0, days=86400.0)
def new_metric(metrics, storage_id, name, value):
if storage_id == '1':
label = name
......@@ -43,6 +44,7 @@ def new_metric(metrics, storage_id, name, value):
label = "%s:%s" % (storage_id, name)
metrics.append("%s=%s" % (label, value))
def result(messages, metrics=(), status=None):
if metrics:
messages[0] += '|' + metrics[0]
......@@ -51,12 +53,15 @@ def result(messages, metrics=(), status=None):
print('\n'.join(messages))
return status
def error(message):
return result((message, ), (), 2)
def warn(message):
return result((message, ), (), 1)
def check(addr, output_metrics, status, per):
m = re.match(r'\[(\S+)\]:(\d+)$', addr)
if m:
......@@ -75,7 +80,7 @@ def check(addr, output_metrics, status, per):
return error("Can't connect %s" % err)
s.sendall(b'\x00\x00\x00\x04ruok')
proto = s.recv(struct.unpack(">I", s.recv(4))[0])
proto = s.recv(struct.unpack(">I", s.recv(4))[0]) # NOQA: F841 unused
datas = s.recv(struct.unpack(">I", s.recv(4))[0])
s.close()
data = json.loads(datas.decode("ascii"))
......@@ -94,8 +99,8 @@ def check(addr, output_metrics, status, per):
now = time.time()
if os.path.exists(status):
dt = now - os.stat(status).st_mtime
if dt > 0: # sanity :)
with open(status) as f: # Read previous
if dt > 0: # sanity :)
with open(status) as f: # Read previous
old = json.loads(f.read())
dt /= per_times[per]
for storage_id, sdata in sorted(data.items()):
......@@ -105,7 +110,7 @@ def check(addr, output_metrics, status, per):
for name in diff_names:
v = (sdata[name] - sold[name]) / dt
new_metric(metrics, storage_id, name, v)
with open(status, 'w') as f: # save current
with open(status, 'w') as f: # save current
f.write(json.dumps(data))
for storage_id, sdata in sorted(data.items()):
......@@ -116,6 +121,7 @@ def check(addr, output_metrics, status, per):
messages.append('OK')
return result(messages, metrics, level or None)
def main(args=None):
if args is None:
args = sys.argv[1:]
......@@ -139,5 +145,6 @@ def main(args=None):
return check(
addr, options.output_metrics, options.status_path, options.time_units)
if __name__ == '__main__':
main()
......@@ -46,21 +46,25 @@ from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address
def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown.
import asyncore
asyncore.close_all()
class ZEOOptionsMixin(object):
storages = None
......@@ -69,14 +73,17 @@ class ZEOOptionsMixin(object):
self.family, self.address = parse_binding_address(arg)
def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*!
from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig(object):
def __init__(self, name, path):
self._name = name
self.path = path
self.stop = None
def getSectionName(self):
return self._name
if not self.storages:
self.storages = []
name = str(1 + len(self.storages))
......@@ -84,6 +91,7 @@ class ZEOOptionsMixin(object):
self.storages.append(conf)
testing_exit_immediately = False
def handle_test(self, *args):
self.testing_exit_immediately = True
......@@ -108,6 +116,7 @@ class ZEOOptionsMixin(object):
None, 'pid-file=')
self.add("ssl", "zeo.ssl")
class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__
......@@ -164,15 +173,15 @@ class ZEOServer(object):
root = logging.getLogger()
root.setLevel(logging.INFO)
fmt = logging.Formatter(
"------\n%(asctime)s %(levelname)s %(name)s %(message)s",
"%Y-%m-%dT%H:%M:%S")
"------\n%(asctime)s %(levelname)s %(name)s %(message)s",
"%Y-%m-%dT%H:%M:%S")
handler = logging.StreamHandler()
handler.setFormatter(fmt)
root.addHandler(handler)
def check_socket(self):
if (isinstance(self.options.address, tuple) and
self.options.address[1] is None):
if isinstance(self.options.address, tuple) and \
self.options.address[1] is None:
self.options.address = self.options.address[0], 0
return
......@@ -217,7 +226,7 @@ class ZEOServer(object):
self.setup_win32_signals()
return
if hasattr(signal, 'SIGXFSZ'):
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
init_signames()
for sig, name in signames.items():
method = getattr(self, "handle_" + name.lower(), None)
......@@ -237,12 +246,12 @@ class ZEOServer(object):
"will *not* be installed.")
return
SignalHandler = Signals.Signals.SignalHandler
if SignalHandler is not None: # may be None if no pywin32.
if SignalHandler is not None: # may be None if no pywin32.
SignalHandler.registerHandler(signal.SIGTERM,
windows_shutdown_handler)
SignalHandler.registerHandler(signal.SIGINT,
windows_shutdown_handler)
SIGUSR2 = 12 # not in signal module on Windows.
SIGUSR2 = 12 # not in signal module on Windows.
SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2)
def create_server(self):
......@@ -278,7 +287,8 @@ class ZEOServer(object):
def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"):
if self.options.config_logger is None or \
os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!",
level=logging.WARNING)
return
......@@ -286,13 +296,13 @@ class ZEOServer(object):
loggers = [self.options.config_logger]
if os.name == "posix":
for l in loggers:
l.reopen()
for logger in loggers:
logger.reopen()
log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers:
for f in l.handler_factories:
handler = f()
else: # nt - same rotation code as in Zope's Signals/Signals.py
for logger in loggers:
for factory in logger.handler_factories:
handler = factory()
if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate()
log("Log files rotation complete", level=logging.INFO)
......@@ -350,21 +360,21 @@ def create_server(storages, options):
return StorageServer(
options.address,
storages,
read_only = options.read_only,
read_only=options.read_only,
client_conflict_resolution=options.client_conflict_resolution,
msgpack=(options.msgpack if isinstance(options.msgpack, bool)
else os.environ.get('ZEO_MSGPACK')),
invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout,
ssl = options.ssl,
)
invalidation_queue_size=options.invalidation_queue_size,
invalidation_age=options.invalidation_age,
transaction_timeout=options.transaction_timeout,
ssl=options.ssl)
# Signal names
signames = None
def signame(sig):
"""Return a symbolic name for a signal.
......@@ -376,6 +386,7 @@ def signame(sig):
init_signames()
return signames.get(sig) or "signal %d" % sig
def init_signames():
global signames
signames = {}
......@@ -395,11 +406,13 @@ def main(args=None):
s = ZEOServer(options)
s.main()
def run(args):
options = ZEOOptions()
options.realize(args)
s = ZEOServer(options)
s.run()
if __name__ == "__main__":
main()
......@@ -27,6 +27,7 @@ from __future__ import print_function, absolute_import
import bisect
import struct
import random
import re
import sys
import ZEO.cache
......@@ -34,6 +35,7 @@ import argparse
from ZODB.utils import z64
from ..cache import ZEC_HEADER_SIZE
from .cache_stats import add_interval_argument
from .cache_stats import add_tracefile_argument
......@@ -46,7 +48,7 @@ def main(args=None):
if args is None:
args = sys.argv[1:]
# Parse options.
MB = 1<<20
MB = 1 << 20
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--size", "-s",
default=20*MB, dest="cachelimit",
......@@ -115,6 +117,7 @@ def main(args=None):
interval_sim.report()
sim.finish()
class Simulation(object):
"""Base class for simulations.
......@@ -270,7 +273,6 @@ class CircularCacheEntry(object):
self.end_tid = end_tid
self.offset = offset
from ZEO.cache import ZEC_HEADER_SIZE
class CircularCacheSimulation(Simulation):
"""Simulate the ZEO 3.0 cache."""
......@@ -285,8 +287,6 @@ class CircularCacheSimulation(Simulation):
evicts = 0
def __init__(self, cachelimit, rearrange):
from ZEO import cache
Simulation.__init__(self, cachelimit, rearrange)
self.total_evicts = 0 # number of cache evictions
......@@ -296,7 +296,7 @@ class CircularCacheSimulation(Simulation):
# Map offset in file to (size, CircularCacheEntry) pair, or to
# (size, None) if the offset starts a free block.
self.filemap = {ZEC_HEADER_SIZE: (self.cachelimit - ZEC_HEADER_SIZE,
None)}
None)}
# Map key to CircularCacheEntry. A key is an (oid, tid) pair.
self.key2entry = {}
......@@ -322,10 +322,11 @@ class CircularCacheSimulation(Simulation):
self.evicted_hit = self.evicted_miss = 0
evicted_hit = evicted_miss = 0
def load(self, oid, size, tid, code):
if (code == 0x20) or (code == 0x22):
# Trying to load current revision.
if oid in self.current: # else it's a cache miss
if oid in self.current: # else it's a cache miss
self.hits += 1
self.total_hits += 1
......@@ -512,7 +513,7 @@ class CircularCacheSimulation(Simulation):
self.inuse = round(100.0 * used / total, 1)
self.total_inuse = self.inuse
Simulation.report(self)
#print self.evicted_hit, self.evicted_miss
# print self.evicted_hit, self.evicted_miss
def check(self):
oidcount = 0
......@@ -536,16 +537,18 @@ class CircularCacheSimulation(Simulation):
def roundup(size):
k = MINSIZE
k = MINSIZE # NOQA: F821 undefined name
while k < size:
k += k
return k
def hitrate(loads, hits):
if loads < 1:
return 'n/a'
return "%5.1f%%" % (100.0 * hits / loads)
def duration(secs):
mm, ss = divmod(secs, 60)
hh, mm = divmod(mm, 60)
......@@ -555,7 +558,10 @@ def duration(secs):
return "%d:%02d" % (mm, ss)
return "%d" % ss
nre = re.compile('([=-]?)(\d+)([.]\d*)?').match
nre = re.compile(r'([=-]?)(\d+)([.]\d*)?').match
def addcommas(n):
sign, s, d = nre(str(n)).group(1, 2, 3)
if d == '.0':
......@@ -569,11 +575,11 @@ def addcommas(n):
return (sign or '') + result + (d or '')
import random
def maybe(f, p=0.5):
if random.random() < p:
f()
if __name__ == "__main__":
sys.exit(main())
......@@ -55,6 +55,7 @@ import gzip
from time import ctime
import six
def add_interval_argument(parser):
def _interval(a):
interval = int(60 * float(a))
......@@ -63,9 +64,11 @@ def add_interval_argument(parser):
elif interval > 3600:
interval = 3600
return interval
parser.add_argument("--interval", "-i",
default=15*60, type=_interval,
help="summarizing interval in minutes (default 15; max 60)")
parser.add_argument(
"--interval", "-i",
default=15*60, type=_interval,
help="summarizing interval in minutes (default 15; max 60)")
def add_tracefile_argument(parser):
......@@ -82,15 +85,17 @@ def add_tracefile_argument(parser):
parser.add_argument("tracefile", type=GzipFileType(),
help="The trace to read; may be gzipped")
def main(args=None):
if args is None:
args = sys.argv[1:]
# Parse options
parser = argparse.ArgumentParser(description="Trace file statistics analyzer",
# Our -h, short for --load-histogram
# conflicts with default for help, so we handle
# manually.
add_help=False)
parser = argparse.ArgumentParser(
description="Trace file statistics analyzer",
# Our -h, short for --load-histogram
# conflicts with default for help, so we handle
# manually.
add_help=False)
verbose_group = parser.add_mutually_exclusive_group()
verbose_group.add_argument('--verbose', '-v',
default=False, action='store_true',
......@@ -99,18 +104,22 @@ def main(args=None):
default=False, action='store_true',
help="Reduce output; don't print summaries")
parser.add_argument("--sizes", '-s',
default=False, action="store_true", dest="print_size_histogram",
default=False, action="store_true",
dest="print_size_histogram",
help="print histogram of object sizes")
parser.add_argument("--no-stats", '-S',
default=True, action="store_false", dest="dostats",
help="don't print statistics")
parser.add_argument("--load-histogram", "-h",
default=False, action="store_true", dest="print_histogram",
default=False, action="store_true",
dest="print_histogram",
help="print histogram of object load frequencies")
parser.add_argument("--check", "-X",
default=False, action="store_true", dest="heuristic",
help=" enable heuristic checking for misaligned records: oids > 2**32"
" will be rejected; this requires the tracefile to be seekable")
help=" enable heuristic checking for misaligned "
"records: oids > 2**32"
" will be rejected; this requires the tracefile "
"to be seekable")
add_interval_argument(parser)
add_tracefile_argument(parser)
......@@ -123,20 +132,20 @@ def main(args=None):
f = options.tracefile
rt0 = time.time()
bycode = {} # map code to count of occurrences
byinterval = {} # map code to count in current interval
records = 0 # number of trace records read
versions = 0 # number of trace records with versions
datarecords = 0 # number of records with dlen set
datasize = 0 # sum of dlen across records with dlen set
oids = {} # map oid to number of times it was loaded
bysize = {} # map data size to number of loads
bysizew = {} # map data size to number of writes
bycode = {} # map code to count of occurrences
byinterval = {} # map code to count in current interval
records = 0 # number of trace records read
versions = 0 # number of trace records with versions
datarecords = 0 # number of records with dlen set
datasize = 0 # sum of dlen across records with dlen set
oids = {} # map oid to number of times it was loaded
bysize = {} # map data size to number of loads
bysizew = {} # map data size to number of writes
total_loads = 0
t0 = None # first timestamp seen
te = None # most recent timestamp seen
h0 = None # timestamp at start of current interval
he = None # timestamp at end of current interval
t0 = None # first timestamp seen
te = None # most recent timestamp seen
h0 = None # timestamp at start of current interval
he = None # timestamp at end of current interval
thisinterval = None # generally te//interval
f_read = f.read
unpack = struct.unpack
......@@ -144,7 +153,8 @@ def main(args=None):
FMT_SIZE = struct.calcsize(FMT)
assert FMT_SIZE == 26
# Read file, gathering statistics, and printing each record if verbose.
print(' '*16, "%7s %7s %7s %7s" % ('loads', 'hits', 'inv(h)', 'writes'), end=' ')
print(' '*16, "%7s %7s %7s %7s" % (
'loads', 'hits', 'inv(h)', 'writes'), end=' ')
print('hitrate')
try:
while 1:
......@@ -187,10 +197,10 @@ def main(args=None):
bycode[code] = bycode.get(code, 0) + 1
byinterval[code] = byinterval.get(code, 0) + 1
if dlen:
if code & 0x70 == 0x20: # All loads
if code & 0x70 == 0x20: # All loads
bysize[dlen] = d = bysize.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1
elif code & 0x70 == 0x50: # All stores
elif code & 0x70 == 0x50: # All stores
bysizew[dlen] = d = bysizew.get(dlen) or {}
d[oid] = d.get(oid, 0) + 1
if options.verbose:
......@@ -205,7 +215,7 @@ def main(args=None):
if code & 0x70 == 0x20:
oids[oid] = oids.get(oid, 0) + 1
total_loads += 1
elif code == 0x00: # restart
elif code == 0x00: # restart
if not options.quiet:
dumpbyinterval(byinterval, h0, he)
byinterval = {}
......@@ -279,6 +289,7 @@ def main(args=None):
dumpbysize(bysizew, "written", "writes")
dumpbysize(bysize, "loaded", "loads")
def dumpbysize(bysize, how, how2):
print()
print("Unique sizes %s: %s" % (how, addcommas(len(bysize))))
......@@ -292,6 +303,7 @@ def dumpbysize(bysize, how, how2):
len(bysize.get(size, "")),
loads))
def dumpbyinterval(byinterval, h0, he):
loads = hits = invals = writes = 0
for code in byinterval:
......@@ -301,7 +313,7 @@ def dumpbyinterval(byinterval, h0, he):
if code in (0x22, 0x26):
hits += n
elif code & 0x40:
writes += byinterval[code]
writes += byinterval[code]
elif code & 0x10:
if code != 0x10:
invals += byinterval[code]
......@@ -315,6 +327,7 @@ def dumpbyinterval(byinterval, h0, he):
ctime(h0)[4:-8], ctime(he)[14:-8],
loads, hits, invals, writes, hr))
def hitrate(bycode):
loads = hits = 0
for code in bycode:
......@@ -328,6 +341,7 @@ def hitrate(bycode):
else:
return 0.0
def histogram(d):
bins = {}
for v in six.itervalues(d):
......@@ -335,15 +349,18 @@ def histogram(d):
L = sorted(bins.items())
return L
def U64(s):
return struct.unpack(">Q", s)[0]
def oid_repr(oid):
if isinstance(oid, six.binary_type) and len(oid) == 8:
return '%16x' % U64(oid)
else:
return repr(oid)
def addcommas(n):
sign, s = '', str(n)
if s[0] == '-':
......@@ -354,6 +371,7 @@ def addcommas(n):
i -= 3
return sign + s
explain = {
# The first hex digit shows the operation, the second the outcome.
# If the second digit is in "02468" then it is a 'miss'.
......
......@@ -3,7 +3,7 @@
"""Parse the BLATHER logging generated by ZEO2.
An example of the log format is:
2002-04-15T13:05:29 BLATHER(-100) ZEO Server storea(3235680, [714], 235339406490168806) ('10.0.26.30', 45514)
2002-04-15T13:05:29 BLATHER(-100) ZEO Server storea(3235680, [714], 235339406490168806) ('10.0.26.30', 45514) # NOQA: E501 line too long
"""
from __future__ import print_function
from __future__ import print_function
......@@ -14,7 +14,8 @@ from __future__ import print_function
import re
import time
rx_time = re.compile('(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
rx_time = re.compile(r'(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None."""
......@@ -26,11 +27,14 @@ def parse_time(line):
time_l = [int(elt) for elt in time_.split(':')]
return int(time.mktime(date_l + time_l + [0, 0, 0]))
rx_meth = re.compile("zrpc:\d+ calling (\w+)\((.*)")
rx_meth = re.compile(r"zrpc:\d+ calling (\w+)\((.*)")
def parse_method(line):
pass
def parse_line(line):
"""Parse a log entry and return time, method info, and client."""
t = parse_time(line)
......@@ -47,6 +51,7 @@ def parse_line(line):
m = meth_name, tuple(meth_args)
return t, m
class TStats(object):
counter = 1
......@@ -61,7 +66,6 @@ class TStats(object):
def report(self):
"""Print a report about the transaction"""
t = time.ctime(self.begin)
if hasattr(self, "vote"):
d_vote = self.vote - self.begin
else:
......@@ -69,10 +73,11 @@ class TStats(object):
if hasattr(self, "finish"):
d_finish = self.finish - self.begin
else:
d_finish = "*"
d_finish = "*"
print(self.fmt % (time.ctime(self.begin), d_vote, d_finish,
self.user, self.url))
class TransactionParser(object):
def __init__(self):
......@@ -122,6 +127,7 @@ class TransactionParser(object):
L.sort()
return [t for (id, t) in L]
if __name__ == "__main__":
import fileinput
......@@ -131,7 +137,7 @@ if __name__ == "__main__":
i += 1
try:
p.parse(line)
except:
except: # NOQA: E722 bare except
print("line", i)
raise
print("Transaction: %d" % len(p.txns))
......
......@@ -12,18 +12,21 @@
#
##############################################################################
from __future__ import print_function
import doctest, re, unittest
import doctest
import re
import unittest
from zope.testing import renormalizing
def test_suite():
return unittest.TestSuite((
doctest.DocFileSuite(
'zeopack.test',
checker=renormalizing.RENormalizing([
(re.compile('usage: Usage: '), 'Usage: '), # Py 2.4
(re.compile('options:'), 'Options:'), # Py 2.4
(re.compile('usage: Usage: '), 'Usage: '), # Py 2.4
(re.compile('options:'), 'Options:'), # Py 2.4
]),
globs={'print_function': print_function},
),
))
......@@ -25,6 +25,7 @@ from ZEO.ClientStorage import ClientStorage
ZERO = '\0'*8
def main():
if len(sys.argv) not in (3, 4):
sys.stderr.write("Usage: timeout.py address delay [storage-name]\n" %
......@@ -68,5 +69,6 @@ def main():
time.sleep(delay)
print("Done.")
if __name__ == "__main__":
main()
......@@ -8,7 +8,6 @@ import time
import traceback
import ZEO.ClientStorage
from six.moves import map
from six.moves import zip
usage = """Usage: %prog [options] [servers]
......@@ -21,7 +20,8 @@ each is of the form:
"""
WAIT = 10 # wait no more than 10 seconds for client to connect
WAIT = 10 # wait no more than 10 seconds for client to connect
def _main(args=None, prog=None):
if args is None:
......@@ -160,10 +160,11 @@ def _main(args=None, prog=None):
continue
cs.pack(packt, wait=True)
cs.close()
except:
except: # NOQA: E722 bare except
traceback.print_exception(*(sys.exc_info()+(99, sys.stderr)))
error("Error packing storage %s in %r" % (name, addr))
def main(*args):
root_logger = logging.getLogger()
old_level = root_logger.getEffectiveLevel()
......@@ -178,6 +179,6 @@ def main(*args):
logging.getLogger().setLevel(old_level)
logging.getLogger().removeHandler(handler)
if __name__ == "__main__":
main()
......@@ -37,7 +37,6 @@ STATEFILE = 'zeoqueue.pck'
PROGRAM = sys.argv[0]
tcre = re.compile(r"""
(?P<ymd>
\d{4}- # year
......@@ -67,7 +66,6 @@ ccre = re.compile(r"""
wcre = re.compile(r'Clients waiting: (?P<num>\d+)')
def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None."""
mo = tcre.match(line)
......@@ -97,7 +95,6 @@ class Txn(object):
return False
class Status(object):
"""Track status of ZEO server by replaying log records.
......@@ -303,7 +300,6 @@ class Status(object):
break
def usage(code, msg=''):
print(__doc__ % globals(), file=sys.stderr)
if msg:
......
......@@ -41,25 +41,25 @@ import time
import getopt
import operator
# ZEO logs measure wall-clock time so for consistency we need to do the same
#from time import clock as now
# from time import clock as now
from time import time as now
from ZODB.FileStorage import FileStorage
#from BDBStorage.BDBFullStorage import BDBFullStorage
#from Standby.primary import PrimaryStorage
#from Standby.config import RS_PORT
# from BDBStorage.BDBFullStorage import BDBFullStorage
# from Standby.primary import PrimaryStorage
# from Standby.config import RS_PORT
from ZODB.Connection import TransactionMetaData
from ZODB.utils import p64
from functools import reduce
datecre = re.compile('(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
methcre = re.compile("ZEO Server (\w+)\((.*)\) \('(.*)', (\d+)")
datecre = re.compile(r'(\d\d\d\d-\d\d-\d\d)T(\d\d:\d\d:\d\d)')
methcre = re.compile(r"ZEO Server (\w+)\((.*)\) \('(.*)', (\d+)")
class StopParsing(Exception):
pass
def usage(code, msg=''):
print(__doc__)
if msg:
......@@ -67,7 +67,6 @@ def usage(code, msg=''):
sys.exit(code)
def parse_time(line):
"""Return the time portion of a zLOG line in seconds or None."""
mo = datecre.match(line)
......@@ -95,7 +94,6 @@ def parse_line(line):
return t, m, c
class StoreStat(object):
def __init__(self, when, oid, size):
self.when = when
......@@ -104,8 +102,10 @@ class StoreStat(object):
# Crufty
def __getitem__(self, i):
if i == 0: return self.oid
if i == 1: return self.size
if i == 0:
return self.oid
if i == 1:
return self.size
raise IndexError
......@@ -136,10 +136,10 @@ class TxnStat(object):
self._finishtime = when
# Mapping oid -> revid
_revids = {}
class ReplayTxn(TxnStat):
def __init__(self, storage):
self._storage = storage
......@@ -157,7 +157,7 @@ class ReplayTxn(TxnStat):
# BAW: simulate a pickle of the given size
data = 'x' * obj.size
# BAW: ignore versions for now
newrevid = self._storage.store(p64(oid), revid, data, '', t)
newrevid = self._storage.store(p64(oid), revid, data, '', t)
_revids[oid] = newrevid
if self._aborttime:
self._storage.tpc_abort(t)
......@@ -172,7 +172,6 @@ class ReplayTxn(TxnStat):
self._replaydelta = t1 - t0 - origdelta
class ZEOParser(object):
def __init__(self, maxtxns=-1, report=1, storage=None):
self.__txns = []
......@@ -261,7 +260,6 @@ class ZEOParser(object):
print('average faster txn was:', float(sum) / len(faster))
def main():
try:
opts, args = getopt.getopt(
......@@ -294,8 +292,8 @@ def main():
if replay:
storage = FileStorage(storagefile)
#storage = BDBFullStorage(storagefile)
#storage = PrimaryStorage('yyz', storage, RS_PORT)
# storage = BDBFullStorage(storagefile)
# storage = PrimaryStorage('yyz', storage, RS_PORT)
t0 = now()
p = ZEOParser(maxtxns, report, storage)
i = 0
......@@ -308,7 +306,7 @@ def main():
p.parse(line)
except StopParsing:
break
except:
except: # NOQA: E722 bare except
print('input file line:', i)
raise
t1 = now()
......@@ -321,6 +319,5 @@ def main():
print('total time:', t3-t0)
if __name__ == '__main__':
main()
......@@ -169,9 +169,11 @@ from __future__ import print_function
from __future__ import print_function
from __future__ import print_function
import datetime, sys, re, os
import datetime
import os
import re
import sys
from six.moves import map
from six.moves import zip
def time(line):
......@@ -187,9 +189,10 @@ def sub(t1, t2):
return delta.days*86400.0+delta.seconds+delta.microseconds/1000000.0
waitre = re.compile(r'Clients waiting: (\d+)')
idre = re.compile(r' ZSS:\d+/(\d+.\d+.\d+.\d+:\d+) ')
def blocked_times(args):
f, thresh = args
......@@ -217,7 +220,6 @@ def blocked_times(args):
t2 = t1
if not blocking and last_blocking:
last_wait = 0
t2 = time(line)
cid = idre.search(line).group(1)
......@@ -225,11 +227,14 @@ def blocked_times(args):
d = sub(t1, time(line))
if d >= thresh:
print(t1, sub(t1, t2), cid, d)
t1 = t2 = cid = blocking = waiting = last_wait = max_wait = 0
t1 = t2 = cid = blocking = waiting = 0
last_blocking = blocking
connidre = re.compile(r' zrpc-conn:(\d+.\d+.\d+.\d+:\d+) ')
def time_calls(f):
f, thresh = f
if f == '-':
......@@ -255,6 +260,7 @@ def time_calls(f):
print(maxd)
def xopen(f):
if f == '-':
return sys.stdin
......@@ -262,6 +268,7 @@ def xopen(f):
return os.popen(f, 'r')
return open(f)
def time_tpc(f):
f, thresh = f
if f == '-':
......@@ -307,11 +314,14 @@ def time_tpc(f):
t = time(line)
d = sub(t1, t)
if d >= thresh:
print('c', t1, cid, sub(t1, t2), vs, sub(t2, t3), sub(t3, t))
print('c', t1, cid, sub(t1, t2),
vs, sub(t2, t3), sub(t3, t))
del transactions[cid]
newobre = re.compile(r"storea\(.*, '\\x00\\x00\\x00\\x00\\x00")
def time_trans(f):
f, thresh = f
if f == '-':
......@@ -363,8 +373,8 @@ def time_trans(f):
t = time(line)
d = sub(t1, t)
if d >= thresh:
print(t1, cid, "%s/%s" % (stores, old), \
sub(t0, t1), sub(t1, t2), vs, \
print(t1, cid, "%s/%s" % (stores, old),
sub(t0, t1), sub(t1, t2), vs,
sub(t2, t), 'abort')
del transactions[cid]
elif ' calling tpc_finish(' in line:
......@@ -377,11 +387,12 @@ def time_trans(f):
t = time(line)
d = sub(t1, t)
if d >= thresh:
print(t1, cid, "%s/%s" % (stores, old), \
sub(t0, t1), sub(t1, t2), vs, \
print(t1, cid, "%s/%s" % (stores, old),
sub(t0, t1), sub(t1, t2), vs,
sub(t2, t3), sub(t3, t))
del transactions[cid]
def minute(f, slice=16, detail=1, summary=1):
f, = f
......@@ -405,10 +416,9 @@ def minute(f, slice=16, detail=1, summary=1):
for line in f:
line = line.strip()
if (line.find('returns') > 0
or line.find('storea') > 0
or line.find('tpc_abort') > 0
):
if line.find('returns') > 0 or \
line.find('storea') > 0 or \
line.find('tpc_abort') > 0:
client = connidre.search(line).group(1)
m = line[:slice]
if m != mlast:
......@@ -452,12 +462,13 @@ def minute(f, slice=16, detail=1, summary=1):
print('Summary: \t', '\t'.join(('min', '10%', '25%', 'med',
'75%', '90%', 'max', 'mean')))
print("n=%6d\t" % len(cls), '-'*62)
print('Clients: \t', '\t'.join(map(str,stats(cls))))
print('Reads: \t', '\t'.join(map(str,stats(rs))))
print('Stores: \t', '\t'.join(map(str,stats(ss))))
print('Commits: \t', '\t'.join(map(str,stats(cs))))
print('Aborts: \t', '\t'.join(map(str,stats(aborts))))
print('Trans: \t', '\t'.join(map(str,stats(ts))))
print('Clients: \t', '\t'.join(map(str, stats(cls))))
print('Reads: \t', '\t'.join(map(str, stats(rs))))
print('Stores: \t', '\t'.join(map(str, stats(ss))))
print('Commits: \t', '\t'.join(map(str, stats(cs))))
print('Aborts: \t', '\t'.join(map(str, stats(aborts))))
print('Trans: \t', '\t'.join(map(str, stats(ts))))
def stats(s):
s.sort()
......@@ -468,13 +479,14 @@ def stats(s):
ni = n + 1
for p in .1, .25, .5, .75, .90:
lp = ni*p
l = int(lp)
lp_int = int(lp)
if lp < 1 or lp > n:
out.append('-')
elif abs(lp-l) < .00001:
out.append(s[l-1])
elif abs(lp-lp_int) < .00001:
out.append(s[lp_int-1])
else:
out.append(int(s[l-1] + (lp - l) * (s[l] - s[l-1])))
out.append(
int(s[lp_int-1] + (lp - lp_int) * (s[lp_int] - s[lp_int-1])))
mean = 0.0
for v in s:
......@@ -484,24 +496,31 @@ def stats(s):
return out
def minutes(f):
minute(f, 16, detail=0)
def hour(f):
minute(f, 13)
def day(f):
minute(f, 10)
def hours(f):
minute(f, 13, detail=0)
def days(f):
minute(f, 10, detail=0)
new_connection_idre = re.compile(
r"new connection \('(\d+.\d+.\d+.\d+)', (\d+)\):")
def verify(f):
f, = f
......@@ -527,6 +546,7 @@ def verify(f):
d = sub(t1, time(line))
print(cid, t1, n, d, n and (d*1000.0/n) or '-')
def recovery(f):
f, = f
......@@ -542,16 +562,16 @@ def recovery(f):
n += 1
if line.find('RecoveryServer') < 0:
continue
l = line.find('sending transaction ')
if l > 0 and last.find('sending transaction ') > 0:
trans.append(line[l+20:].strip())
pos = line.find('sending transaction ')
if pos > 0 and last.find('sending transaction ') > 0:
trans.append(line[pos+20:].strip())
else:
if trans:
if len(trans) > 1:
print(" ... %s similar records skipped ..." % (
len(trans) - 1))
print(n, last.strip())
trans=[]
trans = []
print(n, line.strip())
last = line
......@@ -561,6 +581,5 @@ def recovery(f):
print(n, last.strip())
if __name__ == '__main__':
globals()[sys.argv[1]](sys.argv[2:])
......@@ -47,6 +47,7 @@ from ZEO.ClientStorage import ClientStorage
ZEO_VERSION = 2
def setup_logging():
# Set up logging to stderr which will show messages originating
# at severity ERROR or higher.
......@@ -59,6 +60,7 @@ def setup_logging():
handler.setFormatter(fmt)
root.addHandler(handler)
def check_server(addr, storage, write):
t0 = time.time()
if ZEO_VERSION == 2:
......@@ -97,11 +99,13 @@ def check_server(addr, storage, write):
t1 = time.time()
print("Elapsed time: %.2f" % (t1 - t0))
def usage(exit=1):
print(__doc__)
print(" ".join(sys.argv))
sys.exit(exit)
def main():
host = None
port = None
......@@ -123,7 +127,7 @@ def main():
elif o == '--nowrite':
write = 0
elif o == '-1':
ZEO_VERSION = 1
ZEO_VERSION = 1 # NOQA: F841 unused variable
except Exception as err:
s = str(err)
if s:
......@@ -143,6 +147,7 @@ def main():
setup_logging()
check_server(addr, storage, write)
if __name__ == "__main__":
try:
main()
......
......@@ -14,8 +14,9 @@
REPR_LIMIT = 60
def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes."
"""Return an object repr limited to REPR_LIMIT bytes."""
# Some of the objects being repr'd are large strings. A lot of memory
# would be wasted to repr them and then truncate, so they are treated
......
......@@ -17,6 +17,7 @@ from ZODB.Connection import TransactionMetaData
from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_unpickle
class TransUndoStorageWithCache(object):
def checkUndoInvalidation(self):
......
......@@ -20,12 +20,12 @@ from persistent.TimeStamp import TimeStamp
from ZODB.Connection import TransactionMetaData
from ZODB.tests.StorageTestBase import zodb_pickle, MinPO
import ZEO.ClientStorage
from ZEO.Exceptions import ClientDisconnected
from ZEO.tests.TestThread import TestThread
ZERO = b'\0'*8
class WorkerThread(TestThread):
# run the entire test in a thread so that the blocking call for
......@@ -62,6 +62,7 @@ class WorkerThread(TestThread):
self.ready.set()
future.result(9)
class CommitLockTests(object):
NUM_CLIENTS = 5
......@@ -99,7 +100,7 @@ class CommitLockTests(object):
for i in range(self.NUM_CLIENTS):
storage = self._new_storage_client()
txn = TransactionMetaData()
tid = self._get_timestamp()
self._get_timestamp()
t = WorkerThread(self, storage, txn)
self._threads.append(t)
......@@ -118,9 +119,10 @@ class CommitLockTests(object):
def _get_timestamp(self):
t = time.time()
t = TimeStamp(*time.gmtime(t)[:5]+(t%60,))
t = TimeStamp(*time.gmtime(t)[:5]+(t % 60,))
return repr(t)
class CommitLockVoteTests(CommitLockTests):
def checkCommitLockVoteFinish(self):
......
......@@ -26,11 +26,10 @@ from ZEO.tests import forker
from ZODB.Connection import TransactionMetaData
from ZODB.DB import DB
from ZODB.POSException import ReadOnlyError, ConflictError
from ZODB.POSException import ReadOnlyError
from ZODB.tests.StorageTestBase import StorageTestBase
from ZODB.tests.MinPO import MinPO
from ZODB.tests.StorageTestBase import zodb_pickle, zodb_unpickle
import ZODB.tests.util
import transaction
......@@ -40,6 +39,7 @@ logger = logging.getLogger('ZEO.tests.ConnectionTests')
ZERO = '\0'*8
class TestClientStorage(ClientStorage):
test_connection = False
......@@ -51,6 +51,7 @@ class TestClientStorage(ClientStorage):
self.connection_count_for_tests += 1
self.verify_result = conn.verify_result
class DummyDB(object):
def invalidate(self, *args, **kwargs):
pass
......@@ -93,7 +94,7 @@ class CommonSetupTearDown(StorageTestBase):
self._storage.close()
if hasattr(self._storage, 'cleanup'):
logging.debug("cleanup storage %s" %
self._storage.__name__)
self._storage.__name__)
self._storage.cleanup()
for stop in self._servers:
stop()
......@@ -113,7 +114,7 @@ class CommonSetupTearDown(StorageTestBase):
for dummy in range(5):
try:
os.unlink(path)
except:
except: # NOQA: E722 bare except
time.sleep(0.5)
else:
need_to_delete = False
......@@ -188,7 +189,7 @@ class CommonSetupTearDown(StorageTestBase):
stop = self._servers[index]
if stop is not None:
stop()
self._servers[index] = lambda : None
self._servers[index] = lambda: None
def pollUp(self, timeout=30.0, storage=None):
if storage is None:
......@@ -271,7 +272,6 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ReadOnlyError, self._dostore)
self._storage.close()
def checkDisconnectionError(self):
# Make sure we get a ClientDisconnected when we try to read an
# object when we're not connected to a storage server and the
......@@ -374,7 +374,7 @@ class ConnectionTests(CommonSetupTearDown):
pickle, rev = self._storage.load(oid, '')
newobj = zodb_unpickle(pickle)
self.assertEqual(newobj, obj)
newobj.value = 42 # .value *should* be 42 forever after now, not 13
newobj.value = 42 # .value *should* be 42 forever after now, not 13
self._dostore(oid, data=newobj, revid=rev)
self._storage.close()
......@@ -416,6 +416,7 @@ class ConnectionTests(CommonSetupTearDown):
def checkBadMessage2(self):
# just like a real message, but with an unpicklable argument
global Hack
class Hack(object):
pass
......@@ -505,7 +506,7 @@ class ConnectionTests(CommonSetupTearDown):
r1["a"] = MinPO("a")
transaction.commit()
self.assertEqual(r1._p_state, 0) # up-to-date
self.assertEqual(r1._p_state, 0) # up-to-date
db2 = DB(self.openClientStorage())
r2 = db2.open().root()
......@@ -524,9 +525,9 @@ class ConnectionTests(CommonSetupTearDown):
if r1._p_state == -1:
break
time.sleep(i / 10.0)
self.assertEqual(r1._p_state, -1) # ghost
self.assertEqual(r1._p_state, -1) # ghost
r1.keys() # unghostify
r1.keys() # unghostify
self.assertEqual(r1._p_serial, r2._p_serial)
self.assertEqual(r1["b"].value, "b")
......@@ -551,6 +552,7 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '')
class SSLConnectionTests(ConnectionTests):
def getServerConfig(self, addr, ro_svr):
......@@ -585,13 +587,13 @@ class InvqTests(CommonSetupTearDown):
revid2 = self._dostore(oid2, revid2)
forker.wait_until(
lambda :
lambda:
perstorage.lastTransaction() == self._storage.lastTransaction())
perstorage.load(oid, '')
perstorage.close()
forker.wait_until(lambda : os.path.exists('test-1.zec'))
forker.wait_until(lambda: os.path.exists('test-1.zec'))
revid = self._dostore(oid, revid)
......@@ -617,7 +619,7 @@ class InvqTests(CommonSetupTearDown):
revid = self._dostore(oid, revid)
forker.wait_until(
"Client has seen all of the transactions from the server",
lambda :
lambda:
perstorage.lastTransaction() == self._storage.lastTransaction()
)
perstorage.load(oid, '')
......@@ -635,6 +637,7 @@ class InvqTests(CommonSetupTearDown):
perstorage.close()
class ReconnectionTests(CommonSetupTearDown):
# The setUp() starts a server automatically. In order for its
# state to persist, we set the class variable keep to 1. In
......@@ -798,7 +801,7 @@ class ReconnectionTests(CommonSetupTearDown):
# Start a read-write server
self.startServer(index=1, read_only=0, keep=0)
# After a while, stores should work
for i in range(300): # Try for 30 seconds
for i in range(300): # Try for 30 seconds
try:
self._dostore()
break
......@@ -840,7 +843,7 @@ class ReconnectionTests(CommonSetupTearDown):
revid = self._dostore(oid, revid)
forker.wait_until(
"Client has seen all of the transactions from the server",
lambda :
lambda:
perstorage.lastTransaction() == self._storage.lastTransaction()
)
perstorage.load(oid, '')
......@@ -894,7 +897,6 @@ class ReconnectionTests(CommonSetupTearDown):
# Module ZEO.ClientStorage, line 709, in _update_cache
# KeyError: ...
def checkReconnection(self):
# Check that the client reconnects when a server restarts.
......@@ -952,6 +954,7 @@ class ReconnectionTests(CommonSetupTearDown):
self.assertTrue(did_a_store)
self._storage.close()
class TimeoutTests(CommonSetupTearDown):
timeout = 1
......@@ -967,9 +970,8 @@ class TimeoutTests(CommonSetupTearDown):
# Make sure it's logged as CRITICAL
with open("server.log") as f:
for line in f:
if (('Transaction timeout after' in line) and
('CRITICAL ZEO.StorageServer' in line)
):
if ('Transaction timeout after' in line) and \
('CRITICAL ZEO.StorageServer' in line):
break
else:
self.fail('bad logging')
......@@ -1002,7 +1004,7 @@ class TimeoutTests(CommonSetupTearDown):
t = TransactionMetaData()
old_connection_count = storage.connection_count_for_tests
storage.tpc_begin(t)
revid1 = storage.store(oid, ZERO, zodb_pickle(obj), '', t)
storage.store(oid, ZERO, zodb_pickle(obj), '', t)
storage.tpc_vote(t)
# Now sleep long enough for the storage to time out
time.sleep(3)
......@@ -1021,6 +1023,7 @@ class TimeoutTests(CommonSetupTearDown):
# or the server.
self.assertRaises(KeyError, storage.load, oid, '')
class MSTThread(threading.Thread):
__super_init = threading.Thread.__init__
......@@ -1054,7 +1057,7 @@ class MSTThread(threading.Thread):
# Begin a transaction
t = TransactionMetaData()
for c in clients:
#print("%s.%s.%s begin" % (tname, c.__name, i))
# print("%s.%s.%s begin" % (tname, c.__name, i))
c.tpc_begin(t)
for j in range(testcase.nobj):
......@@ -1063,18 +1066,18 @@ class MSTThread(threading.Thread):
oid = c.new_oid()
c.__oids.append(oid)
data = MinPO("%s.%s.t%d.o%d" % (tname, c.__name, i, j))
#print(data.value)
# print(data.value)
data = zodb_pickle(data)
c.store(oid, ZERO, data, '', t)
# Vote on all servers and handle serials
for c in clients:
#print("%s.%s.%s vote" % (tname, c.__name, i))
# print("%s.%s.%s vote" % (tname, c.__name, i))
c.tpc_vote(t)
# Finish on all servers
for c in clients:
#print("%s.%s.%s finish\n" % (tname, c.__name, i))
# print("%s.%s.%s finish\n" % (tname, c.__name, i))
c.tpc_finish(t)
for c in clients:
......@@ -1090,7 +1093,7 @@ class MSTThread(threading.Thread):
for c in self.clients:
try:
c.close()
except:
except: # NOQA: E722 bare except
pass
......@@ -1101,6 +1104,7 @@ def short_timeout(self):
yield
self._storage._server.timeout = old
# Run IPv6 tests if V6 sockets are supported
try:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
......
......@@ -41,6 +41,7 @@ from ZODB.POSException import ReadConflictError, ConflictError
# thought they added (i.e., the keys for which transaction.commit()
# did not raise any exception).
class FailableThread(TestThread):
# mixin class
......@@ -52,7 +53,7 @@ class FailableThread(TestThread):
def testrun(self):
try:
self._testrun()
except:
except: # NOQA: E722 bare except
# Report the failure here to all the other threads, so
# that they stop quickly.
self.stop.set()
......@@ -81,12 +82,11 @@ class StressTask(object):
tree[key] = self.threadnum
def commit(self):
cn = self.cn
key = self.startnum
self.tm.get().note(u"add key %s" % key)
try:
self.tm.get().commit()
except ConflictError as msg:
except ConflictError:
self.tm.abort()
else:
if self.sleep:
......@@ -98,15 +98,18 @@ class StressTask(object):
self.tm.get().abort()
self.cn.close()
def _runTasks(rounds, *tasks):
'''run *task* interleaved for *rounds* rounds.'''
def commit(run, actions):
actions.append(':')
for t in run:
t.commit()
del run[:]
r = Random()
r.seed(1064589285) # make it deterministic
r.seed(1064589285) # make it deterministic
run = []
actions = []
try:
......@@ -117,7 +120,7 @@ def _runTasks(rounds, *tasks):
run.append(t)
t.doStep()
actions.append(repr(t.startnum))
commit(run,actions)
commit(run, actions)
# stderr.write(' '.join(actions)+'\n')
finally:
for t in tasks:
......@@ -160,13 +163,14 @@ class StressThread(FailableThread):
self.commitdict[self] = 1
if self.sleep:
time.sleep(self.sleep)
except (ReadConflictError, ConflictError) as msg:
except (ReadConflictError, ConflictError):
tm.abort()
else:
self.added_keys.append(key)
key += self.step
cn.close()
class LargeUpdatesThread(FailableThread):
# A thread that performs a lot of updates. It attempts to modify
......@@ -195,7 +199,7 @@ class LargeUpdatesThread(FailableThread):
# print("%d getting tree abort" % self.threadnum)
transaction.abort()
keys_added = {} # set of keys we commit
keys_added = {} # set of keys we commit
tkeys = []
while not self.stop.isSet():
......@@ -212,7 +216,7 @@ class LargeUpdatesThread(FailableThread):
for key in keys:
try:
tree[key] = self.threadnum
except (ReadConflictError, ConflictError) as msg:
except (ReadConflictError, ConflictError): # as msg:
# print("%d setting key %s" % (self.threadnum, msg))
transaction.abort()
break
......@@ -224,7 +228,7 @@ class LargeUpdatesThread(FailableThread):
self.commitdict[self] = 1
if self.sleep:
time.sleep(self.sleep)
except ConflictError as msg:
except ConflictError: # as msg
# print("%d commit %s" % (self.threadnum, msg))
transaction.abort()
continue
......@@ -234,6 +238,7 @@ class LargeUpdatesThread(FailableThread):
self.added_keys = keys_added.keys()
cn.close()
class InvalidationTests(object):
# Minimum # of seconds the main thread lets the workers run. The
......@@ -261,7 +266,7 @@ class InvalidationTests(object):
transaction.abort()
else:
raise
except:
except: # NOQA: E722 bare except
display(tree)
raise
......
......@@ -21,6 +21,7 @@ from ZODB.Connection import TransactionMetaData
from ..asyncio.testing import AsyncRPC
class IterationTests(object):
def _assertIteratorIdsEmpty(self):
......@@ -44,7 +45,7 @@ class IterationTests(object):
# everything goes away as expected.
gc.enable()
gc.collect()
gc.collect() # sometimes PyPy needs it twice to clear weak refs
gc.collect() # sometimes PyPy needs it twice to clear weak refs
self._storage._iterator_gc()
......@@ -147,7 +148,6 @@ class IterationTests(object):
self._dostore()
six.advance_iterator(self._storage.iterator())
iid = list(self._storage._iterator_ids)[0]
t = TransactionMetaData()
self._storage.tpc_begin(t)
# Show that after disconnecting, the client side GCs the iterators
......@@ -176,12 +176,12 @@ def iterator_sane_after_reconnect():
Start a server:
>>> addr, adminaddr = start_server(
>>> addr, adminaddr = start_server( # NOQA: F821 undefined
... '<filestorage>\npath fs\n</filestorage>', keep=1)
Open a client storage to it and commit a some transactions:
>>> import ZEO, ZODB, transaction
>>> import ZEO, ZODB
>>> client = ZEO.client(addr)
>>> db = ZODB.DB(client)
>>> conn = db.open()
......@@ -196,10 +196,11 @@ Create an iterator:
Restart the storage:
>>> stop_server(adminaddr)
>>> wait_disconnected(client)
>>> _ = start_server('<filestorage>\npath fs\n</filestorage>', addr=addr)
>>> wait_connected(client)
>>> stop_server(adminaddr) # NOQA: F821 undefined
>>> wait_disconnected(client) # NOQA: F821 undefined
>>> _ = start_server( # NOQA: F821 undefined
... '<filestorage>\npath fs\n</filestorage>', addr=addr)
>>> wait_connected(client) # NOQA: F821 undefined
Now, we'll create a second iterator:
......
......@@ -16,6 +16,7 @@ import threading
import sys
import six
class TestThread(threading.Thread):
"""Base class for defining threads that run from unittest.
......@@ -46,12 +47,14 @@ class TestThread(threading.Thread):
def run(self):
try:
self.testrun()
except:
except: # NOQA: E722 blank except
self._exc_info = sys.exc_info()
def cleanup(self, timeout=15):
self.join(timeout)
if self._exc_info:
six.reraise(self._exc_info[0], self._exc_info[1], self._exc_info[2])
six.reraise(self._exc_info[0],
self._exc_info[1],
self._exc_info[2])
if self.is_alive():
self._testcase.fail("Thread did not finish: %s" % self)
......@@ -21,6 +21,7 @@ import ZEO.Exceptions
ZERO = '\0'*8
class BasicThread(threading.Thread):
def __init__(self, storage, doNextEvent, threadStartedEvent):
self.storage = storage
......@@ -123,7 +124,6 @@ class ThreadTests(object):
# Helper for checkMTStores
def mtstorehelper(self):
name = threading.currentThread().getName()
objs = []
for i in range(10):
objs.append(MinPO("X" * 200000))
......
This diff is collapsed.
......@@ -14,6 +14,7 @@
_auth_modules = {}
def get_module(name):
if name == 'sha':
from auth_sha import StorageClass, SHAClient, Database
......@@ -24,6 +25,7 @@ def get_module(name):
else:
return _auth_modules.get(name)
def register_module(name, storage_class, client, db):
if name in _auth_modules:
raise TypeError("%s is already registred" % name)
......
......@@ -45,6 +45,7 @@ from ..StorageServer import ZEOStorage
from ZEO.Exceptions import AuthError
from ..hash import sha1
def get_random_bytes(n=8):
try:
b = os.urandom(n)
......@@ -53,9 +54,11 @@ def get_random_bytes(n=8):
b = b"".join(L)
return b
def hexdigest(s):
return sha1(s.encode()).hexdigest()
class DigestDatabase(Database):
def __init__(self, filename, realm=None):
Database.__init__(self, filename, realm)
......@@ -69,6 +72,7 @@ class DigestDatabase(Database):
dig = hexdigest("%s:%s:%s" % (username, self.realm, password))
self._users[username] = dig
def session_key(h_up, nonce):
# The hash itself is a bit too short to be a session key.
# HMAC wants a 64-byte key. We don't want to use h_up
......@@ -77,6 +81,7 @@ def session_key(h_up, nonce):
return (sha1(("%s:%s" % (h_up, nonce)).encode('latin-1')).digest() +
h_up.encode('utf-8')[:44])
class StorageClass(ZEOStorage):
def set_database(self, database):
assert isinstance(database, DigestDatabase)
......@@ -124,6 +129,7 @@ class StorageClass(ZEOStorage):
extensions = [auth_get_challenge, auth_response]
class DigestClient(Client):
extensions = ["auth_get_challenge", "auth_response"]
......
......@@ -22,6 +22,7 @@ from __future__ import print_function
import os
from ..hash import sha1
class Client(object):
# Subclass should override to list the names of methods that
# will be called on the server.
......@@ -32,11 +33,13 @@ class Client(object):
for m in self.extensions:
setattr(self.stub, m, self.stub.extensionMethod(m))
def sort(L):
"""Sort a list in-place and return it."""
L.sort()
return L
class Database(object):
"""Abstracts a password database.
......@@ -49,6 +52,7 @@ class Database(object):
produced from the password string.
"""
realm = None
def __init__(self, filename, realm=None):
"""Creates a new Database
......
......@@ -3,24 +3,26 @@
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC(object):
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
def __init__(self, key, msg=None, digestmod=None):
"""Create a new HMAC object.
key: key for the keyed hash object.
......@@ -49,8 +51,8 @@ class HMAC(object):
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
# def clear(self):
# raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
......@@ -85,7 +87,8 @@ class HMAC(object):
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
def new(key, msg=None, digestmod=None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
......
......@@ -47,6 +47,7 @@ else:
if zeo_dist is not None:
zeo_version = zeo_dist.version
class StorageStats(object):
"""Per-storage usage statistics."""
......@@ -113,6 +114,7 @@ class StorageStats(object):
print("Conflicts:", self.conflicts, file=f)
print("Conflicts resolved:", self.conflicts_resolved, file=f)
class StatsClient(asyncore.dispatcher):
def __init__(self, sock, addr):
......@@ -144,6 +146,7 @@ class StatsClient(asyncore.dispatcher):
if self.closed and not self.buf:
asyncore.dispatcher.close(self)
class StatsServer(asyncore.dispatcher):
StatsConnectionClass = StatsClient
......
......@@ -49,21 +49,24 @@ from zdaemon.zdoptions import ZDOptions
logger = logging.getLogger('ZEO.runzeo')
_pid = str(os.getpid())
def log(msg, level=logging.INFO, exc_info=False):
"""Internal: generic logging function."""
message = "(%s) %s" % (_pid, msg)
logger.log(level, message, exc_info=exc_info)
def parse_binding_address(arg):
# Caution: Not part of the official ZConfig API.
obj = ZConfig.datatypes.SocketBindingAddress(arg)
return obj.family, obj.address
def windows_shutdown_handler():
# Called by the signal mechanism on Windows to perform shutdown.
import asyncore
asyncore.close_all()
class ZEOOptionsMixin(object):
storages = None
......@@ -75,14 +78,18 @@ class ZEOOptionsMixin(object):
self.monitor_family, self.monitor_address = parse_binding_address(arg)
def handle_filename(self, arg):
from ZODB.config import FileStorage # That's a FileStorage *opener*!
from ZODB.config import FileStorage # That's a FileStorage *opener*!
class FSConfig(object):
def __init__(self, name, path):
self._name = name
self.path = path
self.stop = None
def getSectionName(self):
return self._name
if not self.storages:
self.storages = []
name = str(1 + len(self.storages))
......@@ -90,6 +97,7 @@ class ZEOOptionsMixin(object):
self.storages.append(conf)
testing_exit_immediately = False
def handle_test(self, *args):
self.testing_exit_immediately = True
......@@ -117,6 +125,7 @@ class ZEOOptionsMixin(object):
self.add('pid_file', 'zeo.pid_filename',
None, 'pid-file=')
class ZEOOptions(ZDOptions, ZEOOptionsMixin):
__doc__ = __doc__
......@@ -179,8 +188,8 @@ class ZEOServer(object):
root.addHandler(handler)
def check_socket(self):
if (isinstance(self.options.address, tuple) and
self.options.address[1] is None):
if isinstance(self.options.address, tuple) and \
self.options.address[1] is None:
self.options.address = self.options.address[0], 0
return
if self.can_connect(self.options.family, self.options.address):
......@@ -224,7 +233,7 @@ class ZEOServer(object):
self.setup_win32_signals()
return
if hasattr(signal, 'SIGXFSZ'):
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
signal.signal(signal.SIGXFSZ, signal.SIG_IGN) # Special case
init_signames()
for sig, name in signames.items():
method = getattr(self, "handle_" + name.lower(), None)
......@@ -244,12 +253,12 @@ class ZEOServer(object):
"will *not* be installed.")
return
SignalHandler = Signals.Signals.SignalHandler
if SignalHandler is not None: # may be None if no pywin32.
if SignalHandler is not None: # may be None if no pywin32.
SignalHandler.registerHandler(signal.SIGTERM,
windows_shutdown_handler)
SignalHandler.registerHandler(signal.SIGINT,
windows_shutdown_handler)
SIGUSR2 = 12 # not in signal module on Windows.
SIGUSR2 = 12 # not in signal module on Windows.
SignalHandler.registerHandler(SIGUSR2, self.handle_sigusr2)
def create_server(self):
......@@ -275,20 +284,21 @@ class ZEOServer(object):
def handle_sigusr2(self):
# log rotation signal - do the same as Zope 2.7/2.8...
if self.options.config_logger is None or os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!",
if self.options.config_logger is None or \
os.name not in ("posix", "nt"):
log("received SIGUSR2, but it was not handled!",
level=logging.WARNING)
return
loggers = [self.options.config_logger]
if os.name == "posix":
for l in loggers:
l.reopen()
for logger in loggers:
logger.reopen()
log("Log files reopened successfully", level=logging.INFO)
else: # nt - same rotation code as in Zope's Signals/Signals.py
for l in loggers:
for f in l.handler_factories:
else: # nt - same rotation code as in Zope's Signals/Signals.py
for logger in loggers:
for f in logger.handler_factories:
handler = f()
if hasattr(handler, 'rotate') and callable(handler.rotate):
handler.rotate()
......@@ -347,14 +357,14 @@ def create_server(storages, options):
return StorageServer(
options.address,
storages,
read_only = options.read_only,
invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout,
monitor_address = options.monitor_address,
auth_protocol = options.auth_protocol,
auth_database = options.auth_database,
auth_realm = options.auth_realm,
read_only=options.read_only,
invalidation_queue_size=options.invalidation_queue_size,
invalidation_age=options.invalidation_age,
transaction_timeout=options.transaction_timeout,
monitor_address=options.monitor_address,
auth_protocol=options.auth_protocol,
auth_database=options.auth_database,
auth_realm=options.auth_realm,
)
......@@ -362,6 +372,7 @@ def create_server(storages, options):
signames = None
def signame(sig):
"""Return a symbolic name for a signal.
......@@ -373,6 +384,7 @@ def signame(sig):
init_signames()
return signames.get(sig) or "signal %d" % sig
def init_signames():
global signames
signames = {}
......@@ -392,5 +404,6 @@ def main(args=None):
s = ZEOServer(options)
s.main()
if __name__ == "__main__":
main()
......@@ -6,24 +6,26 @@
Implements the HMAC algorithm as described by RFC 2104.
"""
from six.moves import map
from six.moves import zip
def _strxor(s1, s2):
"""Utility method. XOR the two strings s1 and s2 (must have same length).
"""
return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), s1, s2))
# The size of the digests returned by HMAC depends on the underlying
# hashing module used.
digest_size = None
class HMAC(object):
"""RFC2104 HMAC class.
This supports the API for Cryptographic Hash Functions (PEP 247).
"""
def __init__(self, key, msg = None, digestmod = None):
def __init__(self, key, msg=None, digestmod=None):
"""Create a new HMAC object.
key: key for the keyed hash object.
......@@ -56,8 +58,8 @@ class HMAC(object):
if msg is not None:
self.update(msg)
## def clear(self):
## raise NotImplementedError("clear() method not available in HMAC.")
# def clear(self):
# raise NotImplementedError("clear() method not available in HMAC.")
def update(self, msg):
"""Update this hashing object with the string msg.
......@@ -92,7 +94,8 @@ class HMAC(object):
return "".join([hex(ord(x))[2:].zfill(2)
for x in tuple(self.digest())])
def new(key, msg = None, digestmod = None):
def new(key, msg=None, digestmod=None):
"""Create a new hashing object and return it.
key: The starting key for the hash.
......
......@@ -34,6 +34,7 @@ from six.moves import map
def client_timeout():
return 30.0
def client_loop(map):
read = asyncore.read
write = asyncore.write
......@@ -52,7 +53,7 @@ def client_loop(map):
r, w, e = select.select(r, w, e, client_timeout())
except (select.error, RuntimeError) as err:
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
# which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno != errno.EINTR:
if err_errno == errno.EBADF:
......@@ -114,14 +115,13 @@ def client_loop(map):
continue
_exception(obj)
except:
except: # NOQA: E722 bare except
if map:
try:
logging.getLogger(__name__+'.client_loop').critical(
'A ZEO client loop failed.',
exc_info=sys.exc_info())
except:
except: # NOQA: E722 bare except
pass
for fd, obj in map.items():
......@@ -129,14 +129,14 @@ def client_loop(map):
continue
try:
obj.mgr.client.close()
except:
except: # NOQA: E722 bare except
map.pop(fd, None)
try:
logging.getLogger(__name__+'.client_loop'
).critical(
"Couldn't close a dispatcher.",
exc_info=sys.exc_info())
except:
except: # NOQA: E722 bare except
pass
......@@ -152,11 +152,11 @@ class ConnectionManager(object):
self.tmin = min(tmin, tmax)
self.tmax = tmax
self.cond = threading.Condition(threading.Lock())
self.connection = None # Protected by self.cond
self.connection = None # Protected by self.cond
self.closed = 0
# If thread is not None, then there is a helper thread
# attempting to connect.
self.thread = None # Protected by self.cond
self.thread = None # Protected by self.cond
def new_addrs(self, addrs):
self.addrlist = self._parse_addrs(addrs)
......@@ -189,7 +189,8 @@ class ConnectionManager(object):
for addr in addrs:
addr_type = self._guess_type(addr)
if addr_type is None:
raise ValueError("unknown address in list: %s" % repr(addr))
raise ValueError(
"unknown address in list: %s" % repr(addr))
addrlist.append((addr_type, addr))
return addrlist
......@@ -197,10 +198,10 @@ class ConnectionManager(object):
if isinstance(addr, str):
return socket.AF_UNIX
if (len(addr) == 2
and isinstance(addr[0], str)
and isinstance(addr[1], int)):
return socket.AF_INET # also denotes IPv6
if len(addr) == 2 and \
isinstance(addr[0], str) and \
isinstance(addr[1], int):
return socket.AF_INET # also denotes IPv6
# not anything I know about
return None
......@@ -226,7 +227,7 @@ class ConnectionManager(object):
if obj is not self.trigger:
try:
obj.close()
except:
except: # NOQA: E722 bare except
logging.getLogger(__name__+'.'+self.__class__.__name__
).critical(
"Couldn't close a dispatcher.",
......@@ -237,7 +238,7 @@ class ConnectionManager(object):
try:
self.loop_thread.join(9)
except RuntimeError:
pass # we are the thread :)
pass # we are the thread :)
self.trigger.close()
def attempt_connect(self):
......@@ -304,7 +305,7 @@ class ConnectionManager(object):
self.connection = conn
if preferred:
self.thread = None
self.cond.notifyAll() # Wake up connect(sync=1)
self.cond.notifyAll() # Wake up connect(sync=1)
finally:
self.cond.release()
......@@ -331,6 +332,7 @@ class ConnectionManager(object):
finally:
self.cond.release()
# When trying to do a connect on a non-blocking socket, some outcomes
# are expected. Set _CONNECT_IN_PROGRESS to the errno value(s) expected
# when an initial connect can't complete immediately. Set _CONNECT_OK
......@@ -342,10 +344,11 @@ if hasattr(errno, "WSAEWOULDBLOCK"): # Windows
# seen this.
_CONNECT_IN_PROGRESS = (errno.WSAEWOULDBLOCK,)
# Win98: WSAEISCONN; Win2K: WSAEINVAL
_CONNECT_OK = (0, errno.WSAEISCONN, errno.WSAEINVAL)
_CONNECT_OK = (0, errno.WSAEISCONN, errno.WSAEINVAL)
else: # Unix
_CONNECT_IN_PROGRESS = (errno.EINPROGRESS,)
_CONNECT_OK = (0, errno.EISCONN)
_CONNECT_OK = (0, errno.EISCONN)
class ConnectThread(threading.Thread):
"""Thread that tries to connect to server given one or more addresses.
......@@ -455,7 +458,7 @@ class ConnectThread(threading.Thread):
) in socket.getaddrinfo(host or 'localhost', port,
socket.AF_INET,
socket.SOCK_STREAM
): # prune non-TCP results
): # prune non-TCP results
# for IPv6, drop flowinfo, and restrict addresses
# to [host]:port
yield family, sockaddr[:2]
......@@ -495,7 +498,7 @@ class ConnectThread(threading.Thread):
break
try:
r, w, x = select.select([], connecting, connecting, 1.0)
log("CT: select() %d, %d, %d" % tuple(map(len, (r,w,x))))
log("CT: select() %d, %d, %d" % tuple(map(len, (r, w, x))))
except select.error as msg:
log("CT: select failed; msg=%s" % str(msg),
level=logging.WARNING)
......@@ -509,7 +512,7 @@ class ConnectThread(threading.Thread):
for wrap in w:
wrap.connect_procedure()
if wrap.state == "notified":
del wrappers[wrap] # Don't close this one
del wrappers[wrap] # Don't close this one
for wrap in wrappers.keys():
wrap.close()
return 1
......@@ -526,7 +529,7 @@ class ConnectThread(threading.Thread):
else:
wrap.notify_client()
if wrap.state == "notified":
del wrappers[wrap] # Don't close this one
del wrappers[wrap] # Don't close this one
for wrap in wrappers.keys():
wrap.close()
return -1
......@@ -602,7 +605,7 @@ class ConnectWrapper(object):
to do app-level check of the connection.
"""
self.conn = ManagedClientConnection(self.sock, self.addr, self.mgr)
self.sock = None # The socket is now owned by the connection
self.sock = None # The socket is now owned by the connection
try:
self.preferred = self.client.testConnection(self.conn)
self.state = "tested"
......@@ -610,7 +613,7 @@ class ConnectWrapper(object):
log("CW: ReadOnlyError in testConnection (%s)" % repr(self.addr))
self.close()
return
except:
except: # NOQA: E722 bare except
log("CW: error in testConnection (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True)
self.close()
......@@ -629,7 +632,7 @@ class ConnectWrapper(object):
"""
try:
self.client.notifyConnected(self.conn)
except:
except: # NOQA: E722 bare except
log("CW: error in notifyConnected (%s)" % repr(self.addr),
level=logging.ERROR, exc_info=True)
self.close()
......
......@@ -26,12 +26,13 @@ from .log import short_repr, log
from ZODB.loglevels import BLATHER, TRACE
import ZODB.POSException
REPLY = ".reply" # message name used for replies
REPLY = ".reply" # message name used for replies
exception_type_type = type(Exception)
debug_zrpc = False
class Delay(object):
"""Used to delay response to client for synchronous calls.
......@@ -57,7 +58,9 @@ class Delay(object):
def __repr__(self):
return "%s[%s, %r, %r, %r]" % (
self.__class__.__name__, id(self), self.msgid, self.conn, self.sent)
self.__class__.__name__, id(self), self.msgid,
self.conn, self.sent)
class Result(Delay):
......@@ -69,6 +72,7 @@ class Result(Delay):
conn.send_reply(msgid, reply, False)
callback()
class MTDelay(Delay):
def __init__(self):
......@@ -147,6 +151,7 @@ class MTDelay(Delay):
# supply a handshake() method appropriate for their role in protocol
# negotiation.
class Connection(smac.SizedMessageAsyncConnection, object):
"""Dispatcher for RPC on object on both sides of socket.
......@@ -294,7 +299,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
self.fast_encode = marshal.fast_encode
self.closed = False
self.peer_protocol_version = None # set in recv_handshake()
self.peer_protocol_version = None # set in recv_handshake()
assert tag in b"CS"
self.tag = tag
......@@ -359,7 +364,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.addr)
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__
__str__ = __repr__ # Defeat asyncore's dreaded __getattr__
def log(self, message, level=BLATHER, exc_info=False):
self.logger.log(level, self.log_label + message, exc_info=exc_info)
......@@ -441,7 +446,7 @@ class Connection(smac.SizedMessageAsyncConnection, object):
try:
self.message_output(self.fast_encode(msgid, 0, REPLY, ret))
self.poll()
except:
except: # NOQA: E722 bare except
# Fall back to normal version for better error handling
self.send_reply(msgid, ret)
......@@ -520,10 +525,10 @@ class Connection(smac.SizedMessageAsyncConnection, object):
# cPickle may raise.
try:
msg = self.encode(msgid, 0, REPLY, (err_type, err_value))
except: # see above
except: # NOQA: E722 bare except; see above
try:
r = short_repr(err_value)
except:
except: # NOQA: E722 bare except
r = "<unreprable>"
err = ZRPCError("Couldn't pickle error %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
......@@ -656,10 +661,10 @@ class ManagedServerConnection(Connection):
# cPickle may raise.
try:
msg = self.encode(msgid, 0, REPLY, ret)
except: # see above
except: # NOQA: E722 bare except; see above
try:
r = short_repr(ret)
except:
except: # NOQA: E722 bare except
r = "<unreprable>"
err = ZRPCError("Couldn't pickle return %.100s" % r)
msg = self.encode(msgid, 0, REPLY, (ZRPCError, err))
......@@ -669,6 +674,7 @@ class ManagedServerConnection(Connection):
poll = smac.SizedMessageAsyncConnection.handle_write
def server_loop(map):
while len(map) > 1:
try:
......@@ -680,6 +686,7 @@ def server_loop(map):
for o in tuple(map.values()):
o.close()
class ManagedClientConnection(Connection):
"""Client-side Connection subclass."""
__super_init = Connection.__init__
......@@ -740,7 +747,7 @@ class ManagedClientConnection(Connection):
# are queued for the duration. The client will send its own
# handshake after the server's handshake is seen, in recv_handshake()
# below. It will then send any messages queued while waiting.
assert self.queue_output # the constructor already set this
assert self.queue_output # the constructor already set this
def recv_handshake(self, proto):
# The protocol to use is the older of our and the server's preferred
......@@ -778,11 +785,11 @@ class ManagedClientConnection(Connection):
raise DisconnectedError()
msgid = self.send_call(method, args)
r_args = self.wait(msgid)
if (isinstance(r_args, tuple) and len(r_args) > 1
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
if isinstance(r_args, tuple) and len(r_args) > 1 and \
type(r_args[0]) == exception_type_type and \
issubclass(r_args[0], Exception):
inst = r_args[1]
raise inst # error raised by server
raise inst # error raised by server
else:
return r_args
......@@ -821,11 +828,11 @@ class ManagedClientConnection(Connection):
def _deferred_wait(self, msgid):
r_args = self.wait(msgid)
if (isinstance(r_args, tuple)
and type(r_args[0]) == exception_type_type
and issubclass(r_args[0], Exception)):
if isinstance(r_args, tuple) and \
type(r_args[0]) == exception_type_type and \
issubclass(r_args[0], Exception):
inst = r_args[1]
raise inst # error raised by server
raise inst # error raised by server
else:
return r_args
......
......@@ -14,9 +14,11 @@
from ZODB import POSException
from ZEO.Exceptions import ClientDisconnected
class ZRPCError(POSException.StorageError):
pass
class DisconnectedError(ZRPCError, ClientDisconnected):
"""The database storage is disconnected from the storage server.
......
......@@ -17,24 +17,29 @@ import logging
from ZODB.loglevels import BLATHER
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
LOG_THREAD_ID = 0 # Set this to 1 during heavy debugging
logger = logging.getLogger('ZEO.zrpc')
_label = "%s" % os.getpid()
def new_label():
global _label
_label = str(os.getpid())
def log(message, level=BLATHER, label=None, exc_info=False):
label = label or _label
if LOG_THREAD_ID:
label = label + ':' + threading.currentThread().getName()
logger.log(level, '(%s) %s' % (label, message), exc_info=exc_info)
REPR_LIMIT = 60
def short_repr(obj):
"Return an object repr limited to REPR_LIMIT bytes."
......
......@@ -19,7 +19,8 @@ from .log import log, short_repr
PY2 = not PY3
def encode(*args): # args: (msgid, flags, name, args)
def encode(*args): # args: (msgid, flags, name, args)
# (We used to have a global pickler, but that's not thread-safe. :-( )
# It's not thread safe if, in the couse of pickling, we call the
......@@ -41,7 +42,6 @@ def encode(*args): # args: (msgid, flags, name, args)
return res
if PY3:
# XXX: Py3: Needs optimization.
fast_encode = encode
......@@ -50,48 +50,57 @@ elif PYPY:
# every time, getvalue() only works once
fast_encode = encode
else:
def fast_encode():
# Only use in cases where you *know* the data contains only basic
# Python objects
pickler = Pickler(1)
pickler.fast = 1
dump = pickler.dump
def fast_encode(*args):
return dump(args, 1)
return fast_encode
fast_encode = fast_encode()
def decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global
try:
unpickler.find_class = find_global # PyPy, zodbpickle, the non-c-accelerated version
# PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = find_global
except AttributeError:
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
return unpickler.load() # msgid, flags, name, args
except: # NOQA: E722 bare except
log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR)
raise
def server_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global
try:
unpickler.find_class = server_find_global # PyPy, zodbpickle, the non-c-accelerated version
# PyPy, zodbpickle, the non-c-accelerated version
unpickler.find_class = server_find_global
except AttributeError:
pass
try:
return unpickler.load() # msgid, flags, name, args
except:
return unpickler.load() # msgid, flags, name, args
except: # NOQA: E722 bare except
log("can't decode message: %s" % short_repr(msg),
level=logging.ERROR)
raise
_globals = globals()
_silly = ('__doc__',)
......@@ -102,6 +111,7 @@ _SAFE_MODULE_NAMES = (
'builtins', 'copy_reg', '__builtin__',
)
def find_global(module, name):
"""Helper for message unpickler"""
try:
......@@ -114,7 +124,8 @@ def find_global(module, name):
except AttributeError:
raise ZRPCError("module %s has no global %s" % (module, name))
safe = getattr(r, '__no_side_effects__', 0) or (PY2 and module in _SAFE_MODULE_NAMES)
safe = (getattr(r, '__no_side_effects__', 0) or
(PY2 and module in _SAFE_MODULE_NAMES))
if safe:
return r
......@@ -124,6 +135,7 @@ def find_global(module, name):
raise ZRPCError("Unsafe global: %s.%s" % (module, name))
def server_find_global(module, name):
"""Helper for message unpickler"""
if module not in _SAFE_MODULE_NAMES:
......
......@@ -13,6 +13,7 @@
##############################################################################
import asyncore
import socket
import time
# _has_dualstack: True if the dual-stack sockets are supported
try:
......@@ -39,6 +40,7 @@ import logging
# Export the main asyncore loop
loop = asyncore.loop
class Dispatcher(asyncore.dispatcher):
"""A server that accepts incoming RPC connections"""
__super_init = asyncore.dispatcher.__init__
......@@ -74,7 +76,7 @@ class Dispatcher(asyncore.dispatcher):
for i in range(25):
try:
self.bind(self.addr)
except Exception as exc:
except Exception:
log("bind failed %s waiting", i)
if i == 24:
raise
......@@ -98,7 +100,6 @@ class Dispatcher(asyncore.dispatcher):
log("accepted failed: %s" % msg)
return
# We could short-circuit the attempt below in some edge cases
# and avoid a log message by checking for addr being None.
# Unfortunately, our test for the code below,
......@@ -111,12 +112,12 @@ class Dispatcher(asyncore.dispatcher):
# closed, but I don't see a way to do that. :(
# Drop flow-info from IPv6 addresses
if addr: # Sometimes None on Mac. See above.
if addr: # Sometimes None on Mac. See above.
addr = addr[:2]
try:
c = self.factory(sock, addr)
except:
except: # NOQA: E722 bare except
if sock.fileno() in asyncore.socket_map:
del asyncore.socket_map[sock.fileno()]
logger.exception("Error in handle_accept")
......
......@@ -67,19 +67,20 @@ MAC_BIT = 0x80000000
_close_marker = object()
class SizedMessageAsyncConnection(asyncore.dispatcher):
__super_init = asyncore.dispatcher.__init__
__super_close = asyncore.dispatcher.close
__closed = True # Marker indicating that we're closed
__closed = True # Marker indicating that we're closed
socket = None # to outwit Sam's getattr
socket = None # to outwit Sam's getattr
def __init__(self, sock, addr, map=None):
self.addr = addr
# __input_lock protects __inp, __input_len, __state, __msg_size
self.__input_lock = threading.Lock()
self.__inp = None # None, a single String, or a list
self.__inp = None # None, a single String, or a list
self.__input_len = 0
# Instance variables __state, __msg_size and __has_mac work together:
# when __state == 0:
......@@ -168,7 +169,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
d = self.recv(8192)
except socket.error as err:
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
# which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_read_errors:
return
......@@ -190,7 +191,7 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
else:
self.__inp.append(d)
self.__input_len = input_len
return # keep waiting for more input
return # keep waiting for more input
# load all previous input and d into single string inp
if isinstance(inp, six.binary_type):
......@@ -298,15 +299,15 @@ class SizedMessageAsyncConnection(asyncore.dispatcher):
# ensure the above mentioned "output" invariant
output.insert(0, v)
# Python >= 3.3 makes select.error an alias of OSError,
# which is not subscriptable but does have the 'errno' attribute
# which is not subscriptable but does have a 'errno' attribute
err_errno = getattr(err, 'errno', None) or err[0]
if err_errno in expected_socket_write_errors:
break # we couldn't write anything
break # we couldn't write anything
raise
if n < len(v):
output.append(v[n:])
break # we can't write any more
break # we can't write any more
def handle_close(self):
self.close()
......
......@@ -21,7 +21,7 @@ import socket
import errno
from ZODB.utils import positive_id
from ZEO._compat import thread, get_ident
from ZEO._compat import thread
# Original comments follow; they're hard to follow in the context of
# ZEO's use of triggers. TODO: rewrite from a ZEO perspective.
......@@ -56,6 +56,7 @@ from ZEO._compat import thread, get_ident
# new data onto a channel's outgoing data queue at the same time that
# the main thread is trying to remove some]
class _triggerbase(object):
"""OS-independent base class for OS-dependent trigger class."""
......@@ -127,7 +128,7 @@ class _triggerbase(object):
return
try:
thunk[0](*thunk[1:])
except:
except: # NOQA: E722 bare except
nil, t, v, tbinfo = asyncore.compact_traceback()
print(('exception in trigger thunk:'
' (%s:%s %s)' % (t, v, tbinfo)))
......@@ -135,6 +136,7 @@ class _triggerbase(object):
def __repr__(self):
return '<select-trigger (%s) at %x>' % (self.kind, positive_id(self))
if os.name == 'posix':
class trigger(_triggerbase, asyncore.file_dispatcher):
......@@ -187,39 +189,39 @@ else:
count = 0
while 1:
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
connect_address = a.getsockname() # assigned (host, port) pair
a.listen(1)
try:
w.connect(connect_address)
break # success
except socket.error as detail:
if detail[0] != errno.WSAEADDRINUSE:
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
w.close()
raise BindError("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
count += 1
# Bind to a local port; for efficiency, let the OS pick
# a free port for us.
# Unfortunately, stress tests showed that we may not
# be able to connect to that port ("Address already in
# use") despite that the OS picked it. This appears
# to be a race bug in the Windows socket implementation.
# So we loop until a connect() succeeds (almost always
# on the first try). See the long thread at
# http://mail.zope.org/pipermail/zope/2005-July/160433.html
# for hideous details.
a = socket.socket()
a.bind(("127.0.0.1", 0))
connect_address = a.getsockname() # assigned (host, port) pair
a.listen(1)
try:
w.connect(connect_address)
break # success
except socket.error as detail:
if detail[0] != errno.WSAEADDRINUSE:
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
raise
# (10048, 'Address already in use')
# assert count <= 2 # never triggered in Tim's tests
if count >= 10: # I've never seen it go above 2
a.close()
w.close()
raise BindError("Cannot bind trigger!")
# Close `a` and try again. Note: I originally put a short
# sleep() here, but it didn't appear to help or hurt.
a.close()
r, addr = a.accept() # r becomes asyncore's (self.)socket
a.close()
......
......@@ -16,7 +16,6 @@ from __future__ import print_function
import random
import sys
import time
......@@ -56,8 +55,8 @@ def encode_format(fmt):
fmt = fmt.replace(*xform)
return fmt
runner = _forker.runner
runner = _forker.runner
stop_runner = _forker.stop_runner
start_zeo_server = _forker.start_zeo_server
......@@ -70,6 +69,7 @@ else:
shutdown_zeo_server = _forker.shutdown_zeo_server
def get_port(ignored=None):
"""Return a port that is not in use.
......@@ -107,6 +107,7 @@ def get_port(ignored=None):
s1.close()
raise RuntimeError("Can't find port")
def can_connect(port):
c = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
......@@ -119,6 +120,7 @@ def can_connect(port):
finally:
c.close()
def setUp(test):
ZODB.tests.util.setUp(test)
......@@ -194,9 +196,11 @@ def wait_until(label=None, func=None, timeout=30, onfail=None):
return onfail()
time.sleep(0.01)
def wait_connected(storage):
wait_until("storage is connected", storage.is_connected)
def wait_disconnected(storage):
wait_until("storage is disconnected",
lambda: not storage.is_connected())
......
......@@ -34,6 +34,7 @@ import ZEO.asyncio.tests
import ZEO.StorageServer
import ZODB.MappingStorage
class StorageServer(ZEO.StorageServer.StorageServer):
def __init__(self, addr='test_addr', storages=None, **kw):
......@@ -41,6 +42,7 @@ class StorageServer(ZEO.StorageServer.StorageServer):
storages = {'1': ZODB.MappingStorage.MappingStorage()}
ZEO.StorageServer.StorageServer.__init__(self, addr, storages, **kw)
def client(server, name='client'):
zs = ZEO.StorageServer.ZEOStorage(server)
protocol = ZEO.asyncio.tests.server_protocol(
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -16,15 +16,18 @@ import unittest
from ZEO.TransactionBuffer import TransactionBuffer
def random_string(size):
"""Return a random string of size size."""
l = [chr(random.randrange(256)) for i in range(size)]
return "".join(l)
lst = [chr(random.randrange(256)) for i in range(size)]
return "".join(lst)
def new_store_data():
"""Return arbitrary data to use as argument to store() method."""
return random_string(8), random_string(random.randrange(1000))
def store(tbuf, resolved=False):
data = new_store_data()
tbuf.store(*data)
......@@ -32,6 +35,7 @@ def store(tbuf, resolved=False):
tbuf.server_resolve(data[0])
return data
class TransBufTests(unittest.TestCase):
def checkTypicalUsage(self):
......@@ -54,5 +58,6 @@ class TransBufTests(unittest.TestCase):
self.assertEqual(resolved, data[i][1])
tbuf.close()
def test_suite():
return unittest.makeSuite(TransBufTests, 'check')
This diff is collapsed.
This diff is collapsed.
......@@ -27,6 +27,7 @@ from zdaemon.tests.testzdoptions import TestZDOptions
# supplies the empty string.
DEFAULT_BINDING_HOST = ""
class TestZEOOptions(TestZDOptions):
OptionsClass = ZEOOptions
......@@ -59,7 +60,7 @@ class TestZEOOptions(TestZDOptions):
# Hide the base class test_configure
pass
def test_default_help(self): pass # disable silly test w spurious failures
def test_default_help(self): pass # disable silly test w spurious failures
def test_defaults_with_schema(self):
options = self.OptionsClass()
......@@ -106,5 +107,6 @@ def test_suite():
suite.addTest(unittest.makeSuite(cls))
return suite
if __name__ == "__main__":
unittest.main(defaultTest='test_suite')
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -9,4 +9,3 @@ import ZODB.tests.util
threaded_server_tests = ZODB.tests.util.MininalTestLayer(
'threaded_server_tests')
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment