Commit 6b8003ce authored by Jim Fulton's avatar Jim Fulton

Summary: Fixed: misshandled protocol disconnect during connection setup

If a protocol was disconnected while registering, maybe because the
server was still starting, the disconnection was handled correctly by
the Client, because the protocol attribute wasn't set, the connection
wasn't retried.
parent 667b23f6
...@@ -149,16 +149,18 @@ class Protocol(asyncio.Protocol): ...@@ -149,16 +149,18 @@ class Protocol(asyncio.Protocol):
self.heartbeat(write=False) self.heartbeat(write=False)
def connection_lost(self, exc): def connection_lost(self, exc):
logger.debug('connection_lost %r', exc)
self.heartbeat_handle.cancel() self.heartbeat_handle.cancel()
if self.closed: if self.closed:
for f in self.pop_futures(): for f in self.pop_futures():
f.cancel() f.cancel()
else: else:
self.client.disconnected(self)
# We have to be careful processing the futures, because # We have to be careful processing the futures, because
# exception callbacks might modufy them. # exception callbacks might modufy them.
for f in self.pop_futures(): for f in self.pop_futures():
f.set_exception(ClientDisconnected(exc or 'connection lost')) f.set_exception(ClientDisconnected(exc or 'connection lost'))
self.closed = True
self.client.disconnected(self)
def finish_connect(self, protocol_version): def finish_connect(self, protocol_version):
...@@ -439,6 +441,7 @@ class Client: ...@@ -439,6 +441,7 @@ class Client:
self.protocols = () self.protocols = ()
def disconnected(self, protocol=None): def disconnected(self, protocol=None):
logger.debug('disconnected %r %r', self, protocol)
if protocol is None or protocol is self.protocol: if protocol is None or protocol is self.protocol:
if protocol is self.protocol and protocol is not None: if protocol is self.protocol and protocol is not None:
self.client.notify_disconnected() self.client.notify_disconnected()
...@@ -447,6 +450,8 @@ class Client: ...@@ -447,6 +450,8 @@ class Client:
self.connected = concurrent.futures.Future() self.connected = concurrent.futures.Future()
self.protocol = None self.protocol = None
self._clear_protocols() self._clear_protocols()
if all(p.closed for p in self.protocols):
self.try_connecting() self.try_connecting()
def upgrade(self, protocol): def upgrade(self, protocol):
...@@ -457,6 +462,7 @@ class Client: ...@@ -457,6 +462,7 @@ class Client:
self._clear_protocols(protocol) self._clear_protocols(protocol)
def try_connecting(self): def try_connecting(self):
logger.debug('try_connecting')
if not self.closed: if not self.closed:
self.protocols = [ self.protocols = [
Protocol(self.loop, addr, self, Protocol(self.loop, addr, self,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment