#
# Copyright (C) 2009-2010  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

import unittest
from mock import Mock
from neo.tests import NeoTestBase
from neo.pt import PartitionTable
from neo.storage.app import Application
from neo.storage.handlers.verification import VerificationHandler
from neo.protocol import CellStates, ErrorCodes
from neo.protocol import INVALID_OID, INVALID_TID
from neo.exception import PrimaryFailure, OperationFailure
from neo.util import p64, u64

class StorageVerificationHandlerTests(NeoTestBase):

    def setUp(self):
        self.prepareDatabase(number=1)
        # create an application object
        config = self.getStorageConfiguration(master_number=1)
        self.app = Application(config)
        self.verification = VerificationHandler(self.app)
        # define some variable to simulate client and storage node
        self.master_port = 10010
        self.storage_port = 10020
        self.client_port = 11011
        self.num_partitions = 1009
        self.num_replicas = 2
        self.app.operational = False
        self.app.load_lock_dict = {}
        self.app.pt = PartitionTable(self.num_partitions, self.num_replicas)


    def tearDown(self):
        NeoTestBase.tearDown(self)

    # Common methods
    def getLastUUID(self):
        return self.uuid

    def getClientConnection(self):
        address = ("127.0.0.1", self.client_port)
        return self.getFakeConnection(uuid=self.getNewUUID(), address=address)

    def getMasterConnection(self):
        return self.getFakeConnection(address=("127.0.0.1", self.master_port))

    # Tests
    def test_02_timeoutExpired(self):
        conn = self.getClientConnection()
        self.assertRaises(PrimaryFailure, self.verification.timeoutExpired, conn,)
        # nothing happens
        self.checkNoPacketSent(conn)

    def test_03_connectionClosed(self):
        conn = self.getClientConnection()
        self.assertRaises(PrimaryFailure, self.verification.connectionClosed, conn,)
        # nothing happens
        self.checkNoPacketSent(conn)

    def test_04_peerBroken(self):
        conn = self.getClientConnection()
        self.assertRaises(PrimaryFailure, self.verification.peerBroken, conn,)
        # nothing happens
        self.checkNoPacketSent(conn)

    def test_07_askLastIDs(self):
        conn = self.getClientConnection()
        last_ptid = self.getPTID(1)
        last_oid = self.getOID(2)
        self.app.pt = Mock({'getID': last_ptid})
        self.verification.askLastIDs(conn)
        oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True)
        self.assertEqual(oid, None)
        self.assertEqual(tid, None)
        self.assertEqual(ptid, last_ptid)

        # return value stored in db
        # insert some oid
        conn = self.getClientConnection()
        self.app.dm.begin()
        self.app.dm.query("""insert into obj (oid, serial, compression,
        checksum, value) values (3, 'A', 0, 0, '')""")
        self.app.dm.query("""insert into obj (oid, serial, compression,
        checksum, value) values (1, 'A', 0, 0, '')""")
        self.app.dm.query("""insert into obj (oid, serial, compression,
        checksum, value) values (2, 'A', 0, 0, '')""")
        self.app.dm.query("""insert into tobj (oid, serial, compression,
        checksum, value) values (5, 'A', 0, 0, '')""")
        # insert some tid
        self.app.dm.query("""insert into trans (tid, oids, user,
                description, ext) values (1, '', '', '', '')""")
        self.app.dm.query("""insert into trans (tid, oids, user,
                description, ext) values (2, '', '', '', '')""")
        self.app.dm.query("""insert into ttrans (tid, oids, user,
                description, ext) values (3, '', '', '', '')""")
        # max tid is in tobj (serial)
        self.app.dm.query("""insert into tobj (oid, serial, compression,
                checksum, value) values (0, 4, 0, 0, '')""")
        self.app.dm.commit()
        self.app.dm.setLastOID(last_oid)
        self.verification.askLastIDs(conn)
        self.checkAnswerLastIDs(conn)
        oid, tid, ptid = self.checkAnswerLastIDs(conn, decode=True)
        self.assertEqual(oid, last_oid)
        self.assertEqual(u64(tid), 4)
        self.assertEqual(ptid, self.app.pt.getID())

    def test_08_askPartitionTable(self):
        node = self.app.nm.createStorage(
            address=("127.7.9.9", 1),
            uuid=self.getNewUUID()
        )
        self.app.pt.setCell(1, node, CellStates.UP_TO_DATE)
        self.assertTrue(self.app.pt.hasOffset(1))
        conn = self.getClientConnection()
        self.verification.askPartitionTable(conn)
        ptid, row_list = self.checkAnswerPartitionTable(conn, decode=True)
        self.assertEqual(len(row_list), 1009)

    def test_10_notifyPartitionChanges(self):
        # old partition change
        conn = self.getMasterConnection()
        self.verification.notifyPartitionChanges(conn, 1, ())
        self.verification.notifyPartitionChanges(conn, 0, ())
        self.assertEqual(self.app.pt.getID(), 1)

        # new node
        conn = self.getMasterConnection()
        new_uuid = self.getNewUUID()
        cell = (0, new_uuid, CellStates.UP_TO_DATE)
        self.app.nm.createStorage(uuid=new_uuid)
        self.app.pt = PartitionTable(1, 1)
        self.app.dm = Mock({ })
        ptid, self.ptid = self.getTwoIDs()
        # pt updated
        self.verification.notifyPartitionChanges(conn, ptid, (cell, ))
        # check db update
        calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
        self.assertEquals(len(calls), 1)
        self.assertEquals(calls[0].getParam(0), ptid)
        self.assertEquals(calls[0].getParam(1), (cell, ))

    def test_11_startOperation(self):
        conn = self.getMasterConnection()
        self.assertFalse(self.app.operational)
        self.verification.startOperation(conn)
        self.assertTrue(self.app.operational)

    def test_12_stopOperation(self):
        conn = self.getMasterConnection()
        self.assertRaises(OperationFailure, self.verification.stopOperation, conn)

    def test_13_askUnfinishedTransactions(self):
        # client connection with no data
        conn = self.getMasterConnection()
        self.verification.askUnfinishedTransactions(conn)
        (tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
        self.assertEqual(len(tid_list), 0)

        # client connection with some data
        self.app.dm.begin()
        self.app.dm.query("""insert into tobj (oid, serial, compression,
                checksum, value) values (0, 4, 0, 0, '')""")
        self.app.dm.commit()
        conn = self.getMasterConnection()
        self.verification.askUnfinishedTransactions(conn)
        (tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
        self.assertEqual(len(tid_list), 1)
        self.assertEqual(u64(tid_list[0]), 4)

    def test_14_askTransactionInformation(self):
        # ask from client conn with no data
        conn = self.getMasterConnection()
        self.verification.askTransactionInformation(conn, p64(1))
        code, message = self.checkErrorPacket(conn, decode=True)
        self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)

        # input some tmp data and ask from client, must find both transaction
        self.app.dm.begin()
        self.app.dm.query("""insert into ttrans (tid, oids, user,
        description, ext) values (3, '%s', 'u1', 'd1', 'e1')""" %(p64(4),))
        self.app.dm.query("""insert into trans (tid, oids, user,
        description, ext) values (1,'%s', 'u2', 'd2', 'e2')""" %(p64(2),))
        self.app.dm.commit()
        # object from trans
        conn = self.getClientConnection()
        self.verification.askTransactionInformation(conn, p64(1))
        tid, user, desc, ext, packed, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
        self.assertEqual(u64(tid), 1)
        self.assertEqual(user, 'u2')
        self.assertEqual(desc, 'd2')
        self.assertEqual(ext, 'e2')
        self.assertEqual(len(oid_list), 1)
        self.assertEqual(u64(oid_list[0]), 2)
        # object from ttrans
        conn = self.getMasterConnection()
        self.verification.askTransactionInformation(conn, p64(3))
        tid, user, desc, ext, packed, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
        self.assertEqual(u64(tid), 3)
        self.assertEqual(user, 'u1')
        self.assertEqual(desc, 'd1')
        self.assertEqual(ext, 'e1')
        self.assertEqual(len(oid_list), 1)
        self.assertEqual(u64(oid_list[0]), 4)

        # input some tmp data and ask from server, must find one transaction
        conn = self.getMasterConnection()
        # find the one in trans
        self.verification.askTransactionInformation(conn, p64(1))
        tid, user, desc, ext, packed, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
        self.assertEqual(u64(tid), 1)
        self.assertEqual(user, 'u2')
        self.assertEqual(desc, 'd2')
        self.assertEqual(ext, 'e2')
        self.assertEqual(len(oid_list), 1)
        self.assertEqual(u64(oid_list[0]), 2)
        # do not find the one in ttrans
        conn = self.getMasterConnection()
        self.verification.askTransactionInformation(conn, p64(2))
        code, message = self.checkErrorPacket(conn, decode=True)
        self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)

    def test_15_askObjectPresent(self):
        # client connection with no data
        conn = self.getMasterConnection()
        self.verification.askObjectPresent(conn, p64(1), p64(2))
        code, message = self.checkErrorPacket(conn, decode=True)
        self.assertEqual(code, ErrorCodes.OID_NOT_FOUND)

        # client connection with some data
        self.app.dm.begin()
        self.app.dm.query("""insert into tobj (oid, serial, compression,
                checksum, value) values (1, 2, 0, 0, '')""")
        self.app.dm.commit()
        conn = self.getMasterConnection()
        self.verification.askObjectPresent(conn, p64(1), p64(2))
        oid, tid = self.checkAnswerObjectPresent(conn, decode=True)
        self.assertEqual(u64(tid), 2)
        self.assertEqual(u64(oid), 1)

    def test_16_deleteTransaction(self):
        # client connection with no data
        conn = self.getMasterConnection()
        self.verification.deleteTransaction(conn, p64(1))
        # client connection with data
        self.app.dm.begin()
        self.app.dm.query("""insert into tobj (oid, serial, compression,
                checksum, value) values (1, 2, 0, 0, '')""")
        self.app.dm.commit()
        self.verification.deleteTransaction(conn, p64(2))
        result = self.app.dm.query('select * from tobj')
        self.assertEquals(len(result), 0)

    def test_17_commitTransaction(self):
        # commit a transaction
        conn = self.getMasterConnection()
        dm = Mock()
        self.app.dm = dm
        self.verification.commitTransaction(conn, p64(1))
        self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1)
        call = dm.mockGetNamedCalls("finishTransaction")[0]
        tid = call.getParam(0)
        self.assertEqual(u64(tid), 1)

if __name__ == "__main__":
    unittest.main()