Commit fe487c07 authored by Julien Muchembled's avatar Julien Muchembled

ssl: fix handshaking connections being stuck when they're aborted

parent aaefaf8b
...@@ -433,9 +433,12 @@ class Connection(BaseConnection): ...@@ -433,9 +433,12 @@ class Connection(BaseConnection):
def abort(self): def abort(self):
"""Abort dealing with this connection.""" """Abort dealing with this connection."""
assert self.pending()
if self.connecting:
self.close()
return
logging.debug('aborting a connector for %r', self) logging.debug('aborting a connector for %r', self)
self.aborted = True self.aborted = True
assert self.pending()
if self._on_close is not None: if self._on_close is not None:
self._on_close() self._on_close()
self._on_close = None self._on_close = None
......
...@@ -278,8 +278,8 @@ class ServerNode(Node): ...@@ -278,8 +278,8 @@ class ServerNode(Node):
if not address: if not address:
address = self.newAddress() address = self.newAddress()
if cluster is None: if cluster is None:
master_nodes = kw['master_nodes'] master_nodes = kw.get('master_nodes', ())
name = kw['name'] name = kw.get('name', 'test')
else: else:
master_nodes = kw.get('master_nodes', cluster.master_nodes) master_nodes = kw.get('master_nodes', cluster.master_nodes)
name = kw.get('name', cluster.name) name = kw.get('name', cluster.name)
...@@ -292,7 +292,7 @@ class ServerNode(Node): ...@@ -292,7 +292,7 @@ class ServerNode(Node):
self.daemon = True self.daemon = True
self.node_name = '%s_%u' % (self.node_type, port) self.node_name = '%s_%u' % (self.node_type, port)
kw.update(getCluster=name, getBind=address, kw.update(getCluster=name, getBind=address,
getMasters=parseMasterList(master_nodes, address)) getMasters=master_nodes and parseMasterList(master_nodes, address))
super(ServerNode, self).__init__(Mock(kw)) super(ServerNode, self).__init__(Mock(kw))
def getVirtualAddress(self): def getVirtualAddress(self):
......
...@@ -15,8 +15,11 @@ ...@@ -15,8 +15,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
import unittest import unittest
from neo.lib.connection import ClientConnection, ListeningConnection
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets
from .. import SSL from .. import SSL
from . import NEOCluster, test, testReplication from . import MasterApplication, NEOCluster, test, testReplication
class SSLMixin: class SSLMixin:
...@@ -34,6 +37,30 @@ class SSLTests(SSLMixin, test.Test): ...@@ -34,6 +37,30 @@ class SSLTests(SSLMixin, test.Test):
# exclude expected failures # exclude expected failures
testDeadlockAvoidance = testStorageFailureDuringTpcFinish = None testDeadlockAvoidance = testStorageFailureDuringTpcFinish = None
def testAbortConnection(self):
app = MasterApplication(getSSL=SSL, getReplicas=0, getPartitions=1)
handler = EventHandler(app)
app.listening_conn = ListeningConnection(app, handler, app.server)
node = app.nm.createMaster(address=app.listening_conn.getAddress(),
uuid=app.uuid)
for after_handshake in 1, 0:
conn = ClientConnection(app, handler, node)
conn.ask(Packets.Ping())
connector = conn.getConnector()
del connector.connect_limit[connector.addr]
app.em.poll(1)
self.assertTrue(isinstance(connector,
connector.SSLHandshakeConnectorClass))
self.assertNotIn(connector.getDescriptor(), app.em.writer_set)
if after_handshake:
while not isinstance(connector, connector.SSLConnectorClass):
app.em.poll(1)
conn.abort()
fd = connector.getDescriptor()
while fd in app.em.reader_set:
app.em.poll(1)
self.assertIs(conn.getConnector(), None)
class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests): class SSLReplicationTests(SSLMixin, testReplication.ReplicationTests):
# do not repeat slowest tests with SSL # do not repeat slowest tests with SSL
testBackupNodeLost = testBackupNormalCase = None testBackupNodeLost = testBackupNormalCase = None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment