Commit 107f1077 authored by Jim Fulton's avatar Jim Fulton

Updated the asyncio to support multiple addresses and read-only fallback

Also got rid of the adapter machinery. It didn't buy enough to justify
the wrapping.
parent e0c64161
"""Low-level protocol adapters
Derived from ngi connection adapters and filling a similar role to the
old zrpc smac layer for sized messages.
"""
import struct
class BaseTransportAdapter:
def __init__(self, transport):
self.transport = transport
def close(self):
self.transport.close()
def is_closing(self):
return self.transport.is_closing()
def get_extra_info(self, name, default=None):
return self.transport.get_extra_info(name, default)
def pause_reading(self):
self.transport.pause_reading()
def resume_reading(self):
self.transport.resume_reading()
def abort(self):
self.transport.abort()
def can_write_eof(self):
return self.transport.can_write_eof()
def get_write_buffer_size(self):
return self.transport.get_write_buffer_size()
def get_write_buffer_limits(self):
return self.transport.get_write_buffer_limits()
def set_write_buffer_limits(self, high=None, low=None):
self.transport.set_write_buffer_limits(high, low)
def write(self, data):
self.transport.write(data)
def writelines(self, list_of_data):
self.transport.writelines(list_of_data)
def write_eof(self):
self.transport.write_eof()
class BaseProtocolAdapter:
def __init__(self, protocol):
self.protocol = protocol
def connection_made(self, transport):
self.protocol.connection_made(transport)
def connection_lost(self, exc):
self.protocol.connection_lost(exc)
def data_received(self, data):
self.protocol.data_received(data)
def eof_received(self):
return self.protocol.eof_received()
class SizedTransportAdapter(BaseTransportAdapter):
"""Sized-message transport adapter
"""
def write(self, message):
self.transport.writelines((struct.pack(">I", len(message)), message))
def writelines(self, list_of_data):
self.transport.writelines(sized_iter(list_of_data))
def sized_iter(data):
for message in data:
yield struct.pack(">I", len(message))
yield message
class SizedProtocolAdapter(BaseProtocolAdapter):
def __init__(self, protocol):
self.protocol = protocol
self.want = 4
self.got = 0
self.getting_size = True
self.input = []
def data_received(self, data):
self.got += len(data)
self.input.append(data)
while self.got >= self.want:
extra = self.got - self.want
if extra == 0:
collected = b''.join(self.input)
self.input = []
else:
input = self.input
self.input = [data[-extra:]]
input[-1] = input[-1][:-extra]
collected = b''.join(input)
self.got = extra
if self.getting_size:
# we were recieving the message size
assert self.want == 4
self.want = struct.unpack(">I", collected)[0]
self.getting_size = False
else:
self.want = 4
self.getting_size = True
self.protocol.data_received(collected)
from pickle import loads, dumps
from ZODB.ConflictResolution import ResolvedSerial
from struct import unpack
import asyncio
import concurrent.futures
import logging
import random
import threading
import ZEO.Exceptions
from . import adapters
import ZODB.POSException
logger = logging.getLogger(__name__)
Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests
class Client(asyncio.Protocol):
class Closed(Exception):
"""A connection has been closed
"""
class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO client interface
"""
......@@ -23,34 +29,46 @@ class Client(asyncio.Protocol):
# One place where special care was required was in cache setup on
# connect. See finish connect below.
transport = protocol_version = None
def __init__(self, addr, client, storage_key, read_only, loop,
connect_timeout=1):
"""Create a client interface
def __init__(self, addr, client, cache, storage_key, read_only, loop):
addr is either a host,port tuple or a string file name.
client is a ClientStorage. It must be thread safe.
cache is a ZEO.interfaces.IClientCache.
"""
self.loop = loop
self.addr = addr
self.storage_key = storage_key
self.read_only = read_only
self.name = "%s(%r, %r, %r)" % (
self.__class__.__name__, addr, storage_key, read_only)
self.client = client
for name in self.client_delegated:
setattr(self, name, getattr(client, name))
self.info = client.info
self.cache = cache
self.disconnected()
self.connect_timeout = connect_timeout
self.futures = {} # { message_id -> future }
self.input = []
self.connect()
def __repr__(self):
return self.name
closed = False
def close(self):
if not self.closed:
self.closed = True
self._connecting.cancel()
if self.transport is not None:
self.transport.close()
self.cache.close()
for future in self.futures.values():
future.set_exception(Closed())
self.futures.clear()
def protocol_factory(self):
return adapters.SizedProtocolAdapter(self)
def disconnected(self):
self.ready = False
self.connected = concurrent.futures.Future()
self.protocol_version = None
self.futures = {}
self.connect()
return self
def connect(self):
if isinstance(self.addr, tuple):
......@@ -60,24 +78,38 @@ class Client(asyncio.Protocol):
cr = self.loop.create_unix_connection(
self.protocol_factory, self.addr)
cr = asyncio.async(cr, loop=self.loop)
self._connecting = cr = asyncio.async(cr, loop=self.loop)
@cr.add_done_callback
def done_connecting(future):
if future.exception() is not None:
# keep trying
self.loop.call_later(1 + local_random.random(), self.connect)
if not self.closed:
self.loop.call_later(1 + local_random.random(),
self.connect)
def connection_made(self, transport):
logger.info("Connected")
self.transport = adapters.SizedTransportAdapter(transport)
logger.info("Connected %s", self)
self.transport = transport
writelines = transport.writelines
from struct import pack
def write(message):
writelines((pack(">I", len(message)), message))
self._write = write
def connection_lost(self, exc):
exc = exc or ClientDisconnected()
logger.info("Disconnected, %r", exc)
if exc is None:
# we were closed
for f in self.futures.values():
f.cancel()
else:
logger.info("Disconnected, %s, %r", self, exc)
for f in self.futures.values():
f.set_exception(exc)
self.disconnected()
self.client.disconnected(self)
def finish_connect(self, protocol_version):
# We use a promise model rather than coroutines here because
......@@ -95,56 +127,74 @@ class Client(asyncio.Protocol):
# lastTid before processing subsequent invalidations.
self.protocol_version = protocol_version
self.transport.write(protocol_version)
register = self.promise('register', self.storage_key, self.read_only)
lastTransaction = self.promise('lastTransaction')
cache = self.cache
self._write(protocol_version)
@register(lambda _ : lastTransaction)
def verify(server_tid):
if not cache:
return server_tid
register = self.promise(
'register', self.storage_key,
self.read_only if self.read_only is not Fallback else False,
)
# Get lastTransaction in flight right away to make successful
# connection quicker
lastTransaction = self.promise('lastTransaction')
cache_tid = cache.getLastTid()
if not cache_tid:
logger.error("Non-empty cache w/o tid -- clearing")
cache.clear()
self.client.invalidateCache()
return server_tid
elif cache_tid == server_tid:
logger.info("Cache up to date %r", server_tid)
return server_tid
elif cache_tid >= server_tid:
raise AssertionError("Server behind client, %r < %r",
server_tid, cache_tid)
@register
def registered(_):
if self.read_only is Fallback:
self.read_only = False
self.client.registered(self, lastTransaction)
@register.catch
def register_failed(exc):
if (isinstance(exc, ZODB.POSException.ReadOnlyError) and
self.read_only is Fallback):
# We tried a write connection, degrade to a read-only one
self.read_only = True
register = self.promise(
'register', self.storage_key, self.read_only)
lastTransaction = self.promise('lastTransaction')
@self.promise('getInvalidations', cache_tid)
def verify_invalidations(vdata):
if vdata:
tid, oids = vdata
for oid in oids:
cache.invalidate(oid, None)
return tid
else:
# cache is too old
self.cache.clear()
self.client.invalidateCache()
return server_tid
@register
def registered(_):
self.client.registered(self, lastTransaction)
return verify_invalidations
@register.catch
def register_failed(exc):
self.client.register_failed(self, exc)
@verify
def finish_verification(server_tid):
cache.setLastTid(server_tid)
self.ready = True
else:
self.client.register_failed(self, exc)
finish_verification(
lambda _ : self.connected.set_result(None),
lambda e: self.connected.set_exception(e),
)
got = 0
want = 4
getting_size = True
def data_received(self, data):
self.got += len(data)
self.input.append(data)
while self.got >= self.want:
extra = self.got - self.want
if extra == 0:
collected = b''.join(self.input)
self.input = []
else:
input = self.input
self.input = [data[-extra:]]
input[-1] = input[-1][:-extra]
collected = b''.join(input)
self.got = extra
if self.getting_size:
# we were recieving the message size
assert self.want == 4
self.want = unpack(">I", collected)[0]
self.getting_size = False
else:
self.want = 4
self.getting_size = True
self.message_received(collected)
exception_type_type = type(Exception)
def data_received(self, data):
def message_received(self, data):
if self.protocol_version is None:
self.finish_connect(data)
else:
......@@ -153,45 +203,191 @@ class Client(asyncio.Protocol):
future = self.futures.pop(msgid)
if (isinstance(args, tuple) and len(args) > 1 and
type(args[0]) == self.exception_type_type and
issubclass(r_args[0], Exception)
issubclass(args[0], Exception)
):
future.set_exception(args[0])
future.set_exception(args[1])
else:
future.set_result(args)
else:
assert async # clients only get async calls
if name in self.client_methods:
getattr(self, name)(*args)
getattr(self.client, name)(*args)
else:
raise AttributeError(name)
def call_async(self, method, *args):
if self.ready:
self.transport.write(dumps((0, True, method, args), 3))
else:
raise ZEO.Exceptions.ClientDisconnected()
def call_async_threadsafe(self, future, method, args):
try:
self.call_async(method, *args)
except Exception as e:
future.set_exception(e)
else:
future.set_result(None)
def call_async(self, method, args):
self._write(dumps((0, True, method, args), 3))
message_id = 0
def _call(self, future, method, args):
def call(self, future, method, args):
self.message_id += 1
self.futures[self.message_id] = future
self.transport.write(dumps((self.message_id, False, method, args), 3))
self._write(dumps((self.message_id, False, method, args), 3))
return future
def promise(self, method, *args):
return self._call(Promise(), method, args)
return self.call(Promise(), method, args)
# Methods called by the server:
client_methods = (
'invalidateTransaction', 'serialnos', 'info',
'receiveBlobStart', 'receiveBlobChunk', 'receiveBlobStop',
)
client_delegated = client_methods[1:]
def call_threadsafe(self, result_future, method, args):
class Client:
"""asyncio low-level ZEO client interface
"""
# All of the code in this class runs in a single dedicated
# thread. Thus, we can mostly avoid worrying about interleaved
# operations.
# One place where special care was required was in cache setup on
# connect.
protocol = None
ready = False
def __init__(self, addrs, client, cache, storage_key, read_only, loop):
"""Create a client interface
addr is either a host,port tuple or a string file name.
client is a ClientStorage. It must be thread safe.
cache is a ZEO.interfaces.IClientCache.
"""
self.loop = loop
self.addrs = addrs
self.storage_key = storage_key
self.read_only = read_only
self.client = client
for name in Protocol.client_delegated:
setattr(self, name, getattr(client, name))
self.cache = cache
self.protocols = ()
self.disconnected(None)
closed = False
def close(self):
if not self.closed:
self.closed = True
self.protocol.close()
self.cache.close()
self._clear_protocols()
def _clear_protocols(self, protocol=None):
for p in self.protocols:
if p is not protocol:
p.close()
self.protocols = ()
def disconnected(self, protocol=None):
if protocol is None or protocol is self.protocol:
self.ready = False
self.connected = concurrent.futures.Future()
self.protocol = None
self._clear_protocols()
self.try_connecting()
def upgrade(self, protocol):
self.ready = False
self.connected = concurrent.futures.Future()
self.protocol.close()
self.protocol = protocol
self._clear_protocols(protocol)
def try_connecting(self):
if not self.closed:
self.protocols = [
Protocol(addr, self, self.storage_key, self.read_only,
self.loop)
for addr in self.addrs
]
def registered(self, protocol, last_transaction_promise):
if self.protocol is None:
self.protocol = protocol
if not (self.read_only is Fallback and protocol.read_only):
# We're happy with this protocol. Tell the others to
# stop trying.
self._clear_protocols(protocol)
self.verify(last_transaction_promise)
elif (self.read_only is Fallback and not protocol.read_only and
self.protocol.read_only):
self.upgrade(protocol)
self.verify(last_transaction_promise)
else:
protocol.close() # too late, we went home with another
def register_failed(self, protocol, exc):
# A protcol failed registration. That's weird. If they've all
# failed, we should try again in a bit.
protocol.close()
logger.error("Registration or cache validation failed, %s", exc)
if (self.protocol is None and not
any(not p.closed for p in self.protocols)
):
self.loop.call_later(9 + local_random.random(), self.try_connecting)
def verify(self, last_transaction_promise):
protocol = self.protocol
@last_transaction_promise
def finish_verify(server_tid):
cache = self.cache
if cache:
cache_tid = cache.getLastTid()
if not cache_tid:
logger.error("Non-empty cache w/o tid -- clearing")
cache.clear()
self.client.invalidateCache()
self.finished_verify(server_tid)
elif cache_tid > server_tid:
raise AssertionError("Server behind client, %r < %r, %s",
server_tid, cache_tid, protocol)
elif cache_tid == server_tid:
self.finished_verify(server_tid)
else:
@protocol.promise('getInvalidations', cache_tid)
def verify_invalidations(vdata):
if vdata:
tid, oids = vdata
for oid in oids:
cache.invalidate(oid, None)
return tid
else:
# cache is too old
logger.info("cache too old %s", protocol)
self.cache.clear()
self.client.invalidateCache()
return server_tid
verify_invalidations(self.finished_verify,
self.connected.set_exception)
else:
self.finished_verify(server_tid)
@finish_verify.catch
def verify_failed(exc):
del self.protocol
self.register_failed(protocol, exc)
def finished_verify(self, server_tid):
self.cache.setLastTid(server_tid)
self.ready = True
self.connected.set_result(None)
def call_async_threadsafe(self, future, method, args):
if self.ready:
return self._call(result_future, method, args)
self.protocol.call_async(method, args)
future.set_result(None)
else:
future.set_exception(ZEO.Exceptions.ClientDisconnected())
def _when_ready(self, func, result_future, *args):
@self.connected.add_done_callback
def done(future):
......@@ -199,9 +395,16 @@ class Client(asyncio.Protocol):
if e is not None:
result_future.set_exception(e)
else:
self.call_threadsafe(result_future, method, args)
if self.ready:
func(result_future, *args)
else:
self._when_ready(func, result_future, *args)
return result_future
def call_threadsafe(self, future, method, args):
if self.ready:
self.protocol.call(future, method, args)
else:
self._when_ready(self.call_threadsafe, future, method, args)
# Special methods because they update the cache.
......@@ -209,21 +412,23 @@ class Client(asyncio.Protocol):
data = self.cache.load(oid)
if data is not None:
future.set_result(data)
else:
@self.promise('loadEx', oid)
elif self.ready:
@self.protocol.promise('loadEx', oid)
def load(data):
future.set_result(data)
data, tid = data
self.cache.store(oid, tid, None, data)
load.catch(future.set_exception)
else:
self._when_ready(self.load_threadsafe, future, oid)
def load_before_threadsafe(self, future, oid, tid):
data = self.cache.loadBefore(oid, tid)
if data is not None:
future.set_result(data)
else:
@self.promise('loadBefore', oid, tid)
elif self.ready:
@self.protocol.promise('loadBefore', oid, tid)
def load_before(data):
future.set_result(data)
if data:
......@@ -231,9 +436,12 @@ class Client(asyncio.Protocol):
self.cache.store(oid, start, end, data)
load_before.catch(future.set_exception)
else:
self._when_ready(self.load_before_threadsafe, future, oid, tid)
def tpc_finish_threadsafe(self, future, tid, updates):
@self.promise('tpc_finish', tid)
if self.ready:
@self.protocol.promise('tpc_finish', tid)
def committed(_):
cache = self.cache
for oid, s, data in updates:
......@@ -244,6 +452,8 @@ class Client(asyncio.Protocol):
future.set_result(None)
committed.catch(future.set_exception)
else:
future.set_exception(ClientDisconnected())
# Methods called by the server:
......@@ -254,6 +464,7 @@ class Client(asyncio.Protocol):
client_delegated = client_methods[1:]
def invalidateTransaction(self, tid, oids):
if self.ready:
for oid in oids:
self.cache.invalidate(oid, tid)
self.cache.setLastTid(tid)
......@@ -262,9 +473,9 @@ class Client(asyncio.Protocol):
class ClientRunner:
def set_options(self, addr, wrapper, cache, storage_key, read_only,
def set_options(self, addrs, wrapper, cache, storage_key, read_only,
timeout=30):
self.__args = addr, wrapper, cache, storage_key, read_only
self.__args = addrs, wrapper, cache, storage_key, read_only
self.timeout = timeout
self.connected = concurrent.futures.Future()
......@@ -316,9 +527,9 @@ class ClientThread(ClientRunner):
Calls to it are made in a thread-safe fashion.
"""
def __init__(self, addr, client, cache,
def __init__(self, addrs, client, cache,
storage_key='1', read_only=False, timeout=30):
self.set_options(addr, client, cache, storage_key, read_only, timeout)
self.set_options(addrs, client, cache, storage_key, read_only, timeout)
threading.Thread(
target=self.run,
args=(addr, client, cache, storage_key, read_only),
......@@ -348,7 +559,7 @@ class Promise:
# completed if the callbacks are set in the same code that
# created the promise, which they are.
next = success_callback = error_callback = None
next = success_callback = error_callback = cancelled = None
def __call__(self, success_callback = None, error_callback = None):
self.next = self.__class__()
......@@ -356,6 +567,9 @@ class Promise:
self.error_callback = error_callback
return self.next
def cancel(self):
self.set_exception(concurrent.futures.CancelledError)
def catch(self, error_callback):
self.error_callback = error_callback
......
import asyncio
import pprint
class Loop:
def __init__(self, debug=True):
protocol = transport = None
def __init__(self, addrs=(), debug=True):
self.addrs = addrs
self.get_debug = lambda : debug
self.connecting = {}
self.later = []
self.exceptions = []
def call_soon(self, func, *args):
func(*args)
def create_connection(self, protocol_factory, host, port):
def _connect(self, future, protocol_factory):
self.protocol = protocol = protocol_factory()
self.transport = transport = Transport()
future = asyncio.Future(loop=self)
future.set_result((transport, protocol))
protocol.connection_made(transport)
future.set_result((transport, protocol))
def connect_connecting(self, addr):
future, protocol_factory = self.connecting.pop(addr)
self._connect(future, protocol_factory)
def fail_connecting(self, addr):
future, protocol_factory = self.connecting.pop(addr)
if not future.cancelled():
future.set_exception(ConnectionRefusedError())
def create_connection(self, protocol_factory, host, port):
future = asyncio.Future(loop=self)
addr = host, port
if addr in self.addrs:
self._connect(future, protocol_factory)
else:
self.connecting[addr] = future, protocol_factory
return future
def create_unix_connection(self, protocol_factory, path):
future = asyncio.Future(loop=self)
if path in self.addrs:
self._connect(future, protocol_factory)
else:
self.connecting[path] = future, protocol_factory
return future
def call_soon_threadsafe(self, func, *args):
func(*args)
def call_later(self, delay, func, *args):
self.later.append((delay, func, args))
def call_exception_handler(self, context):
self.exceptions.append(context)
class Transport:
def __init__(self):
......
from zope.testing import setupstack
from concurrent.futures import Future
from unittest import mock
from ZODB.POSException import ReadOnlyError
import asyncio
import collections
import logging
import pdb
import pickle
import struct
import unittest
from .testing import Loop
from .client import ClientRunner
from .client import ClientRunner, Fallback
from ..Exceptions import ClientDisconnected
class AsyncTests(setupstack.TestCase, ClientRunner):
addr = ('127.0.0.1', 8200)
def start(self):
def start(self,
addrs=(('127.0.0.1', 8200), ), loop_addrs=None,
read_only=False,
):
# To create a client, we need to specify an address, a client
# object and a cache.
wrapper = mock.Mock()
cache = MemoryCache()
self.set_options(self.addr, wrapper, cache, 'TEST', False)
self.set_options(addrs, wrapper, cache, 'TEST', read_only)
# We can also provide an event loop. We'll use a testing loop
# so we don't have to actually make any network connection.
self.setup_delegation(Loop())
protocol = self.loop.protocol
transport = self.loop.transport
loop = Loop(addrs if loop_addrs is None else loop_addrs)
self.setup_delegation(loop)
protocol = loop.protocol
transport = loop.transport
def send(meth, *args):
protocol.data_received(
loop.protocol.data_received(
sized(pickle.dumps((0, True, meth, args), 3)))
def respond(message_id, result):
protocol.data_received(
loop.protocol.data_received(
sized(pickle.dumps((message_id, False, '.reply', result), 3)))
return (wrapper, cache, self.loop, self.client, protocol, transport,
......@@ -159,7 +165,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'd'*8))
self.assertEqual(cache.load(b'4'*8), ('committed 4', b'd'*8))
# Is the protocol is disconnected, it will reconnect and will
# If the protocol is disconnected, it will reconnect and will
# resolve outstanding requests with exceptions:
loaded = self.load(b'1'*8)
f1 = self.call('foo', 1, 2)
......@@ -186,11 +192,11 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(parse(transport.pop()),
[(10, False, 'register', ('TEST', False)),
(11, False, 'lastTransaction', ()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(10, None)
respond(11, b'd'*8)
respond(1, None)
respond(2, b'd'*8)
# Because the server tid matches the cache tid, we're done connecting
self.assert_(client.connected.done() and not transport.data)
......@@ -279,15 +285,147 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(cache)
wrapper.invalidateCache.assert_called_with()
def test_cache_crazy(self):
def test_multiple_addresses(self):
# We can pass multiple addresses to client constructor
addrs = [('1.2.3.4', 8200), ('2.2.3.4', 8200)]
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
self.start(addrs, ()))
cache.setLastTid(b'e'*8)
cache.store(b'4'*8, b'e'*8, None, '4 data')
self.assertTrue(cache)
# We haven't connected yet
self.assert_(protocol is None and transport is None)
# There are 2 connection attempts outstanding:
self.assertEqual(sorted(loop.connecting), addrs)
# We cause the first one to fail:
loop.fail_connecting(addrs[0])
self.assertEqual(sorted(loop.connecting), addrs[1:])
# The failed connection is attempted in the future:
delay, func, args = loop.later.pop(0)
self.assert_(1 <= delay <= 2)
func(*args)
self.assertEqual(sorted(loop.connecting), addrs)
# Let's connect the second address
loop.connect_connecting(addrs[1])
self.assertEqual(sorted(loop.connecting), addrs[:1])
protocol = loop.protocol
transport = loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
respond(1, None)
# Now, when the first connection fails, it won't be retried,
# because we're already connected.
self.assertEqual(sorted(loop.later), [])
loop.fail_connecting(addrs[0])
self.assertEqual(sorted(loop.connecting), [])
self.assertEqual(sorted(loop.later), [])
def test_bad_server_tid(self):
# If in verification we get a server_tid behing the cache's, make sure
# we retry the connection later.
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
cache.store(b'4'*8, b'a'*8, None, '4 data')
cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
parse = self.parse
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, 'a'*8)
self.assertFalse(client.connected.done() or transport.data)
delay, func, args = loop.later.pop(0)
self.assert_(8 < delay < 10)
self.assertEqual(len(loop.later), 0)
func(*args) # connect again
self.assertFalse(protocol is loop.protocol)
self.assertFalse(transport is loop.transport)
protocol = loop.protocol
transport = loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, 'b'*8)
self.assert_(client.connected.done() and not transport.data)
self.assert_(client.ready)
def test_readonly_fallback(self):
addrs = [('1.2.3.4', 8200), ('2.2.3.4', 8200)]
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(addrs, (), read_only=Fallback))
# We'll treat the first address as read-only and we'll let it connect:
loop.connect_connecting(addrs[0])
protocol, transport = loop.protocol, loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
# We see that the client tried a writable connection:
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
# We respond with a read-only exception:
respond(1, (ReadOnlyError, ReadOnlyError()))
# The client tries for a read-only connection:
self.assertEqual(self.parse(transport.pop()),
[(3, False, 'register', ('TEST', True)),
(4, False, 'lastTransaction', ()),
])
# We respond with successfully:
respond(3, None)
respond(4, 'b'*8)
# At this point, the client is ready and using the protocol,
# and the protocol is read-only:
self.assert_(client.ready)
self.assertEqual(client.protocol, protocol)
self.assertEqual(protocol.read_only, True)
connected = client.connected
self.assert_(connected.done())
# We connect the second address:
loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), b'Z101')
self.assertEqual(self.parse(loop.transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
# We respond and the writable connection succeeds:
respond(1, None)
# Now, the original protocol is closed, and the client is
# no-longer ready:
self.assertFalse(client.ready)
self.assertFalse(client.protocol is protocol)
self.assertEqual(client.protocol, loop.protocol)
self.assertEqual(protocol.closed, True)
self.assert_(client.connected is not connected)
self.assertFalse(client.connected.done())
protocol, transport = loop.protocol, loop.transport
self.assertEqual(protocol.read_only, False)
# Now, we finish verification
respond(2, 'b'*8)
self.assert_(client.ready)
self.assert_(client.connected.done())
def test_invalidations_while_verifying(self):
# While we're verifying, invalidations are ignored
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(self.parse(transport.pop()),
......@@ -295,15 +433,35 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
(2, False, 'lastTransaction', ()),
])
respond(1, None)
send('invalidateTransaction', b'b'*8, [b'1'*8])
self.assertFalse(wrapper.invalidateTransaction.called)
respond(2, b'a'*8)
send('invalidateTransaction', b'c'*8, [b'1'*8])
wrapper.invalidateTransaction.assert_called_with(b'c'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
# The server tid is less than the client tid, WTF? We error
with self.assertRaisesRegex(AssertionError, 'Server behind client'):
client.connected.result()
# We'll disconnect:
protocol.connection_lost(Exception("lost"))
self.assert_(protocol is not loop.protocol)
self.assert_(transport is not loop.transport)
protocol = loop.protocol
transport = loop.transport
# todo:
# bad cache validation, make sure ZODB cache is cleared
# cache boolean value in interface
# Similarly, invalidations aren't processed while reconnecting:
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
send('invalidateTransaction', b'd'*8, [b'1'*8])
self.assertFalse(wrapper.invalidateTransaction.called)
respond(2, b'c'*8)
send('invalidateTransaction', b'e'*8, [b'1'*8])
wrapper.invalidateTransaction.assert_called_with(b'e'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
def unsized(self, data, unpickle=False):
result = []
......@@ -378,6 +536,21 @@ class MemoryCache:
def setLastTid(self, tid):
self.last_tid = tid
class Logging:
def __init__(self, level=logging.ERROR):
self.level = level
def __enter__(self):
self.handler = logging.StreamHandler()
logging.getLogger().addHandler(self.handler)
logging.getLogger().setLevel(self.level)
def __exit__(self, *args):
logging.getLogger().removeHandler(self.handler)
logging.getLogger().setLevel(logging.NOTSET)
def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(AsyncTests))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment