Commit b31fed14 authored by Jim Fulton's avatar Jim Fulton

Implemented msgpack as an optional ZEO message encoding with basic tests.

parent c3183420
...@@ -289,6 +289,13 @@ client-conflict-resolution ...@@ -289,6 +289,13 @@ client-conflict-resolution
Flag indicating that clients should perform conflict Flag indicating that clients should perform conflict
resolution. This option defaults to false. resolution. This option defaults to false.
msgpack
Use msgpack to serialize and de-serialize ZEO protocol messages.
An advantage of using msgpack for ZEO communication is that
it's a little bit faster and a ZEO server can support Python 2
or Python 3 clients (but not both).
Server SSL configuration Server SSL configuration
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~
......
...@@ -36,7 +36,7 @@ install_requires = [ ...@@ -36,7 +36,7 @@ install_requires = [
'zope.interface', 'zope.interface',
] ]
tests_require = ['zope.testing', 'manuel', 'random2', 'mock'] tests_require = ['zope.testing', 'manuel', 'random2', 'mock', 'msgpack-python']
if sys.version_info[:2] < (3, ): if sys.version_info[:2] < (3, ):
install_requires.extend(('futures', 'trollius')) install_requires.extend(('futures', 'trollius'))
...@@ -128,7 +128,11 @@ setup(name="ZEO", ...@@ -128,7 +128,11 @@ setup(name="ZEO",
classifiers = classifiers, classifiers = classifiers,
test_suite="__main__.alltests", # to support "setup.py test" test_suite="__main__.alltests", # to support "setup.py test"
tests_require = tests_require, tests_require = tests_require,
extras_require = dict(test=tests_require, uvloop=['uvloop >=0.5.1']), extras_require = dict(
test=tests_require,
uvloop=['uvloop >=0.5.1'],
msgpack=['msgpack-python'],
),
install_requires = install_requires, install_requires = install_requires,
zip_safe = False, zip_safe = False,
entry_points = """ entry_points = """
......
...@@ -663,6 +663,7 @@ class StorageServer: ...@@ -663,6 +663,7 @@ class StorageServer:
ssl=None, ssl=None,
client_conflict_resolution=False, client_conflict_resolution=False,
Acceptor=Acceptor, Acceptor=Acceptor,
msgpack=False,
): ):
"""StorageServer constructor. """StorageServer constructor.
...@@ -757,7 +758,7 @@ class StorageServer: ...@@ -757,7 +758,7 @@ class StorageServer:
self.client_conflict_resolution = client_conflict_resolution self.client_conflict_resolution = client_conflict_resolution
if addr is not None: if addr is not None:
self.acceptor = Acceptor(self, addr, ssl) self.acceptor = Acceptor(self, addr, ssl, msgpack)
if isinstance(addr, tuple) and addr[0]: if isinstance(addr, tuple) and addr[0]:
self.addr = self.acceptor.addr self.addr = self.acceptor.addr
else: else:
......
...@@ -10,8 +10,6 @@ import socket ...@@ -10,8 +10,6 @@ import socket
from struct import unpack from struct import unpack
import sys import sys
from .marshal import encoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6 INET_FAMILIES = socket.AF_INET, socket.AF_INET6
...@@ -129,13 +127,13 @@ class Protocol(asyncio.Protocol): ...@@ -129,13 +127,13 @@ class Protocol(asyncio.Protocol):
self.getting_size = True self.getting_size = True
self.message_received(collected) self.message_received(collected)
except Exception: except Exception:
#import traceback; traceback.print_exc()
logger.exception("data_received %s %s %s", logger.exception("data_received %s %s %s",
self.want, self.got, self.getting_size) self.want, self.got, self.getting_size)
def first_message_received(self, protocol_version): def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__ # Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on del self.message_received # use default handler from here on
self.encode = encoder()
self.finish_connect(protocol_version) self.finish_connect(protocol_version)
def call_async(self, method, args): def call_async(self, method, args):
......
...@@ -13,7 +13,7 @@ import ZEO.interfaces ...@@ -13,7 +13,7 @@ import ZEO.interfaces
from . import base from . import base
from .compat import asyncio, new_event_loop from .compat import asyncio, new_event_loop
from .marshal import decode from .marshal import encoder, decoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -63,7 +63,7 @@ class Protocol(base.Protocol): ...@@ -63,7 +63,7 @@ class Protocol(base.Protocol):
# One place where special care was required was in cache setup on # One place where special care was required was in cache setup on
# connect. See finish connect below. # connect. See finish connect below.
protocols = b'Z309', b'Z310', b'Z3101', b'Z4', b'Z5' protocols = b'309', b'310', b'3101', b'4', b'5'
def __init__(self, loop, def __init__(self, loop,
addr, client, storage_key, read_only, connect_poll=1, addr, client, storage_key, read_only, connect_poll=1,
...@@ -150,6 +150,8 @@ class Protocol(base.Protocol): ...@@ -150,6 +150,8 @@ class Protocol(base.Protocol):
# We have to be careful processing the futures, because # We have to be careful processing the futures, because
# exception callbacks might modufy them. # exception callbacks might modufy them.
for f in self.pop_futures(): for f in self.pop_futures():
if isinstance(f, tuple):
continue
f.set_exception(ClientDisconnected(exc or 'connection lost')) f.set_exception(ClientDisconnected(exc or 'connection lost'))
self.closed = True self.closed = True
self.client.disconnected(self) self.client.disconnected(self)
...@@ -165,13 +167,17 @@ class Protocol(base.Protocol): ...@@ -165,13 +167,17 @@ class Protocol(base.Protocol):
# lastTid before processing (and possibly missing) subsequent # lastTid before processing (and possibly missing) subsequent
# invalidations. # invalidations.
self.protocol_version = min(protocol_version, self.protocols[-1]) version = min(protocol_version[1:], self.protocols[-1])
if version not in self.protocols:
if self.protocol_version not in self.protocols:
self.client.register_failed( self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version)) self, ZEO.Exceptions.ProtocolError(protocol_version))
return return
self.protocol_version = protocol_version[:1] + version
self.encode = encoder(protocol_version)
self.decode = decoder(protocol_version)
self.heartbeat_bytes = self.encode(-1, 0, '.reply', None)
self._write(self.protocol_version) self._write(self.protocol_version)
credentials = (self.credentials,) if self.credentials else () credentials = (self.credentials,) if self.credentials else ()
...@@ -199,9 +205,12 @@ class Protocol(base.Protocol): ...@@ -199,9 +205,12 @@ class Protocol(base.Protocol):
exception_type_type = type(Exception) exception_type_type = type(Exception)
def message_received(self, data): def message_received(self, data):
msgid, async, name, args = decode(data) msgid, async, name, args = self.decode(data)
if name == '.reply': if name == '.reply':
future = self.futures.pop(msgid) future = self.futures.pop(msgid)
if isinstance(future, tuple):
future = self.futures.pop(future)
if (async): # ZEO 5 exception if (async): # ZEO 5 exception
class_, args = args class_, args = args
factory = exc_factories.get(class_) factory = exc_factories.get(class_)
...@@ -245,13 +254,15 @@ class Protocol(base.Protocol): ...@@ -245,13 +254,15 @@ class Protocol(base.Protocol):
def load_before(self, oid, tid): def load_before(self, oid, tid):
# Special-case loadBefore, so we collapse outstanding requests # Special-case loadBefore, so we collapse outstanding requests
message_id = (oid, tid) oid_tid = (oid, tid)
future = self.futures.get(message_id) future = self.futures.get(oid_tid)
if future is None: if future is None:
future = asyncio.Future(loop=self.loop) future = asyncio.Future(loop=self.loop)
self.futures[message_id] = future self.futures[oid_tid] = future
self.message_id += 1
self.futures[self.message_id] = oid_tid
self._write( self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid))) self.encode(self.message_id, False, 'loadBefore', (oid, tid)))
return future return future
# Methods called by the server. # Methods called by the server.
...@@ -267,7 +278,7 @@ class Protocol(base.Protocol): ...@@ -267,7 +278,7 @@ class Protocol(base.Protocol):
def heartbeat(self, write=True): def heartbeat(self, write=True):
if write: if write:
self._write(b'(J\xff\xff\xff\xffK\x00U\x06.replyNt.') self._write(self.heartbeat_bytes)
self.heartbeat_handle = self.loop.call_later( self.heartbeat_handle = self.loop.call_later(
self.heartbeat_interval, self.heartbeat) self.heartbeat_interval, self.heartbeat)
......
...@@ -26,10 +26,18 @@ from ..shortrepr import short_repr ...@@ -26,10 +26,18 @@ from ..shortrepr import short_repr
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def encoder(): def encoder(protocol):
"""Return a non-thread-safe encoder """Return a non-thread-safe encoder
""" """
if protocol[:1] == b'M':
from msgpack import packb
def encode(*args):
return packb(args, use_bin_type=True)
return encode
else:
assert protocol[:1] == b'Z'
if PY3 or PYPY: if PY3 or PYPY:
f = BytesIO() f = BytesIO()
getvalue = f.getvalue getvalue = f.getvalue
...@@ -54,9 +62,20 @@ def encoder(): ...@@ -54,9 +62,20 @@ def encoder():
def encode(*args): def encode(*args):
return encoder()(*args) return encoder(b'Z')(*args)
def decode(msg): def decoder(protocol):
if protocol[:1] == b'M':
from msgpack import unpackb
def msgpack_decode(data):
"""Decodes msg and returns its parts"""
return unpackb(data, encoding='utf-8')
return msgpack_decode
else:
assert protocol[:1] == b'Z'
return pickle_decode
def pickle_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global unpickler.find_global = find_global
...@@ -71,7 +90,14 @@ def decode(msg): ...@@ -71,7 +90,14 @@ def decode(msg):
logger.error("can't decode message: %s" % short_repr(msg)) logger.error("can't decode message: %s" % short_repr(msg))
raise raise
def server_decode(msg): def server_decoder(protocol):
if protocol[:1] == b'M':
return decoder(protocol)
else:
assert protocol[:1] == b'Z'
return pickle_server_decode
def pickle_server_decode(msg):
"""Decodes msg and returns its parts""" """Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg)) unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global unpickler.find_global = server_find_global
......
...@@ -76,13 +76,14 @@ class Acceptor(asyncore.dispatcher): ...@@ -76,13 +76,14 @@ class Acceptor(asyncore.dispatcher):
And creates a separate thread for each. And creates a separate thread for each.
""" """
def __init__(self, storage_server, addr, ssl): def __init__(self, storage_server, addr, ssl, msgpack):
self.storage_server = storage_server self.storage_server = storage_server
self.addr = addr self.addr = addr
self.__socket_map = {} self.__socket_map = {}
asyncore.dispatcher.__init__(self, map=self.__socket_map) asyncore.dispatcher.__init__(self, map=self.__socket_map)
self.ssl_context = ssl self.ssl_context = ssl
self.msgpack = msgpack
self._open_socket() self._open_socket()
def _open_socket(self): def _open_socket(self):
...@@ -165,7 +166,7 @@ class Acceptor(asyncore.dispatcher): ...@@ -165,7 +166,7 @@ class Acceptor(asyncore.dispatcher):
def run(): def run():
loop = new_event_loop() loop = new_event_loop()
zs = self.storage_server.create_client_handler() zs = self.storage_server.create_client_handler()
protocol = ServerProtocol(loop, self.addr, zs) protocol = ServerProtocol(loop, self.addr, zs, self.msgpack)
protocol.stop = loop.stop protocol.stop = loop.stop
if self.ssl_context is None: if self.ssl_context is None:
......
...@@ -11,13 +11,13 @@ from ..shortrepr import short_repr ...@@ -11,13 +11,13 @@ from ..shortrepr import short_repr
from . import base from . import base
from .compat import asyncio, new_event_loop from .compat import asyncio, new_event_loop
from .marshal import server_decode from .marshal import server_decoder, encoder
class ServerProtocol(base.Protocol): class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface """asyncio low-level ZEO server interface
""" """
protocols = (b'Z5', ) protocols = (b'5', )
name = 'server protocol' name = 'server protocol'
methods = set(('register', )) methods = set(('register', ))
...@@ -26,12 +26,16 @@ class ServerProtocol(base.Protocol): ...@@ -26,12 +26,16 @@ class ServerProtocol(base.Protocol):
ZODB.POSException.POSKeyError, ZODB.POSException.POSKeyError,
) )
def __init__(self, loop, addr, zeo_storage): def __init__(self, loop, addr, zeo_storage, msgpack):
"""Create a server's client interface """Create a server's client interface
""" """
super(ServerProtocol, self).__init__(loop, addr) super(ServerProtocol, self).__init__(loop, addr)
self.zeo_storage = zeo_storage self.zeo_storage = zeo_storage
self.announce_protocol = (
(b'M' if msgpack else b'Z') + best_protocol_version
)
closed = False closed = False
def close(self): def close(self):
logger.debug("Closing server protocol") logger.debug("Closing server protocol")
...@@ -44,7 +48,7 @@ class ServerProtocol(base.Protocol): ...@@ -44,7 +48,7 @@ class ServerProtocol(base.Protocol):
def connection_made(self, transport): def connection_made(self, transport):
self.connected = True self.connected = True
super(ServerProtocol, self).connection_made(transport) super(ServerProtocol, self).connection_made(transport)
self._write(best_protocol_version) self._write(self.announce_protocol)
def connection_lost(self, exc): def connection_lost(self, exc):
self.connected = False self.connected = False
...@@ -61,10 +65,13 @@ class ServerProtocol(base.Protocol): ...@@ -61,10 +65,13 @@ class ServerProtocol(base.Protocol):
self._write(json.dumps(self.zeo_storage.ruok()).encode("ascii")) self._write(json.dumps(self.zeo_storage.ruok()).encode("ascii"))
self.close() self.close()
else: else:
if protocol_version in self.protocols: version = protocol_version[1:]
if version in self.protocols:
logger.info("received handshake %r" % logger.info("received handshake %r" %
str(protocol_version.decode('ascii'))) str(protocol_version.decode('ascii')))
self.protocol_version = protocol_version self.protocol_version = protocol_version
self.encode = encoder(protocol_version)
self.decode = server_decoder(protocol_version)
self.zeo_storage.notify_connected(self) self.zeo_storage.notify_connected(self)
else: else:
logger.error("bad handshake %s" % short_repr(protocol_version)) logger.error("bad handshake %s" % short_repr(protocol_version))
...@@ -79,7 +86,7 @@ class ServerProtocol(base.Protocol): ...@@ -79,7 +86,7 @@ class ServerProtocol(base.Protocol):
def message_received(self, message): def message_received(self, message):
try: try:
message_id, async, name, args = server_decode(message) message_id, async, name, args = self.decode(message)
except Exception: except Exception:
logger.exception("Can't deserialize message") logger.exception("Can't deserialize message")
self.close() self.close()
...@@ -144,8 +151,8 @@ best_protocol_version = os.environ.get( ...@@ -144,8 +151,8 @@ best_protocol_version = os.environ.get(
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8') ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage): def new_connection(loop, addr, socket, zeo_storage, msgpack):
protocol = ServerProtocol(loop, addr, zeo_storage) protocol = ServerProtocol(loop, addr, zeo_storage, msgpack)
cr = loop.create_connection((lambda : protocol), sock=socket) cr = loop.create_connection((lambda : protocol), sock=socket)
asyncio.async(cr, loop=loop) asyncio.async(cr, loop=loop)
...@@ -213,10 +220,11 @@ class MTDelay(Delay): ...@@ -213,10 +220,11 @@ class MTDelay(Delay):
class Acceptor(object): class Acceptor(object):
def __init__(self, storage_server, addr, ssl): def __init__(self, storage_server, addr, ssl, msgpack):
self.storage_server = storage_server self.storage_server = storage_server
self.addr = addr self.addr = addr
self.ssl_context = ssl self.ssl_context = ssl
self.msgpack = msgpack
self.event_loop = loop = new_event_loop() self.event_loop = loop = new_event_loop()
if isinstance(addr, tuple): if isinstance(addr, tuple):
...@@ -243,7 +251,8 @@ class Acceptor(object): ...@@ -243,7 +251,8 @@ class Acceptor(object):
try: try:
logger.debug("Accepted connection") logger.debug("Accepted connection")
zs = self.storage_server.create_client_handler() zs = self.storage_server.create_client_handler()
protocol = ServerProtocol(self.event_loop, self.addr, zs) protocol = ServerProtocol(
self.event_loop, self.addr, zs, self.msgpack)
except Exception: except Exception:
logger.exception("Failure in protocol factory") logger.exception("Failure in protocol factory")
......
...@@ -21,13 +21,16 @@ from ..Exceptions import ClientDisconnected, ProtocolError ...@@ -21,13 +21,16 @@ from ..Exceptions import ClientDisconnected, ProtocolError
from .testing import Loop from .testing import Loop
from .client import ClientRunner, Fallback from .client import ClientRunner, Fallback
from .server import new_connection, best_protocol_version from .server import new_connection, best_protocol_version
from .marshal import encoder, decode from .marshal import encoder, decoder
class Base(object): class Base(object):
enc = b'Z'
def setUp(self): def setUp(self):
super(Base, self).setUp() super(Base, self).setUp()
self.encode = encoder() self.encode = encoder(self.enc)
self.decode = decoder(self.enc)
def unsized(self, data, unpickle=False): def unsized(self, data, unpickle=False):
result = [] result = []
...@@ -36,7 +39,11 @@ class Base(object): ...@@ -36,7 +39,11 @@ class Base(object):
data = data[2:] data = data[2:]
self.assertEqual(struct.unpack(">I", size)[0], len(message)) self.assertEqual(struct.unpack(">I", size)[0], len(message))
if unpickle: if unpickle:
message = decode(message) message = tuple(self.decode(message))
if isinstance(message[-1], list):
message = message[:-1] + (tuple(message[-1]),)
if isinstance(message[0], list):
message = (tuple(message[-1]),) + message[1:]
result.append(message) result.append(message)
if len(result) == 1: if len(result) == 1:
...@@ -98,8 +105,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -98,8 +105,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
transport = loop.transport transport = loop.transport
if finish_start: if finish_start:
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.pop(2, False), b'Z3101') self.assertEqual(self.pop(2, False), self.enc + b'3101')
self.respond(1, None) self.respond(1, None)
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
self.pop(4) self.pop(4)
...@@ -108,9 +115,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -108,9 +115,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
return (wrapper, cache, self.loop, self.client, protocol, transport) return (wrapper, cache, self.loop, self.client, protocol, transport)
def respond(self, message_id, result): def respond(self, message_id, result, async=False):
self.loop.protocol.data_received( self.loop.protocol.data_received(
sized(self.encode(message_id, False, '.reply', result))) sized(self.encode(message_id, async, '.reply', result)))
def wait_for_result(self, future, timeout): def wait_for_result(self, future, timeout):
if future.done() and future.exception() is not None: if future.done() and future.exception() is not None:
...@@ -133,11 +140,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -133,11 +140,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# The server sends the client it's protocol. In this case, # The server sends the client it's protocol. In this case,
# it's a very high one. The client will send it's highest that # it's a very high one. The client will send it's highest that
# it can use. # it can use.
protocol.data_received(sized(b'Z99999')) protocol.data_received(sized(self.enc + b'99999'))
# The client sends back a handshake, and registers the # The client sends back a handshake, and registers the
# storage, and requests the last transaction. # storage, and requests the last transaction.
self.assertEqual(self.pop(2, False), b'Z5') self.assertEqual(self.pop(2, False), self.enc + b'5')
self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False))) self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False)))
# The client isn't connected until it initializes it's cache: # The client isn't connected until it initializes it's cache:
...@@ -195,12 +202,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -195,12 +202,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, maxtid) loaded = self.load_before(b'1'*8, maxtid)
# The data wasn't in the cache, so we made a server call: # The data wasn't in the cache, so we made a server call:
self.assertEqual( self.assertEqual(self.pop(), (5, False, 'loadBefore', (b'1'*8, maxtid)))
self.pop(),
((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)))
# Note load_before uses the oid as the message id. # Note load_before uses the oid as the message id.
self.respond((b'1'*8, maxtid), (b'data', b'a'*8, None)) self.respond(5, (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None)) self.assertEqual(tuple(loaded.result()), (b'data', b'a'*8, None))
# If we make another request, it will be satisfied from the cache: # If we make another request, it will be satisfied from the cache:
loaded = self.load_before(b'1'*8, maxtid) loaded = self.load_before(b'1'*8, maxtid)
...@@ -217,27 +222,23 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -217,27 +222,23 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# the requests will be collapsed: # the requests will be collapsed:
loaded2 = self.load_before(b'1'*8, maxtid) loaded2 = self.load_before(b'1'*8, maxtid)
self.assertEqual( self.assertEqual(self.pop(), (6, False, 'loadBefore', (b'1'*8, maxtid)))
self.pop(), self.respond(6, (b'data2', b'b'*8, None))
((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid))) self.assertEqual(tuple(loaded.result()), (b'data2', b'b'*8, None))
self.respond((b'1'*8, maxtid), (b'data2', b'b'*8, None)) self.assertEqual(tuple(loaded2.result()), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))
# Loading non-current data may also be satisfied from cache # Loading non-current data may also be satisfied from cache
loaded = self.load_before(b'1'*8, b'b'*8) loaded = self.load_before(b'1'*8, b'b'*8)
self.assertEqual(loaded.result(), (b'data', b'a'*8, b'b'*8)) self.assertEqual(tuple(loaded.result()), (b'data', b'a'*8, b'b'*8))
self.assertFalse(transport.data) self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'c'*8) loaded = self.load_before(b'1'*8, b'c'*8)
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None)) self.assertEqual(tuple(loaded.result()), (b'data2', b'b'*8, None))
self.assertFalse(transport.data) self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8) loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual( self.assertEqual(self.pop(), (7, False, 'loadBefore', (b'1'*8, b'_'*8)))
self.pop(), self.respond(7, (b'data0', b'^'*8, b'_'*8))
((b'1'*8, b'_'*8), False, 'loadBefore', (b'1'*8, b'_'*8))) self.assertEqual(tuple(loaded.result()), (b'data0', b'^'*8, b'_'*8))
self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache # When committing transactions, we need to update the cache
# with committed data. To do this, we pass a (oid, data, resolved) # with committed data. To do this, we pass a (oid, data, resolved)
...@@ -259,8 +260,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -259,8 +260,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
cache.load(b'4'*8)) cache.load(b'4'*8))
self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8)) self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8))
self.assertEqual(self.pop(), self.assertEqual(self.pop(),
(5, False, 'tpc_finish', (b'd'*8,))) (8, False, 'tpc_finish', (b'd'*8,)))
self.respond(5, b'e'*8) self.respond(8, b'e'*8)
self.assertEqual(committed.result(), b'e'*8) self.assertEqual(committed.result(), b'e'*8)
self.assertEqual(cache.load(b'1'*8), None) self.assertEqual(cache.load(b'1'*8), None)
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8)) self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8))
...@@ -274,8 +275,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -274,8 +275,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(loaded.done() or f1.done()) self.assertFalse(loaded.done() or f1.done())
self.assertEqual( self.assertEqual(
self.pop(), self.pop(),
[((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)), [(9, False, 'loadBefore', (b'1'*8, maxtid)),
(6, False, 'foo', (1, 2))], (10, False, 'foo', (1, 2))],
) )
exc = TypeError(43) exc = TypeError(43)
...@@ -301,8 +302,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -301,8 +302,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# This time we'll send a lower protocol version. The client # This time we'll send a lower protocol version. The client
# will send it back, because it's lower than the client's # will send it back, because it's lower than the client's
# protocol: # protocol:
protocol.data_received(sized(b'Z310')) protocol.data_received(sized(self.enc + b'310'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z310') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'310')
self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False))) self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False)))
self.assertFalse(wrapper.notify_connected.called) self.assertFalse(wrapper.notify_connected.called)
...@@ -337,8 +338,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -337,8 +338,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
cache.store(b'2'*8, b'a'*8, None, '2 data') cache.store(b'2'*8, b'a'*8, None, '2 data')
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None) self.respond(1, None)
self.respond(2, b'e'*8) self.respond(2, b'e'*8)
self.pop(4) self.pop(4)
...@@ -372,8 +373,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -372,8 +373,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertTrue(cache) self.assertTrue(cache)
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None) self.respond(1, None)
self.respond(2, b'e'*8) self.respond(2, b'e'*8)
self.pop(4) self.pop(4)
...@@ -423,8 +424,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -423,8 +424,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertEqual(sorted(loop.connecting), addrs[:1]) self.assertEqual(sorted(loop.connecting), addrs[:1])
protocol = loop.protocol protocol = loop.protocol
transport = loop.transport transport = loop.transport
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None) self.respond(1, None)
# Now, when the first connection fails, it won't be retried, # Now, when the first connection fails, it won't be retried,
...@@ -441,8 +442,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -441,8 +442,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport = self.start() wrapper, cache, loop, client, protocol, transport = self.start()
cache.store(b'4'*8, b'a'*8, None, '4 data') cache.store(b'4'*8, b'a'*8, None, '4 data')
cache.setLastTid('b'*8) cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None) self.respond(1, None)
self.respond(2, 'a'*8) self.respond(2, 'a'*8)
self.pop() self.pop()
...@@ -455,8 +456,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -455,8 +456,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(transport is loop.transport) self.assertFalse(transport is loop.transport)
protocol = loop.protocol protocol = loop.protocol
transport = loop.transport transport = loop.transport
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None) self.respond(1, None)
self.respond(2, 'b'*8) self.respond(2, 'b'*8)
self.pop(4) self.pop(4)
...@@ -475,13 +476,13 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -475,13 +476,13 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We'll treat the first address as read-only and we'll let it connect: # We'll treat the first address as read-only and we'll let it connect:
loop.connect_connecting(addrs[0]) loop.connect_connecting(addrs[0])
protocol, transport = loop.protocol, loop.transport protocol, transport = loop.protocol, loop.transport
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
# We see that the client tried a writable connection: # We see that the client tried a writable connection:
self.assertEqual(self.pop(), self.assertEqual(self.pop(),
(1, False, 'register', ('TEST', False))) (1, False, 'register', ('TEST', False)))
# We respond with a read-only exception: # We respond with a read-only exception:
self.respond(1, (ReadOnlyError, ReadOnlyError())) self.respond(1, ('ZODB.POSException.ReadOnlyError', ()), True)
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
# The client tries for a read-only connection: # The client tries for a read-only connection:
...@@ -507,8 +508,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -507,8 +508,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We connect the second address: # We connect the second address:
loop.connect_connecting(addrs[1]) loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(b'Z3101')) loop.protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(loop.transport.pop(2)), self.enc + b'3101')
self.assertEqual(self.parse(loop.transport.pop()), self.assertEqual(self.parse(loop.transport.pop()),
(1, False, 'register', ('TEST', False))) (1, False, 'register', ('TEST', False)))
self.assertTrue(self.is_read_only()) self.assertTrue(self.is_read_only())
...@@ -542,8 +543,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -542,8 +543,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
def test_invalidations_while_verifying(self): def test_invalidations_while_verifying(self):
# While we're verifying, invalidations are ignored # While we're verifying, invalidations are ignored
wrapper, cache, loop, client, protocol, transport = self.start() wrapper, cache, loop, client, protocol, transport = self.start()
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None) self.respond(1, None)
self.pop(4) self.pop(4)
self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False) self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False)
...@@ -560,8 +561,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -560,8 +561,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# Similarly, invalidations aren't processed while reconnecting: # Similarly, invalidations aren't processed while reconnecting:
protocol.data_received(sized(b'Z3101')) protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101') self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None) self.respond(1, None)
self.pop(4) self.pop(4)
self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False) self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False)
...@@ -604,7 +605,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -604,7 +605,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport = self.start() wrapper, cache, loop, client, protocol, transport = self.start()
with mock.patch("ZEO.asyncio.client.logger.error") as error: with mock.patch("ZEO.asyncio.client.logger.error") as error:
self.assertFalse(error.called) self.assertFalse(error.called)
protocol.data_received(sized(b'Z200')) protocol.data_received(sized(self.enc + b'200'))
self.assert_(isinstance(error.call_args[0][1], ProtocolError)) self.assert_(isinstance(error.call_args[0][1], ProtocolError))
...@@ -688,6 +689,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner): ...@@ -688,6 +689,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.connection_lost(None) protocol.connection_lost(None)
self.assertTrue(handle.cancelled) self.assertTrue(handle.cancelled)
class MsgpackClientTests(ClientTests):
enc = b'M'
class MemoryCache(object): class MemoryCache(object):
...@@ -750,12 +753,13 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -750,12 +753,13 @@ class ServerTests(Base, setupstack.TestCase):
# connections. Servers are pretty passive. # connections. Servers are pretty passive.
def connect(self, finish=False): def connect(self, finish=False):
protocol = server_protocol() protocol = server_protocol(self.enc == b'M')
self.loop = protocol.loop self.loop = protocol.loop
self.target = protocol.zeo_storage self.target = protocol.zeo_storage
if finish: if finish:
self.assertEqual(self.pop(parse=False), best_protocol_version) self.assertEqual(self.pop(parse=False),
protocol.data_received(sized(b'Z5')) self.enc + best_protocol_version)
protocol.data_received(sized(self.enc + b'5'))
return protocol return protocol
message_id = 0 message_id = 0
...@@ -790,12 +794,13 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -790,12 +794,13 @@ class ServerTests(Base, setupstack.TestCase):
self.assertFalse(protocol.zeo_storage.notify_connected.called) self.assertFalse(protocol.zeo_storage.notify_connected.called)
# The server sends it's protocol. # The server sends it's protocol.
self.assertEqual(self.pop(parse=False), best_protocol_version) self.assertEqual(self.pop(parse=False),
self.enc + best_protocol_version)
# The client sends it's protocol: # The client sends it's protocol:
protocol.data_received(sized(b'Z5')) protocol.data_received(sized(self.enc + b'5'))
self.assertEqual(protocol.protocol_version, b'Z5') self.assertEqual(protocol.protocol_version, self.enc + b'5')
protocol.zeo_storage.notify_connected.assert_called_once_with(protocol) protocol.zeo_storage.notify_connected.assert_called_once_with(protocol)
...@@ -823,7 +828,11 @@ class ServerTests(Base, setupstack.TestCase): ...@@ -823,7 +828,11 @@ class ServerTests(Base, setupstack.TestCase):
self.call('foo', target=None) self.call('foo', target=None)
self.assertTrue(protocol.loop.transport.closed) self.assertTrue(protocol.loop.transport.closed)
def server_protocol(zeo_storage=None, class MsgpackServerTests(ServerTests):
enc = b'M'
def server_protocol(msgpack,
zeo_storage=None,
protocol_version=None, protocol_version=None,
addr=('1.2.3.4', '42'), addr=('1.2.3.4', '42'),
): ):
...@@ -831,7 +840,7 @@ def server_protocol(zeo_storage=None, ...@@ -831,7 +840,7 @@ def server_protocol(zeo_storage=None,
zeo_storage = mock.Mock() zeo_storage = mock.Mock()
loop = Loop() loop = Loop()
sock = () # anything not None sock = () # anything not None
new_connection(loop, addr, sock, zeo_storage) new_connection(loop, addr, sock, zeo_storage, msgpack)
if protocol_version: if protocol_version:
loop.protocol.data_received(sized(protocol_version)) loop.protocol.data_received(sized(protocol_version))
return loop.protocol return loop.protocol
...@@ -861,4 +870,6 @@ def test_suite(): ...@@ -861,4 +870,6 @@ def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ClientTests)) suite.addTest(unittest.makeSuite(ClientTests))
suite.addTest(unittest.makeSuite(ServerTests)) suite.addTest(unittest.makeSuite(ServerTests))
suite.addTest(unittest.makeSuite(MsgpackClientTests))
suite.addTest(unittest.makeSuite(MsgpackServerTests))
return suite return suite
...@@ -100,6 +100,7 @@ class ZEOOptionsMixin: ...@@ -100,6 +100,7 @@ class ZEOOptionsMixin:
self.add("client_conflict_resolution", self.add("client_conflict_resolution",
"zeo.client_conflict_resolution", "zeo.client_conflict_resolution",
default=0) default=0)
self.add("msgpack", "zeo.msgpack", default=0)
self.add("invalidation_queue_size", "zeo.invalidation_queue_size", self.add("invalidation_queue_size", "zeo.invalidation_queue_size",
default=100) default=100)
self.add("invalidation_age", "zeo.invalidation_age") self.add("invalidation_age", "zeo.invalidation_age")
...@@ -342,6 +343,7 @@ def create_server(storages, options): ...@@ -342,6 +343,7 @@ def create_server(storages, options):
storages, storages,
read_only = options.read_only, read_only = options.read_only,
client_conflict_resolution=options.client_conflict_resolution, client_conflict_resolution=options.client_conflict_resolution,
msgpack=options.msgpack,
invalidation_queue_size = options.invalidation_queue_size, invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age, invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout, transaction_timeout = options.transaction_timeout,
......
...@@ -115,6 +115,16 @@ ...@@ -115,6 +115,16 @@
</description> </description>
</key> </key>
<key name="msgpack" datatype="boolean" required="no" default="false">
<description>
Use msgpack to serialize and de-serialize ZEO protocol messages.
An advantage of using msgpack for ZEO communication is that
it's a little bit faster and a ZEO server can support Python 2
or Python 3 clients (but not both).
</description>
</key>
</sectiontype> </sectiontype>
</component> </component>
...@@ -17,7 +17,7 @@ Let's start a Z4 server ...@@ -17,7 +17,7 @@ Let's start a Z4 server
... ''' ... '''
>>> addr, stop = start_server( >>> addr, stop = start_server(
... storage_conf, dict(invalidation_queue_size=5), protocol=b'Z4') ... storage_conf, dict(invalidation_queue_size=5), protocol=b'4')
A current client should be able to connect to a old server: A current client should be able to connect to a old server:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment