Commit b31fed14 authored by Jim Fulton's avatar Jim Fulton

Implemented msgpack as an optional ZEO message encoding with basic tests.

parent c3183420
......@@ -289,6 +289,13 @@ client-conflict-resolution
Flag indicating that clients should perform conflict
resolution. This option defaults to false.
msgpack
Use msgpack to serialize and de-serialize ZEO protocol messages.
An advantage of using msgpack for ZEO communication is that
it's a little bit faster and a ZEO server can support Python 2
or Python 3 clients (but not both).
Server SSL configuration
~~~~~~~~~~~~~~~~~~~~~~~~
......
......@@ -36,7 +36,7 @@ install_requires = [
'zope.interface',
]
tests_require = ['zope.testing', 'manuel', 'random2', 'mock']
tests_require = ['zope.testing', 'manuel', 'random2', 'mock', 'msgpack-python']
if sys.version_info[:2] < (3, ):
install_requires.extend(('futures', 'trollius'))
......@@ -128,7 +128,11 @@ setup(name="ZEO",
classifiers = classifiers,
test_suite="__main__.alltests", # to support "setup.py test"
tests_require = tests_require,
extras_require = dict(test=tests_require, uvloop=['uvloop >=0.5.1']),
extras_require = dict(
test=tests_require,
uvloop=['uvloop >=0.5.1'],
msgpack=['msgpack-python'],
),
install_requires = install_requires,
zip_safe = False,
entry_points = """
......
......@@ -663,6 +663,7 @@ class StorageServer:
ssl=None,
client_conflict_resolution=False,
Acceptor=Acceptor,
msgpack=False,
):
"""StorageServer constructor.
......@@ -757,7 +758,7 @@ class StorageServer:
self.client_conflict_resolution = client_conflict_resolution
if addr is not None:
self.acceptor = Acceptor(self, addr, ssl)
self.acceptor = Acceptor(self, addr, ssl, msgpack)
if isinstance(addr, tuple) and addr[0]:
self.addr = self.acceptor.addr
else:
......
......@@ -10,8 +10,6 @@ import socket
from struct import unpack
import sys
from .marshal import encoder
logger = logging.getLogger(__name__)
INET_FAMILIES = socket.AF_INET, socket.AF_INET6
......@@ -129,13 +127,13 @@ class Protocol(asyncio.Protocol):
self.getting_size = True
self.message_received(collected)
except Exception:
#import traceback; traceback.print_exc()
logger.exception("data_received %s %s %s",
self.want, self.got, self.getting_size)
def first_message_received(self, protocol_version):
# Handler for first/handshake message, set up in __init__
del self.message_received # use default handler from here on
self.encode = encoder()
self.finish_connect(protocol_version)
def call_async(self, method, args):
......
......@@ -13,7 +13,7 @@ import ZEO.interfaces
from . import base
from .compat import asyncio, new_event_loop
from .marshal import decode
from .marshal import encoder, decoder
logger = logging.getLogger(__name__)
......@@ -63,7 +63,7 @@ class Protocol(base.Protocol):
# One place where special care was required was in cache setup on
# connect. See finish connect below.
protocols = b'Z309', b'Z310', b'Z3101', b'Z4', b'Z5'
protocols = b'309', b'310', b'3101', b'4', b'5'
def __init__(self, loop,
addr, client, storage_key, read_only, connect_poll=1,
......@@ -150,6 +150,8 @@ class Protocol(base.Protocol):
# We have to be careful processing the futures, because
# exception callbacks might modufy them.
for f in self.pop_futures():
if isinstance(f, tuple):
continue
f.set_exception(ClientDisconnected(exc or 'connection lost'))
self.closed = True
self.client.disconnected(self)
......@@ -165,13 +167,17 @@ class Protocol(base.Protocol):
# lastTid before processing (and possibly missing) subsequent
# invalidations.
self.protocol_version = min(protocol_version, self.protocols[-1])
if self.protocol_version not in self.protocols:
version = min(protocol_version[1:], self.protocols[-1])
if version not in self.protocols:
self.client.register_failed(
self, ZEO.Exceptions.ProtocolError(protocol_version))
return
self.protocol_version = protocol_version[:1] + version
self.encode = encoder(protocol_version)
self.decode = decoder(protocol_version)
self.heartbeat_bytes = self.encode(-1, 0, '.reply', None)
self._write(self.protocol_version)
credentials = (self.credentials,) if self.credentials else ()
......@@ -199,9 +205,12 @@ class Protocol(base.Protocol):
exception_type_type = type(Exception)
def message_received(self, data):
msgid, async, name, args = decode(data)
msgid, async, name, args = self.decode(data)
if name == '.reply':
future = self.futures.pop(msgid)
if isinstance(future, tuple):
future = self.futures.pop(future)
if (async): # ZEO 5 exception
class_, args = args
factory = exc_factories.get(class_)
......@@ -245,13 +254,15 @@ class Protocol(base.Protocol):
def load_before(self, oid, tid):
# Special-case loadBefore, so we collapse outstanding requests
message_id = (oid, tid)
future = self.futures.get(message_id)
oid_tid = (oid, tid)
future = self.futures.get(oid_tid)
if future is None:
future = asyncio.Future(loop=self.loop)
self.futures[message_id] = future
self.futures[oid_tid] = future
self.message_id += 1
self.futures[self.message_id] = oid_tid
self._write(
self.encode(message_id, False, 'loadBefore', (oid, tid)))
self.encode(self.message_id, False, 'loadBefore', (oid, tid)))
return future
# Methods called by the server.
......@@ -267,7 +278,7 @@ class Protocol(base.Protocol):
def heartbeat(self, write=True):
if write:
self._write(b'(J\xff\xff\xff\xffK\x00U\x06.replyNt.')
self._write(self.heartbeat_bytes)
self.heartbeat_handle = self.loop.call_later(
self.heartbeat_interval, self.heartbeat)
......
......@@ -26,10 +26,18 @@ from ..shortrepr import short_repr
logger = logging.getLogger(__name__)
def encoder():
def encoder(protocol):
"""Return a non-thread-safe encoder
"""
if protocol[:1] == b'M':
from msgpack import packb
def encode(*args):
return packb(args, use_bin_type=True)
return encode
else:
assert protocol[:1] == b'Z'
if PY3 or PYPY:
f = BytesIO()
getvalue = f.getvalue
......@@ -54,9 +62,20 @@ def encoder():
def encode(*args):
return encoder()(*args)
return encoder(b'Z')(*args)
def decode(msg):
def decoder(protocol):
if protocol[:1] == b'M':
from msgpack import unpackb
def msgpack_decode(data):
"""Decodes msg and returns its parts"""
return unpackb(data, encoding='utf-8')
return msgpack_decode
else:
assert protocol[:1] == b'Z'
return pickle_decode
def pickle_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = find_global
......@@ -71,7 +90,14 @@ def decode(msg):
logger.error("can't decode message: %s" % short_repr(msg))
raise
def server_decode(msg):
def server_decoder(protocol):
if protocol[:1] == b'M':
return decoder(protocol)
else:
assert protocol[:1] == b'Z'
return pickle_server_decode
def pickle_server_decode(msg):
"""Decodes msg and returns its parts"""
unpickler = Unpickler(BytesIO(msg))
unpickler.find_global = server_find_global
......
......@@ -76,13 +76,14 @@ class Acceptor(asyncore.dispatcher):
And creates a separate thread for each.
"""
def __init__(self, storage_server, addr, ssl):
def __init__(self, storage_server, addr, ssl, msgpack):
self.storage_server = storage_server
self.addr = addr
self.__socket_map = {}
asyncore.dispatcher.__init__(self, map=self.__socket_map)
self.ssl_context = ssl
self.msgpack = msgpack
self._open_socket()
def _open_socket(self):
......@@ -165,7 +166,7 @@ class Acceptor(asyncore.dispatcher):
def run():
loop = new_event_loop()
zs = self.storage_server.create_client_handler()
protocol = ServerProtocol(loop, self.addr, zs)
protocol = ServerProtocol(loop, self.addr, zs, self.msgpack)
protocol.stop = loop.stop
if self.ssl_context is None:
......
......@@ -11,13 +11,13 @@ from ..shortrepr import short_repr
from . import base
from .compat import asyncio, new_event_loop
from .marshal import server_decode
from .marshal import server_decoder, encoder
class ServerProtocol(base.Protocol):
"""asyncio low-level ZEO server interface
"""
protocols = (b'Z5', )
protocols = (b'5', )
name = 'server protocol'
methods = set(('register', ))
......@@ -26,12 +26,16 @@ class ServerProtocol(base.Protocol):
ZODB.POSException.POSKeyError,
)
def __init__(self, loop, addr, zeo_storage):
def __init__(self, loop, addr, zeo_storage, msgpack):
"""Create a server's client interface
"""
super(ServerProtocol, self).__init__(loop, addr)
self.zeo_storage = zeo_storage
self.announce_protocol = (
(b'M' if msgpack else b'Z') + best_protocol_version
)
closed = False
def close(self):
logger.debug("Closing server protocol")
......@@ -44,7 +48,7 @@ class ServerProtocol(base.Protocol):
def connection_made(self, transport):
self.connected = True
super(ServerProtocol, self).connection_made(transport)
self._write(best_protocol_version)
self._write(self.announce_protocol)
def connection_lost(self, exc):
self.connected = False
......@@ -61,10 +65,13 @@ class ServerProtocol(base.Protocol):
self._write(json.dumps(self.zeo_storage.ruok()).encode("ascii"))
self.close()
else:
if protocol_version in self.protocols:
version = protocol_version[1:]
if version in self.protocols:
logger.info("received handshake %r" %
str(protocol_version.decode('ascii')))
self.protocol_version = protocol_version
self.encode = encoder(protocol_version)
self.decode = server_decoder(protocol_version)
self.zeo_storage.notify_connected(self)
else:
logger.error("bad handshake %s" % short_repr(protocol_version))
......@@ -79,7 +86,7 @@ class ServerProtocol(base.Protocol):
def message_received(self, message):
try:
message_id, async, name, args = server_decode(message)
message_id, async, name, args = self.decode(message)
except Exception:
logger.exception("Can't deserialize message")
self.close()
......@@ -144,8 +151,8 @@ best_protocol_version = os.environ.get(
ServerProtocol.protocols[-1].decode('utf-8')).encode('utf-8')
assert best_protocol_version in ServerProtocol.protocols
def new_connection(loop, addr, socket, zeo_storage):
protocol = ServerProtocol(loop, addr, zeo_storage)
def new_connection(loop, addr, socket, zeo_storage, msgpack):
protocol = ServerProtocol(loop, addr, zeo_storage, msgpack)
cr = loop.create_connection((lambda : protocol), sock=socket)
asyncio.async(cr, loop=loop)
......@@ -213,10 +220,11 @@ class MTDelay(Delay):
class Acceptor(object):
def __init__(self, storage_server, addr, ssl):
def __init__(self, storage_server, addr, ssl, msgpack):
self.storage_server = storage_server
self.addr = addr
self.ssl_context = ssl
self.msgpack = msgpack
self.event_loop = loop = new_event_loop()
if isinstance(addr, tuple):
......@@ -243,7 +251,8 @@ class Acceptor(object):
try:
logger.debug("Accepted connection")
zs = self.storage_server.create_client_handler()
protocol = ServerProtocol(self.event_loop, self.addr, zs)
protocol = ServerProtocol(
self.event_loop, self.addr, zs, self.msgpack)
except Exception:
logger.exception("Failure in protocol factory")
......
......@@ -21,13 +21,16 @@ from ..Exceptions import ClientDisconnected, ProtocolError
from .testing import Loop
from .client import ClientRunner, Fallback
from .server import new_connection, best_protocol_version
from .marshal import encoder, decode
from .marshal import encoder, decoder
class Base(object):
enc = b'Z'
def setUp(self):
super(Base, self).setUp()
self.encode = encoder()
self.encode = encoder(self.enc)
self.decode = decoder(self.enc)
def unsized(self, data, unpickle=False):
result = []
......@@ -36,7 +39,11 @@ class Base(object):
data = data[2:]
self.assertEqual(struct.unpack(">I", size)[0], len(message))
if unpickle:
message = decode(message)
message = tuple(self.decode(message))
if isinstance(message[-1], list):
message = message[:-1] + (tuple(message[-1]),)
if isinstance(message[0], list):
message = (tuple(message[-1]),) + message[1:]
result.append(message)
if len(result) == 1:
......@@ -98,8 +105,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
transport = loop.transport
if finish_start:
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.pop(2, False), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.pop(2, False), self.enc + b'3101')
self.respond(1, None)
self.respond(2, 'a'*8)
self.pop(4)
......@@ -108,9 +115,9 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
return (wrapper, cache, self.loop, self.client, protocol, transport)
def respond(self, message_id, result):
def respond(self, message_id, result, async=False):
self.loop.protocol.data_received(
sized(self.encode(message_id, False, '.reply', result)))
sized(self.encode(message_id, async, '.reply', result)))
def wait_for_result(self, future, timeout):
if future.done() and future.exception() is not None:
......@@ -133,11 +140,11 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# The server sends the client it's protocol. In this case,
# it's a very high one. The client will send it's highest that
# it can use.
protocol.data_received(sized(b'Z99999'))
protocol.data_received(sized(self.enc + b'99999'))
# The client sends back a handshake, and registers the
# storage, and requests the last transaction.
self.assertEqual(self.pop(2, False), b'Z5')
self.assertEqual(self.pop(2, False), self.enc + b'5')
self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False)))
# The client isn't connected until it initializes it's cache:
......@@ -195,12 +202,10 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, maxtid)
# The data wasn't in the cache, so we made a server call:
self.assertEqual(
self.pop(),
((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)))
self.assertEqual(self.pop(), (5, False, 'loadBefore', (b'1'*8, maxtid)))
# Note load_before uses the oid as the message id.
self.respond((b'1'*8, maxtid), (b'data', b'a'*8, None))
self.assertEqual(loaded.result(), (b'data', b'a'*8, None))
self.respond(5, (b'data', b'a'*8, None))
self.assertEqual(tuple(loaded.result()), (b'data', b'a'*8, None))
# If we make another request, it will be satisfied from the cache:
loaded = self.load_before(b'1'*8, maxtid)
......@@ -217,27 +222,23 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# the requests will be collapsed:
loaded2 = self.load_before(b'1'*8, maxtid)
self.assertEqual(
self.pop(),
((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)))
self.respond((b'1'*8, maxtid), (b'data2', b'b'*8, None))
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(loaded2.result(), (b'data2', b'b'*8, None))
self.assertEqual(self.pop(), (6, False, 'loadBefore', (b'1'*8, maxtid)))
self.respond(6, (b'data2', b'b'*8, None))
self.assertEqual(tuple(loaded.result()), (b'data2', b'b'*8, None))
self.assertEqual(tuple(loaded2.result()), (b'data2', b'b'*8, None))
# Loading non-current data may also be satisfied from cache
loaded = self.load_before(b'1'*8, b'b'*8)
self.assertEqual(loaded.result(), (b'data', b'a'*8, b'b'*8))
self.assertEqual(tuple(loaded.result()), (b'data', b'a'*8, b'b'*8))
self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'c'*8)
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertEqual(tuple(loaded.result()), (b'data2', b'b'*8, None))
self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(
self.pop(),
((b'1'*8, b'_'*8), False, 'loadBefore', (b'1'*8, b'_'*8)))
self.respond((b'1'*8, b'_'*8), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
self.assertEqual(self.pop(), (7, False, 'loadBefore', (b'1'*8, b'_'*8)))
self.respond(7, (b'data0', b'^'*8, b'_'*8))
self.assertEqual(tuple(loaded.result()), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache
# with committed data. To do this, we pass a (oid, data, resolved)
......@@ -259,8 +260,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
cache.load(b'4'*8))
self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8))
self.assertEqual(self.pop(),
(5, False, 'tpc_finish', (b'd'*8,)))
self.respond(5, b'e'*8)
(8, False, 'tpc_finish', (b'd'*8,)))
self.respond(8, b'e'*8)
self.assertEqual(committed.result(), b'e'*8)
self.assertEqual(cache.load(b'1'*8), None)
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8))
......@@ -274,8 +275,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(loaded.done() or f1.done())
self.assertEqual(
self.pop(),
[((b'1'*8, maxtid), False, 'loadBefore', (b'1'*8, maxtid)),
(6, False, 'foo', (1, 2))],
[(9, False, 'loadBefore', (b'1'*8, maxtid)),
(10, False, 'foo', (1, 2))],
)
exc = TypeError(43)
......@@ -301,8 +302,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# This time we'll send a lower protocol version. The client
# will send it back, because it's lower than the client's
# protocol:
protocol.data_received(sized(b'Z310'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z310')
protocol.data_received(sized(self.enc + b'310'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'310')
self.assertEqual(self.pop(), (1, False, 'register', ('TEST', False)))
self.assertFalse(wrapper.notify_connected.called)
......@@ -337,8 +338,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
cache.store(b'2'*8, b'a'*8, None, '2 data')
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None)
self.respond(2, b'e'*8)
self.pop(4)
......@@ -372,8 +373,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertTrue(cache)
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None)
self.respond(2, b'e'*8)
self.pop(4)
......@@ -423,8 +424,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertEqual(sorted(loop.connecting), addrs[:1])
protocol = loop.protocol
transport = loop.transport
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None)
# Now, when the first connection fails, it won't be retried,
......@@ -441,8 +442,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport = self.start()
cache.store(b'4'*8, b'a'*8, None, '4 data')
cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None)
self.respond(2, 'a'*8)
self.pop()
......@@ -455,8 +456,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
self.assertFalse(transport is loop.transport)
protocol = loop.protocol
transport = loop.transport
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None)
self.respond(2, 'b'*8)
self.pop(4)
......@@ -475,13 +476,13 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We'll treat the first address as read-only and we'll let it connect:
loop.connect_connecting(addrs[0])
protocol, transport = loop.protocol, loop.transport
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
# We see that the client tried a writable connection:
self.assertEqual(self.pop(),
(1, False, 'register', ('TEST', False)))
# We respond with a read-only exception:
self.respond(1, (ReadOnlyError, ReadOnlyError()))
self.respond(1, ('ZODB.POSException.ReadOnlyError', ()), True)
self.assertTrue(self.is_read_only())
# The client tries for a read-only connection:
......@@ -507,8 +508,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# We connect the second address:
loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), b'Z3101')
loop.protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), self.enc + b'3101')
self.assertEqual(self.parse(loop.transport.pop()),
(1, False, 'register', ('TEST', False)))
self.assertTrue(self.is_read_only())
......@@ -542,8 +543,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
def test_invalidations_while_verifying(self):
# While we're verifying, invalidations are ignored
wrapper, cache, loop, client, protocol, transport = self.start()
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None)
self.pop(4)
self.send('invalidateTransaction', b'b'*8, [b'1'*8], called=False)
......@@ -560,8 +561,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
# Similarly, invalidations aren't processed while reconnecting:
protocol.data_received(sized(b'Z3101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z3101')
protocol.data_received(sized(self.enc + b'3101'))
self.assertEqual(self.unsized(transport.pop(2)), self.enc + b'3101')
self.respond(1, None)
self.pop(4)
self.send('invalidateTransaction', b'd'*8, [b'1'*8], called=False)
......@@ -604,7 +605,7 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
wrapper, cache, loop, client, protocol, transport = self.start()
with mock.patch("ZEO.asyncio.client.logger.error") as error:
self.assertFalse(error.called)
protocol.data_received(sized(b'Z200'))
protocol.data_received(sized(self.enc + b'200'))
self.assert_(isinstance(error.call_args[0][1], ProtocolError))
......@@ -688,6 +689,8 @@ class ClientTests(Base, setupstack.TestCase, ClientRunner):
protocol.connection_lost(None)
self.assertTrue(handle.cancelled)
class MsgpackClientTests(ClientTests):
enc = b'M'
class MemoryCache(object):
......@@ -750,12 +753,13 @@ class ServerTests(Base, setupstack.TestCase):
# connections. Servers are pretty passive.
def connect(self, finish=False):
protocol = server_protocol()
protocol = server_protocol(self.enc == b'M')
self.loop = protocol.loop
self.target = protocol.zeo_storage
if finish:
self.assertEqual(self.pop(parse=False), best_protocol_version)
protocol.data_received(sized(b'Z5'))
self.assertEqual(self.pop(parse=False),
self.enc + best_protocol_version)
protocol.data_received(sized(self.enc + b'5'))
return protocol
message_id = 0
......@@ -790,12 +794,13 @@ class ServerTests(Base, setupstack.TestCase):
self.assertFalse(protocol.zeo_storage.notify_connected.called)
# The server sends it's protocol.
self.assertEqual(self.pop(parse=False), best_protocol_version)
self.assertEqual(self.pop(parse=False),
self.enc + best_protocol_version)
# The client sends it's protocol:
protocol.data_received(sized(b'Z5'))
protocol.data_received(sized(self.enc + b'5'))
self.assertEqual(protocol.protocol_version, b'Z5')
self.assertEqual(protocol.protocol_version, self.enc + b'5')
protocol.zeo_storage.notify_connected.assert_called_once_with(protocol)
......@@ -823,7 +828,11 @@ class ServerTests(Base, setupstack.TestCase):
self.call('foo', target=None)
self.assertTrue(protocol.loop.transport.closed)
def server_protocol(zeo_storage=None,
class MsgpackServerTests(ServerTests):
enc = b'M'
def server_protocol(msgpack,
zeo_storage=None,
protocol_version=None,
addr=('1.2.3.4', '42'),
):
......@@ -831,7 +840,7 @@ def server_protocol(zeo_storage=None,
zeo_storage = mock.Mock()
loop = Loop()
sock = () # anything not None
new_connection(loop, addr, sock, zeo_storage)
new_connection(loop, addr, sock, zeo_storage, msgpack)
if protocol_version:
loop.protocol.data_received(sized(protocol_version))
return loop.protocol
......@@ -861,4 +870,6 @@ def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(ClientTests))
suite.addTest(unittest.makeSuite(ServerTests))
suite.addTest(unittest.makeSuite(MsgpackClientTests))
suite.addTest(unittest.makeSuite(MsgpackServerTests))
return suite
......@@ -100,6 +100,7 @@ class ZEOOptionsMixin:
self.add("client_conflict_resolution",
"zeo.client_conflict_resolution",
default=0)
self.add("msgpack", "zeo.msgpack", default=0)
self.add("invalidation_queue_size", "zeo.invalidation_queue_size",
default=100)
self.add("invalidation_age", "zeo.invalidation_age")
......@@ -342,6 +343,7 @@ def create_server(storages, options):
storages,
read_only = options.read_only,
client_conflict_resolution=options.client_conflict_resolution,
msgpack=options.msgpack,
invalidation_queue_size = options.invalidation_queue_size,
invalidation_age = options.invalidation_age,
transaction_timeout = options.transaction_timeout,
......
......@@ -115,6 +115,16 @@
</description>
</key>
<key name="msgpack" datatype="boolean" required="no" default="false">
<description>
Use msgpack to serialize and de-serialize ZEO protocol messages.
An advantage of using msgpack for ZEO communication is that
it's a little bit faster and a ZEO server can support Python 2
or Python 3 clients (but not both).
</description>
</key>
</sectiontype>
</component>
......@@ -17,7 +17,7 @@ Let's start a Z4 server
... '''
>>> addr, stop = start_server(
... storage_conf, dict(invalidation_queue_size=5), protocol=b'Z4')
... storage_conf, dict(invalidation_queue_size=5), protocol=b'4')
A current client should be able to connect to a old server:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment