Commit cbf756c7 authored by Jim Fulton

Hardened asyncio interface

- Moved cache into async thread to avoid lots of locking.

- Set up delegation to storage.

- Provide thread wrapper that runs the async protocol in a thread.
parent 08703051
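The thread wrapper mentioned in the commit message boils down to running an asyncio event loop in a daemon thread and marshalling calls into it with call_soon_threadsafe, handing results back through concurrent.futures. Here is a minimal, self-contained sketch of that pattern (an illustration only, not code from this commit):

import asyncio
import concurrent.futures
import threading

# Run an event loop in a background daemon thread.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

def call_in_loop(func, *args):
    # Schedule func in the loop thread and block the caller on the result.
    result = concurrent.futures.Future()
    def run():
        try:
            result.set_result(func(*args))
        except Exception as e:
            result.set_exception(e)
    loop.call_soon_threadsafe(run)
    return result.result()  # blocks the calling thread, never the loop thread

print(call_in_loop(sum, [1, 2, 3]))  # -> 6

ClientThread and ClientRunner below apply the same idea to the ZEO protocol objects.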
from pickle import loads, dumps
import asyncio
import concurrent.futures
import logging
import struct
import threading
logger = logging.getLogger(__name__)
"""Low-level protocol adapters
class Disconnected(Exception):
pass
Derived from ngi connection adapters and filling a similar role to the
old zrpc smac layer for sized messages.
"""
import struct
class BaseTransportAdapter:
def __init__(self, transport):
self.transport = transport
def close(self):
self.transport.close()
def is_closing(self):
return self.transport.is_closing()
def get_extra_info(self, name, default=None):
@@ -57,8 +52,7 @@ class SizedTransportAdapter(BaseTransportAdapter):
"""
def write(self, message):
self.transport.writelines((struct.pack(">I", len(message)), message))
def writelines(self, list_of_data):
self.transport.writelines(sized_iter(list_of_data))
@@ -68,7 +62,6 @@ def sized_iter(data):
yield struct.pack(">I", len(message))
yield message
class SizedProtocolAdapter(BaseProtocolAdapter):
def __init__(self, protocol):
@@ -103,144 +96,3 @@ class SizedProtocolAdapter(BaseProtocolAdapter):
self.want = 4
self.getting_size = True
self.protocol.data_received(collected)
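To make the sized-message framing concrete: every message on the wire is preceded by a 4-byte big-endian length, which is what the write, writelines, and data_received code above produces and consumes. A standalone sketch of the format (illustration only, not part of the commit):

import struct

def frame(message):
    # Prefix a message with its 4-byte big-endian length.
    return struct.pack(">I", len(message)) + message

def split_frames(buffer):
    # Yield complete messages from a byte buffer; in this sketch a partial
    # trailing message is simply left unparsed.
    while len(buffer) >= 4:
        size = struct.unpack(">I", buffer[:4])[0]
        if len(buffer) < 4 + size:
            break
        yield buffer[4:4 + size]
        buffer = buffer[4 + size:]

assert list(split_frames(frame(b'hello') + frame(b'world'))) == [b'hello', b'world']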
class ClientProtocol(asyncio.Protocol):
"""asyncio low-level ZEO client interface
"""
def __init__(self, addr,
client=None, storage_key='1', read_only=False, loop=None):
if loop is None:
loop = asyncio.get_event_loop()
self.loop = loop
self.addr = addr
self.storage_key = storage_key
self.read_only = read_only
self.client = client
self.connected = asyncio.Future()
def protocol_factory(self):
return SizedProtocolAdapter(self)
def connect(self):
self.protocol_version = None
self.futures = {} # outstanding requests {request_id -> future}
if isinstance(self.addr, tuple):
host, port = self.addr
cr = self.loop.create_connection(self.protocol_factory, host, port)
else:
cr = self.loop.create_unix_connection(
self.protocol_factory, self.addr)
future = asyncio.async(cr, loop=self.loop)
@future.add_done_callback
def done_connecting(future):
e = future.exception()
if e is not None:
self.connected.set_exception(e)
return self.connected
def connection_made(self, transport):
logger.info("Connected")
self.transport = SizedTransportAdapter(transport)
def connection_lost(self, exc):
logger.info("Disconnected, %r", exc)
for f in self.futures.values():
f.set_exception(exc or Disconnected())
self.futures = {}
self.connect() # Reconnect
exception_type_type = type(Exception)
def data_received(self, data):
if self.protocol_version is None:
self.protocol_version = data
self.transport.write(data) # pleased to meet you version :)
self.call_async('register', self.storage_key, self.read_only)
self.connected.set_result(data)
else:
msgid, async, name, args = loads(data)
if name == '.reply':
future = self.futures.pop(msgid)
if (isinstance(args, tuple) and len(args) > 1 and
type(args[0]) == self.exception_type_type and
issubclass(args[0], Exception)
):
future.set_exception(args[0]) # XXX security checks
else:
future.set_result(args)
else:
assert async # clients only get async calls
if self.client:
getattr(self.client, name)(*args) # XXX security
else:
logger.info('called %r %r', name, args)
def call_async(self, method, *args):
# XXX connection status...
self.transport.write(dumps((0, True, method, args), 3))
message_id = 0
def call(self, method, *args):
future = asyncio.Future()
self.message_id += 1
self.futures[self.message_id] = future
self.transport.write(dumps((self.message_id, False, method, args), 3))
return future
def call_concurrent(self, result_future, method, *args):
future = self.call(method, *args)
@future.add_done_callback
def concurrent_result(future):
if future.exception() is None:
result_future.set_result(future.result())
else:
result_future.set_exception(future.exception())
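For reference, the messages written by call and call_async above are protocol-3 pickles of (message_id, async_flag, method_name, args) tuples, and replies come back under the reserved name '.reply' with the original message id. An illustration with invented values (not a captured exchange):

from pickle import dumps, loads

request = dumps((1, False, 'lastTransaction', ()), 3)  # synchronous call, id 1
reply = dumps((1, False, '.reply', b'a' * 8), 3)       # server's answer to id 1

msgid, async_flag, name, args = loads(reply)
assert (msgid, name, args) == (1, '.reply', b'a' * 8)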
class ClientThread:
"""Thread wrapper for client interface
A ClientProtocol is run in a dedicated thread.
Calls to it are made in a thread-safe fashion.
"""
def __init__(self, addr,
client=None, storage_key='1', read_only=False, timeout=None):
self.addr = addr
self.client = client
self.storage_key = storage_key
self.read_only = read_only
self.connected = concurrent.futures.Future()
threading.Thread(target=self.run,
name='zeo_client_'+storage_key,
daemon=True,
).start()
self.connected.result(timeout)
def run(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.loop = loop
self.proto = ClientProtocol(
self.addr, None, self.storage_key, self.read_only)
f = self.proto.connect()
@f.add_done_callback
def thread_done_connecting(future):
e = future.exception()
if e is not None:
self.connected.set_exception(e)
else:
self.connected.set_result(None) # XXX prob return some info
loop.run_forever()
def call_async(self, method, *args):
self.loop.call_soon_threadsafe(self.proto.call_async, method, *args)
def call(self, method, *args, timeout=None):
result = concurrent.futures.Future()
self.loop.call_soon_threadsafe(
self.proto.call_concurrent, result, method, *args)
return result.result()
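Hypothetical usage of the ClientThread wrapper above, assuming a ZEO server is listening on the address shown (the address is invented for illustration; lastTransaction is a standard ZEO server method):

zeo = ClientThread(('localhost', 8100))  # blocks until connected or an error occurs
last_tid = zeo.call('lastTransaction')   # marshalled into the event-loop thread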
from pickle import loads, dumps
from ZODB.ConflictResolution import ResolvedSerial
import asyncio
import concurrent.futures
import logging
import random
import threading
import ZEO.Exceptions
from . import adapters
logger = logging.getLogger(__name__)
local_random = random.Random() # use separate generator to facilitate tests
class Client(asyncio.Protocol):
"""asyncio low-level ZEO client interface
"""
# All of the code in this class runs in a single dedicated
# thread. Thus, we can mostly avoid worrying about interleaved
# operations.
# One place where special care was required was in cache setup on
# connect. See finish connect below.
def __init__(self, addr, client, cache, storage_key, read_only, loop):
self.loop = loop
self.addr = addr
self.storage_key = storage_key
self.read_only = read_only
self.client = client
for name in self.client_delegated:
setattr(self, name, getattr(client, name))
self.info = client.info
self.cache = cache
self.disconnected()
closed = False
def close(self):
self.closed = True
self.transport.close()
self.cache.close()
def protocol_factory(self):
return adapters.SizedProtocolAdapter(self)
def disconnected(self):
self.ready = False
self.connected = concurrent.futures.Future()
self.protocol_version = None
self.futures = {}
self.connect()
def connect(self):
if isinstance(self.addr, tuple):
host, port = self.addr
cr = self.loop.create_connection(self.protocol_factory, host, port)
else:
cr = self.loop.create_unix_connection(
self.protocol_factory, self.addr)
cr = asyncio.async(cr, loop=self.loop)
@cr.add_done_callback
def done_connecting(future):
if future.exception() is not None:
# keep trying
self.loop.call_later(1 + local_random.random(), self.connect)
def connection_made(self, transport):
logger.info("Connected")
self.transport = adapters.SizedTransportAdapter(transport)
def connection_lost(self, exc):
exc = exc or ZEO.Exceptions.ClientDisconnected()
logger.info("Disconnected, %r", exc)
for f in self.futures.values():
f.set_exception(exc)
self.disconnected()
def finish_connect(self, protocol_version):
# We use a promise model rather than coroutines here because
# for the most part, this class is reactive and coroutines
# aren't a good model of its activities. During
# initialization, however, we use promises to provide an
# imperative flow.
# The promise(/future) implementation we use differs from
# asyncio.Future in that callbacks are called immediately,
# rather than using the loop's call_soon. We want to avoid a
# race between invalidations and cache initialization. In
# particular, after calling lastTransaction or
# getInvalidations, we want to make sure we set the cache's
# lastTid before processing subsequent invalidations.
self.protocol_version = protocol_version
self.transport.write(protocol_version)
register = self.promise('register', self.storage_key, self.read_only)
lastTransaction = self.promise('lastTransaction')
cache = self.cache
@register(lambda _ : lastTransaction)
def verify(server_tid):
if not cache:
return server_tid
cache_tid = cache.getLastTid()
if not cache_tid:
logger.error("Non-empty cache w/o tid -- clearing")
cache.clear()
self.client.invalidateCache()
return server_tid
elif cache_tid == server_tid:
logger.info("Cache up to date %r", server_tid)
return server_tid
elif cache_tid > server_tid:
raise AssertionError("Server behind client, %r < %r"
% (server_tid, cache_tid))
@self.promise('getInvalidations', cache_tid)
def verify_invalidations(vdata):
if vdata:
tid, oids = vdata
for oid in oids:
cache.invalidate(oid, None)
return tid
else:
# cache is too old
self.cache.clear()
self.client.invalidateCache()
return server_tid
return verify_invalidations
@verify
def finish_verification(server_tid):
cache.setLastTid(server_tid)
self.ready = True
finish_verification(
lambda _ : self.connected.set_result(None),
lambda e: self.connected.set_exception(e),
)
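Ignoring the promise plumbing, the verification above is a five-way decision on the cache state versus the server's last transaction id. Restated as a plain function purely for illustration (the function and its return values are not part of the commit):

def verification_action(cache_is_empty, cache_tid, server_tid):
    # Mirrors verify()/verify_invalidations() above, minus the promises.
    if cache_is_empty:
        return 'adopt the server tid'
    elif not cache_tid:
        return 'clear the cache and invalidate the ZODB cache'
    elif cache_tid == server_tid:
        return 'cache is up to date'
    elif cache_tid > server_tid:
        raise AssertionError('Server behind client')
    else:
        return 'request invalidations since cache_tid'

assert verification_action(False, b'a' * 8, b'e' * 8) == \
    'request invalidations since cache_tid'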
exception_type_type = type(Exception)
def data_received(self, data):
if self.protocol_version is None:
self.finish_connect(data)
else:
msgid, async, name, args = loads(data)
if name == '.reply':
future = self.futures.pop(msgid)
if (isinstance(args, tuple) and len(args) > 1 and
type(args[0]) == self.exception_type_type and
issubclass(args[0], Exception)
):
future.set_exception(args[0])
else:
future.set_result(args)
else:
assert async # clients only get async calls
if name in self.client_methods:
getattr(self, name)(*args)
else:
raise AttributeError(name)
def call_async(self, method, *args):
if self.ready:
self.transport.write(dumps((0, True, method, args), 3))
else:
raise ZEO.Exceptions.ClientDisconnected()
def call_async_threadsafe(self, future, method, args):
try:
self.call_async(method, *args)
except Exception as e:
future.set_exception(e)
else:
future.set_result(None)
message_id = 0
def _call(self, future, method, args):
self.message_id += 1
self.futures[self.message_id] = future
self.transport.write(dumps((self.message_id, False, method, args), 3))
return future
def promise(self, method, *args):
return self._call(Promise(), method, args)
def call_threadsafe(self, result_future, method, args):
if self.ready:
return self._call(result_future, method, args)
@self.connected.add_done_callback
def done(future):
e = future.exception()
if e is not None:
result_future.set_exception(e)
else:
self.call_threadsafe(result_future, method, args)
return result_future
# Special methods because they update the cache.
def load_threadsafe(self, future, oid):
data = self.cache.load(oid)
if data is not None:
future.set_result(data)
else:
@self.promise('loadEx', oid)
def load(data):
future.set_result(data)
data, tid = data
self.cache.store(oid, tid, None, data)
load.catch(future.set_exception)
def load_before_threadsafe(self, future, oid, tid):
data = self.cache.loadBefore(oid, tid)
if data is not None:
future.set_result(data)
else:
@self.promise('loadBefore', oid, tid)
def load_before(data):
future.set_result(data)
if data:
data, start, end = data
self.cache.store(oid, start, end, data)
load_before.catch(future.set_exception)
def tpc_finish_threadsafe(self, future, tid, updates):
@self.promise('tpc_finish', tid)
def committed(_):
cache = self.cache
for oid, s, data in updates:
cache.invalidate(oid, tid)
if data and s != ResolvedSerial:
cache.store(oid, tid, None, data)
cache.setLastTid(tid)
future.set_result(None)
committed.catch(future.set_exception)
# Methods called by the server:
client_methods = (
'invalidateTransaction', 'serialnos', 'info',
'receiveBlobStart', 'receiveBlobChunk', 'receiveBlobStop',
)
client_delegated = client_methods[1:]
def invalidateTransaction(self, tid, oids):
for oid in oids:
self.cache.invalidate(oid, tid)
self.cache.setLastTid(tid)
self.client.invalidateTransaction(tid, oids)
class ClientRunner:
def set_options(self, addr, wrapper, cache, storage_key, read_only,
timeout=30):
self.__args = addr, wrapper, cache, storage_key, read_only
self.timeout = timeout
self.connected = concurrent.futures.Future()
def setup_delegation(self, loop):
self.loop = loop
self.client = Client(*self.__args, loop=loop)
from concurrent.futures import Future
call_soon_threadsafe = loop.call_soon_threadsafe
def call(meth, *args, timeout=False):
result = Future()
call_soon_threadsafe(meth, result, *args)
return self.wait_for_result(result, timeout)
self.__call = call
@self.client.connected.add_done_callback
def thread_done_connecting(future):
e = future.exception()
if e is not None:
self.connected.set_exception(e)
else:
self.connected.set_result(None)
def wait_for_result(self, future, timeout):
return future.result(self.timeout if timeout is False else timeout)
def call(self, method, *args, timeout=None):
return self.__call(self.client.call_threadsafe, method, args)
def callAsync(self, method, *args):
return self.__call(self.client.call_async_threadsafe, method, args)
def load(self, oid):
return self.__call(self.client.load_threadsafe, oid)
def load_before(self, oid, tid):
return self.__call(self.client.load_before_threadsafe, oid, tid)
def tpc_finish(self, tid, updates):
return self.__call(self.client.tpc_finish_threadsafe, tid, updates)
class ClientThread(ClientRunner):
"""Thread wrapper for client interface
A ClientProtocol is run in a dedicated thread.
Calls to it are made in a thread-safe fashion.
"""
def __init__(self, addr, client, cache,
storage_key='1', read_only=False, timeout=30):
self.set_options(addr, client, cache, storage_key, read_only, timeout)
threading.Thread(
target=self.run,
args=(addr, client, cache, storage_key, read_only),
name='zeo_client_'+storage_key,
daemon=True,
).start()
self.connected.result(timeout)
def run(self, *args):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.setup_delegation(loop)
loop.run_forever()
class Promise:
"""Lightweight future with a partial promise API.
These are lightweight because they call callbacks synchronously
rather than through an event loop, and because they only support
single callbacks.
"""
# Note that we can rely on callbacks being set up before a promise
# completes, because promises are used to make network requests.
# Requests are made by writing to a transport. Because we're used
# in a single-threaded protocol, we can't receive a response (and so
# complete) before the code that created the promise, which also sets
# the callbacks, has finished running.
next = success_callback = error_callback = None
def __call__(self, success_callback = None, error_callback = None):
self.next = self.__class__()
self.success_callback = success_callback
self.error_callback = error_callback
return self.next
def catch(self, error_callback):
self.error_callback = error_callback
def set_exception(self, exc):
self._notify(None, exc)
def set_result(self, result):
self._notify(result, None)
def _notify(self, result, exc):
next = self.next
if exc is not None:
if self.error_callback is not None:
try:
result = self.error_callback(exc)
except Exception:
logger.exception("Exception handling error %s", exc)
if next is not None:
next.set_exception(exc)
else:
if next is not None:
next.set_result(result)
elif next is not None:
next.set_exception(exc)
else:
if self.success_callback is not None:
try:
result = self.success_callback(result)
except Exception as exc:
logger.exception("Exception in success callback")
if next is not None:
next.set_exception(exc)
else:
if next is not None:
if isinstance(result, Promise):
result(next.set_result, next.set_exception)
else:
next.set_result(result)
elif next is not None:
next.set_result(result)
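A tiny illustration of how these promises behave (not from the commit): calling a promise registers its callbacks and returns the next promise in the chain, and callbacks run synchronously as soon as a result is delivered.

p = Promise()
chained = p(lambda value: value + 1)  # success callback; returns the next Promise
seen = []
chained(seen.append)                  # collect the chained result
p.set_result(41)                      # callbacks run immediately, no event loop
assert seen == [42]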
import asyncio
class Loop:
def __init__(self, debug=True):
self.get_debug = lambda : debug
def call_soon(self, func, *args):
func(*args)
def create_connection(self, protocol_factory, host, port):
self.protocol = protocol = protocol_factory()
self.transport = transport = Transport()
future = asyncio.Future(loop=self)
future.set_result((transport, protocol))
protocol.connection_made(transport)
return future
def call_soon_threadsafe(self, func, *args):
func(*args)
class Transport:
def __init__(self):
self.data = []
def write(self, data):
self.data.append(data)
def writelines(self, lines):
self.data.extend(lines)
def pop(self, count=None):
if count:
r = self.data[:count]
del self.data[:count]
else:
r = self.data[:]
del self.data[:]
return r
closed = False
def close(self):
self.closed = True
from zope.testing import setupstack
from concurrent.futures import Future
from unittest import mock
import asyncio
import collections
import pickle
import struct
import unittest
from .testing import Loop
from .client import ClientRunner
from ..Exceptions import ClientDisconnected
class AsyncTests(setupstack.TestCase, ClientRunner):
addr = ('127.0.0.1', 8200)
def start(self):
# To create a client, we need to specify an address, a client
# object and a cache.
wrapper = mock.Mock()
cache = MemoryCache()
self.set_options(self.addr, wrapper, cache, 'TEST', False)
# We can also provide an event loop. We'll use a testing loop
# so we don't have to actually make any network connection.
self.setup_delegation(Loop())
protocol = self.loop.protocol
transport = self.loop.transport
def send(meth, *args):
protocol.data_received(
sized(pickle.dumps((0, True, meth, args), 3)))
def respond(message_id, result):
protocol.data_received(
sized(pickle.dumps((message_id, False, '.reply', result), 3)))
return (wrapper, cache, self.loop, self.client, protocol, transport,
send, respond)
def wait_for_result(self, future, timeout):
return future
def testBasics(self):
# Here, we'll go through the basic usage of the asyncio ZEO
# network client. The client is responsible for the core
# functionality of a ZEO client storage. The client storage
# is largely just a wrapper around the asyncio client.
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
# The client isn't connected until the server sends it some data.
self.assertFalse(client.connected.done() or transport.data)
# The server sends the client some data:
protocol.data_received(sized(b'Z101'))
# The client sends back a handshake, and registers the
# storage, and requests the last transaction.
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
parse = self.parse
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
# Actually, the client isn't connected until it initializes its cache:
self.assertFalse(client.connected.done() or transport.data)
# If we try to make calls while the client is connecting, they're queued
f1 = self.call('foo', 1, 2)
self.assertFalse(f1.done())
# If we try to make an async call, we get an immediate error:
f2 = self.callAsync('bar', 3, 4)
self.assert_(isinstance(f2.exception(), ClientDisconnected))
# Let's respond to those first 2 calls:
respond(1, None)
respond(2, 'a'*8)
# Now we're connected, the cache was initialized, and the
# queued message has been sent:
self.assert_(client.connected.done())
self.assertEqual(cache.getLastTid(), 'a'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'foo', (1, 2)))
respond(3, 42)
self.assertEqual(f1.result(), 42)
# Now we can make async calls:
f2 = self.callAsync('bar', 3, 4)
self.assert_(f2.done() and f2.exception() is None)
self.assertEqual(parse(transport.pop()), (0, True, 'bar', (3, 4)))
# Loading objects gets special handling to leverage the cache.
loaded = self.load(b'1'*8)
# The data wasn't in the cache, so we make a server call:
self.assertEqual(parse(transport.pop()),
(4, False, 'loadEx', (b'1'*8,)))
respond(4, (b'data', b'a'*8))
self.assertEqual(loaded.result(), (b'data', b'a'*8))
# If we make another request, it will be satisfied from the cache:
loaded = self.load(b'1'*8)
self.assertEqual(loaded.result(), (b'data', b'a'*8))
self.assertFalse(transport.data)
# Let's send an invalidation:
send('invalidateTransaction', b'b'*8, [b'1'*8])
wrapper.invalidateTransaction.assert_called_with(b'b'*8, [b'1'*8])
# Now, if we try to load current again, we'll make a server request.
loaded = self.load(b'1'*8)
self.assertEqual(parse(transport.pop()),
(5, False, 'loadEx', (b'1'*8,)))
respond(5, (b'data2', b'b'*8))
self.assertEqual(loaded.result(), (b'data2', b'b'*8))
# Loading non-current data may also be satisfied from cache
loaded = self.load_before(b'1'*8, b'b'*8)
self.assertEqual(loaded.result(), (b'data', b'a'*8, b'b'*8))
self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'c'*8)
self.assertEqual(loaded.result(), (b'data2', b'b'*8, None))
self.assertFalse(transport.data)
loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(parse(transport.pop()),
(6, False, 'loadBefore', (b'1'*8, b'_'*8)))
respond(6, (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache
# with committed data. To do this, we pass a (oid, tid, data)
# iterable to tpc_finish_threadsafe.
from ZODB.ConflictResolution import ResolvedSerial
committed = self.tpc_finish(
b'd'*8,
[(b'2'*8, b'd'*8, 'committed 2'),
(b'1'*8, ResolvedSerial, 'committed 3'),
(b'4'*8, b'd'*8, 'committed 4'),
])
self.assertFalse(committed.done() or
cache.load(b'2'*8) or
cache.load(b'4'*8))
self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8))
self.assertEqual(parse(transport.pop()),
(7, False, 'tpc_finish', (b'd'*8,)))
respond(7, None)
self.assertEqual(committed.result(), None)
self.assertEqual(cache.load(b'1'*8), None)
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'd'*8))
self.assertEqual(cache.load(b'4'*8), ('committed 4', b'd'*8))
# If the protocol is disconnected, it will reconnect and will
# resolve outstanding requests with exceptions:
loaded = self.load(b'1'*8)
f1 = self.call('foo', 1, 2)
self.assertFalse(loaded.done() or f1.done())
self.assertEqual(parse(transport.pop()),
[(8, False, 'loadEx', (b'1'*8,)),
(9, False, 'foo', (1, 2))],
)
exc = TypeError(43)
protocol.connection_lost(exc)
self.assertEqual(loaded.exception(), exc)
self.assertEqual(f1.exception(), exc)
# Because we reconnected, a new protocol and transport were created:
self.assert_(protocol is not loop.protocol)
self.assert_(transport is not loop.transport)
protocol = loop.protocol
transport = loop.transport
# and we have a new incomplete connect future:
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(parse(transport.pop()),
[(10, False, 'register', ('TEST', False)),
(11, False, 'lastTransaction', ()),
])
respond(10, None)
respond(11, b'd'*8)
# Because the server tid matches the cache tid, we're done connecting
self.assert_(client.connected.done() and not transport.data)
self.assertEqual(cache.getLastTid(), b'd'*8)
# Because we were able to update the cache, we didn't have to
# invalidate the database cache:
wrapper.invalidateTransaction.assert_not_called()
# The close method closes the connection and cache:
client.close()
self.assert_(transport.closed and cache.closed)
# The client doesn't reconnect
self.assertEqual(loop.protocol, protocol)
self.assertEqual(loop.transport, transport)
def test_cache_behind(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
cache.setLastTid(b'a'*8)
cache.store(b'4'*8, b'a'*8, None, '4 data')
cache.store(b'2'*8, b'a'*8, None, '2 data')
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, b'e'*8)
# We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done())
self.assertEqual(self.parse(transport.pop()),
(3, False, 'getInvalidations', (b'a'*8, )))
respond(3, (b'e'*8, [b'4'*8]))
# Now that verification is done, we're done connecting
self.assert_(client.connected.done() and not transport.data)
self.assertEqual(cache.getLastTid(), b'e'*8)
# And the cache has been updated:
self.assertEqual(cache.load(b'2'*8),
('2 data', b'a'*8)) # unchanged
self.assertEqual(cache.load(b'4'*8), None)
# Because we were able to update the cache, we didn't have to
# invalidate the database cache:
wrapper.invalidateCache.assert_not_called()
def test_cache_way_behind(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
cache.setLastTid(b'a'*8)
cache.store(b'4'*8, b'a'*8, None, '4 data')
self.assertTrue(cache)
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, b'e'*8)
# We have to verify the cache, so we're not done connecting:
self.assertFalse(client.connected.done())
self.assertEqual(self.parse(transport.pop()),
(3, False, 'getInvalidations', (b'a'*8, )))
# We respond None, indicating that we're too far out of date:
respond(3, None)
# Now that verification is done, we're done connecting
self.assert_(client.connected.done() and not transport.data)
self.assertEqual(cache.getLastTid(), b'e'*8)
# But the cache is now empty and we invalidated the database cache
self.assertFalse(cache)
wrapper.invalidateCache.assert_called_with()
def test_cache_crazy(self):
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
cache.setLastTid(b'e'*8)
cache.store(b'4'*8, b'e'*8, None, '4 data')
self.assertTrue(cache)
self.assertFalse(client.connected.done() or transport.data)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, b'a'*8)
# The server tid is less than the client tid, WTF? We error
with self.assertRaisesRegex(AssertionError, 'Server behind client'):
client.connected.result()
# todo:
# bad cache validation, make sure ZODB cache is cleared
# cache boolean value in interface
def unsized(self, data, unpickle=False):
result = []
while data:
size, message, *data = data
self.assertEqual(struct.unpack(">I", size)[0], len(message))
if unpickle:
message = pickle.loads(message)
result.append(message)
if len(result) == 1:
result = result[0]
return result
def parse(self, data):
return self.unsized(data, True)
def response(*data):
return sized(pickle.dumps(data, 3))
def sized(message):
return struct.pack(">I", len(message)) + message
class MemoryCache:
def __init__(self):
# { oid -> [(start, end, data)] }
self.data = collections.defaultdict(list)
self.last_tid = None
clear = __init__
closed = False
def close(self):
self.closed = True
def __len__(self):
return len(self.data)
def load(self, oid):
revisions = self.data[oid]
if revisions:
start, end, data = revisions[-1]
if not end:
return data, start
return None
def store(self, oid, start_tid, end_tid, data):
assert start_tid is not None
revisions = self.data[oid]
revisions.append((start_tid, end_tid, data))
revisions.sort()
def loadBefore(self, oid, tid):
for start, end, data in self.data[oid]:
if start < tid and (end is None or end >= tid):
return data, start, end
def invalidate(self, oid, tid):
revisions = self.data[oid]
if revisions:
if tid is None:
del revisions[:]
else:
start, end, data = revisions[-1]
if end is None:
revisions[-1] = start, tid, data
def getLastTid(self):
return self.last_tid
def setLastTid(self, tid):
self.last_tid = tid
def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(AsyncTests))
return suite