# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2015 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import unittest
from time import time
from mock import Mock
from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ListeningConnection, \
    Connection, ClientConnection, ServerConnection, MTClientConnection, \
    HandlerSwitcher, CRITICAL_TIMEOUT
from neo.lib.connector import registerConnectorHandler
from neo.lib.connector import ConnectorException, ConnectorTryAgainException, \
    ConnectorInProgressException, ConnectorConnectionRefusedException
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, PACKET_HEADER_FORMAT
from . import NeoUnitTestBase, Patch

# Counter used to hand out a distinct descriptor to each DummyConnector.
connector_cpt = 0


class DummyConnector(Mock):
    """Mock connector whose operations are all no-ops by default.

    Individual operations are overridden per test with
    ``Patch(DummyConnector, name=...)``.
    """

    def __init__(self, addr, s=None):
        # NOTE(review): ``s`` is unused here; presumably it mirrors an
        # optional argument of the real connector class -- TODO confirm.
        logging.info("initializing connector")
        global connector_cpt
        self.desc = connector_cpt
        connector_cpt += 1
        self.packet_cpt = 0
        self.addr = addr
        Mock.__init__(self)

    def getAddress(self):
        return self.addr

    def getDescriptor(self):
        return self.desc

    # Default behaviour: every connector operation does nothing and
    # returns None.
    accept = getError = makeClientConnection = makeListeningConnection = \
        receive = send = lambda *args, **kw: None


# While active, BaseConnection instantiates a DummyConnector instead of its
# regular ConnectorClass.
dummy_connector = Patch(BaseConnection,
    ConnectorClass=lambda orig, self, *args, **kw: DummyConnector(*args, **kw))


class ConnectionTests(NeoUnitTestBase):

    def setUp(self):
        NeoUnitTestBase.setUp(self)
        self.app = Mock({'__repr__': 'Fake App'})
        self.em = Mock({'__repr__': 'Fake Em'})
        self.handler = Mock({'__repr__': 'Fake Handler'})
        self.address = ("127.0.0.7", 93413)
        self.node = Mock({'getAddress': self.address})
        connection.connect_limit = 0

    def _makeListeningConnection(self, addr):
        with dummy_connector:
            conn = ListeningConnection(self.em, self.handler, addr)
        self.connector = conn.connector
        return conn

    def _makeServerConnection(self):
        addr = self.address
        self.connector = DummyConnector(addr)
        return Connection(self.em, self.handler, self.connector, addr)

    def _makeClientConnection(self):
        with dummy_connector:
            conn = ClientConnection(self.em, self.handler, self.node)
        self.connector = conn.connector
        return conn

    # Class-level alias: it is bound to *this* class's _makeClientConnection,
    # so subclasses overriding _makeClientConnection do not change it.
    _makeConnection = _makeClientConnection

    def _checkRegistered(self, n=1):
        self.assertEqual(len(self.em.mockGetNamedCalls("register")), n)

    def _checkUnregistered(self, n=1):
        self.assertEqual(len(self.em.mockGetNamedCalls("unregister")), n)

    def _checkReaderRemoved(self, n=1):
        self.assertEqual(len(self.em.mockGetNamedCalls("removeReader")), n)

    def _checkWriterAdded(self, n=1):
        self.assertEqual(len(self.em.mockGetNamedCalls("addWriter")), n)

    def _checkWriterRemoved(self, n=1):
        self.assertEqual(len(self.em.mockGetNamedCalls("removeWriter")), n)

    def _checkClose(self, n=1):
        self.assertEqual(len(self.connector.mockGetNamedCalls("close")), n)

    def _checkAccept(self, n=1):
        calls = self.connector.mockGetNamedCalls('accept')
        self.assertEqual(len(calls), n)

    def _checkSend(self, n=1, data=None):
        """Check that the connector's send() was called exactly n times.

        If data is given (and n > 0), also check that it was the argument
        of the last call.
        """
        calls = self.connector.mockGetNamedCalls('send')
        self.assertEqual(len(calls), n)
        if n and data is not None:
            self.assertEqual(calls[n - 1].getParam(0), data)

    def _checkConnectionAccepted(self, n=1):
        calls = self.handler.mockGetNamedCalls('connectionAccepted')
        self.assertEqual(len(calls), n)

    def _checkConnectionFailed(self, n=1):
        calls = self.handler.mockGetNamedCalls('connectionFailed')
        self.assertEqual(len(calls), n)

    def _checkConnectionClosed(self, n=1):
        calls = self.handler.mockGetNamedCalls('connectionClosed')
        self.assertEqual(len(calls), n)

    def _checkConnectionStarted(self, n=1):
        calls = self.handler.mockGetNamedCalls('connectionStarted')
        self.assertEqual(len(calls), n)

    def _checkConnectionCompleted(self, n=1):
        calls = self.handler.mockGetNamedCalls('connectionCompleted')
        self.assertEqual(len(calls), n)

    def _checkMakeListeningConnection(self, n=1):
        calls = self.connector.mockGetNamedCalls('makeListeningConnection')
        self.assertEqual(len(calls), n)

    def _checkMakeClientConnection(self, n=1):
        calls = self.connector.mockGetNamedCalls("makeClientConnection")
        self.assertEqual(len(calls), n)

    def _checkPacketReceived(self, n=1):
        calls = self.handler.mockGetNamedCalls('packetReceived')
        self.assertEqual(len(calls), n)

    def _checkReadBuf(self, bc, data):
        # Reads (and so consumes) the whole read buffer.
        content = bc.read_buf.read(len(bc.read_buf))
        self.assertEqual(''.join(content), data)

    def _appendToReadBuf(self, bc, data):
        bc.read_buf.append(data)

    def _appendPacketToReadBuf(self, bc, packet):
        data = ''.join(packet.encode())
        bc.read_buf.append(data)

    def _checkWriteBuf(self, bc, data):
        self.assertEqual(''.join(bc.write_buf), data)

    def _checkQueuedPacket(self, bc, expected, index=0):
        """Check the index-th packet appended to bc._queue matches expected.

        Compares packet type, id and decoded content.
        """
        call = bc._queue.mockGetNamedCalls("append")[index]
        data = call.getParam(0)
        self.assertEqual(type(data), type(expected))
        self.assertEqual(data.getId(), expected.getId())
        self.assertEqual(data.decode(), expected.decode())

    def _checkNewConnectionState(self, bc):
        """Check buffers, flags, UUID and id generation of a new connection."""
        self._checkReadBuf(bc, '')
        self._checkWriteBuf(bc, '')
        self.assertEqual(bc.cur_id, 0)
        self.assertFalse(bc.aborted)
        # test uuid
        self.assertIsNone(bc.uuid)
        self.assertIsNone(bc.getUUID())
        uuid = self.getNewUUID(None)
        bc.setUUID(uuid)
        self.assertEqual(bc.getUUID(), uuid)
        # test next id
        cur_id = bc.cur_id
        next_id = bc._getNextId()
        self.assertEqual(next_id, cur_id)
        next_id = bc._getNextId()
        self.assertGreater(next_id, cur_id)
        # test overflow of next id: after 0xffffffff, ids wrap to 0
        bc.cur_id = 0xffffffff
        next_id = bc._getNextId()
        self.assertEqual(next_id, 0xffffffff)
        next_id = bc._getNextId()
        self.assertEqual(next_id, 0)

    def test_01_BaseConnection(self):
        # init with address
        bc = self._makeConnection()
        self.assertEqual(bc.getAddress(), self.address)
        self.assertIsNot(bc.connector, None)
        self._checkRegistered(1)

    def test_02_ListeningConnection1(self):
        # test init part
        addr = ("127.0.0.7", 93413)
        with Patch(DummyConnector, accept=lambda orig, self: (self, ('', 0))):
            bc = self._makeListeningConnection(addr=addr)
            self.assertEqual(bc.getAddress(), addr)
            self._checkRegistered()
            self._checkMakeListeningConnection()
            # test readable
            bc.readable()
            self._checkAccept()
            self._checkConnectionAccepted()

    def test_02_ListeningConnection2(self):
        # test with exception raise when getting new connection
        def accept(orig, self):
            raise ConnectorTryAgainException
        addr = ("127.0.0.7", 93413)
        with Patch(DummyConnector, accept=accept):
            bc = self._makeListeningConnection(addr=addr)
            self.assertEqual(bc.getAddress(), addr)
            self._checkRegistered()
            self._checkMakeListeningConnection()
            # test readable
            bc.readable()
            self._checkAccept(1)
            self._checkConnectionAccepted(0)

    def test_03_Connection(self):
        bc = self._makeConnection()
        self.assertEqual(bc.getAddress(), self.address)
        self._checkNewConnectionState(bc)

    def test_Connection_pending(self):
        bc = self._makeConnection()
        self.assertEqual(''.join(bc.write_buf), '')
        self.assertFalse(bc.pending())
        bc.write_buf += '1'
        self.assertTrue(bc.pending())

    def test_Connection_recv1(self):
        # patch receive method to return data
        with Patch(DummyConnector, receive=lambda orig, self: "testdata"):
            bc = self._makeConnection()
            self._checkReadBuf(bc, '')
            bc._recv()
            self._checkReadBuf(bc, 'testdata')

    def test_Connection_recv2(self):
        # patch receive method to raise try again
        def receive(orig, self):
            raise ConnectorTryAgainException
        with Patch(DummyConnector, receive=receive):
            bc = self._makeConnection()
            self._checkReadBuf(bc, '')
            bc._recv()
            self._checkReadBuf(bc, '')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_recv3(self):
        # patch receive method to raise ConnectorConnectionRefusedException
        def receive(orig, self):
            raise ConnectorConnectionRefusedException
        with Patch(DummyConnector, receive=receive):
            bc = self._makeConnection()
            self._checkReadBuf(bc, '')
            # fake client connection instance with connecting attribute
            bc.connecting = True
            bc._recv()
            self._checkReadBuf(bc, '')
            self._checkConnectionFailed(1)
            self._checkUnregistered(1)

    def test_Connection_recv4(self):
        # patch receive method to raise any other connector error
        def receive(orig, self):
            raise ConnectorException
        with Patch(DummyConnector, receive=receive):
            bc = self._makeConnection()
            self._checkReadBuf(bc, '')
            self.assertRaises(ConnectorException, bc._recv)
            self._checkReadBuf(bc, '')
            self._checkConnectionClosed(1)
            self._checkUnregistered(1)

    def test_Connection_send1(self):
        # no data to send: the connector must not be used
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc._send()
        self._checkSend(0)
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_send2(self):
        # send all data
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            bc._send()
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send3(self):
        # send part of the data
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            bc._send()
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, 'data')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send4(self):
        # send multiple packet
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata", "second", "third"]
            bc._send()
            self._checkSend(1, "testdatasecondthird")
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send5(self):
        # send part of multiple packet
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata", "second", "third"]
            bc._send()
            self._checkSend(1, "testdatasecondthird")
            self._checkWriteBuf(bc, 'econdthird')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send6(self):
        # raise try again
        def send(orig, self, data):
            raise ConnectorTryAgainException
        with Patch(DummyConnector, send=send):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata", "second", "third"]
            bc._send()
            self._checkSend(1, "testdatasecondthird")
            self._checkWriteBuf(bc, 'testdatasecondthird')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send7(self):
        # raise other error
        def send(orig, self, data):
            raise ConnectorException
        with Patch(DummyConnector, send=send):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata", "second", "third"]
            self.assertRaises(ConnectorException, bc._send)
            self._checkSend(1, "testdatasecondthird")
            # connection closed -> buffers flushed
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(1)
            self._checkUnregistered(1)

    def test_07_Connection_addPacket(self):
        # new packet
        p = Packets.Ping()
        p._id = 0
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc._addPacket(p)
        self._checkWriteBuf(bc, PACKET_HEADER_FORMAT.pack(0, p._code, 10))
        self._checkWriterAdded(1)

    def test_Connection_analyse1(self):
        # nothing to read, nothing is done
        bc = self._makeConnection()
        bc._queue = Mock()
        self._checkReadBuf(bc, '')
        bc._analyse()
        self._checkPacketReceived(0)
        self._checkReadBuf(bc, '')

        p = Packets.AnswerPrimary(self.getNewUUID(None))
        p.setId(1)
        p_data = ''.join(p.encode())
        data_edge = len(p_data) - 1
        p_data_1, p_data_2 = p_data[:data_edge], p_data[data_edge:]

        # append an incomplete packet, nothing is done
        bc.read_buf.append(p_data_1)
        bc._analyse()
        self._checkPacketReceived(0)
        self.assertNotEqual(len(bc.read_buf), 0)
        self.assertNotEqual(len(bc.read_buf), len(p_data))

        # append the rest of the packet
        bc.read_buf.append(p_data_2)
        bc._analyse()
        # check packet decoded
        self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
        self._checkQueuedPacket(bc, p)
        self._checkReadBuf(bc, '')

    def test_Connection_analyse2(self):
        # give multiple packet
        bc = self._makeConnection()
        bc._queue = Mock()
        p1 = Packets.AnswerPrimary(self.getNewUUID(None))
        p1.setId(1)
        self._appendPacketToReadBuf(bc, p1)
        p2 = Packets.AnswerPrimary(self.getNewUUID(None))
        p2.setId(2)
        self._appendPacketToReadBuf(bc, p2)
        self.assertEqual(len(bc.read_buf), len(p1) + len(p2))
        bc._analyse()
        # check two packets decoded, in order
        self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 2)
        self._checkQueuedPacket(bc, p1, 0)
        self._checkQueuedPacket(bc, p2, 1)
        self._checkReadBuf(bc, '')

    def test_Connection_analyse3(self):
        # give a bad packet, won't be decoded
        bc = self._makeConnection()
        p = Packets.Ping()
        p.setId(1)
        self._appendToReadBuf(bc, '%s%sdatadatadatadata' % p.encode())
        bc._analyse()
        self._checkPacketReceived(1) # ping packet
        self._checkClose(1) # malformed packet

    def test_Connection_analyse4(self):
        # give an expected packet
        bc = self._makeConnection()
        bc._queue = Mock()
        p = Packets.AnswerPrimary(self.getNewUUID(None))
        p.setId(1)
        self._appendPacketToReadBuf(bc, p)
        bc._analyse()
        # check packet decoded
        self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
        self._checkQueuedPacket(bc, p)
        self._checkReadBuf(bc, '')

    def test_Connection_writable1(self):
        # with pending operation after send
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            self.assertFalse(bc.aborted)
            bc.writable()
            # test send was called
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, "data")
            self._checkConnectionClosed(0)
            self._checkClose(0)
            self._checkUnregistered(0)
            # pending, so nothing called
            self.assertTrue(bc.pending())
            self._checkWriterRemoved(0)
            self._checkReaderRemoved(0)
            self._checkClose(0)

    def test_Connection_writable2(self):
        # without pending operation after send
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            self.assertFalse(bc.aborted)
            bc.writable()
            # test send was called
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(0)
            self._checkClose(0)
            self._checkUnregistered(0)
            # nothing else pending, so writer has been removed
            self.assertFalse(bc.pending())
            self._checkWriterRemoved(1)
            self._checkReaderRemoved(0)
            self._checkClose(0)

    def test_Connection_writable3(self):
        # without pending operation after send and aborted set to true
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            bc.abort()
            self.assertTrue(bc.aborted)
            bc.writable()
            # test send was called
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(1)
            self._checkClose(1)
            self._checkUnregistered(1)
            # nothing else pending, so writer has been removed
            self.assertFalse(bc.pending())
            self._checkClose(1)

    def test_Connection_readable(self):
        # With aborted set to false
        # patch receive method to return data
        def receive(orig, connector):
            # ``connector`` is the DummyConnector; ``self`` stays the test
            # case so that NeoUnitTestBase.getNewUUID is used.
            p = Packets.AnswerPrimary(self.getNewUUID(None))
            p.setId(1)
            return ''.join(p.encode())
        with Patch(DummyConnector, receive=receive):
            bc = self._makeConnection()
            bc._queue = Mock({'__len__': 0})
            self._checkReadBuf(bc, '')
            self.assertFalse(bc.aborted)
            bc.readable()
            # check packet decoded
            self._checkReadBuf(bc, '')
            self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
            call = bc._queue.mockGetNamedCalls("append")[0]
            data = call.getParam(0)
            self.assertEqual(type(data), Packets.AnswerPrimary)
            self.assertEqual(data.getId(), 1)
            self._checkReadBuf(bc, '')
            # check not aborted
            self.assertFalse(bc.aborted)
            self._checkUnregistered(0)
            self._checkWriterRemoved(0)
            self._checkReaderRemoved(0)
            self._checkClose(0)

    def test_ClientConnection_init1(self):
        # create a good client connection
        bc = self._makeClientConnection()
        # check connector created and connection initialize
        self.assertFalse(bc.connecting)
        self.assertFalse(bc.isServer())
        self._checkMakeClientConnection(1)
        # check call to handler
        self.assertIsNotNone(bc.getHandler())
        self._checkConnectionStarted(1)
        self._checkConnectionCompleted(1)
        self._checkConnectionFailed(0)
        # check call to event manager
        self.assertIsNot(bc.em, None)
        self._checkWriterAdded(0)

    def test_ClientConnection_init2(self):
        # raise connection in progress
        def makeClientConnection(orig, self):
            raise ConnectorInProgressException
        with Patch(DummyConnector, makeClientConnection=makeClientConnection):
            bc = self._makeClientConnection()
            # check connector created and connection initialize
            self.assertTrue(bc.connecting)
            self.assertFalse(bc.isServer())
            self._checkMakeClientConnection(1)
            # check call to handler
            self.assertIsNotNone(bc.getHandler())
            self._checkConnectionStarted(1)
            self._checkConnectionCompleted(0)
            self._checkConnectionFailed(0)
            # check call to event manager
            self.assertIsNot(bc.em, None)
            self._checkWriterAdded(1)

    def test_ClientConnection_init3(self):
        # raise another error, connection must fail
        def makeClientConnection(orig, self):
            raise ConnectorException
        with Patch(DummyConnector, makeClientConnection=makeClientConnection):
            self.assertRaises(ConnectorException, self._makeClientConnection)
            # since the exception was raised, the connection is not created
            # check call to handler
            self._checkConnectionStarted(1)
            self._checkConnectionCompleted(0)
            self._checkConnectionFailed(1)
            # check call to event manager
            self._checkWriterAdded(0)

    def test_ClientConnection_writable1(self):
        # with a non connecting connection, will call parent's method
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)), \
                Patch(DummyConnector,
                      makeClientConnection=lambda orig, self: "OK") as p:
            bc = self._makeClientConnection()
            p.revert()
            # check connector created and connection initialize
            self.assertFalse(bc.connecting)
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            self.assertFalse(bc.aborted)
            # call
            self._checkConnectionCompleted(1)
            bc.writable()
            self.assertFalse(bc.pending())
            self.assertFalse(bc.aborted)
            self.assertFalse(bc.connecting)
            self._checkSend(1, "testdata")
            self._checkConnectionClosed(0)
            self._checkConnectionCompleted(1)
            self._checkConnectionFailed(0)
            self._checkUnregistered(0)
            self._checkWriterRemoved(1)
            self._checkReaderRemoved(0)
            self._checkClose(0)

    def test_ClientConnection_writable2(self):
        # with a connecting connection, must not call parent's method
        # with errors, close connection
        with Patch(DummyConnector, getError=lambda orig, self: True):
            bc = self._makeClientConnection()
            # check connector created and connection initialize
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            self.assertFalse(bc.aborted)
            # call
            self._checkConnectionCompleted(1)
            bc.writable()
            self.assertFalse(bc.connecting)
            self.assertFalse(bc.pending())
            self.assertFalse(bc.aborted)
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(1)
            self._checkConnectionCompleted(1)
            self._checkConnectionFailed(0)
            self._checkUnregistered(1)

    def test_14_ServerConnection(self):
        bc = self._makeServerConnection()
        self.assertEqual(bc.getAddress(), ("127.0.0.7", 93413))
        self._checkNewConnectionState(bc)

    def test_15_Timeout(self):
        # NOTE: This method uses ping/pong packets only because MT connection
        #       don't accept any other packet without specifying a queue.
        self.handler = EventHandler(self.app)
        conn = self._makeClientConnection()

        use_case_list = (
            # (a) For a single packet sent at T,
            #     the limit time for the answer is T + (1 * CRITICAL_TIMEOUT)
            ((), (1., 0)),
            # (b) Same as (a), even if send another packet at (T + CT/2).
            #     But receiving a packet (at T + CT - ε) resets the timeout
            #     (which means the limit for the 2nd one is T + 2*CT)
            ((.5, None), (1., 0, 2., 1)),
            # (c) Same as (b) with a first answer well before the limit
            #     (T' = T + CT/2). The limit for the second one is T' + CT.
            ((.1, None, .5, 1), (1.5, 0)),
        )

        def set_time(t):
            # Replace the clock seen by the connection module; the original
            # ``time`` is restored in the ``finally`` clause below.
            connection.time = lambda: int(CRITICAL_TIMEOUT * (1000 + t))

        closed = []
        conn.close = lambda: closed.append(connection.time())

        def answer(packet_id):
            p = Packets.Pong()
            p.setId(packet_id)
            conn.connector.receive = [''.join(p.encode())].pop
            conn.readable()
            checkTimeout()
            conn.process()

        def checkTimeout():
            timeout = conn.getTimeout()
            if timeout and timeout <= connection.time():
                conn.onTimeout()

        try:
            for use_case, expected in use_case_list:
                i = iter(use_case)
                conn.cur_id = 0
                set_time(0)
                # No timeout when no pending request
                self.assertIsNone(conn._handlers.getNextTimeout())
                conn.ask(Packets.Ping())
                # use_case is a flat sequence of (time, packet_id) pairs:
                # None asks another Ping, anything else answers that id.
                for t in i:
                    set_time(t)
                    checkTimeout()
                    packet_id = i.next()
                    if packet_id is None:
                        conn.ask(Packets.Ping())
                    else:
                        answer(packet_id)
                # expected is a flat sequence of (timeout time, packet_id)
                i = iter(expected)
                for t in i:
                    set_time(t - .1)
                    checkTimeout()
                    set_time(t)
                    # this test method relies on the fact that only
                    # conn.close is called in case of a timeout
                    checkTimeout()
                    self.assertEqual(closed.pop(), connection.time())
                    answer(i.next())
                self.assertFalse(conn.isPending())
                self.assertFalse(closed)
        finally:
            connection.time = time


class MTConnectionTests(ConnectionTests):
    # XXX: here we test non-client-connection-related things too, which
    #      duplicates test suite work... Should be fragmented into finer-grained
    #      test classes.

    def setUp(self):
        super(MTConnectionTests, self).setUp()
        self.dispatcher = Mock({'__repr__': 'Fake Dispatcher'})

    def _makeClientConnection(self):
        with dummy_connector:
            conn = MTClientConnection(self.em, self.handler, self.node,
                                      dispatcher=self.dispatcher)
        self.connector = conn.connector
        return conn

    def test_MTClientConnectionQueueParameter(self):
        ask = self._makeClientConnection().ask
        packet = Packets.AskPrimary() # Any non-Ping simple "ask" packet
        # One cannot "ask" anything without a queue
        self.assertRaises(TypeError, ask, packet)
        ask(packet, queue=object())
        # ... except Ping
        ask(Packets.Ping())


class HandlerSwitcherTests(NeoUnitTestBase):

    def setUp(self):
        NeoUnitTestBase.setUp(self)
        self._handler = handler = Mock({
            '__repr__': 'initial handler',
        })
        self._connection = Mock({
            '__repr__': 'connection',
            'getAddress': ('127.0.0.1', 10000),
        })
        self._handlers = HandlerSwitcher(handler)

    def _makeNotification(self, msg_id):
        packet = Packets.StartOperation()
        packet.setId(msg_id)
        return packet

    def _makeRequest(self, msg_id):
        packet = Packets.AskBeginTransaction()
        packet.setId(msg_id)
        return packet

    def _makeAnswer(self, msg_id):
        packet = Packets.AnswerBeginTransaction(self.getNextTID())
        packet.setId(msg_id)
        return packet

    def _makeHandler(self):
        return Mock({'__repr__': 'handler'})

    def _checkPacketReceived(self, handler, packet, index=0):
        """Check handler got exactly index + 1 packets, the last one being
        packet.
        """
        calls = handler.mockGetNamedCalls('packetReceived')
        self.assertEqual(len(calls), index + 1)
        # the packet is the 2nd positional argument, after the connection
        self.assertIs(calls[index].getParam(1), packet)

    def _checkCurrentHandler(self, handler):
        self.assertIs(self._handlers.getHandler(), handler)

    def testInit(self):
        self._checkCurrentHandler(self._handler)
        self.assertFalse(self._handlers.isPending())

    def testEmit(self):
        # First case, emit is called outside of a handler
        self.assertFalse(self._handlers.isPending())
        request = self._makeRequest(1)
        self._handlers.emit(request, 0, None)
        self.assertTrue(self._handlers.isPending())
        # Second case, emit is called from inside a handler with a pending
        # handler change.
        new_handler = self._makeHandler()
        applied = self._handlers.setHandler(new_handler)
        self.assertFalse(applied)
        self._checkCurrentHandler(self._handler)
        call_tracker = []
        def packetReceived(conn, packet, kw):
            self._handlers.emit(self._makeRequest(2), 0, None)
            call_tracker.append(True)
        self._handler.packetReceived = packetReceived
        self._handlers.handle(self._connection, self._makeAnswer(1))
        self.assertEqual(call_tracker, [True])
        # Effective handler must not have changed (new request is blocking
        # it)
        self._checkCurrentHandler(self._handler)
        # Handling the next response will cause the handler to change
        delattr(self._handler, 'packetReceived')
        self._handlers.handle(self._connection, self._makeAnswer(2))
        self._checkCurrentHandler(new_handler)

    def testHandleNotification(self):
        # handle with current handler
        notif1 = self._makeNotification(1)
        self._handlers.handle(self._connection, notif1)
        self._checkPacketReceived(self._handler, notif1)
        # emit a request and delay a handler change
        request = self._makeRequest(2)
        self._handlers.emit(request, 0, None)
        handler = self._makeHandler()
        applied = self._handlers.setHandler(handler)
        self.assertFalse(applied)
        # next notification falls into the current handler
        notif2 = self._makeNotification(3)
        self._handlers.handle(self._connection, notif2)
        self._checkPacketReceived(self._handler, notif2, index=1)
        # once the answer is handled, the new handler gets notifications
        answer = self._makeAnswer(2)
        self._handlers.handle(self._connection, answer)
        notif3 = self._makeNotification(4)
        self._handlers.handle(self._connection, notif3)
        self._checkPacketReceived(handler, notif3)

    def testHandleAnswer1(self):
        # handle with current handler
        request = self._makeRequest(1)
        self._handlers.emit(request, 0, None)
        answer = self._makeAnswer(1)
        self._handlers.handle(self._connection, answer)
        self._checkPacketReceived(self._handler, answer)

    def testHandleAnswer2(self):
        # handle with blocking handler
        request = self._makeRequest(1)
        self._handlers.emit(request, 0, None)
        handler = self._makeHandler()
        applied = self._handlers.setHandler(handler)
        self.assertFalse(applied)
        answer = self._makeAnswer(1)
        self._handlers.handle(self._connection, answer)
        self._checkPacketReceived(self._handler, answer)
        self._checkCurrentHandler(handler)

    def testHandleAnswer3(self):
        # multiple setHandler
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        r3 = self._makeRequest(3)
        a1 = self._makeAnswer(1)
        a2 = self._makeAnswer(2)
        a3 = self._makeAnswer(3)
        h1 = self._makeHandler()
        h2 = self._makeHandler()
        h3 = self._makeHandler()
        # emit all requests and setHandlers
        self._handlers.emit(r1, 0, None)
        applied = self._handlers.setHandler(h1)
        self.assertFalse(applied)
        self._handlers.emit(r2, 0, None)
        applied = self._handlers.setHandler(h2)
        self.assertFalse(applied)
        self._handlers.emit(r3, 0, None)
        applied = self._handlers.setHandler(h3)
        self.assertFalse(applied)
        self._checkCurrentHandler(self._handler)
        self.assertTrue(self._handlers.isPending())
        # process answers
        self._handlers.handle(self._connection, a1)
        self._checkCurrentHandler(h1)
        self._handlers.handle(self._connection, a2)
        self._checkCurrentHandler(h2)
        self._handlers.handle(self._connection, a3)
        self._checkCurrentHandler(h3)

    def testHandleAnswer4(self):
        # several requests pending before a single handler change
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        r3 = self._makeRequest(3)
        a1 = self._makeAnswer(1)
        a2 = self._makeAnswer(2)
        a3 = self._makeAnswer(3)
        h = self._makeHandler()
        # emit all requests
        self._handlers.emit(r1, 0, None)
        self._handlers.emit(r2, 0, None)
        self._handlers.emit(r3, 0, None)
        applied = self._handlers.setHandler(h)
        self.assertFalse(applied)
        # process answers: the switch only happens after the last one
        self._handlers.handle(self._connection, a1)
        self._checkCurrentHandler(self._handler)
        self._handlers.handle(self._connection, a2)
        self._checkCurrentHandler(self._handler)
        self._handlers.handle(self._connection, a3)
        self._checkCurrentHandler(h)

    def testHandleUnexpected(self):
        # an answer for a later state arrives before the pending one
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        a2 = self._makeAnswer(2)
        h = self._makeHandler()
        # emit requests around state setHandler
        self._handlers.emit(r1, 0, None)
        applied = self._handlers.setHandler(h)
        self.assertFalse(applied)
        self._handlers.emit(r2, 0, None)
        # process answer for next state
        self._handlers.handle(self._connection, a2)
        self.checkAborted(self._connection)


if __name__ == '__main__':
    unittest.main()