Commit 107f1077 authored by Jim Fulton's avatar Jim Fulton

Updated the asyncio to support multiple addresses and read-only fallback

Also got rid of the adapter machinery. It didn't buy enough to justify
the wrapping.
parent e0c64161
"""Low-level protocol adapters
Derived from ngi connection adapters and filling a similar role to the
old zrpc smac layer for sized messages.
"""
import struct
class BaseTransportAdapter:
def __init__(self, transport):
self.transport = transport
def close(self):
self.transport.close()
def is_closing(self):
return self.transport.is_closing()
def get_extra_info(self, name, default=None):
return self.transport.get_extra_info(name, default)
def pause_reading(self):
self.transport.pause_reading()
def resume_reading(self):
self.transport.resume_reading()
def abort(self):
self.transport.abort()
def can_write_eof(self):
return self.transport.can_write_eof()
def get_write_buffer_size(self):
return self.transport.get_write_buffer_size()
def get_write_buffer_limits(self):
return self.transport.get_write_buffer_limits()
def set_write_buffer_limits(self, high=None, low=None):
self.transport.set_write_buffer_limits(high, low)
def write(self, data):
self.transport.write(data)
def writelines(self, list_of_data):
self.transport.writelines(list_of_data)
def write_eof(self):
self.transport.write_eof()
class BaseProtocolAdapter:
def __init__(self, protocol):
self.protocol = protocol
def connection_made(self, transport):
self.protocol.connection_made(transport)
def connection_lost(self, exc):
self.protocol.connection_lost(exc)
def data_received(self, data):
self.protocol.data_received(data)
def eof_received(self):
return self.protocol.eof_received()
class SizedTransportAdapter(BaseTransportAdapter):
"""Sized-message transport adapter
"""
def write(self, message):
self.transport.writelines((struct.pack(">I", len(message)), message))
def writelines(self, list_of_data):
self.transport.writelines(sized_iter(list_of_data))
def sized_iter(data):
for message in data:
yield struct.pack(">I", len(message))
yield message
class SizedProtocolAdapter(BaseProtocolAdapter):
def __init__(self, protocol):
self.protocol = protocol
self.want = 4
self.got = 0
self.getting_size = True
self.input = []
def data_received(self, data):
self.got += len(data)
self.input.append(data)
while self.got >= self.want:
extra = self.got - self.want
if extra == 0:
collected = b''.join(self.input)
self.input = []
else:
input = self.input
self.input = [data[-extra:]]
input[-1] = input[-1][:-extra]
collected = b''.join(input)
self.got = extra
if self.getting_size:
# we were recieving the message size
assert self.want == 4
self.want = struct.unpack(">I", collected)[0]
self.getting_size = False
else:
self.want = 4
self.getting_size = True
self.protocol.data_received(collected)
from pickle import loads, dumps from pickle import loads, dumps
from ZODB.ConflictResolution import ResolvedSerial from ZODB.ConflictResolution import ResolvedSerial
from struct import unpack
import asyncio import asyncio
import concurrent.futures import concurrent.futures
import logging import logging
import random import random
import threading import threading
import ZEO.Exceptions import ZEO.Exceptions
import ZODB.POSException
from . import adapters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
Fallback = object()
local_random = random.Random() # use separate generator to facilitate tests local_random = random.Random() # use separate generator to facilitate tests
class Client(asyncio.Protocol): class Closed(Exception):
"""A connection has been closed
"""
class Protocol(asyncio.Protocol):
"""asyncio low-level ZEO client interface """asyncio low-level ZEO client interface
""" """
...@@ -23,34 +29,46 @@ class Client(asyncio.Protocol): ...@@ -23,34 +29,46 @@ class Client(asyncio.Protocol):
# One place where special care was required was in cache setup on # One place where special care was required was in cache setup on
# connect. See finish connect below. # connect. See finish connect below.
transport = protocol_version = None
def __init__(self, addr, client, storage_key, read_only, loop,
connect_timeout=1):
"""Create a client interface
addr is either a host,port tuple or a string file name.
def __init__(self, addr, client, cache, storage_key, read_only, loop): client is a ClientStorage. It must be thread safe.
cache is a ZEO.interfaces.IClientCache.
"""
self.loop = loop self.loop = loop
self.addr = addr self.addr = addr
self.storage_key = storage_key self.storage_key = storage_key
self.read_only = read_only self.read_only = read_only
self.name = "%s(%r, %r, %r)" % (
self.__class__.__name__, addr, storage_key, read_only)
self.client = client self.client = client
for name in self.client_delegated: self.connect_timeout = connect_timeout
setattr(self, name, getattr(client, name)) self.futures = {} # { message_id -> future }
self.info = client.info self.input = []
self.cache = cache self.connect()
self.disconnected()
def __repr__(self):
return self.name
closed = False closed = False
def close(self): def close(self):
self.closed = True if not self.closed:
self.transport.close() self.closed = True
self.cache.close() self._connecting.cancel()
if self.transport is not None:
self.transport.close()
for future in self.futures.values():
future.set_exception(Closed())
self.futures.clear()
def protocol_factory(self): def protocol_factory(self):
return adapters.SizedProtocolAdapter(self) return self
def disconnected(self):
self.ready = False
self.connected = concurrent.futures.Future()
self.protocol_version = None
self.futures = {}
self.connect()
def connect(self): def connect(self):
if isinstance(self.addr, tuple): if isinstance(self.addr, tuple):
...@@ -60,24 +78,38 @@ class Client(asyncio.Protocol): ...@@ -60,24 +78,38 @@ class Client(asyncio.Protocol):
cr = self.loop.create_unix_connection( cr = self.loop.create_unix_connection(
self.protocol_factory, self.addr) self.protocol_factory, self.addr)
cr = asyncio.async(cr, loop=self.loop) self._connecting = cr = asyncio.async(cr, loop=self.loop)
@cr.add_done_callback @cr.add_done_callback
def done_connecting(future): def done_connecting(future):
if future.exception() is not None: if future.exception() is not None:
# keep trying # keep trying
self.loop.call_later(1 + local_random.random(), self.connect) if not self.closed:
self.loop.call_later(1 + local_random.random(),
self.connect)
def connection_made(self, transport): def connection_made(self, transport):
logger.info("Connected") logger.info("Connected %s", self)
self.transport = adapters.SizedTransportAdapter(transport) self.transport = transport
writelines = transport.writelines
from struct import pack
def write(message):
writelines((pack(">I", len(message)), message))
self._write = write
def connection_lost(self, exc): def connection_lost(self, exc):
exc = exc or ClientDisconnected() if exc is None:
logger.info("Disconnected, %r", exc) # we were closed
for f in self.futures.values(): for f in self.futures.values():
f.set_exception(exc) f.cancel()
self.disconnected() else:
logger.info("Disconnected, %s, %r", self, exc)
for f in self.futures.values():
f.set_exception(exc)
self.client.disconnected(self)
def finish_connect(self, protocol_version): def finish_connect(self, protocol_version):
# We use a promise model rather than coroutines here because # We use a promise model rather than coroutines here because
...@@ -95,56 +127,74 @@ class Client(asyncio.Protocol): ...@@ -95,56 +127,74 @@ class Client(asyncio.Protocol):
# lastTid before processing subsequent invalidations. # lastTid before processing subsequent invalidations.
self.protocol_version = protocol_version self.protocol_version = protocol_version
self.transport.write(protocol_version) self._write(protocol_version)
register = self.promise('register', self.storage_key, self.read_only)
register = self.promise(
'register', self.storage_key,
self.read_only if self.read_only is not Fallback else False,
)
# Get lastTransaction in flight right away to make successful
# connection quicker
lastTransaction = self.promise('lastTransaction') lastTransaction = self.promise('lastTransaction')
cache = self.cache
@register(lambda _ : lastTransaction)
def verify(server_tid):
if not cache:
return server_tid
cache_tid = cache.getLastTid()
if not cache_tid:
logger.error("Non-empty cache w/o tid -- clearing")
cache.clear()
self.client.invalidateCache()
return server_tid
elif cache_tid == server_tid:
logger.info("Cache up to date %r", server_tid)
return server_tid
elif cache_tid >= server_tid:
raise AssertionError("Server behind client, %r < %r",
server_tid, cache_tid)
@self.promise('getInvalidations', cache_tid)
def verify_invalidations(vdata):
if vdata:
tid, oids = vdata
for oid in oids:
cache.invalidate(oid, None)
return tid
else:
# cache is too old
self.cache.clear()
self.client.invalidateCache()
return server_tid
return verify_invalidations @register
def registered(_):
if self.read_only is Fallback:
self.read_only = False
self.client.registered(self, lastTransaction)
@register.catch
def register_failed(exc):
if (isinstance(exc, ZODB.POSException.ReadOnlyError) and
self.read_only is Fallback):
# We tried a write connection, degrade to a read-only one
self.read_only = True
register = self.promise(
'register', self.storage_key, self.read_only)
lastTransaction = self.promise('lastTransaction')
@register
def registered(_):
self.client.registered(self, lastTransaction)
@register.catch
def register_failed(exc):
self.client.register_failed(self, exc)
@verify else:
def finish_verification(server_tid): self.client.register_failed(self, exc)
cache.setLastTid(server_tid)
self.ready = True
finish_verification( got = 0
lambda _ : self.connected.set_result(None), want = 4
lambda e: self.connected.set_exception(e), getting_size = True
) def data_received(self, data):
self.got += len(data)
self.input.append(data)
while self.got >= self.want:
extra = self.got - self.want
if extra == 0:
collected = b''.join(self.input)
self.input = []
else:
input = self.input
self.input = [data[-extra:]]
input[-1] = input[-1][:-extra]
collected = b''.join(input)
self.got = extra
if self.getting_size:
# we were recieving the message size
assert self.want == 4
self.want = unpack(">I", collected)[0]
self.getting_size = False
else:
self.want = 4
self.getting_size = True
self.message_received(collected)
exception_type_type = type(Exception) exception_type_type = type(Exception)
def data_received(self, data): def message_received(self, data):
if self.protocol_version is None: if self.protocol_version is None:
self.finish_connect(data) self.finish_connect(data)
else: else:
...@@ -153,45 +203,191 @@ class Client(asyncio.Protocol): ...@@ -153,45 +203,191 @@ class Client(asyncio.Protocol):
future = self.futures.pop(msgid) future = self.futures.pop(msgid)
if (isinstance(args, tuple) and len(args) > 1 and if (isinstance(args, tuple) and len(args) > 1 and
type(args[0]) == self.exception_type_type and type(args[0]) == self.exception_type_type and
issubclass(r_args[0], Exception) issubclass(args[0], Exception)
): ):
future.set_exception(args[0]) future.set_exception(args[1])
else: else:
future.set_result(args) future.set_result(args)
else: else:
assert async # clients only get async calls assert async # clients only get async calls
if name in self.client_methods: if name in self.client_methods:
getattr(self, name)(*args) getattr(self.client, name)(*args)
else: else:
raise AttributeError(name) raise AttributeError(name)
def call_async(self, method, *args): def call_async(self, method, args):
if self.ready: self._write(dumps((0, True, method, args), 3))
self.transport.write(dumps((0, True, method, args), 3))
else:
raise ZEO.Exceptions.ClientDisconnected()
def call_async_threadsafe(self, future, method, args):
try:
self.call_async(method, *args)
except Exception as e:
future.set_exception(e)
else:
future.set_result(None)
message_id = 0 message_id = 0
def _call(self, future, method, args): def call(self, future, method, args):
self.message_id += 1 self.message_id += 1
self.futures[self.message_id] = future self.futures[self.message_id] = future
self.transport.write(dumps((self.message_id, False, method, args), 3)) self._write(dumps((self.message_id, False, method, args), 3))
return future return future
def promise(self, method, *args): def promise(self, method, *args):
return self._call(Promise(), method, args) return self.call(Promise(), method, args)
# Methods called by the server:
client_methods = (
'invalidateTransaction', 'serialnos', 'info',
'receiveBlobStart', 'receiveBlobChunk', 'receiveBlobStop',
)
client_delegated = client_methods[1:]
class Client:
"""asyncio low-level ZEO client interface
"""
# All of the code in this class runs in a single dedicated
# thread. Thus, we can mostly avoid worrying about interleaved
# operations.
# One place where special care was required was in cache setup on
# connect.
protocol = None
ready = False
def __init__(self, addrs, client, cache, storage_key, read_only, loop):
"""Create a client interface
addr is either a host,port tuple or a string file name.
client is a ClientStorage. It must be thread safe.
cache is a ZEO.interfaces.IClientCache.
"""
self.loop = loop
self.addrs = addrs
self.storage_key = storage_key
self.read_only = read_only
self.client = client
for name in Protocol.client_delegated:
setattr(self, name, getattr(client, name))
self.cache = cache
self.protocols = ()
self.disconnected(None)
def call_threadsafe(self, result_future, method, args): closed = False
def close(self):
if not self.closed:
self.closed = True
self.protocol.close()
self.cache.close()
self._clear_protocols()
def _clear_protocols(self, protocol=None):
for p in self.protocols:
if p is not protocol:
p.close()
self.protocols = ()
def disconnected(self, protocol=None):
if protocol is None or protocol is self.protocol:
self.ready = False
self.connected = concurrent.futures.Future()
self.protocol = None
self._clear_protocols()
self.try_connecting()
def upgrade(self, protocol):
self.ready = False
self.connected = concurrent.futures.Future()
self.protocol.close()
self.protocol = protocol
self._clear_protocols(protocol)
def try_connecting(self):
if not self.closed:
self.protocols = [
Protocol(addr, self, self.storage_key, self.read_only,
self.loop)
for addr in self.addrs
]
def registered(self, protocol, last_transaction_promise):
if self.protocol is None:
self.protocol = protocol
if not (self.read_only is Fallback and protocol.read_only):
# We're happy with this protocol. Tell the others to
# stop trying.
self._clear_protocols(protocol)
self.verify(last_transaction_promise)
elif (self.read_only is Fallback and not protocol.read_only and
self.protocol.read_only):
self.upgrade(protocol)
self.verify(last_transaction_promise)
else:
protocol.close() # too late, we went home with another
def register_failed(self, protocol, exc):
# A protcol failed registration. That's weird. If they've all
# failed, we should try again in a bit.
protocol.close()
logger.error("Registration or cache validation failed, %s", exc)
if (self.protocol is None and not
any(not p.closed for p in self.protocols)
):
self.loop.call_later(9 + local_random.random(), self.try_connecting)
def verify(self, last_transaction_promise):
protocol = self.protocol
@last_transaction_promise
def finish_verify(server_tid):
cache = self.cache
if cache:
cache_tid = cache.getLastTid()
if not cache_tid:
logger.error("Non-empty cache w/o tid -- clearing")
cache.clear()
self.client.invalidateCache()
self.finished_verify(server_tid)
elif cache_tid > server_tid:
raise AssertionError("Server behind client, %r < %r, %s",
server_tid, cache_tid, protocol)
elif cache_tid == server_tid:
self.finished_verify(server_tid)
else:
@protocol.promise('getInvalidations', cache_tid)
def verify_invalidations(vdata):
if vdata:
tid, oids = vdata
for oid in oids:
cache.invalidate(oid, None)
return tid
else:
# cache is too old
logger.info("cache too old %s", protocol)
self.cache.clear()
self.client.invalidateCache()
return server_tid
verify_invalidations(self.finished_verify,
self.connected.set_exception)
else:
self.finished_verify(server_tid)
@finish_verify.catch
def verify_failed(exc):
del self.protocol
self.register_failed(protocol, exc)
def finished_verify(self, server_tid):
self.cache.setLastTid(server_tid)
self.ready = True
self.connected.set_result(None)
def call_async_threadsafe(self, future, method, args):
if self.ready: if self.ready:
return self._call(result_future, method, args) self.protocol.call_async(method, args)
future.set_result(None)
else:
future.set_exception(ZEO.Exceptions.ClientDisconnected())
def _when_ready(self, func, result_future, *args):
@self.connected.add_done_callback @self.connected.add_done_callback
def done(future): def done(future):
...@@ -199,9 +395,16 @@ class Client(asyncio.Protocol): ...@@ -199,9 +395,16 @@ class Client(asyncio.Protocol):
if e is not None: if e is not None:
result_future.set_exception(e) result_future.set_exception(e)
else: else:
self.call_threadsafe(result_future, method, args) if self.ready:
func(result_future, *args)
else:
self._when_ready(func, result_future, *args)
return result_future def call_threadsafe(self, future, method, args):
if self.ready:
self.protocol.call(future, method, args)
else:
self._when_ready(self.call_threadsafe, future, method, args)
# Special methods because they update the cache. # Special methods because they update the cache.
...@@ -209,21 +412,23 @@ class Client(asyncio.Protocol): ...@@ -209,21 +412,23 @@ class Client(asyncio.Protocol):
data = self.cache.load(oid) data = self.cache.load(oid)
if data is not None: if data is not None:
future.set_result(data) future.set_result(data)
else: elif self.ready:
@self.promise('loadEx', oid) @self.protocol.promise('loadEx', oid)
def load(data): def load(data):
future.set_result(data) future.set_result(data)
data, tid = data data, tid = data
self.cache.store(oid, tid, None, data) self.cache.store(oid, tid, None, data)
load.catch(future.set_exception) load.catch(future.set_exception)
else:
self._when_ready(self.load_threadsafe, future, oid)
def load_before_threadsafe(self, future, oid, tid): def load_before_threadsafe(self, future, oid, tid):
data = self.cache.loadBefore(oid, tid) data = self.cache.loadBefore(oid, tid)
if data is not None: if data is not None:
future.set_result(data) future.set_result(data)
else: elif self.ready:
@self.promise('loadBefore', oid, tid) @self.protocol.promise('loadBefore', oid, tid)
def load_before(data): def load_before(data):
future.set_result(data) future.set_result(data)
if data: if data:
...@@ -231,19 +436,24 @@ class Client(asyncio.Protocol): ...@@ -231,19 +436,24 @@ class Client(asyncio.Protocol):
self.cache.store(oid, start, end, data) self.cache.store(oid, start, end, data)
load_before.catch(future.set_exception) load_before.catch(future.set_exception)
else:
self._when_ready(self.load_before_threadsafe, future, oid, tid)
def tpc_finish_threadsafe(self, future, tid, updates): def tpc_finish_threadsafe(self, future, tid, updates):
@self.promise('tpc_finish', tid) if self.ready:
def committed(_): @self.protocol.promise('tpc_finish', tid)
cache = self.cache def committed(_):
for oid, s, data in updates: cache = self.cache
cache.invalidate(oid, tid) for oid, s, data in updates:
if data and s != ResolvedSerial: cache.invalidate(oid, tid)
cache.store(oid, tid, None, data) if data and s != ResolvedSerial:
cache.setLastTid(tid) cache.store(oid, tid, None, data)
future.set_result(None) cache.setLastTid(tid)
future.set_result(None)
committed.catch(future.set_exception)
committed.catch(future.set_exception)
else:
future.set_exception(ClientDisconnected())
# Methods called by the server: # Methods called by the server:
...@@ -254,17 +464,18 @@ class Client(asyncio.Protocol): ...@@ -254,17 +464,18 @@ class Client(asyncio.Protocol):
client_delegated = client_methods[1:] client_delegated = client_methods[1:]
def invalidateTransaction(self, tid, oids): def invalidateTransaction(self, tid, oids):
for oid in oids: if self.ready:
self.cache.invalidate(oid, tid) for oid in oids:
self.cache.setLastTid(tid) self.cache.invalidate(oid, tid)
self.client.invalidateTransaction(tid, oids) self.cache.setLastTid(tid)
self.client.invalidateTransaction(tid, oids)
class ClientRunner: class ClientRunner:
def set_options(self, addr, wrapper, cache, storage_key, read_only, def set_options(self, addrs, wrapper, cache, storage_key, read_only,
timeout=30): timeout=30):
self.__args = addr, wrapper, cache, storage_key, read_only self.__args = addrs, wrapper, cache, storage_key, read_only
self.timeout = timeout self.timeout = timeout
self.connected = concurrent.futures.Future() self.connected = concurrent.futures.Future()
...@@ -316,9 +527,9 @@ class ClientThread(ClientRunner): ...@@ -316,9 +527,9 @@ class ClientThread(ClientRunner):
Calls to it are made in a thread-safe fashion. Calls to it are made in a thread-safe fashion.
""" """
def __init__(self, addr, client, cache, def __init__(self, addrs, client, cache,
storage_key='1', read_only=False, timeout=30): storage_key='1', read_only=False, timeout=30):
self.set_options(addr, client, cache, storage_key, read_only, timeout) self.set_options(addrs, client, cache, storage_key, read_only, timeout)
threading.Thread( threading.Thread(
target=self.run, target=self.run,
args=(addr, client, cache, storage_key, read_only), args=(addr, client, cache, storage_key, read_only),
...@@ -348,7 +559,7 @@ class Promise: ...@@ -348,7 +559,7 @@ class Promise:
# completed if the callbacks are set in the same code that # completed if the callbacks are set in the same code that
# created the promise, which they are. # created the promise, which they are.
next = success_callback = error_callback = None next = success_callback = error_callback = cancelled = None
def __call__(self, success_callback = None, error_callback = None): def __call__(self, success_callback = None, error_callback = None):
self.next = self.__class__() self.next = self.__class__()
...@@ -356,6 +567,9 @@ class Promise: ...@@ -356,6 +567,9 @@ class Promise:
self.error_callback = error_callback self.error_callback = error_callback
return self.next return self.next
def cancel(self):
self.set_exception(concurrent.futures.CancelledError)
def catch(self, error_callback): def catch(self, error_callback):
self.error_callback = error_callback self.error_callback = error_callback
......
import asyncio import asyncio
import pprint
class Loop: class Loop:
def __init__(self, debug=True): protocol = transport = None
def __init__(self, addrs=(), debug=True):
self.addrs = addrs
self.get_debug = lambda : debug self.get_debug = lambda : debug
self.connecting = {}
self.later = []
self.exceptions = []
def call_soon(self, func, *args): def call_soon(self, func, *args):
func(*args) func(*args)
def create_connection(self, protocol_factory, host, port): def _connect(self, future, protocol_factory):
self.protocol = protocol = protocol_factory() self.protocol = protocol = protocol_factory()
self.transport = transport = Transport() self.transport = transport = Transport()
future = asyncio.Future(loop=self)
future.set_result((transport, protocol))
protocol.connection_made(transport) protocol.connection_made(transport)
future.set_result((transport, protocol))
def connect_connecting(self, addr):
future, protocol_factory = self.connecting.pop(addr)
self._connect(future, protocol_factory)
def fail_connecting(self, addr):
future, protocol_factory = self.connecting.pop(addr)
if not future.cancelled():
future.set_exception(ConnectionRefusedError())
def create_connection(self, protocol_factory, host, port):
future = asyncio.Future(loop=self)
addr = host, port
if addr in self.addrs:
self._connect(future, protocol_factory)
else:
self.connecting[addr] = future, protocol_factory
return future
def create_unix_connection(self, protocol_factory, path):
future = asyncio.Future(loop=self)
if path in self.addrs:
self._connect(future, protocol_factory)
else:
self.connecting[path] = future, protocol_factory
return future return future
def call_soon_threadsafe(self, func, *args): def call_soon_threadsafe(self, func, *args):
func(*args) func(*args)
def call_later(self, delay, func, *args):
self.later.append((delay, func, args))
def call_exception_handler(self, context):
self.exceptions.append(context)
class Transport: class Transport:
def __init__(self): def __init__(self):
......
from zope.testing import setupstack from zope.testing import setupstack
from concurrent.futures import Future from concurrent.futures import Future
from unittest import mock from unittest import mock
from ZODB.POSException import ReadOnlyError
import asyncio import asyncio
import collections import collections
import logging
import pdb
import pickle import pickle
import struct import struct
import unittest import unittest
from .testing import Loop from .testing import Loop
from .client import ClientRunner from .client import ClientRunner, Fallback
from ..Exceptions import ClientDisconnected from ..Exceptions import ClientDisconnected
class AsyncTests(setupstack.TestCase, ClientRunner): class AsyncTests(setupstack.TestCase, ClientRunner):
addr = ('127.0.0.1', 8200) def start(self,
addrs=(('127.0.0.1', 8200), ), loop_addrs=None,
def start(self): read_only=False,
):
# To create a client, we need to specify an address, a client # To create a client, we need to specify an address, a client
# object and a cache. # object and a cache.
wrapper = mock.Mock() wrapper = mock.Mock()
cache = MemoryCache() cache = MemoryCache()
self.set_options(self.addr, wrapper, cache, 'TEST', False) self.set_options(addrs, wrapper, cache, 'TEST', read_only)
# We can also provide an event loop. We'll use a testing loop # We can also provide an event loop. We'll use a testing loop
# so we don't have to actually make any network connection. # so we don't have to actually make any network connection.
self.setup_delegation(Loop()) loop = Loop(addrs if loop_addrs is None else loop_addrs)
protocol = self.loop.protocol self.setup_delegation(loop)
transport = self.loop.transport protocol = loop.protocol
transport = loop.transport
def send(meth, *args): def send(meth, *args):
protocol.data_received( loop.protocol.data_received(
sized(pickle.dumps((0, True, meth, args), 3))) sized(pickle.dumps((0, True, meth, args), 3)))
def respond(message_id, result): def respond(message_id, result):
protocol.data_received( loop.protocol.data_received(
sized(pickle.dumps((message_id, False, '.reply', result), 3))) sized(pickle.dumps((message_id, False, '.reply', result), 3)))
return (wrapper, cache, self.loop, self.client, protocol, transport, return (wrapper, cache, self.loop, self.client, protocol, transport,
...@@ -159,7 +165,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -159,7 +165,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'd'*8)) self.assertEqual(cache.load(b'2'*8), ('committed 2', b'd'*8))
self.assertEqual(cache.load(b'4'*8), ('committed 4', b'd'*8)) self.assertEqual(cache.load(b'4'*8), ('committed 4', b'd'*8))
# Is the protocol is disconnected, it will reconnect and will # If the protocol is disconnected, it will reconnect and will
# resolve outstanding requests with exceptions: # resolve outstanding requests with exceptions:
loaded = self.load(b'1'*8) loaded = self.load(b'1'*8)
f1 = self.call('foo', 1, 2) f1 = self.call('foo', 1, 2)
...@@ -186,11 +192,11 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -186,11 +192,11 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
[(10, False, 'register', ('TEST', False)), [(1, False, 'register', ('TEST', False)),
(11, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(10, None) respond(1, None)
respond(11, b'd'*8) respond(2, b'd'*8)
# Because the server tid matches the cache tid, we're done connecting # Because the server tid matches the cache tid, we're done connecting
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
...@@ -279,15 +285,147 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -279,15 +285,147 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(cache) self.assertFalse(cache)
wrapper.invalidateCache.assert_called_with() wrapper.invalidateCache.assert_called_with()
def test_cache_crazy(self): def test_multiple_addresses(self):
# We can pass multiple addresses to client constructor
addrs = [('1.2.3.4', 8200), ('2.2.3.4', 8200)]
wrapper, cache, loop, client, protocol, transport, send, respond = ( wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start()) self.start(addrs, ()))
cache.setLastTid(b'e'*8) # We haven't connected yet
cache.store(b'4'*8, b'e'*8, None, '4 data') self.assert_(protocol is None and transport is None)
self.assertTrue(cache)
# There are 2 connection attempts outstanding:
self.assertEqual(sorted(loop.connecting), addrs)
# We cause the first one to fail:
loop.fail_connecting(addrs[0])
self.assertEqual(sorted(loop.connecting), addrs[1:])
# The failed connection is attempted in the future:
delay, func, args = loop.later.pop(0)
self.assert_(1 <= delay <= 2)
func(*args)
self.assertEqual(sorted(loop.connecting), addrs)
# Let's connect the second address
loop.connect_connecting(addrs[1])
self.assertEqual(sorted(loop.connecting), addrs[:1])
protocol = loop.protocol
transport = loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
respond(1, None)
# Now, when the first connection fails, it won't be retried,
# because we're already connected.
self.assertEqual(sorted(loop.later), [])
loop.fail_connecting(addrs[0])
self.assertEqual(sorted(loop.connecting), [])
self.assertEqual(sorted(loop.later), [])
def test_bad_server_tid(self):
# If in verification we get a server_tid behing the cache's, make sure
# we retry the connection later.
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
cache.store(b'4'*8, b'a'*8, None, '4 data')
cache.setLastTid('b'*8)
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
parse = self.parse
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, 'a'*8)
self.assertFalse(client.connected.done() or transport.data) self.assertFalse(client.connected.done() or transport.data)
delay, func, args = loop.later.pop(0)
self.assert_(8 < delay < 10)
self.assertEqual(len(loop.later), 0)
func(*args) # connect again
self.assertFalse(protocol is loop.protocol)
self.assertFalse(transport is loop.transport)
protocol = loop.protocol
transport = loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
respond(2, 'b'*8)
self.assert_(client.connected.done() and not transport.data)
self.assert_(client.ready)
def test_readonly_fallback(self):
addrs = [('1.2.3.4', 8200), ('2.2.3.4', 8200)]
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(addrs, (), read_only=Fallback))
# We'll treat the first address as read-only and we'll let it connect:
loop.connect_connecting(addrs[0])
protocol, transport = loop.protocol, loop.transport
protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
# We see that the client tried a writable connection:
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
# We respond with a read-only exception:
respond(1, (ReadOnlyError, ReadOnlyError()))
# The client tries for a read-only connection:
self.assertEqual(self.parse(transport.pop()),
[(3, False, 'register', ('TEST', True)),
(4, False, 'lastTransaction', ()),
])
# We respond with successfully:
respond(3, None)
respond(4, 'b'*8)
# At this point, the client is ready and using the protocol,
# and the protocol is read-only:
self.assert_(client.ready)
self.assertEqual(client.protocol, protocol)
self.assertEqual(protocol.read_only, True)
connected = client.connected
self.assert_(connected.done())
# We connect the second address:
loop.connect_connecting(addrs[1])
loop.protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(loop.transport.pop(2)), b'Z101')
self.assertEqual(self.parse(loop.transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
# We respond and the writable connection succeeds:
respond(1, None)
# Now, the original protocol is closed, and the client is
# no-longer ready:
self.assertFalse(client.ready)
self.assertFalse(client.protocol is protocol)
self.assertEqual(client.protocol, loop.protocol)
self.assertEqual(protocol.closed, True)
self.assert_(client.connected is not connected)
self.assertFalse(client.connected.done())
protocol, transport = loop.protocol, loop.transport
self.assertEqual(protocol.read_only, False)
# Now, we finish verification
respond(2, 'b'*8)
self.assert_(client.ready)
self.assert_(client.connected.done())
def test_invalidations_while_verifying(self):
# While we're verifying, invalidations are ignored
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start())
protocol.data_received(sized(b'Z101')) protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101') self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(self.parse(transport.pop()), self.assertEqual(self.parse(transport.pop()),
...@@ -295,15 +433,35 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -295,15 +433,35 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
(2, False, 'lastTransaction', ()), (2, False, 'lastTransaction', ()),
]) ])
respond(1, None) respond(1, None)
send('invalidateTransaction', b'b'*8, [b'1'*8])
self.assertFalse(wrapper.invalidateTransaction.called)
respond(2, b'a'*8) respond(2, b'a'*8)
send('invalidateTransaction', b'c'*8, [b'1'*8])
wrapper.invalidateTransaction.assert_called_with(b'c'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
# The server tid is less than the client tid, WTF? We error # We'll disconnect:
with self.assertRaisesRegex(AssertionError, 'Server behind client'): protocol.connection_lost(Exception("lost"))
client.connected.result() self.assert_(protocol is not loop.protocol)
self.assert_(transport is not loop.transport)
protocol = loop.protocol
transport = loop.transport
# todo: # Similarly, invalidations aren't processed while reconnecting:
# bad cache validation, make sure ZODB cache is cleared
# cache boolean value in interface protocol.data_received(sized(b'Z101'))
self.assertEqual(self.unsized(transport.pop(2)), b'Z101')
self.assertEqual(self.parse(transport.pop()),
[(1, False, 'register', ('TEST', False)),
(2, False, 'lastTransaction', ()),
])
respond(1, None)
send('invalidateTransaction', b'd'*8, [b'1'*8])
self.assertFalse(wrapper.invalidateTransaction.called)
respond(2, b'c'*8)
send('invalidateTransaction', b'e'*8, [b'1'*8])
wrapper.invalidateTransaction.assert_called_with(b'e'*8, [b'1'*8])
wrapper.invalidateTransaction.reset_mock()
def unsized(self, data, unpickle=False): def unsized(self, data, unpickle=False):
result = [] result = []
...@@ -378,6 +536,21 @@ class MemoryCache: ...@@ -378,6 +536,21 @@ class MemoryCache:
def setLastTid(self, tid): def setLastTid(self, tid):
self.last_tid = tid self.last_tid = tid
class Logging:
def __init__(self, level=logging.ERROR):
self.level = level
def __enter__(self):
self.handler = logging.StreamHandler()
logging.getLogger().addHandler(self.handler)
logging.getLogger().setLevel(self.level)
def __exit__(self, *args):
logging.getLogger().removeHandler(self.handler)
logging.getLogger().setLevel(logging.NOTSET)
def test_suite(): def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(AsyncTests)) suite.addTest(unittest.makeSuite(AsyncTests))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment