Commit bdbc36dd authored by Jim Fulton's avatar Jim Fulton

Async changes:

- Issue with notify_connected, ClientStorage wants to make requests in
  response to being notified.  This is problematic because
  synchronsouse calls cause deadlock in this situation as do
  asyncronous calls done in a multi-threaded fashion.

  - Call get_info from io thread during startup, because
    notify_connected wants it.

  - Added an same-thread asyncronous API.

  - Added comment warning of this issue.

  - Added a little more logging.

- fixed an ordering issue when protocol is disconnected. It should
  notify the client before it cleans up it's futures to prevent
  getting more.

- Expose protocol_version to client so it can adjust it's behavior to
  the .

- More logging
parent b6ec0eca
...@@ -92,6 +92,8 @@ class Protocol(asyncio.Protocol): ...@@ -92,6 +92,8 @@ class Protocol(asyncio.Protocol):
@cr.add_done_callback @cr.add_done_callback
def done_connecting(future): def done_connecting(future):
if future.exception() is not None: if future.exception() is not None:
logger.info("Connection to %rfailed, retrying, %s",
self.addr, future.exception())
# keep trying # keep trying
if not self.closed: if not self.closed:
self.loop.call_later( self.loop.call_later(
...@@ -160,9 +162,9 @@ class Protocol(asyncio.Protocol): ...@@ -160,9 +162,9 @@ class Protocol(asyncio.Protocol):
f.cancel() f.cancel()
else: else:
logger.info("Disconnected, %s, %r", self, exc) logger.info("Disconnected, %s, %r", self, exc)
self.client.disconnected(self)
for f in self.futures.values(): for f in self.futures.values():
f.set_exception(exc) f.set_exception(exc)
self.client.disconnected(self)
def finish_connect(self, protocol_version): def finish_connect(self, protocol_version):
...@@ -295,11 +297,14 @@ class Protocol(asyncio.Protocol): ...@@ -295,11 +297,14 @@ class Protocol(asyncio.Protocol):
def promise(self, method, *args): def promise(self, method, *args):
return self.call(Promise(), method, args) return self.call(Promise(), method, args)
# Methods called by the server: # Methods called by the server.
# WARNING WARNING we can't call methods that call back to us
# syncronously, as that would lead to DEADLOCK!
client_methods = ( client_methods = (
'invalidateTransaction', 'serialnos', 'info', 'invalidateTransaction', 'serialnos', 'info',
'receiveBlobStart', 'receiveBlobChunk', 'receiveBlobStop', 'receiveBlobStart', 'receiveBlobChunk', 'receiveBlobStop',
# plus: notify_connected, notify_disconnected
) )
client_delegated = client_methods[1:] client_delegated = client_methods[1:]
...@@ -400,7 +405,7 @@ class Client: ...@@ -400,7 +405,7 @@ class Client:
# A protcol failed registration. That's weird. If they've all # A protcol failed registration. That's weird. If they've all
# failed, we should try again in a bit. # failed, we should try again in a bit.
protocol.close() protocol.close()
logger.error("Registration or cache validation failed, %s", exc) logger.exception("Registration or cache validation failed, %s", exc)
if (self.protocol is None and not if (self.protocol is None and not
any(not p.closed for p in self.protocols) any(not p.closed for p in self.protocols)
): ):
...@@ -455,10 +460,20 @@ class Client: ...@@ -455,10 +460,20 @@ class Client:
self.register_failed(protocol, exc) self.register_failed(protocol, exc)
def finished_verify(self, server_tid): def finished_verify(self, server_tid):
# The cache is validated and the last tid we got from the server.
# Set ready so we apply any invalidations that follow.
# We've been ignoring them up to this point.
self.cache.setLastTid(server_tid) self.cache.setLastTid(server_tid)
self.ready = True self.ready = True
@self.protocol.promise('get_info')
def got_info(info):
self.connected.set_result(None) self.connected.set_result(None)
self.client.notify_connected(self) self.client.notify_connected(self, info)
@got_info.catch
def failed_info(exc):
self.register_failed(self, exc)
def get_peername(self): def get_peername(self):
return self.protocol.get_peername() return self.protocol.get_peername()
...@@ -470,6 +485,9 @@ class Client: ...@@ -470,6 +485,9 @@ class Client:
else: else:
future.set_exception(ZEO.Exceptions.ClientDisconnected()) future.set_exception(ZEO.Exceptions.ClientDisconnected())
def call_async_from_same_thread(self, method, *args):
return self.protocol.call_async(method, args)
def call_async_iter_threadsafe(self, future, it): def call_async_iter_threadsafe(self, future, it):
if self.ready: if self.ready:
self.protocol.call_async_iter(it) self.protocol.call_async_iter(it)
...@@ -557,6 +575,9 @@ class Client: ...@@ -557,6 +575,9 @@ class Client:
self.cache.setLastTid(tid) self.cache.setLastTid(tid)
self.client.invalidateTransaction(tid, oids) self.client.invalidateTransaction(tid, oids)
@property
def protocol_version(self):
return self.protocol.protocol_version
class ClientRunner: class ClientRunner:
...@@ -641,21 +662,21 @@ class ClientThread(ClientRunner): ...@@ -641,21 +662,21 @@ class ClientThread(ClientRunner):
""" """
def __init__(self, addrs, client, cache, def __init__(self, addrs, client, cache,
storage_key='1', read_only=False, timeout=30): storage_key='1', read_only=False, timeout=30,
disconnect_poll=1):
self.set_options(addrs, client, cache, storage_key, read_only, self.set_options(addrs, client, cache, storage_key, read_only,
timeout, disconnect_poll) timeout, disconnect_poll)
threading.Thread( threading.Thread(
target=self.run, target=self.run,
args=(addr, client, cache, storage_key, read_only),
name='zeo_client_'+storage_key, name='zeo_client_'+storage_key,
daemon=True, daemon=True,
).start() ).start()
self.connected.result(timeout) self.connected.result(timeout)
def run(self, *args): def run(self):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
self.setup_delegation(loop, *args) self.setup_delegation(loop)
loop.run_forever() loop.run_forever()
class Promise: class Promise:
......
...@@ -56,6 +56,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -56,6 +56,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
]) ])
respond(1, None) respond(1, None)
respond(2, 'a'*8) respond(2, 'a'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'get_info', ()))
respond(3, dict(length=42))
return (wrapper, cache, self.loop, self.client, protocol, transport, return (wrapper, cache, self.loop, self.client, protocol, transport,
send, respond) send, respond)
...@@ -110,16 +112,20 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -110,16 +112,20 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
respond(1, None) respond(1, None)
respond(2, 'a'*8) respond(2, 'a'*8)
# After verification, the client requests info:
self.assertEqual(parse(transport.pop()), (3, False, 'get_info', ()))
respond(3, dict(length=42))
# Now we're connected, the cache was initialized, and the # Now we're connected, the cache was initialized, and the
# queued message has been sent: # queued message has been sent:
self.assert_(client.connected.done()) self.assert_(client.connected.done())
self.assertEqual(cache.getLastTid(), 'a'*8) self.assertEqual(cache.getLastTid(), 'a'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'foo', (1, 2))) self.assertEqual(parse(transport.pop()), (4, False, 'foo', (1, 2)))
# The wrapper object (ClientStorage) has been notified: # The wrapper object (ClientStorage) has been notified:
wrapper.notify_connected.assert_called_with(client) wrapper.notify_connected.assert_called_with(client, {'length': 42})
respond(3, 42) respond(4, 42)
self.assertEqual(f1.result(), 42) self.assertEqual(f1.result(), 42)
# Now we can make async calls: # Now we can make async calls:
...@@ -132,8 +138,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -132,8 +138,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# The data wasn't in the cache, so we make a server call: # The data wasn't in the cache, so we make a server call:
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
(4, False, 'loadEx', (b'1'*8,))) (5, False, 'loadEx', (b'1'*8,)))
respond(4, (b'data', b'a'*8)) respond(5, (b'data', b'a'*8))
self.assertEqual(loaded.result(), (b'data', b'a'*8)) self.assertEqual(loaded.result(), (b'data', b'a'*8))
# If we make another request, it will be satisfied from the cache: # If we make another request, it will be satisfied from the cache:
...@@ -149,8 +155,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -149,8 +155,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Now, if we try to load current again, we'll make a server request. # Now, if we try to load current again, we'll make a server request.
loaded = self.load(b'1'*8) loaded = self.load(b'1'*8)
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
(5, False, 'loadEx', (b'1'*8,))) (6, False, 'loadEx', (b'1'*8,)))
respond(5, (b'data2', b'b'*8)) respond(6, (b'data2', b'b'*8))
self.assertEqual(loaded.result(), (b'data2', b'b'*8)) self.assertEqual(loaded.result(), (b'data2', b'b'*8))
# Loading non-current data may also be satisfied from cache # Loading non-current data may also be satisfied from cache
...@@ -163,8 +169,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -163,8 +169,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
loaded = self.load_before(b'1'*8, b'_'*8) loaded = self.load_before(b'1'*8, b'_'*8)
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
(6, False, 'loadBefore', (b'1'*8, b'_'*8))) (7, False, 'loadBefore', (b'1'*8, b'_'*8)))
respond(6, (b'data0', b'^'*8, b'_'*8)) respond(7, (b'data0', b'^'*8, b'_'*8))
self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8)) self.assertEqual(loaded.result(), (b'data0', b'^'*8, b'_'*8))
# When committing transactions, we need to update the cache # When committing transactions, we need to update the cache
...@@ -187,8 +193,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -187,8 +193,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
cache.load(b'4'*8)) cache.load(b'4'*8))
self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8)) self.assertEqual(cache.load(b'1'*8), (b'data2', b'b'*8))
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
(7, False, 'tpc_finish', (b'd'*8,))) (8, False, 'tpc_finish', (b'd'*8,)))
respond(7, b'e'*8) respond(8, b'e'*8)
self.assertEqual(committed.result(), None) self.assertEqual(committed.result(), None)
self.assertEqual(cache.load(b'1'*8), None) self.assertEqual(cache.load(b'1'*8), None)
self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8)) self.assertEqual(cache.load(b'2'*8), ('committed 2', b'e'*8))
...@@ -201,8 +207,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -201,8 +207,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
f1 = self.call('foo', 1, 2) f1 = self.call('foo', 1, 2)
self.assertFalse(loaded.done() or f1.done()) self.assertFalse(loaded.done() or f1.done())
self.assertEqual(parse(transport.pop()), self.assertEqual(parse(transport.pop()),
[(8, False, 'loadEx', (b'1'*8,)), [(9, False, 'loadEx', (b'1'*8,)),
(9, False, 'foo', (1, 2))], (10, False, 'foo', (1, 2))],
) )
exc = TypeError(43) exc = TypeError(43)
...@@ -235,9 +241,11 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -235,9 +241,11 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertFalse(wrapper.notify_connected.called) self.assertFalse(wrapper.notify_connected.called)
respond(1, None) respond(1, None)
respond(2, b'e'*8) respond(2, b'e'*8)
wrapper.notify_connected.assert_called_with(client) self.assertEqual(parse(transport.pop()), (3, False, 'get_info', ()))
respond(3, dict(length=42))
# Because the server tid matches the cache tid, we're done connecting # Because the server tid matches the cache tid, we're done connecting
wrapper.notify_connected.assert_called_with(client, {'length': 42})
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
self.assertEqual(cache.getLastTid(), b'e'*8) self.assertEqual(cache.getLastTid(), b'e'*8)
...@@ -277,6 +285,10 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -277,6 +285,10 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
(3, False, 'getInvalidations', (b'a'*8, ))) (3, False, 'getInvalidations', (b'a'*8, )))
respond(3, (b'e'*8, [b'4'*8])) respond(3, (b'e'*8, [b'4'*8]))
self.assertEqual(self.parse(transport.pop()),
(4, False, 'get_info', ()))
respond(4, dict(length=42))
# Now that verification is done, we're done connecting # Now that verification is done, we're done connecting
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
self.assertEqual(cache.getLastTid(), b'e'*8) self.assertEqual(cache.getLastTid(), b'e'*8)
...@@ -316,6 +328,10 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -316,6 +328,10 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# We respond None, indicating that we're too far out of date: # We respond None, indicating that we're too far out of date:
respond(3, None) respond(3, None)
self.assertEqual(self.parse(transport.pop()),
(4, False, 'get_info', ()))
respond(4, dict(length=42))
# Now that verification is done, we're done connecting # Now that verification is done, we're done connecting
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
self.assertEqual(cache.getLastTid(), b'e'*8) self.assertEqual(cache.getLastTid(), b'e'*8)
...@@ -395,6 +411,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -395,6 +411,8 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
]) ])
respond(1, None) respond(1, None)
respond(2, 'b'*8) respond(2, 'b'*8)
self.assertEqual(parse(transport.pop()), (3, False, 'get_info', ()))
respond(3, dict(length=42))
self.assert_(client.connected.done() and not transport.data) self.assert_(client.connected.done() and not transport.data)
self.assert_(client.ready) self.assert_(client.ready)
...@@ -435,6 +453,12 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -435,6 +453,12 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.assertEqual(client.protocol, protocol) self.assertEqual(client.protocol, protocol)
self.assertEqual(protocol.read_only, True) self.assertEqual(protocol.read_only, True)
connected = client.connected connected = client.connected
# The client asks for info, and we respond:
self.assertEqual(self.parse(transport.pop()),
(5, False, 'get_info', ()))
respond(5, dict(length=42))
self.assert_(connected.done()) self.assert_(connected.done())
# We connect the second address: # We connect the second address:
...@@ -464,6 +488,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -464,6 +488,7 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
# Now, we finish verification # Now, we finish verification
respond(2, 'b'*8) respond(2, 'b'*8)
respond(3, dict(length=42))
self.assert_(client.ready) self.assert_(client.ready)
self.assert_(client.connected.done()) self.assert_(client.connected.done())
...@@ -558,6 +583,16 @@ class AsyncTests(setupstack.TestCase, ClientRunner): ...@@ -558,6 +583,16 @@ class AsyncTests(setupstack.TestCase, ClientRunner):
self.start(finish_start=True)) self.start(finish_start=True))
self.assertEqual(client.get_peername(), '1.2.3.4') self.assertEqual(client.get_peername(), '1.2.3.4')
def test_call_async_from_same_thread(self):
# There are a few (1?) cases where we call into client storage
# where it needs to call back asyncronously. Because we're
# calling from the same thread, we don't need to use a futurte.
wrapper, cache, loop, client, protocol, transport, send, respond = (
self.start(finish_start=True))
client.call_async_from_same_thread('foo', 1)
self.assertEqual(self.parse(transport.pop()), (0, True, 'foo', (1, )))
def unsized(self, data, unpickle=False): def unsized(self, data, unpickle=False):
result = [] result = []
while data: while data:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment