# -*- coding: utf-8 -*-
#
# Copyright (C) 2009-2015  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
import unittest
from time import time
from mock import Mock
from neo.lib import connection, logging
from neo.lib.connection import BaseConnection, ListeningConnection, \
     Connection, ClientConnection, ServerConnection, MTClientConnection, \
     HandlerSwitcher, CRITICAL_TIMEOUT
from neo.lib.connector import registerConnectorHandler
from neo.lib.connector import ConnectorException, ConnectorTryAgainException, \
     ConnectorInProgressException, ConnectorConnectionRefusedException
from neo.lib.handler import EventHandler
from neo.lib.protocol import Packets, PACKET_HEADER_FORMAT
from . import NeoUnitTestBase, Patch


# Module-wide counter: DummyConnector.__init__ hands out its current value as
# the new connector's descriptor and then increments it.
connector_cpt = 0

class DummyConnector(Mock):
    """In-memory stand-in for a real connector, used throughout these tests.

    Each instance takes a unique descriptor from the module-level
    ``connector_cpt`` counter.  The I/O entry points (accept, getError,
    makeClientConnection, makeListeningConnection, receive, send) do nothing
    and return None; individual tests swap them out with ``Patch`` when they
    need another behaviour.
    """

    def __init__(self, addr, s=None):
        # NOTE(review): ``s`` is accepted but unused here.
        global connector_cpt
        logging.info("initializing connector")
        self.desc, connector_cpt = connector_cpt, connector_cpt + 1
        self.packet_cpt = 0
        self.addr = addr
        # The Mock base is initialised only after our own attributes exist.
        Mock.__init__(self)

    def getAddress(self):
        """Return the address this connector was created for."""
        return self.addr

    def getDescriptor(self):
        """Return the unique descriptor number assigned at creation."""
        return self.desc

    def _noop(*args, **kw):
        return None

    # Default behaviour for every I/O entry point: do nothing, return None.
    accept = getError = makeClientConnection = makeListeningConnection = \
        receive = send = _noop
    del _noop


# While applied, makes BaseConnection build a DummyConnector (from the same
# constructor arguments) instead of using its ConnectorClass; the original
# class passed in as ``orig`` is ignored.
dummy_connector = Patch(
    BaseConnection,
    ConnectorClass=lambda orig, self, *args, **kw: DummyConnector(*args, **kw),
)


class ConnectionTests(NeoUnitTestBase):
    """Unit tests for Connection, ListeningConnection and ClientConnection.

    Connections are built on top of a DummyConnector and mock event manager /
    handler objects; the tests mostly assert which methods were called on
    those mocks (via ``mockGetNamedCalls``) and the state of the connection
    buffers.
    """

    def setUp(self):
        NeoUnitTestBase.setUp(self)
        self.app = Mock({'__repr__': 'Fake App'})
        self.em = Mock({'__repr__': 'Fake Em'})
        self.handler = Mock({'__repr__': 'Fake Handler'})
        self.address = ("127.0.0.7", 93413)
        self.node = Mock({'getAddress': self.address})
        # NOTE(review): module-level setting, assumed to disable connection
        # throttling for these tests — TODO confirm; it is not restored.
        connection.connect_limit = 0

    # -- connection factories -------------------------------------------

    def _makeListeningConnection(self, addr):
        """Return a ListeningConnection on *addr* backed by a DummyConnector."""
        with dummy_connector:
            conn = ListeningConnection(self.em, self.handler, addr)
        self.connector = conn.connector
        return conn

    def _makeServerConnection(self):
        """Return a plain Connection wrapping a new DummyConnector."""
        addr = self.address
        self.connector = DummyConnector(addr)
        return Connection(self.em, self.handler, self.connector, addr)

    def _makeClientConnection(self):
        """Return a ClientConnection to self.node backed by a DummyConnector."""
        with dummy_connector:
            conn = ClientConnection(self.em, self.handler, self.node)
        self.connector = conn.connector
        return conn

    # Bound to *this* class's _makeClientConnection when the class body runs,
    # so subclasses overriding _makeClientConnection do not change it.
    _makeConnection = _makeClientConnection

    # -- call-count assertions ------------------------------------------

    def _checkCallCount(self, mock, name, n):
        """Assert that method *name* of *mock* was called exactly *n* times."""
        self.assertEqual(len(mock.mockGetNamedCalls(name)), n)

    # Event manager.

    def _checkRegistered(self, n=1):
        self._checkCallCount(self.em, "register", n)

    def _checkUnregistered(self, n=1):
        self._checkCallCount(self.em, "unregister", n)

    def _checkReaderRemoved(self, n=1):
        self._checkCallCount(self.em, "removeReader", n)

    def _checkWriterAdded(self, n=1):
        self._checkCallCount(self.em, "addWriter", n)

    def _checkWriterRemoved(self, n=1):
        self._checkCallCount(self.em, "removeWriter", n)

    # Connector (the one recorded by the last _make*Connection call).

    def _checkClose(self, n=1):
        self._checkCallCount(self.connector, "close", n)

    def _checkAccept(self, n=1):
        self._checkCallCount(self.connector, "accept", n)

    def _checkSend(self, n=1, data=None):
        """Assert the connector's ``send`` was called exactly *n* times.

        If *data* is given (and at least one call is expected), also assert
        that the payload of the last ``send`` call equals *data*.
        """
        calls = self.connector.mockGetNamedCalls('send')
        self.assertEqual(len(calls), n)
        if n and data is not None:
            self.assertEqual(calls[-1].getParam(0), data)

    def _checkMakeListeningConnection(self, n=1):
        self._checkCallCount(self.connector, "makeListeningConnection", n)

    def _checkMakeClientConnection(self, n=1):
        self._checkCallCount(self.connector, "makeClientConnection", n)

    # Handler.

    def _checkConnectionAccepted(self, n=1):
        self._checkCallCount(self.handler, "connectionAccepted", n)

    def _checkConnectionFailed(self, n=1):
        self._checkCallCount(self.handler, "connectionFailed", n)

    def _checkConnectionClosed(self, n=1):
        self._checkCallCount(self.handler, "connectionClosed", n)

    def _checkConnectionStarted(self, n=1):
        self._checkCallCount(self.handler, "connectionStarted", n)

    def _checkConnectionCompleted(self, n=1):
        self._checkCallCount(self.handler, "connectionCompleted", n)

    def _checkPacketReceived(self, n=1):
        self._checkCallCount(self.handler, "packetReceived", n)

    # -- buffer helpers -------------------------------------------------

    def _checkReadBuf(self, bc, data):
        """Assert that bc's read buffer contains exactly *data*.

        NOTE(review): this reads ``len(bc.read_buf)`` bytes out of the
        buffer, presumably consuming them — confirm against ReadBuffer.read.
        """
        content = bc.read_buf.read(len(bc.read_buf))
        self.assertEqual(''.join(content), data)

    def _appendToReadBuf(self, bc, data):
        bc.read_buf.append(data)

    def _appendPacketToReadBuf(self, bc, packet):
        self._appendToReadBuf(bc, ''.join(packet.encode()))

    def _checkWriteBuf(self, bc, data):
        self.assertEqual(''.join(bc.write_buf), data)

    def _checkNewConnection(self, bc, address):
        """Check the initial state of a fresh connection, its UUID accessors
        and its message-id counter (including the 32-bit wrap-around)."""
        self.assertEqual(bc.getAddress(), address)
        self._checkReadBuf(bc, '')
        self._checkWriteBuf(bc, '')
        self.assertEqual(bc.cur_id, 0)
        self.assertFalse(bc.aborted)
        # test uuid
        self.assertIsNone(bc.uuid)
        self.assertIsNone(bc.getUUID())
        uuid = self.getNewUUID(None)
        bc.setUUID(uuid)
        self.assertEqual(bc.getUUID(), uuid)
        # test next id: returns the current id, then increments it
        cur_id = bc.cur_id
        next_id = bc._getNextId()
        self.assertEqual(next_id, cur_id)
        next_id = bc._getNextId()
        self.assertGreater(next_id, cur_id)
        # test overflow of next id
        bc.cur_id = 0xffffffff
        next_id = bc._getNextId()
        self.assertEqual(next_id, 0xffffffff)
        next_id = bc._getNextId()
        self.assertEqual(next_id, 0)

    # -- tests ----------------------------------------------------------

    def test_01_BaseConnection(self):
        # init with address
        bc = self._makeConnection()
        self.assertEqual(bc.getAddress(), self.address)
        self.assertIsNot(bc.connector, None)
        self._checkRegistered(1)

    def test_02_ListeningConnection1(self):
        # test init part
        addr = ("127.0.0.7", 93413)
        with Patch(DummyConnector, accept=lambda orig, self: (self, ('', 0))):
            bc = self._makeListeningConnection(addr=addr)
            self.assertEqual(bc.getAddress(), addr)
            self._checkRegistered()
            self._checkMakeListeningConnection()
            # test readable
            bc.readable()
            self._checkAccept()
            self._checkConnectionAccepted()

    def test_02_ListeningConnection2(self):
        # test with exception raise when getting new connection
        def accept(orig, self):
            raise ConnectorTryAgainException
        addr = ("127.0.0.7", 93413)
        with Patch(DummyConnector, accept=accept):
            bc = self._makeListeningConnection(addr=addr)
            self.assertEqual(bc.getAddress(), addr)
            self._checkRegistered()
            self._checkMakeListeningConnection()
            # test readable
            bc.readable()
            self._checkAccept(1)
            self._checkConnectionAccepted(0)

    def test_03_Connection(self):
        self._checkNewConnection(self._makeConnection(), self.address)

    def test_Connection_pending(self):
        bc = self._makeConnection()
        self.assertEqual(''.join(bc.write_buf), '')
        self.assertFalse(bc.pending())
        bc.write_buf += '1'
        self.assertTrue(bc.pending())

    def test_Connection_recv1(self):
        # patch receive method to return data
        with Patch(DummyConnector, receive=lambda orig, self: "testdata"):
            bc = self._makeConnection()
            self._checkReadBuf(bc, '')
            bc._recv()
            self._checkReadBuf(bc, 'testdata')

    def test_Connection_recv2(self):
        # patch receive method to raise try again
        def receive(orig, self):
            raise ConnectorTryAgainException
        with Patch(DummyConnector, receive=receive):
            bc = self._makeConnection()
            self._checkReadBuf(bc, '')
            bc._recv()
            self._checkReadBuf(bc, '')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_recv3(self):
        # patch receive method to raise ConnectorConnectionRefusedException
        def receive(orig, self):
            raise ConnectorConnectionRefusedException
        with Patch(DummyConnector, receive=receive):
            bc = self._makeConnection()
            self._checkReadBuf(bc, '')
            # fake client connection instance with connecting attribute
            bc.connecting = True
            bc._recv()
            self._checkReadBuf(bc, '')
            self._checkConnectionFailed(1)
            self._checkUnregistered(1)

    def test_Connection_recv4(self):
        # patch receive method to raise any other connector error
        def receive(orig, self):
            raise ConnectorException
        with Patch(DummyConnector, receive=receive):
            bc = self._makeConnection()
            self._checkReadBuf(bc, '')
            self.assertRaises(ConnectorException, bc._recv)
            self._checkReadBuf(bc, '')
            self._checkConnectionClosed(1)
            self._checkUnregistered(1)

    def test_Connection_send1(self):
        # no data, nothing done
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc._send()
        self._checkSend(0)
        self._checkConnectionClosed(0)
        self._checkUnregistered(0)

    def test_Connection_send2(self):
        # send all data
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            bc._send()
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send3(self):
        # send part of the data
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            bc._send()
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, 'data')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send4(self):
        # send multiple packet
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata", "second", "third"]
            bc._send()
            self._checkSend(1, "testdatasecondthird")
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send5(self):
        # send part of multiple packet
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata", "second", "third"]
            bc._send()
            self._checkSend(1, "testdatasecondthird")
            self._checkWriteBuf(bc, 'econdthird')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send6(self):
        # raise try again
        def send(orig, self, data):
            raise ConnectorTryAgainException
        with Patch(DummyConnector, send=send):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata", "second", "third"]
            bc._send()
            self._checkSend(1, "testdatasecondthird")
            self._checkWriteBuf(bc, 'testdatasecondthird')
            self._checkConnectionClosed(0)
            self._checkUnregistered(0)

    def test_Connection_send7(self):
        # raise other error
        def send(orig, self, data):
            raise ConnectorException
        with Patch(DummyConnector, send=send):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata", "second", "third"]
            self.assertRaises(ConnectorException, bc._send)
            self._checkSend(1, "testdatasecondthird")
            # connection closed -> buffers flushed
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(1)
            self._checkUnregistered(1)

    def test_07_Connection_addPacket(self):
        # new packet
        p = Packets.Ping()
        p._id = 0
        bc = self._makeConnection()
        self._checkWriteBuf(bc, '')
        bc._addPacket(p)
        self._checkWriteBuf(bc, PACKET_HEADER_FORMAT.pack(0, p._code, 10))
        self._checkWriterAdded(1)

    def test_Connection_analyse1(self):
        # nothing to read, nothing is done
        bc = self._makeConnection()
        bc._queue = Mock()
        self._checkReadBuf(bc, '')
        bc._analyse()
        self._checkPacketReceived(0)
        self._checkReadBuf(bc, '')

        p = Packets.AnswerPrimary(self.getNewUUID(None))
        p.setId(1)
        p_data = ''.join(p.encode())
        data_edge = len(p_data) - 1
        p_data_1, p_data_2 = p_data[:data_edge], p_data[data_edge:]
        # append an incomplete packet, nothing is done
        bc.read_buf.append(p_data_1)
        bc._analyse()
        self._checkPacketReceived(0)
        self.assertNotEqual(len(bc.read_buf), 0)
        self.assertNotEqual(len(bc.read_buf), len(p_data))
        # append the rest of the packet
        bc.read_buf.append(p_data_2)
        bc._analyse()
        # check packet decoded
        self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
        call = bc._queue.mockGetNamedCalls("append")[0]
        data = call.getParam(0)
        self.assertEqual(type(data), type(p))
        self.assertEqual(data.getId(), p.getId())
        self.assertEqual(data.decode(), p.decode())
        self._checkReadBuf(bc, '')

    def test_Connection_analyse2(self):
        # give multiple packet
        bc = self._makeConnection()
        bc._queue = Mock()
        p1 = Packets.AnswerPrimary(self.getNewUUID(None))
        p1.setId(1)
        self._appendPacketToReadBuf(bc, p1)
        p2 = Packets.AnswerPrimary(self.getNewUUID(None))
        p2.setId(2)
        self._appendPacketToReadBuf(bc, p2)
        self.assertEqual(len(bc.read_buf), len(p1) + len(p2))
        bc._analyse()
        # check two packets decoded
        self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 2)
        # packet 1
        call = bc._queue.mockGetNamedCalls("append")[0]
        data = call.getParam(0)
        self.assertEqual(type(data), type(p1))
        self.assertEqual(data.getId(), p1.getId())
        self.assertEqual(data.decode(), p1.decode())
        # packet 2
        call = bc._queue.mockGetNamedCalls("append")[1]
        data = call.getParam(0)
        self.assertEqual(type(data), type(p2))
        self.assertEqual(data.getId(), p2.getId())
        self.assertEqual(data.decode(), p2.decode())
        self._checkReadBuf(bc, '')

    def test_Connection_analyse3(self):
        # give a bad packet, won't be decoded
        bc = self._makeConnection()
        p = Packets.Ping()
        p.setId(1)
        self._appendToReadBuf(bc, '%s%sdatadatadatadata' % p.encode())
        bc._analyse()
        self._checkPacketReceived(1) # ping packet
        self._checkClose(1) # malformed packet

    def test_Connection_analyse4(self):
        # give an expected packet
        bc = self._makeConnection()
        bc._queue = Mock()
        p = Packets.AnswerPrimary(self.getNewUUID(None))
        p.setId(1)
        self._appendPacketToReadBuf(bc, p)
        bc._analyse()
        # check packet decoded
        self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
        call = bc._queue.mockGetNamedCalls("append")[0]
        data = call.getParam(0)
        self.assertEqual(type(data), type(p))
        self.assertEqual(data.getId(), p.getId())
        self.assertEqual(data.decode(), p.decode())
        self._checkReadBuf(bc, '')

    def test_Connection_writable1(self):
        # with pending operation after send
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)//2):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            self.assertFalse(bc.aborted)
            bc.writable()
            # test send was called
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, "data")
            self._checkConnectionClosed(0)
            self._checkClose(0)
            self._checkUnregistered(0)
            # pending, so nothing called
            self.assertTrue(bc.pending())
            self._checkWriterRemoved(0)
            self._checkReaderRemoved(0)
            self._checkClose(0)

    def test_Connection_writable2(self):
        # without pending operation after send
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            self.assertFalse(bc.aborted)
            bc.writable()
            # test send was called
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(0)
            self._checkClose(0)
            self._checkUnregistered(0)
            # nothing else pending, so writer has been removed
            self.assertFalse(bc.pending())
            self._checkWriterRemoved(1)
            self._checkReaderRemoved(0)
            self._checkClose(0)

    def test_Connection_writable3(self):
        # without pending operation after send and aborted set to true
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)):
            bc = self._makeConnection()
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            bc.abort()
            self.assertTrue(bc.aborted)
            bc.writable()
            # test send was called
            self._checkSend(1, "testdata")
            self._checkWriteBuf(bc, '')
            self._checkConnectionClosed(1)
            self._checkClose(1)
            self._checkUnregistered(1)
            # nothing else pending, so writer has been removed
            self.assertFalse(bc.pending())
            self._checkClose(1)

    def test_Connection_readable(self):
        # With aborted set to false
        # patch receive method to return data
        def receive(orig, connector):
            # The second argument is the DummyConnector instance; the UUID
            # must come from the test case (self), not from the mock.
            p = Packets.AnswerPrimary(self.getNewUUID(None))
            p.setId(1)
            return ''.join(p.encode())
        with Patch(DummyConnector, receive=receive):
            bc = self._makeConnection()
            bc._queue = Mock({'__len__': 0})
            self._checkReadBuf(bc, '')
            self.assertFalse(bc.aborted)
            bc.readable()
            # check packet decoded
            self._checkReadBuf(bc, '')
            self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
            call = bc._queue.mockGetNamedCalls("append")[0]
            data = call.getParam(0)
            self.assertEqual(type(data), Packets.AnswerPrimary)
            self.assertEqual(data.getId(), 1)
            self._checkReadBuf(bc, '')
            # check not aborted
            self.assertFalse(bc.aborted)
            self._checkUnregistered(0)
            self._checkWriterRemoved(0)
            self._checkReaderRemoved(0)
            self._checkClose(0)

    def test_ClientConnection_init1(self):
        # create a good client connection
        bc = self._makeClientConnection()
        # check connector created and connection initialize
        self.assertFalse(bc.connecting)
        self.assertFalse(bc.isServer())
        self._checkMakeClientConnection(1)
        # check call to handler
        self.assertIsNotNone(bc.getHandler())
        self._checkConnectionStarted(1)
        self._checkConnectionCompleted(1)
        self._checkConnectionFailed(0)
        # check call to event manager
        self.assertIsNot(bc.em, None)
        self._checkWriterAdded(0)

    def test_ClientConnection_init2(self):
        # raise connection in progress
        def makeClientConnection(orig, self):
            raise ConnectorInProgressException
        with Patch(DummyConnector, makeClientConnection=makeClientConnection):
            bc = self._makeClientConnection()
        # check connector created and connection initialize
        self.assertTrue(bc.connecting)
        self.assertFalse(bc.isServer())
        self._checkMakeClientConnection(1)
        # check call to handler
        self.assertIsNotNone(bc.getHandler())
        self._checkConnectionStarted(1)
        self._checkConnectionCompleted(0)
        self._checkConnectionFailed(0)
        # check call to event manager
        self.assertIsNot(bc.em, None)
        self._checkWriterAdded(1)

    def test_ClientConnection_init3(self):
        # raise another error, connection must fail
        def makeClientConnection(orig, self):
            raise ConnectorException
        with Patch(DummyConnector, makeClientConnection=makeClientConnection):
            self.assertRaises(ConnectorException, self._makeClientConnection)
        # since the exception was raised, the connection is not created
        # check call to handler
        self._checkConnectionStarted(1)
        self._checkConnectionCompleted(0)
        self._checkConnectionFailed(1)
        # check call to event manager
        self._checkWriterAdded(0)

    def test_ClientConnection_writable1(self):
        # with a non connecting connection, will call parent's method
        with Patch(DummyConnector, send=lambda orig, self, data: len(data)), \
             Patch(DummyConnector,
                   makeClientConnection=lambda orig, self: "OK") as p:
            bc = self._makeClientConnection()
            p.revert()
            # check connector created and connection initialize
            self.assertFalse(bc.connecting)
            self._checkWriteBuf(bc, '')
            bc.write_buf = ["testdata"]
            self.assertTrue(bc.pending())
            self.assertFalse(bc.aborted)
            # call
            self._checkConnectionCompleted(1)
            bc.writable()
            self.assertFalse(bc.pending())
            self.assertFalse(bc.aborted)
            self.assertFalse(bc.connecting)
            self._checkSend(1, "testdata")
            self._checkConnectionClosed(0)
            self._checkConnectionCompleted(1)
            self._checkConnectionFailed(0)
            self._checkUnregistered(0)
            self._checkWriterRemoved(1)
            self._checkReaderRemoved(0)
            self._checkClose(0)

    def test_ClientConnection_writable2(self):
        # with a connecting connection, must not call parent's method
        # with errors, close connection
        with Patch(DummyConnector, getError=lambda orig, self: True):
            bc = self._makeClientConnection()
        # check connector created and connection initialize
        self._checkWriteBuf(bc, '')
        bc.write_buf = ["testdata"]
        self.assertTrue(bc.pending())
        self.assertFalse(bc.aborted)
        # call
        self._checkConnectionCompleted(1)
        bc.writable()
        self.assertFalse(bc.connecting)
        self.assertFalse(bc.pending())
        self.assertFalse(bc.aborted)
        self._checkWriteBuf(bc, '')
        self._checkConnectionClosed(1)
        self._checkConnectionCompleted(1)
        self._checkConnectionFailed(0)
        self._checkUnregistered(1)

    def test_14_ServerConnection(self):
        self._checkNewConnection(self._makeServerConnection(),
                                 ("127.0.0.7", 93413))

    def test_15_Timeout(self):
        # NOTE: This method uses ping/pong packets only because MT connections
        #       don't accept any other packet without specifying a queue.
        self.handler = EventHandler(self.app)
        conn = self._makeClientConnection()

        # Each use case is (events, expected), both flat sequences of
        # (time, packet_id) pairs, times being in units of CRITICAL_TIMEOUT:
        # - events: at `time`, a None id sends another Ping, an int id
        #   answers the request with that id;
        # - expected: at `time` a timeout must fire (and not just before),
        #   after which the request with that id is answered.
        use_case_list = (
            # (a) For a single packet sent at T,
            #     the limit time for the answer is T + (1 * CRITICAL_TIMEOUT)
            ((), (1., 0)),
            # (b) Same as (a), even if another packet is sent at (T + CT/2).
            #     But receiving a packet (at T + CT - ε) resets the timeout
            #     (which means the limit for the 2nd one is T + 2*CT)
            ((.5, None), (1., 0, 2., 1)),
            # (c) Same as (b) with a first answer well before the limit
            #     (T' = T + CT/2). The limit for the second one is T' + CT.
            ((.1, None, .5, 1), (1.5, 0)),
        )

        def set_time(t):
            connection.time = lambda: int(CRITICAL_TIMEOUT * (1000 + t))
        closed = []
        conn.close = lambda: closed.append(connection.time())
        def answer(packet_id):
            p = Packets.Pong()
            p.setId(packet_id)
            conn.connector.receive = [''.join(p.encode())].pop
            conn.readable()
            checkTimeout()
            conn.process()
        def checkTimeout():
            timeout = conn.getTimeout()
            if timeout and timeout <= connection.time():
                conn.onTimeout()
        try:
            for use_case, expected in use_case_list:
                i = iter(use_case)
                conn.cur_id = 0
                set_time(0)
                # No timeout when no pending request
                self.assertEqual(conn._handlers.getNextTimeout(), None)
                conn.ask(Packets.Ping())
                for t in i:
                    set_time(t)
                    checkTimeout()
                    packet_id = i.next()
                    if packet_id is None:
                        conn.ask(Packets.Ping())
                    else:
                        answer(packet_id)
                i = iter(expected)
                for t in i:
                    set_time(t - .1)
                    checkTimeout()
                    set_time(t)
                    # this test method relies on the fact that only
                    # conn.close is called in case of a timeout
                    checkTimeout()
                    self.assertEqual(closed.pop(), connection.time())
                    answer(i.next())
                self.assertFalse(conn.isPending())
                self.assertFalse(closed)
        finally:
            connection.time = time

class MTConnectionTests(ConnectionTests):
    """ConnectionTests variant whose _makeClientConnection builds an
    MTClientConnection (with a mock dispatcher).

    XXX: here we test non-client-connection-related things too, which
    duplicates test suite work... Should be fragmented into finer-grained
    test classes.

    NOTE(review): ConnectionTests binds ``_makeConnection`` to its own
    ``_makeClientConnection`` in its class body, so inherited tests that go
    through ``self._makeConnection()`` still build a plain ClientConnection;
    only tests calling ``_makeClientConnection()`` directly get an
    MTClientConnection.  Confirm whether that is intended.
    """

    def setUp(self):
        super(MTConnectionTests, self).setUp()
        # Mock dispatcher handed to every MTClientConnection built below.
        dispatcher = Mock({'__repr__': 'Fake Dispatcher'})
        self.dispatcher = dispatcher

    def _makeClientConnection(self):
        """Return an MTClientConnection to self.node on a DummyConnector."""
        with dummy_connector:
            mt_conn = MTClientConnection(
                self.em, self.handler, self.node, dispatcher=self.dispatcher)
        self.connector = mt_conn.connector
        return mt_conn

    def test_MTClientConnectionQueueParameter(self):
        conn = self._makeClientConnection()
        # Any non-Ping simple "ask" packet: it cannot be asked without a
        # queue, but is accepted once one is given.
        request = Packets.AskPrimary()
        self.assertRaises(TypeError, conn.ask, request)
        conn.ask(request, queue=object())
        # ... except Ping, which needs no queue.
        conn.ask(Packets.Ping())

class HandlerSwitcherTests(NeoUnitTestBase):
    """Unit tests for HandlerSwitcher.

    A handler change requested while requests are pending is not applied
    immediately (setHandler returns False); the tests below check when the
    switch actually happens and which handler receives each packet.
    """

    def setUp(self):
        NeoUnitTestBase.setUp(self)
        # Handler given to the switcher at construction time.
        self._handler = handler = Mock({
            '__repr__': 'initial handler',
        })
        # Connection passed to HandlerSwitcher.handle(); only its repr and
        # address are stubbed.
        self._connection = Mock({
            '__repr__': 'connection',
            'getAddress': ('127.0.0.1', 10000),
        })
        self._handlers = HandlerSwitcher(handler)

    def _makeNotification(self, msg_id):
        """Return a notification packet (no answer expected) with id msg_id."""
        packet = Packets.StartOperation()
        packet.setId(msg_id)
        return packet

    def _makeRequest(self, msg_id):
        """Return a request packet with id msg_id."""
        packet = Packets.AskBeginTransaction()
        packet.setId(msg_id)
        return packet

    def _makeAnswer(self, msg_id):
        """Return the answer matching _makeRequest(msg_id)."""
        packet = Packets.AnswerBeginTransaction(self.getNextTID())
        packet.setId(msg_id)
        return packet

    def _makeHandler(self):
        return Mock({'__repr__': 'handler'})

    def _checkPacketReceived(self, handler, packet, index=0):
        """Check that `handler` has received exactly index + 1 packets via
        packetReceived() and that the one at position `index` is `packet`.
        """
        calls = handler.mockGetNamedCalls('packetReceived')
        self.assertEqual(len(calls), index + 1)
        # packetReceived is called as (conn, packet[, kw]) -- see the
        # replacement installed in testEmit -- so the packet is the second
        # positional argument.
        self.assertTrue(calls[index].getParam(1) is packet)

    def _checkCurrentHandler(self, handler):
        self.assertTrue(self._handlers.getHandler() is handler)

    def testInit(self):
        self._checkCurrentHandler(self._handler)
        self.assertFalse(self._handlers.isPending())

    def testEmit(self):
        # First case, emit is called outside of a handler
        self.assertFalse(self._handlers.isPending())
        request = self._makeRequest(1)
        self._handlers.emit(request, 0, None)
        self.assertTrue(self._handlers.isPending())
        # Second case, emit is called from inside a handler with a pending
        # handler change.
        new_handler = self._makeHandler()
        applied = self._handlers.setHandler(new_handler)
        self.assertFalse(applied)
        self._checkCurrentHandler(self._handler)
        call_tracker = []
        def packetReceived(conn, packet, kw):
            self._handlers.emit(self._makeRequest(2), 0, None)
            call_tracker.append(True)
        self._handler.packetReceived = packetReceived
        self._handlers.handle(self._connection, self._makeAnswer(1))
        self.assertEqual(call_tracker, [True])
        # Effective handler must not have changed (new request is blocking
        # it)
        self._checkCurrentHandler(self._handler)
        # Handling the next response will cause the handler to change
        delattr(self._handler, 'packetReceived')
        self._handlers.handle(self._connection, self._makeAnswer(2))
        self._checkCurrentHandler(new_handler)

    def testHandleNotification(self):
        # handle with current handler
        notif1 = self._makeNotification(1)
        self._handlers.handle(self._connection, notif1)
        self._checkPacketReceived(self._handler, notif1)
        # emit a request and delay a handler change
        request = self._makeRequest(2)
        self._handlers.emit(request, 0, None)
        handler = self._makeHandler()
        applied = self._handlers.setHandler(handler)
        self.assertFalse(applied)
        # next notification falls into the current handler
        notif2 = self._makeNotification(3)
        self._handlers.handle(self._connection, notif2)
        self._checkPacketReceived(self._handler, notif2, index=1)
        # the answer is still processed by the current (blocking) handler
        answer = self._makeAnswer(2)
        self._handlers.handle(self._connection, answer)
        self._checkPacketReceived(self._handler, answer, index=2)
        # following notifications go to the new handler
        notif3 = self._makeNotification(4)
        self._handlers.handle(self._connection, notif3)
        self._checkPacketReceived(handler, notif3)

    def testHandleAnswer1(self):
        # handle with current handler
        request = self._makeRequest(1)
        self._handlers.emit(request, 0, None)
        answer = self._makeAnswer(1)
        self._handlers.handle(self._connection, answer)
        self._checkPacketReceived(self._handler, answer)

    def testHandleAnswer2(self):
        # handle with blocking handler
        request = self._makeRequest(1)
        self._handlers.emit(request, 0, None)
        handler = self._makeHandler()
        applied = self._handlers.setHandler(handler)
        self.assertFalse(applied)
        answer = self._makeAnswer(1)
        self._handlers.handle(self._connection, answer)
        self._checkPacketReceived(self._handler, answer)
        self._checkCurrentHandler(handler)

    def testHandleAnswer3(self):
        # multiple setHandler
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        r3 = self._makeRequest(3)
        a1 = self._makeAnswer(1)
        a2 = self._makeAnswer(2)
        a3 = self._makeAnswer(3)
        h1 = self._makeHandler()
        h2 = self._makeHandler()
        h3 = self._makeHandler()
        # emit all requests, each followed by a setHandler
        self._handlers.emit(r1, 0, None)
        applied = self._handlers.setHandler(h1)
        self.assertFalse(applied)
        self._handlers.emit(r2, 0, None)
        applied = self._handlers.setHandler(h2)
        self.assertFalse(applied)
        self._handlers.emit(r3, 0, None)
        applied = self._handlers.setHandler(h3)
        self.assertFalse(applied)
        self._checkCurrentHandler(self._handler)
        self.assertTrue(self._handlers.isPending())
        # process answers
        self._handlers.handle(self._connection, a1)
        self._checkCurrentHandler(h1)
        self._handlers.handle(self._connection, a2)
        self._checkCurrentHandler(h2)
        self._handlers.handle(self._connection, a3)
        self._checkCurrentHandler(h3)

    def testHandleAnswer4(self):
        # several pending requests block a single handler change until the
        # last of them is answered
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        r3 = self._makeRequest(3)
        a1 = self._makeAnswer(1)
        a2 = self._makeAnswer(2)
        a3 = self._makeAnswer(3)
        h = self._makeHandler()
        # emit all requests
        self._handlers.emit(r1, 0, None)
        self._handlers.emit(r2, 0, None)
        self._handlers.emit(r3, 0, None)
        applied = self._handlers.setHandler(h)
        self.assertFalse(applied)
        # process answers
        self._handlers.handle(self._connection, a1)
        self._checkCurrentHandler(self._handler)
        self._handlers.handle(self._connection, a2)
        self._checkCurrentHandler(self._handler)
        self._handlers.handle(self._connection, a3)
        self._checkCurrentHandler(h)

    def testHandleUnexpected(self):
        # an answer to a request emitted after a handler change arrives
        # before the answer to the request emitted before it
        r1 = self._makeRequest(1)
        r2 = self._makeRequest(2)
        a2 = self._makeAnswer(2)
        h = self._makeHandler()
        # emit requests around the setHandler call
        self._handlers.emit(r1, 0, None)
        applied = self._handlers.setHandler(h)
        self.assertFalse(applied)
        self._handlers.emit(r2, 0, None)
        # process answer for next state: the connection is expected to be
        # aborted (checkAborted)
        self._handlers.handle(self._connection, a2)
        self.checkAborted(self._connection)


if __name__ == '__main__':
    # Executed as a script: let unittest discover and run this module's tests.
    unittest.main()