Commit cbf756c7 authored by Jim Fulton's avatar Jim Fulton

Hardened asyncio interface

- Moved cache into async thread to avoid lots of locking.

- Setup delegation to storage.

- Provide thread wrapper that runs the async protocol in a thread.
parent 08703051
from pickle import loads, dumps """Low-level protocol adapters
import asyncio
import concurrent.futures
import logging
import struct
import threading
logger = logging.getLogger(__name__)
class Disconnected(Exception): Derived from ngi connection adapters and filling a similar role to the
pass old zrpc smac layer for sized messages.
"""
import struct
class BaseTransportAdapter: class BaseTransportAdapter:
def __init__(self, transport): def __init__(self, transport):
self.transport = transport self.transport = transport
def close(self): def close(self):
self.transport.close self.transport.close()
def is_closing(self): def is_closing(self):
return self.transport.is_closing() return self.transport.is_closing()
def get_extra_info(self, name, default=None): def get_extra_info(self, name, default=None):
...@@ -57,8 +52,7 @@ class SizedTransportAdapter(BaseTransportAdapter): ...@@ -57,8 +52,7 @@ class SizedTransportAdapter(BaseTransportAdapter):
""" """
def write(self, message): def write(self, message):
self.transport.write(struct.pack(">I", len(message))) self.transport.writelines((struct.pack(">I", len(message)), message))
self.transport.write(message)
def writelines(self, list_of_data): def writelines(self, list_of_data):
self.transport.writelines(sized_iter(list_of_data)) self.transport.writelines(sized_iter(list_of_data))
...@@ -68,7 +62,6 @@ def sized_iter(data): ...@@ -68,7 +62,6 @@ def sized_iter(data):
yield struct.pack(">I", len(message)) yield struct.pack(">I", len(message))
yield message yield message
class SizedProtocolAdapter(BaseProtocolAdapter): class SizedProtocolAdapter(BaseProtocolAdapter):
def __init__(self, protocol): def __init__(self, protocol):
...@@ -103,144 +96,3 @@ class SizedProtocolAdapter(BaseProtocolAdapter): ...@@ -103,144 +96,3 @@ class SizedProtocolAdapter(BaseProtocolAdapter):
self.want = 4 self.want = 4
self.getting_size = True self.getting_size = True
self.protocol.data_received(collected) self.protocol.data_received(collected)
class ClientProtocol(asyncio.Protocol):
"""asyncio low-level ZEO client interface
"""
def __init__(self, addr,
client=None, storage_key='1', read_only=False, loop=None):
if loop is None:
loop = asyncio.get_event_loop()
self.loop = loop
self.addr = addr
self.storage_key = storage_key
self.read_only = read_only
self.client = client
self.connected = asyncio.Future()
def protocol_factory(self):
return SizedProtocolAdapter(self)
def connect(self):
self.protocol_version = None
self.futures = {} # outstanding requests {request_id -> future}
if isinstance(self.addr, tuple):
host, port = self.addr
cr = self.loop.create_connection(self.protocol_factory, host, port)
else:
cr = self.loop.create_unix_connection(
self.protocol_factory, self.addr)
future = asyncio.async(cr, loop=self.loop)
@future.add_done_callback
def done_connecting(future):
e = future.exception()
if e is not None:
self.connected.set_exception(e)
return self.connected
def connection_made(self, transport):
logger.info("Connected")
self.transport = SizedTransportAdapter(transport)
def connection_lost(self, exc):
logger.info("Disconnected, %r", exc)
for f in self.futures.values():
d.set_exception(exc or Disconnected())
self.futures = {}
self.connect() # Reconnect
exception_type_type = type(Exception)
def data_received(self, data):
if self.protocol_version is None:
self.protocol_version = data
self.transport.write(data) # pleased to meet you version :)
self.call_async('register', self.storage_key, self.read_only)
self.connected.set_result(data)
else:
msgid, async, name, args = loads(data)
if name == '.reply':
future = self.futures.pop(msgid)
if (isinstance(args, tuple) and len(args) > 1 and
type(args[0]) == self.exception_type_type and
issubclass(r_args[0], Exception)
):
future.set_exception(args[0]) # XXX security checks
else:
future.set_result(args)
else:
assert async # clients only get async calls
if self.client:
getattr(self.client, name)(*args) # XXX security
else:
logger.info('called %r %r', (name, args))
def call_async(self, method, *args):
# XXX connection status...
self.transport.write(dumps((0, True, method, args), 3))
message_id = 0
def call(self, method, *args):
future = asyncio.Future()
self.message_id += 1
self.futures[self.message_id] = future
self.transport.write(dumps((self.message_id, False, method, args), 3))
return future
def call_concurrent(self, result_future, method, *args):
future = self.call(method, *args)
@future.add_done_callback
def concurrent_result(future):
if future.exception() is None:
result_future.set_result(future.result())
else:
result_future.set_exception(future.exception())
class ClientThread:
"""Thread wrapper for client interface
A ClientProtocol is run in a dedicated thread.
Calls to it are made in a thread-safe fashion.
"""
def __init__(self, addr,
client=None, storage_key='1', read_only=False, timeout=None):
self.addr = addr
self.client = client
self.storage_key = storage_key
self.read_only = read_only
self.connected = concurrent.futures.Future()
threading.Thread(target=self.run,
name='zeo_client_'+storage_key,
daemon=True,
).start()
self.connected.result(timeout)
def run(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self.loop = loop
self.proto = ClientProtocol(
self.addr, None, self.storage_key, self.read_only)
f = self.proto.connect()
@f.add_done_callback
def thread_done_connecting(future):
e = future.exception()
if e is not None:
self.connected.set_exception(e)
else:
self.connected.set_result(None) # XXX prob return some info
loop.run_forever()
def call_async(self, method, *args):
self.loop.call_soon_threadsafe(self.proto.call_async, method, *args)
def call(self, method, *args, timeout=None):
result = concurrent.futures.Future()
self.loop.call_soon_threadsafe(
self.proto.call_concurrent, result, method, *args)
return result.result()
This diff is collapsed.
import asyncio
class Loop:
def __init__(self, debug=True):
self.get_debug = lambda : debug
def call_soon(self, func, *args):
func(*args)
def create_connection(self, protocol_factory, host, port):
self.protocol = protocol = protocol_factory()
self.transport = transport = Transport()
future = asyncio.Future(loop=self)
future.set_result((transport, protocol))
protocol.connection_made(transport)
return future
def call_soon_threadsafe(self, func, *args):
func(*args)
class Transport:
def __init__(self):
self.data = []
def write(self, data):
self.data.append(data)
def writelines(self, lines):
self.data.extend(lines)
def pop(self, count=None):
if count:
r = self.data[:count]
del self.data[:count]
else:
r = self.data[:]
del self.data[:]
return r
closed = False
def close(self):
self.closed = True
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment