#
# Copyright (C) 2009  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

import os
import unittest
import tempfile
import MySQLdb
from neo import logging
from mock import Mock
from neo import protocol
from neo.protocol import PacketTypes

DB_PREFIX = 'test_neo_'
DB_ADMIN = 'root'
DB_PASSWD = None
DB_USER = 'test'


def getNewUUID():
    """ Return a random 16-byte UUID that differs from INVALID_UUID """
    uuid = protocol.INVALID_UUID
    # draw again in the (unlikely) case the random value is the invalid one
    while uuid == protocol.INVALID_UUID:
        uuid = os.urandom(16)
    return uuid


class NeoTestBase(unittest.TestCase):
    """ Base class for neo tests, implements common checks """

    def prepareDatabase(self, number, admin=DB_ADMIN, password=DB_PASSWD,
            user=DB_USER, prefix=DB_PREFIX):
        """ Create `number` empty databases named <prefix><index>

        Each database is dropped if it already exists, re-created, and all
        privileges on it are granted to `user`@localhost. The connection is
        opened as `admin`, with `password` only when it is not None.
        """
        # SQL connection
        connect_arg_dict = {'user': admin}
        if password is not None:
            connect_arg_dict['passwd'] = password
        sql_connection = MySQLdb.Connect(**connect_arg_dict)
        # try/finally so that a failing statement does not leak the cursor
        # or the connection
        try:
            cursor = sql_connection.cursor()
            try:
                # drop and create each database
                for i in xrange(number):
                    database = "%s%d" % (prefix, i)
                    cursor.execute('DROP DATABASE IF EXISTS %s' % (database, ))
                    cursor.execute('CREATE DATABASE %s' % (database, ))
                    cursor.execute('GRANT ALL ON %s.* TO "%s"@"localhost" IDENTIFIED BY ""' % (database, user))
            finally:
                cursor.close()
        finally:
            sql_connection.close()

    def getMasterConfiguration(self, cluster='main', master_number=2,
            replicas=2, partitions=1009, uuid=None):
        """ Return a master node configuration dictionary

        Master addresses are 127.0.0.1:10010 .. 127.0.0.1:1001<n-1>; the
        first one is used as the bind address.
        """
        # the port is built by appending a single digit to '1001'
        assert 1 <= master_number <= 10
        masters = ['127.0.0.1:1001%d' % i for i in xrange(master_number)]
        return {
            'cluster': cluster,
            'bind': masters[0],
            'masters': '/'.join(masters),
            'replicas': replicas,
            'partitions': partitions,
            'uuid': uuid,
        }

    def getStorageConfiguration(self, cluster='main', master_number=2,
            index=0, prefix=DB_PREFIX, uuid=None):
        """ Return a storage node configuration dictionary

        The node binds to 127.0.0.1:1002<index> and uses the database
        <prefix><index>, accessed with the module-level DB_USER and
        DB_PASSWD credentials.
        """
        # ports are built by appending a single digit to a 4-digit prefix
        assert 1 <= master_number <= 10
        assert 0 <= index <= 9
        masters = ['127.0.0.1:1001%d' % i for i in xrange(master_number)]
        if DB_PASSWD is None:
            database = '%s:@%s%d' % (DB_USER, prefix, index)
        else:
            database = '%s:%s@%s%d' % (DB_USER, DB_PASSWD, prefix, index)
        return {
            'cluster': cluster,
            'bind': '127.0.0.1:1002%d' % (index, ),
            'masters': '/'.join(masters),
            'database': database,
            'uuid': uuid,
            'reset': False,
        }

    # XXX: according to changes with namespaced UUIDs, it would be better to
    # implement get<NodeType>UUID() methods
    def getNewUUID(self):
        """ Generate a new UUID, remember it as self.uuid and return it """
        self.uuid = getNewUUID()
        return self.uuid

    def getTwoIDs(self):
        """ Return a tuple of two sorted UUIDs """
        # generate two ptid, first is lower
        uuids = self.getNewUUID(), self.getNewUUID()
        return min(uuids), max(uuids)

    def getFakeConnection(self, uuid=None, address=('127.0.0.1', 10000)):
        """ Return a Mock whose getUUID/getAddress return the given values """
        return Mock({
            'getUUID': uuid,
            'getAddress': address,
        })

    def checkProtocolErrorRaised(self, method, *args, **kwargs):
        """ Check if the ProtocolError exception was raised """
        self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)

    def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
        """ Check if the UnexpectedPacketError exception was raised """
        self.assertRaises(protocol.UnexpectedPacketError, method, *args,
            **kwargs)

    def checkIdenficationRequired(self, method, *args, **kwargs):
        """ Check if the identification_required decorator is applied """
        self.checkUnexpectedPacketRaised(method, *args, **kwargs)

    def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
        """ Check if the BrokenNodeDisallowedError exception was raised """
        self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args,
            **kwargs)

    def checkNotReadyErrorRaised(self, method, *args, **kwargs):
        """ Check if the NotReadyError exception was raised """
        self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)

    def checkAborted(self, conn):
        """ Ensure the connection was aborted """
        self.assertEquals(len(conn.mockGetNamedCalls('abort')), 1)

    def checkNotAborted(self, conn):
        """ Ensure the connection was not aborted """
        self.assertEquals(len(conn.mockGetNamedCalls('abort')), 0)

    def checkClosed(self, conn):
        """ Ensure the connection was closed """
        self.assertEquals(len(conn.mockGetNamedCalls('close')), 1)

    def checkNotClosed(self, conn):
        """ Ensure the connection was not closed """
        self.assertEquals(len(conn.mockGetNamedCalls('close')), 0)

    def checkNoPacketSent(self, conn):
        """ Check that no packet was sent (no notify, answer or ask call) """
        self.assertEquals(len(conn.mockGetNamedCalls('notify')), 0)
        self.assertEquals(len(conn.mockGetNamedCalls('answer')), 0)
        self.assertEquals(len(conn.mockGetNamedCalls('ask')), 0)

    def checkNoUUIDSet(self, conn):
        """ Ensure no UUID was set on the connection """
        self.assertEquals(len(conn.mockGetNamedCalls('setUUID')), 0)

    def checkUUIDSet(self, conn, uuid=None):
        """ Ensure a UUID was set exactly once on the connection

        When `uuid` is given, also check that it is the value that was set.
        """
        calls = conn.mockGetNamedCalls('setUUID')
        self.assertEquals(len(calls), 1)
        if uuid is not None:
            self.assertEquals(calls[0].getParam(0), uuid)

    def _checkPacketType(self, packet, packet_type):
        """ Check that `packet` is a Packet of type `packet_type` """
        self.assertTrue(isinstance(packet, protocol.Packet))
        self.assertEquals(packet.getType(), packet_type)

    def _getPacketResult(self, packet, decode):
        """ Return the decoded packet body if `decode`, else the packet """
        if decode:
            return protocol.decode_table[packet.getType()](packet._body)
        return packet

    # In check(Ask|Answer|Notify)Packet the packet is returned so that tests
    # can perform more accurate checks on it when required.

    def checkErrorPacket(self, conn, decode=False):
        """ Check if exactly one packet was answered and it is an error """
        calls = conn.mockGetNamedCalls("answer")
        self.assertEquals(len(calls), 1)
        packet = calls[0].getParam(0)
        self._checkPacketType(packet, PacketTypes.ERROR)
        return self._getPacketResult(packet, decode)

    def checkAskPacket(self, conn, packet_type, decode=False):
        """ Check if exactly one ask-packet with the right type is sent """
        calls = conn.mockGetNamedCalls('ask')
        self.assertEquals(len(calls), 1)
        packet = calls[0].getParam(0)
        self._checkPacketType(packet, packet_type)
        return self._getPacketResult(packet, decode)

    def checkAnswerPacket(self, conn, packet_type, answered_packet=None,
            decode=False):
        """ Check if exactly one answer-packet with the right type is sent

        When `answered_packet` is given, the second parameter of the answer
        call must equal the id of that packet.
        """
        calls = conn.mockGetNamedCalls('answer')
        self.assertEquals(len(calls), 1)
        packet = calls[0].getParam(0)
        self._checkPacketType(packet, packet_type)
        if answered_packet is not None:
            msg_id = calls[0].getParam(1)
            self.assertEqual(msg_id, answered_packet.getId())
        return self._getPacketResult(packet, decode)

    def checkNotifyPacket(self, conn, packet_type, packet_number=0,
            decode=False):
        """ Check if a notify-packet with the right type is sent

        `packet_number` selects which notify call (0-based) is checked; at
        least packet_number + 1 notify calls must have been made.
        """
        calls = conn.mockGetNamedCalls('notify')
        self.assertTrue(len(calls) > packet_number)
        packet = calls[packet_number].getParam(0)
        self._checkPacketType(packet, packet_type)
        return self._getPacketResult(packet, decode)

    # Shortcuts for notify-packets

    def checkNotifyNodeInformation(self, conn, **kw):
        return self.checkNotifyPacket(conn,
            PacketTypes.NOTIFY_NODE_INFORMATION, **kw)

    def checkSendPartitionTable(self, conn, **kw):
        return self.checkNotifyPacket(conn,
            PacketTypes.SEND_PARTITION_TABLE, **kw)

    def checkStartOperation(self, conn, **kw):
        return self.checkNotifyPacket(conn, PacketTypes.START_OPERATION, **kw)

    def checkNotifyTransactionFinished(self, conn, **kw):
        return self.checkNotifyPacket(conn,
            PacketTypes.NOTIFY_TRANSACTION_FINISHED, **kw)

    def checkNotifyInformationLocked(self, conn, **kw):
        # despite its name, this packet is looked up among 'answer' calls
        return self.checkAnswerPacket(conn,
            PacketTypes.NOTIFY_INFORMATION_LOCKED, **kw)

    # Shortcuts for ask-packets

    def checkLockInformation(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.LOCK_INFORMATION, **kw)

    def checkUnlockInformation(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.UNLOCK_INFORMATION, **kw)

    def checkRequestNodeIdentification(self, conn, **kw):
        return self.checkAskPacket(conn,
            PacketTypes.REQUEST_NODE_IDENTIFICATION, **kw)

    def checkAskPrimaryMaster(self, conn, **kw):
        # forward keyword arguments (e.g. decode) like the sibling shortcuts
        return self.checkAskPacket(conn, PacketTypes.ASK_PRIMARY_MASTER, **kw)

    def checkAskUnfinishedTransactions(self, conn, **kw):
        # forward keyword arguments (e.g. decode) like the sibling shortcuts
        return self.checkAskPacket(conn,
            PacketTypes.ASK_UNFINISHED_TRANSACTIONS, **kw)

    def checkAskTransactionInformation(self, conn, **kw):
        return self.checkAskPacket(conn,
            PacketTypes.ASK_TRANSACTION_INFORMATION, **kw)

    def checkAskObjectPresent(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_OBJECT_PRESENT, **kw)

    def checkAskObject(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_OBJECT, **kw)

    def checkAskStoreObject(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_STORE_OBJECT, **kw)

    def checkAskStoreTransaction(self, conn, **kw):
        return self.checkAskPacket(conn,
            PacketTypes.ASK_STORE_TRANSACTION, **kw)

    def checkFinishTransaction(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.FINISH_TRANSACTION, **kw)

    def checkAskNewTid(self, conn, **kw):
        return self.checkAskPacket(conn,
            PacketTypes.ASK_BEGIN_TRANSACTION, **kw)

    def checkAskLastIDs(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_LAST_IDS, **kw)

    # Shortcuts for answer-packets

    def checkAcceptNodeIdentification(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ACCEPT_NODE_IDENTIFICATION, **kw)

    def checkAnswerPrimaryMaster(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ANSWER_PRIMARY_MASTER, **kw)

    def checkAnswerLastIDs(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_LAST_IDS, **kw)

    def checkAnswerUnfinishedTransactions(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ANSWER_UNFINISHED_TRANSACTIONS, **kw)

    def checkAnswerObject(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OBJECT, **kw)

    def checkAnswerTransactionInformation(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ANSWER_TRANSACTION_INFORMATION, **kw)

    def checkAnswerTids(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_TIDS, **kw)

    def checkAnswerObjectHistory(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ANSWER_OBJECT_HISTORY, **kw)

    def checkAnswerStoreTransaction(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ANSWER_STORE_TRANSACTION, **kw)

    def checkAnswerStoreObject(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ANSWER_STORE_OBJECT, **kw)

    def checkAnswerOids(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OIDS, **kw)

    def checkAnswerPartitionTable(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ANSWER_PARTITION_TABLE, **kw)

    def checkAnswerObjectPresent(self, conn, **kw):
        return self.checkAnswerPacket(conn,
            PacketTypes.ANSWER_OBJECT_PRESENT, **kw)


# XXX: imported from neo.master.test.connector since it's used at many places

# module-wide counter used to hand out a distinct descriptor per connector
connector_cpt = 0


class DoNothingConnector(Mock):
    """ Mock connector that only remembers its address and descriptor """

    def __init__(self, s=None):
        global connector_cpt
        logging.info("initializing connector")
        # take the current counter value as descriptor, then bump it
        self.desc = connector_cpt
        connector_cpt += 1
        self.packet_cpt = 0
        Mock.__init__(self)

    def getAddress(self):
        return self.addr

    def makeClientConnection(self, addr):
        self.addr = addr

    def getDescriptor(self):
        return self.desc


class TestElectionConnector(DoNothingConnector):
    """ Connector replaying a fixed sequence of packets for election tests """

    def receive(self):
        """ Simulate the behavior of an election

        The first call returns an encoded node identification acceptance,
        the second an encoded primary master answer, and every further call
        raises ConnectorTryAgainException.
        """
        if self.packet_cpt == 0:
            # first : identify
            logging.info("in patched analyse / IDENTIFICATION")
            p = protocol.Packet()
            self.uuid = getNewUUID()
            # NOTE(review): `NodeType` is neither imported nor defined in the
            # visible part of this module, so this line looks like it raises
            # NameError when reached -- confirm the intended name in
            # neo.protocol and import it.
            p.acceptNodeIdentification(1, NodeType.MASTER, self.uuid,
                self.getAddress()[0], self.getAddress()[1], 1009, 2)
            self.packet_cpt += 1
            return p.encode()
        elif self.packet_cpt == 1:
            # second : answer primary master nodes
            logging.info("in patched analyse / ANSWER PM")
            p = protocol.Packet()
            p.answerPrimaryMaster(2, protocol.INVALID_UUID, [])
            self.packet_cpt += 1
            return p.encode()
        else:
            # then do nothing
            from neo.connector import ConnectorTryAgainException
            raise ConnectorTryAgainException