#
# Copyright (C) 2009  Nexedi SA
# 
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import os
import unittest
import logging
import MySQLdb
from tempfile import mkstemp
from mock import Mock
from neo import protocol
from neo.node import MasterNode
from neo.pt import PartitionTable
from neo.storage.app import Application, StorageNode
from neo.storage.verification import VerificationEventHandler
from neo.protocol import STORAGE_NODE_TYPE, MASTER_NODE_TYPE, CLIENT_NODE_TYPE
from neo.protocol import BROKEN_STATE, RUNNING_STATE, Packet, INVALID_UUID, \
     UP_TO_DATE_STATE, INVALID_OID, INVALID_TID, PROTOCOL_ERROR_CODE
from neo.protocol import ACCEPT_NODE_IDENTIFICATION, REQUEST_NODE_IDENTIFICATION, \
     NOTIFY_PARTITION_CHANGES, STOP_OPERATION, ASK_LAST_IDS, ASK_PARTITION_TABLE, \
     ANSWER_LAST_IDS, ASK_UNFINISHED_TRANSACTIONS, ANSWER_UNFINISHED_TRANSACTIONS, \
     ANSWER_OBJECT_PRESENT, ASK_OBJECT_PRESENT, OID_NOT_FOUND_CODE, LOCK_INFORMATION, \
     UNLOCK_INFORMATION, TID_NOT_FOUND_CODE, ASK_TRANSACTION_INFORMATION, ANSWER_TRANSACTION_INFORMATION, \
     ANSWER_PARTITION_TABLE,SEND_PARTITION_TABLE, COMMIT_TRANSACTION
from neo.protocol import ERROR, BROKEN_NODE_DISALLOWED_CODE, ASK_PRIMARY_MASTER
from neo.protocol import ANSWER_PRIMARY_MASTER
from neo.exception import PrimaryFailure, OperationFailure
from neo.storage.mysqldb import MySQLDatabaseManager, p64, u64

SQL_ADMIN_USER = 'root'
SQL_ADMIN_PASSWORD = None
NEO_SQL_USER = 'test'
NEO_SQL_DATABASE = 'test_neo1'

class StorageVerificationTests(unittest.TestCase):

    def setUp(self):
        logging.basicConfig(level = logging.ERROR)
        # create an application object
        config_file_text = """# Default parameters.
[DEFAULT]
# The list of master nodes.
master_nodes: 127.0.0.1:10010 
# The number of replicas.
replicas: 2
# The number of partitions.
partitions: 1009
# The name of this cluster.
name: main
# The user name for the database.
user: %(user)s
connector : SocketConnector
# The first master.
[mastertest]
server: 127.0.0.1:10010

[storagetest]
database: %(database)s
server: 127.0.0.1:10020
""" % {
    'database': NEO_SQL_DATABASE,
    'user': NEO_SQL_USER,
}
        # SQL connection
        connect_arg_dict = {'user': SQL_ADMIN_USER}
        if SQL_ADMIN_PASSWORD is not None:
            connect_arg_dict['raise NotImplementedErrorwd'] = SQL_ADMIN_PASSWORD
        sql_connection = MySQLdb.Connect(**connect_arg_dict)
        cursor = sql_connection.cursor()
        # new database
        cursor.execute('DROP DATABASE IF EXISTS %s' % (NEO_SQL_DATABASE, ))
        cursor.execute('CREATE DATABASE %s' % (NEO_SQL_DATABASE, ))
        cursor.execute('GRANT ALL ON %s.* TO "%s"@"localhost" IDENTIFIED BY ""' % 
                (NEO_SQL_DATABASE, NEO_SQL_USER))
        cursor.close()
        # config file
        tmp_id, self.tmp_path = mkstemp()
        tmp_file = os.fdopen(tmp_id, "w+b")
        tmp_file.write(config_file_text)
        tmp_file.close()
        self.app = Application(self.tmp_path, "storagetest")        
        self.verification = VerificationEventHandler(self.app)
        # define some variable to simulate client and storage node
        self.master_port = 10010
        self.storage_port = 10020
        self.client_port = 11011
        self.num_partitions = 1009
        self.num_replicas = 2
        self.app.num_partitions = 1009
        self.app.num_replicas = 2
        self.app.operational = False
        self.app.load_lock_dict = {}
        self.app.pt = PartitionTable(self.app.num_partitions, self.app.num_replicas)

        
    def tearDown(self):
        # Delete tmp file
        os.remove(self.tmp_path)

    # Common methods
    def getNewUUID(self):
        uuid = INVALID_UUID
        while uuid == INVALID_UUID:
            uuid = os.urandom(16)
        self.uuid = uuid
        return uuid

    def getLastUUID(self):
        return self.uuid

    def getTwoIDs(self):
        # generate two ptid, first is lower
        ptids = self.getNewUUID(), self.getNewUUID()
        return min(ptids), max(ptids)
        ptid = min(ptids)

    def checkCalledAbort(self, conn, packet_number=0):
        """Check the abort method has been called and an error packet has been sent"""
        # sometimes we answer an error, sometimes we just send it
        send_calls_len = len(conn.mockGetNamedCalls("send"))
        answer_calls_len = len(conn.mockGetNamedCalls('answer'))
        self.assertEquals(send_calls_len + answer_calls_len, 1)
        self.assertEquals(len(conn.mockGetNamedCalls("abort")), 1)
        self.assertEquals(len(conn.mockGetNamedCalls("expectMessage")), 0)
        if send_calls_len == 1:
            call = conn.mockGetNamedCalls("send")[packet_number]
        else:
            call = conn.mockGetNamedCalls("answer")[packet_number]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ERROR)

    # Tests
    def test_01_connectionAccepted(self):
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port)})
        self.verification.connectionAccepted(conn, None, ("127.0.0.1", self.client_port))
        # nothing happens
        self.assertEquals(len(conn.mockGetNamedCalls("addPacket")), 0)
        
    def test_02_timeoutExpired(self):
        # listening connection
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        self.verification.timeoutExpired(conn)
        # nothing happens
        self.assertEquals(len(conn.mockGetNamedCalls("addPacket")), 0)
        
        # client connection
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.assertRaises(PrimaryFailure, self.verification.timeoutExpired, conn,)
        # nothing happens
        self.assertEquals(len(conn.mockGetNamedCalls("addPacket")), 0)

    def test_03_connectionClosed(self):
        # listening connection
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        self.verification.connectionClosed(conn)
        # nothing happens
        self.assertEquals(len(conn.mockGetNamedCalls("addPacket")), 0)
        
        # client connection
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.assertRaises(PrimaryFailure, self.verification.connectionClosed, conn,)
        # nothing happens
        self.assertEquals(len(conn.mockGetNamedCalls("addPacket")), 0)


    def test_04_peerBroken(self):
        # listening connection
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        self.verification.peerBroken(conn)
        # nothing happens
        self.assertEquals(len(conn.mockGetNamedCalls("addPacket")), 0)
        
        # client connection
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.assertRaises(PrimaryFailure, self.verification.peerBroken, conn,)
        # nothing happens
        self.assertEquals(len(conn.mockGetNamedCalls("addPacket")), 0)


    def test_05_handleRequestNodeIdentification(self):
        # listening connection
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        p = Packet(msg_type=REQUEST_NODE_IDENTIFICATION)
        self.verification.handleRequestNodeIdentification(conn, p, CLIENT_NODE_TYPE,
                                      uuid, "127.0.0.1", self.client_port, "zatt")
        self.checkCalledAbort(conn)

        # not a master node
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        p = Packet(msg_type=REQUEST_NODE_IDENTIFICATION)
        self.verification.handleRequestNodeIdentification(conn, p, CLIENT_NODE_TYPE,
                                      uuid, "127.0.0.1", self.client_port, "zatt")
        self.checkCalledAbort(conn)

        # bad name
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.master_port),
                     "isServerConnection" : True})
        p = Packet(msg_type=REQUEST_NODE_IDENTIFICATION)
        self.verification.handleRequestNodeIdentification(conn, p, MASTER_NODE_TYPE,
                                      uuid, "127.0.0.1", self.client_port, "zatt")
        self.checkCalledAbort(conn)

        # new node
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.master_port),
                     "isServerConnection" : True})
        p = Packet(msg_type=REQUEST_NODE_IDENTIFICATION)
        self.assertEqual(self.app.nm.getNodeByServer(conn.getAddress()), None)
        self.verification.handleRequestNodeIdentification(conn, p, MASTER_NODE_TYPE,
                                                          uuid, "127.0.0.1", self.master_port, "main")
        self.assertNotEqual(self.app.nm.getNodeByServer(conn.getAddress()), None)
        node = self.app.nm.getNodeByServer(conn.getAddress())
        self.assertEqual(node.getUUID(), uuid)
        self.assertEqual(node.getState(), RUNNING_STATE)
        self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
        self.assertEquals(len(conn.mockGetNamedCalls("abort")), 1)
        
        # notify a node declared as broken
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.master_port),
                     "isServerConnection" : True})
        node = self.app.nm.getNodeByServer(conn.getAddress())
        node.setState(BROKEN_STATE)
        self.assertEqual(node.getUUID(), uuid)
        self.verification.handleRequestNodeIdentification(conn, p, MASTER_NODE_TYPE,
                                                          uuid, "127.0.0.1", self.master_port, "main")
        self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ERROR)
        self.assertEquals(len(conn.mockGetNamedCalls("abort")), 1)

        # change uuid of a known node
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.master_port),
                     "isServerConnection" : True})
        node = self.app.nm.getNodeByServer(conn.getAddress())
        node.setState(RUNNING_STATE)
        self.assertNotEqual(node.getUUID(), uuid)
        self.verification.handleRequestNodeIdentification(conn, p, MASTER_NODE_TYPE,
                                                          uuid, "127.0.0.1", self.master_port, "main")
        self.assertNotEqual(self.app.nm.getNodeByServer(conn.getAddress()), None)
        node = self.app.nm.getNodeByServer(conn.getAddress())
        self.assertEqual(node.getUUID(), uuid)
        self.assertEqual(node.getState(), RUNNING_STATE)
        self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ACCEPT_NODE_IDENTIFICATION)
        self.assertEquals(len(conn.mockGetNamedCalls("abort")), 1)


    def test_06_handleAcceptNodeIdentification(self):
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        p = Packet(msg_type=ACCEPT_NODE_IDENTIFICATION)
        self.verification.handleAcceptNodeIdentification(conn, p, CLIENT_NODE_TYPE,
                 self.getNewUUID(),"127.0.0.1", self.client_port, 1009, 2, uuid)
        self.checkCalledAbort(conn)    

    def test_07_handleAnswerPrimaryMaster(self):
        # reject server connection 
        packet = Packet(msg_type=ANSWER_PRIMARY_MASTER)
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        self.verification.handleAnswerPrimaryMaster(conn, packet,self.getNewUUID(), ())
        self.checkCalledAbort(conn)

        # raise id uuid is different
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.app.primary_master_node = MasterNode(uuid=self.getNewUUID())
        self.assertNotEqual(uuid, self.app.primary_master_node.getUUID())
        self.assertRaises(PrimaryFailure, self.verification.handleAnswerPrimaryMaster, conn, packet,uuid, ())
        
        # same uuid, do nothing
        uuid = self.app.primary_master_node.getUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.assertEqual(uuid, self.app.primary_master_node.getUUID())
        self.verification.handleAnswerPrimaryMaster(conn, packet,uuid, ())
        
    def test_07_handleAskLastIDs(self):
        # reject server connection 
        packet = Packet(msg_type=ASK_LAST_IDS)
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        self.verification.handleAskLastIDs(conn, packet)
        self.checkCalledAbort(conn)

        # return invalid if db store nothing
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.verification.handleAskLastIDs(conn, packet)
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_LAST_IDS)
        oid, tid, ptid = packet.decode()
        self.assertEqual(oid, INVALID_OID)
        self.assertEqual(tid, INVALID_TID)
        self.assertEqual(ptid, self.app.ptid)

        # return value stored in db
        # insert some oid
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.app.dm.begin()
        self.app.dm.query("""insert into obj (oid, serial, compression,
        checksum, value) values (3, 'A', 0, 0, '')""")
        self.app.dm.query("""insert into obj (oid, serial, compression,
        checksum, value) values (1, 'A', 0, 0, '')""")
        self.app.dm.query("""insert into obj (oid, serial, compression,
        checksum, value) values (2, 'A', 0, 0, '')""")
        self.app.dm.query("""insert into tobj (oid, serial, compression,
        checksum, value) values (5, 'A', 0, 0, '')""")
        # insert some tid
        self.app.dm.query("""insert into trans (tid, oids, user,
                description, ext) values (1, '', '', '', '')""")
        self.app.dm.query("""insert into trans (tid, oids, user,
                description, ext) values (2, '', '', '', '')""")
        self.app.dm.query("""insert into ttrans (tid, oids, user,
                description, ext) values (3, '', '', '', '')""")
        # max tid is in tobj (serial)
        self.app.dm.query("""insert into tobj (oid, serial, compression,
                checksum, value) values (0, 4, 0, 0, '')""")
        self.app.dm.commit()
        self.verification.handleAskLastIDs(conn, packet)
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_LAST_IDS)
        oid, tid, ptid = packet.decode()
        self.assertEqual(u64(oid), 5)
        self.assertEqual(u64(tid), 4)
        self.assertEqual(ptid, self.app.ptid)
        
    def test_08_handleAskPartitionTable(self):
        # reject server connection 
        packet = Packet(msg_type=ASK_PARTITION_TABLE)
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        self.verification.handleAskPartitionTable(conn, packet, [1,])
        self.checkCalledAbort(conn)

        # try to get unknown offset
        self.assertEqual(len(self.app.pt.getNodeList()), 0)
        self.assertFalse(self.app.pt.hasOffset(1))
        self.assertEqual(len(self.app.pt.getCellList(1)), 0)
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.verification.handleAskPartitionTable(conn, packet, [1,])
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_PARTITION_TABLE)
        ptid, row_list = packet.decode()
        self.assertEqual(len(row_list), 1)
        offset, rows = row_list[0]
        self.assertEqual(offset, 1)
        self.assertEqual(len(rows), 0)

        # try to get known offset
        node = StorageNode(("127.7.9.9", 1), self.getNewUUID())
        self.app.pt.setCell(1, node, UP_TO_DATE_STATE)
        self.assertTrue(self.app.pt.hasOffset(1))
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})
        self.verification.handleAskPartitionTable(conn, packet, [1,])
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_PARTITION_TABLE)
        ptid, row_list = packet.decode()
        self.assertEqual(len(row_list), 1)
        offset, rows = row_list[0]
        self.assertEqual(offset, 1)
        self.assertEqual(len(rows), 1)

    def test_09_handleSendPartitionTable(self):
        # reject server connection 
        packet = Packet(msg_type=SEND_PARTITION_TABLE)
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        self.app.ptid = 1
        self.verification.handleSendPartitionTable(conn, packet, 0, ())
        self.assertEquals(self.app.ptid, 1)
        self.checkCalledAbort(conn)

        # send a table
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : False})

        self.app.pt = PartitionTable(3, 2)
        node_1 = self.getNewUUID()
        node_2 = self.getNewUUID()
        node_3 = self.getNewUUID()
        # SN already known one of the node
        self.app.nm.add(StorageNode(uuid=node_1))
        self.app.ptid = 1
        self.app.num_partitions = 3
        self.app.num_replicas =2 
        self.assertEqual(self.app.dm.getPartitionTable(), ())
        row_list = [(0, ((node_1, UP_TO_DATE_STATE), (node_2, UP_TO_DATE_STATE))),
                    (1, ((node_3, UP_TO_DATE_STATE), (node_1, UP_TO_DATE_STATE))),
                    (2, ((node_2, UP_TO_DATE_STATE), (node_3, UP_TO_DATE_STATE)))]
        self.assertFalse(self.app.pt.filled())
        # send part of the table, won't be filled
        self.verification.handleSendPartitionTable(conn, packet, "1", row_list[:1])
        self.assertFalse(self.app.pt.filled())
        self.assertEqual(self.app.ptid, "1")
        self.assertEqual(self.app.dm.getPartitionTable(), ())
        # send remaining of the table
        self.verification.handleSendPartitionTable(conn, packet, "1", row_list[1:])
        self.assertTrue(self.app.pt.filled())
        self.assertEqual(self.app.ptid, "1")
        self.assertNotEqual(self.app.dm.getPartitionTable(), ())
        # send a complete new table
        self.verification.handleSendPartitionTable(conn, packet, "2", row_list)
        self.assertTrue(self.app.pt.filled())
        self.assertEqual(self.app.ptid, "2")
        self.assertNotEqual(self.app.dm.getPartitionTable(), ())

    def test_10_handleNotifyPartitionChanges(self):
        # reject server connection 
        packet = Packet(msg_type=NOTIFY_PARTITION_CHANGES)
        uuid = self.getNewUUID()
        conn = Mock({"getUUID" : uuid,
                     "getAddress" : ("127.0.0.1", self.client_port),
                     "isServerConnection" : True})
        self.app.ptid = 1
        self.verification.handleNotifyPartitionChanges(conn, packet, 0, ())
        self.assertEquals(self.app.ptid, 1)
        self.checkCalledAbort(conn)
        
        # old partition change
        conn = Mock({
            "isServerConnection": False,
            "getAddress" : ("127.0.0.1", self.master_port), 
        })
        packet = Packet(msg_type=NOTIFY_PARTITION_CHANGES)
        self.app.ptid = 1
        self.verification.handleNotifyPartitionChanges(conn, packet, 0, ())
        self.assertEquals(self.app.ptid, 1)

        # new node
        conn = Mock({
            "isServerConnection": False,
            "getAddress" : ("127.0.0.1", self.master_port), 
        })
        packet = Packet(msg_type=NOTIFY_PARTITION_CHANGES)
        cell = (0, self.getNewUUID(), UP_TO_DATE_STATE)
        count = len(self.app.nm.getNodeList())
        self.app.pt = PartitionTable(1, 1)
        self.app.dm = Mock({ })
        ptid, self.ptid = self.getTwoIDs()
        # pt updated
        self.verification.handleNotifyPartitionChanges(conn, packet, ptid, (cell, ))
        self.assertEquals(len(self.app.nm.getNodeList()), count + 1)
        # check db update
        calls = self.app.dm.mockGetNamedCalls('changePartitionTable')
        self.assertEquals(len(calls), 1)
        self.assertEquals(calls[0].getParam(0), ptid)
        self.assertEquals(calls[0].getParam(1), (cell, ))

    def test_11_handleStartOperation(self):
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        packet = Packet(msg_type=STOP_OPERATION)
        self.verification.handleStartOperation(conn, packet)
        self.checkCalledAbort(conn)        
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False })
        self.assertFalse(self.app.operational)
        packet = Packet(msg_type=STOP_OPERATION)
        self.verification.handleStartOperation(conn, packet)
        self.assertTrue(self.app.operational)

    def test_12_handleStopOperation(self):
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        packet = Packet(msg_type=STOP_OPERATION)
        self.verification.handleStopOperation(conn, packet)
        self.checkCalledAbort(conn)        
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False })
        packet = Packet(msg_type=STOP_OPERATION)
        self.assertRaises(OperationFailure, self.verification.handleStopOperation, conn, packet)

    def test_13_handleAskUnfinishedTransactions(self):
        # server connection
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        packet = Packet(msg_type=ASK_UNFINISHED_TRANSACTIONS)
        self.verification.handleAskUnfinishedTransactions(conn, packet)
        self.checkCalledAbort(conn)
        # client connection with no data
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False})
        packet = Packet(msg_type=ASK_UNFINISHED_TRANSACTIONS)
        self.verification.handleAskUnfinishedTransactions(conn, packet)
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_UNFINISHED_TRANSACTIONS)
        tid_list = packet.decode()[0]        
        self.assertEqual(len(tid_list), 0)

        # client connection with some data
        self.app.dm.begin()
        self.app.dm.query("""insert into tobj (oid, serial, compression,
                checksum, value) values (0, 4, 0, 0, '')""")
        self.app.dm.commit()
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False})
        packet = Packet(msg_type=ASK_UNFINISHED_TRANSACTIONS)
        self.verification.handleAskUnfinishedTransactions(conn, packet)
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_UNFINISHED_TRANSACTIONS)
        tid_list = packet.decode()[0]        
        self.assertEqual(len(tid_list), 1)
        self.assertEqual(u64(tid_list[0]), 4)

    def test_14_handleAskTransactionInformation(self):
        # ask from server with no data
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        packet = Packet(msg_type=ASK_TRANSACTION_INFORMATION)
        self.verification.handleAskTransactionInformation(conn, packet, p64(1))
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ERROR)
        code, message = packet.decode()     
        self.assertEqual(code, TID_NOT_FOUND_CODE)
        # ask from client conn with no data
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False })
        packet = Packet(msg_type=ASK_TRANSACTION_INFORMATION)
        self.verification.handleAskTransactionInformation(conn, packet, p64(1))
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ERROR)
        code, message = packet.decode()     
        self.assertEqual(code, TID_NOT_FOUND_CODE)

        # input some tmp data and ask from client, must find both transaction
        self.app.dm.begin()
        self.app.dm.query("""insert into ttrans (tid, oids, user,
        description, ext) values (3, '%s', 'u1', 'd1', 'e1')""" %(p64(4),))
        self.app.dm.query("""insert into trans (tid, oids, user,
        description, ext) values (1,'%s', 'u2', 'd2', 'e2')""" %(p64(2),)) 
        self.app.dm.commit()
        # object from trans
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False })
        packet = Packet(msg_type=ASK_TRANSACTION_INFORMATION)
        self.verification.handleAskTransactionInformation(conn, packet, p64(1))
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_TRANSACTION_INFORMATION)
        tid, user, desc, ext, oid_list = packet.decode()
        self.assertEqual(u64(tid), 1)
        self.assertEqual(user, 'u2')
        self.assertEqual(desc, 'd2')
        self.assertEqual(ext, 'e2')
        self.assertEqual(len(oid_list), 1)
        self.assertEqual(u64(oid_list[0]), 2)
        # object from ttrans
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False })
        packet = Packet(msg_type=ASK_TRANSACTION_INFORMATION)
        self.verification.handleAskTransactionInformation(conn, packet, p64(3))
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_TRANSACTION_INFORMATION)
        tid, user, desc, ext, oid_list = packet.decode()
        self.assertEqual(u64(tid), 3)
        self.assertEqual(user, 'u1')
        self.assertEqual(desc, 'd1')
        self.assertEqual(ext, 'e1')
        self.assertEqual(len(oid_list), 1)
        self.assertEqual(u64(oid_list[0]), 4)

        # input some tmp data and ask from server, must find one transaction
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        # find the one in trans
        packet = Packet(msg_type=ASK_TRANSACTION_INFORMATION)
        self.verification.handleAskTransactionInformation(conn, packet, p64(1))
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_TRANSACTION_INFORMATION)
        tid, user, desc, ext, oid_list = packet.decode()
        self.assertEqual(u64(tid), 1)
        self.assertEqual(user, 'u2')
        self.assertEqual(desc, 'd2')
        self.assertEqual(ext, 'e2')
        self.assertEqual(len(oid_list), 1)
        self.assertEqual(u64(oid_list[0]), 2)
        # do not find the one in ttrans
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        packet = Packet(msg_type=ASK_TRANSACTION_INFORMATION)
        self.verification.handleAskTransactionInformation(conn, packet, p64(3))
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ERROR)
        code, message = packet.decode()     
        self.assertEqual(code, TID_NOT_FOUND_CODE)

    def test_15_handleAskObjectPresent(self):
        # server connection
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        packet = Packet(msg_type=ASK_OBJECT_PRESENT)
        self.verification.handleAskObjectPresent(conn, packet, p64(1), p64(2))
        self.checkCalledAbort(conn)
        # client connection with no data
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False})
        packet = Packet(msg_type=ASK_OBJECT_PRESENT)
        self.verification.handleAskObjectPresent(conn, packet, p64(1), p64(2))
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ERROR)
        code, message = packet.decode()     
        self.assertEqual(code, OID_NOT_FOUND_CODE)

        # client connection with some data
        self.app.dm.begin()
        self.app.dm.query("""insert into tobj (oid, serial, compression,
                checksum, value) values (1, 2, 0, 0, '')""")
        self.app.dm.commit()
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False})
        packet = Packet(msg_type=ASK_OBJECT_PRESENT)
        self.verification.handleAskObjectPresent(conn, packet, p64(1), p64(2))
        call = conn.mockGetNamedCalls("answer")[0]
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, Packet))
        self.assertEquals(packet.getType(), ANSWER_OBJECT_PRESENT)
        oid, tid = packet.decode()
        self.assertEqual(u64(tid), 2)
        self.assertEqual(u64(oid), 1)

    def test_16_handleDeleteTransaction(self):
        # server connection
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        packet = Packet(msg_type=ASK_OBJECT_PRESENT)
        self.verification.handleDeleteTransaction(conn, packet, p64(1))
        self.checkCalledAbort(conn)
        # client connection with no data
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False})
        packet = Packet(msg_type=ASK_OBJECT_PRESENT)
        self.verification.handleDeleteTransaction(conn, packet, p64(1))
        # client connection with data
        self.app.dm.begin()
        self.app.dm.query("""insert into tobj (oid, serial, compression,
                checksum, value) values (1, 2, 0, 0, '')""")
        self.app.dm.commit()
        self.verification.handleDeleteTransaction(conn, packet, p64(2))
        result = self.app.dm.query('select * from tobj')
        self.assertEquals(len(result), 0)

    def test_17_handleCommitTransaction(self):
        # server connection
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': True })
        dm = Mock()
        self.app.dm = dm
        packet = Packet(msg_type=COMMIT_TRANSACTION)
        self.verification.handleCommitTransaction(conn, packet, p64(1))
        self.checkCalledAbort(conn)
        self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 0)
        # commit a transaction
        conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
                      'isServerConnection': False })
        dm = Mock()
        self.app.dm = dm
        packet = Packet(msg_type=COMMIT_TRANSACTION)
        self.verification.handleCommitTransaction(conn, packet, p64(1))
        self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1)
        call = dm.mockGetNamedCalls("finishTransaction")[0]
        tid = call.getParam(0)
        self.assertEqual(u64(tid), 1)

    def test_18_handleLockInformation(self):
        conn = Mock({"getAddress" : ("127.0.0.1", self.master_port),
                     'isServerConnection': False})
        packet = Packet(msg_type=LOCK_INFORMATION)
        self.assertEquals(len(self.app.load_lock_dict), 0)
        self.verification.handleLockInformation(conn, packet, p64(1))
        self.assertEquals(len(self.app.load_lock_dict), 0)

    def test_19_handleUnlockInformation(self):
        conn = Mock({"getAddress" : ("127.0.0.1", self.master_port),
                     'isServerConnection': False})
        self.app.load_lock_dict[p64(1)] = Mock()
        packet = Packet(msg_type=UNLOCK_INFORMATION)
        self.verification.handleUnlockInformation(conn, packet, p64(1))
        self.assertEquals(len(self.app.load_lock_dict), 1)
    
if __name__ == "__main__":
    unittest.main()