#
# Copyright (C) 2009-2010 Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

import os
import random
import socket
import sys
import tempfile
import unittest
import MySQLdb
import neo

from mock import Mock
from neo.lib import protocol
from neo.lib.protocol import Packets
from neo.lib.util import getAddressType
from time import time, gmtime
from struct import pack, unpack

# Test database settings, overridable through the environment.
DB_PREFIX = os.getenv('NEO_DB_PREFIX', 'test_neo_')
DB_ADMIN = os.getenv('NEO_DB_ADMIN', 'root')
DB_PASSWD = os.getenv('NEO_DB_PASSWD', None)
DB_USER = os.getenv('NEO_DB_USER', 'test')

# Loopback address to use for each supported address family.
IP_VERSION_FORMAT_DICT = {
    socket.AF_INET: '127.0.0.1',
    socket.AF_INET6: '::1',
}

ADDRESS_TYPE = socket.AF_INET


def buildUrlFromString(address):
    """ Return 'address' wrapped in brackets if it is an IPv6 literal,
    unchanged otherwise (e.g. '::1' -> '[::1]') """
    try:
        socket.inet_pton(socket.AF_INET6, address)
        address = '[%s]' % address
    except Exception:
        # Deliberate best-effort: anything that does not parse as an IPv6
        # literal is returned as-is.
        pass
    return address


class NeoTestBase(unittest.TestCase):
    """ Base test case that prints each test id before it runs """

    def setUp(self):
        sys.stdout.write(' * %s ' % (self.id(), ))
        sys.stdout.flush()
        unittest.TestCase.setUp(self)

    def tearDown(self):
        unittest.TestCase.tearDown(self)
        sys.stdout.write('\n')
        sys.stdout.flush()


class NeoUnitTestBase(NeoTestBase):
    """ Base class for neo tests, implements common checks """

    local_ip = IP_VERSION_FORMAT_DICT[ADDRESS_TYPE]

    def prepareDatabase(self, number, admin=DB_ADMIN, password=DB_PASSWD,
            user=DB_USER, prefix=DB_PREFIX, address_type=ADDRESS_TYPE):
        """ Create empty databases

        Drop and re-create 'number' MySQL databases named '<prefix><i>'
        (i in 0..number-1) as 'admin', then grant all privileges on each to
        'user'@'localhost' with an empty password. 'address_type' is
        accepted but currently unused.
        """
        # SQL connection
        connect_arg_dict = {'user': admin}
        if password is not None:
            connect_arg_dict['passwd'] = password
        sql_connection = MySQLdb.Connect(**connect_arg_dict)
        try:
            cursor = sql_connection.cursor()
            try:
                # drop and create each database
                for i in xrange(number):
                    database = "%s%d" % (prefix, i)
                    # Database/user identifiers cannot be bound as query
                    # parameters; they come from the test configuration.
                    cursor.execute('DROP DATABASE IF EXISTS %s' % (database, ))
                    cursor.execute('CREATE DATABASE %s' % (database, ))
                    cursor.execute('GRANT ALL ON %s.* TO "%s"@"localhost" IDENTIFIED BY ""' %
                        (database, user))
            finally:
                # Release the cursor and connection even if a statement fails.
                cursor.close()
        finally:
            sql_connection.close()

    def getMasterConfiguration(self, cluster='main', master_number=2,
            replicas=2, partitions=1009, uuid=None):
        """ Return a Mock master configuration listening on local_ip with
        'master_number' (1..10) masters on ports 10010 and up """
        assert master_number >= 1 and master_number <= 10
        masters = [(self.local_ip, 10010 + i)
                   for i in xrange(master_number)]
        return Mock({
                'getCluster': cluster,
                'getBind': masters[0],
                'getMasters': (masters, getAddressType((
                        self.local_ip, 0))),
                'getReplicas': replicas,
                'getPartitions': partitions,
                'getUUID': uuid,
        })

    def getStorageConfiguration(self, cluster='main', master_number=2,
            index=0, prefix=DB_PREFIX, uuid=None):
        """ Return a Mock configuration for storage node 'index' (0..9),
        bound to port 10020 + index and using database '<prefix><index>' """
        assert master_number >= 1 and master_number <= 10
        assert index >= 0 and index <= 9
        masters = [(buildUrlFromString(self.local_ip),
                     10010 + i) for i in xrange(master_number)]
        database = '%s@%s%s' % (DB_USER, prefix, index)
        return Mock({
                'getCluster': cluster,
                'getName': 'storage',
                'getBind': (masters[0], 10020 + index),
                'getMasters': (masters, getAddressType((
                        self.local_ip, 0))),
                'getDatabase': database,
                'getUUID': uuid,
                'getReset': False,
                'getAdapter': 'MySQL',
        })

    def _makeUUID(self, prefix):
        """ Return a 16-bytes UUID according to namespace 'prefix'

        The first byte is 'prefix', the remaining 15 bytes are random and
        never equal to those of protocol.INVALID_UUID.
        """
        assert len(prefix) == 1
        uuid = protocol.INVALID_UUID
        while uuid[1:] == protocol.INVALID_UUID[1:]:
            uuid = prefix + os.urandom(15)
        return uuid

    def getNewUUID(self):
        return self._makeUUID('\0')

    def getClientUUID(self):
        return self._makeUUID('C')

    def getMasterUUID(self):
        return self._makeUUID('M')

    def getStorageUUID(self):
        return self._makeUUID('S')

    def getAdminUUID(self):
        return self._makeUUID('A')

    @staticmethod
    def _nextTIDMinute(upper):
        """ Return the TID upper half designating the minute after 'upper'

        'upper' is decoded back into a calendar minute so the result always
        designates an existing date (e.g. Feb 28 23:59 -> Mar 1 00:00 on a
        non-leap year), matching how getNextTID encodes the current time.
        """
        from datetime import timedelta, datetime
        minute = upper % 60
        rest = upper // 60
        hour = rest % 24
        rest //= 24
        day = rest % 31 + 1
        rest //= 31
        month = rest % 12 + 1
        year = rest // 12 + 1900
        try:
            d = datetime(year, month, day, hour, minute) + timedelta(0, 60)
        except ValueError:
            # 'upper' does not encode a real date (hand-crafted TID): the
            # next encoded minute is still strictly greater.
            return upper + 1
        return ((((d.year - 1900) * 12 + d.month - 1) * 31
            + d.day - 1) * 24 + d.hour) * 60 + d.minute

    def getNextTID(self, ltid=None):
        """ Return a new 8-bytes TID derived from the current UTC time

        The upper 32 bits count minutes since 1900-01-01 00:00 (every month
        counted as 31 days); the lower 32 bits give the position within that
        minute in units of 60 / 2**32 seconds. When 'ltid' is given and the
        time-based TID is not greater than it, the TID following 'ltid' is
        returned instead, so generated TIDs keep increasing.
        """
        tm = time()
        gmt = gmtime(tm)
        upper = ((((gmt.tm_year - 1900) * 12 + gmt.tm_mon - 1) * 31
            + gmt.tm_mday - 1) * 24 + gmt.tm_hour) * 60 + gmt.tm_min
        # 'tm_sec % 60' folds a possible leap second (60) back into range.
        lower = int((gmt.tm_sec % 60 + (tm - int(tm)))
            / (60.0 / 65536.0 / 65536.0))
        tid = pack('!LL', upper, lower)
        if ltid is not None and tid <= ltid:
            # Continue from 'ltid' itself (this used to read an undefined
            # 'self._last_tid' attribute).
            upper, lower = unpack('!LL', ltid)
            if lower == 0xffffffff:
                # Sub-minute counter exhausted: jump to the first TID of the
                # minute after the one in 'ltid'. This should not happen
                # usually.
                upper = self._nextTIDMinute(upper)
                lower = 0
            else:
                lower += 1
            tid = pack('!LL', upper, lower)
        return tid

    def getPTID(self, i=None):
        """ Return an integer PTID: 'i' if given, otherwise a random value
        in [1, 2**64 - 1] (so that it fits in 64 unsigned bits) """
        if i is None:
            return random.randint(1, 2**64 - 1)
        return i

    def getOID(self, i=None):
        """ Return a 8-bytes OID: random, or 'i' packed as big-endian """
        if i is None:
            return os.urandom(8)
        return pack('!Q', i)

    def getTwoIDs(self):
        """ Return a tuple of two sorted UUIDs """
        # generate two UUIDs, first is lower
        uuids = self.getNewUUID(), self.getNewUUID()
        return min(uuids), max(uuids)

    def getFakeConnector(self, descriptor=None):
        return Mock({
            '__repr__': 'FakeConnector',
            'getDescriptor': descriptor,
        })

    def getFakeConnection(self, uuid=None, address=('127.0.0.1', 10000),
            is_server=False, connector=None, peer_id=None):
        """ Return a Mock connection; it evaluates as false in a boolean
        context ('__nonzero__' returns 0) """
        if connector is None:
            connector = self.getFakeConnector()
        return Mock({
                'getUUID': uuid,
                'getAddress': address,
                'isServer': is_server,
                '__repr__': 'FakeConnection',
                '__nonzero__': 0,
                'getConnector': connector,
                'getPeerId': peer_id,
        })

    def checkProtocolErrorRaised(self, method, *args, **kwargs):
        """ Check if the ProtocolError exception was raised """
        self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)

    def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
        """ Check if the UnexpectedPacketError exception was raised """
        self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)

    def checkIdenficationRequired(self, method, *args, **kwargs):
        """ Check if the identification_required decorator is applied """
        self.checkUnexpectedPacketRaised(method, *args, **kwargs)

    def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
        """ Check if the BrokenNodeDisallowedError exception was raised """
        self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)

    def checkNotReadyErrorRaised(self, method, *args, **kwargs):
        """ Check if the NotReadyError exception was raised """
        self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)

    def checkAborted(self, conn):
        """ Ensure the connection was aborted """
        self.assertEquals(len(conn.mockGetNamedCalls('abort')), 1)

    def checkNotAborted(self, conn):
        """ Ensure the connection was not aborted """
        self.assertEquals(len(conn.mockGetNamedCalls('abort')), 0)

    def checkClosed(self, conn):
        """ Ensure the connection was closed """
        self.assertEquals(len(conn.mockGetNamedCalls('close')), 1)

    def checkNotClosed(self, conn):
        """ Ensure the connection was not closed """
        self.assertEquals(len(conn.mockGetNamedCalls('close')), 0)

    def _checkNoPacketSend(self, conn, method_id):
        call_list = conn.mockGetNamedCalls(method_id)
        self.assertEquals(len(call_list), 0, call_list)

    def checkNoPacketSent(self, conn, check_notify=True, check_answer=True,
            check_ask=True):
        """ Check that no packet was sent """
        if check_notify:
            self._checkNoPacketSend(conn, 'notify')
        if check_answer:
            self._checkNoPacketSend(conn, 'answer')
        if check_ask:
            self._checkNoPacketSend(conn, 'ask')

    def checkNoUUIDSet(self, conn):
        """ Ensure no UUID was set on the connection """
        self.assertEquals(len(conn.mockGetNamedCalls('setUUID')), 0)

    def checkUUIDSet(self, conn, uuid=None):
        """ Ensure a UUID was set exactly once on the connection, and that
        it equals 'uuid' when one is given """
        calls = conn.mockGetNamedCalls('setUUID')
        self.assertEquals(len(calls), 1)
        if uuid is not None:
            self.assertEquals(calls[0].getParam(0), uuid)

    # in check(Ask|Answer|Notify)Packet we return the packet so it can be used
    # in tests if more accurate checks are required

    def _checkSinglePacket(self, conn, method_id, packet_type, decode):
        """ Check that conn.<method_id> was called exactly once with a packet
        of type 'packet_type'; return the packet, or its decoded content
        when 'decode' is true """
        calls = conn.mockGetNamedCalls(method_id)
        self.assertEquals(len(calls), 1)
        packet = calls[0].getParam(0)
        self.assertTrue(isinstance(packet, protocol.Packet))
        self.assertEquals(packet.getType(), packet_type)
        if decode:
            return packet.decode()
        return packet

    def checkErrorPacket(self, conn, decode=False):
        """ Check if an error packet was answered """
        return self._checkSinglePacket(conn, 'answer', Packets.Error, decode)

    def checkAskPacket(self, conn, packet_type, decode=False):
        """ Check if an ask-packet with the right type is sent """
        return self._checkSinglePacket(conn, 'ask', packet_type, decode)

    def checkAnswerPacket(self, conn, packet_type, decode=False):
        """ Check if an answer-packet with the right type is sent """
        return self._checkSinglePacket(conn, 'answer', packet_type, decode)

    def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False):
        """ Check if a notify-packet with the right type is sent

        'packet_number' selects which of the notify calls is inspected.
        """
        calls = conn.mockGetNamedCalls('notify')
        self.assertTrue(len(calls) > packet_number, (len(calls), packet_number))
        packet = calls[packet_number].getParam(0)
        self.assertTrue(isinstance(packet, protocol.Packet))
        self.assertEquals(packet.getType(), packet_type)
        if decode:
            return packet.decode()
        return packet

    def checkNotify(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.Notify, **kw)

    def checkNotifyNodeInformation(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation, **kw)

    def checkSendPartitionTable(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.SendPartitionTable, **kw)

    def checkStartOperation(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.StartOperation, **kw)

    def checkInvalidateObjects(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.InvalidateObjects, **kw)

    def checkAbortTransaction(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.AbortTransaction, **kw)

    def checkNotifyLastOID(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.NotifyLastOID, **kw)

    def checkAnswerTransactionFinished(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerTransactionFinished, **kw)

    def checkAnswerInformationLocked(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerInformationLocked, **kw)

    def checkAskLockInformation(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskLockInformation, **kw)

    def checkNotifyUnlockInformation(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.NotifyUnlockInformation, **kw)

    def checkNotifyTransactionFinished(self, conn, **kw):
        return self.checkNotifyPacket(conn, Packets.NotifyTransactionFinished, **kw)

    def checkRequestIdentification(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.RequestIdentification, **kw)

    def checkAskPrimary(self, conn, **kw):
        # Forward **kw (e.g. decode=True) like every other helper does.
        return self.checkAskPacket(conn, Packets.AskPrimary, **kw)

    def checkAskUnfinishedTransactions(self, conn, **kw):
        # Forward **kw (e.g. decode=True) like every other helper does.
        return self.checkAskPacket(conn, Packets.AskUnfinishedTransactions, **kw)

    def checkAskTransactionInformation(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)

    def checkAskObjectPresent(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskObjectPresent, **kw)

    def checkAskObject(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskObject, **kw)

    def checkAskStoreObject(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskStoreObject, **kw)

    def checkAskStoreTransaction(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskStoreTransaction, **kw)

    def checkAskFinishTransaction(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskFinishTransaction, **kw)

    def checkAskNewTid(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskBeginTransaction, **kw)

    def checkAskLastIDs(self, conn, **kw):
        return self.checkAskPacket(conn, Packets.AskLastIDs, **kw)

    def checkAcceptIdentification(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AcceptIdentification, **kw)

    def checkAnswerPrimary(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerPrimary, **kw)

    def checkAnswerLastIDs(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerLastIDs, **kw)

    def checkAnswerUnfinishedTransactions(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerUnfinishedTransactions, **kw)

    def checkAnswerObject(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerObject, **kw)

    def checkAnswerTransactionInformation(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerTransactionInformation, **kw)

    def checkAnswerBeginTransaction(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction, **kw)

    def checkAnswerTids(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)

    def checkAnswerTidsFrom(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerTIDsFrom, **kw)

    def checkAnswerObjectHistory(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)

    def checkAnswerObjectHistoryFrom(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerObjectHistoryFrom, **kw)

    def checkAnswerStoreTransaction(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)

    def checkAnswerStoreObject(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)

    def checkAnswerOids(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerOIDs, **kw)

    def checkAnswerPartitionTable(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)

    def checkAnswerObjectPresent(self, conn, **kw):
        return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)


# Module-wide counter handing out a distinct descriptor to each
# DoNothingConnector instance.
connector_cpt = 0

class DoNothingConnector(Mock):
    """ Mock connector that only records the address it was given """

    def __init__(self, s=None):
        neo.lib.logging.info("initializing connector")
        global connector_cpt
        self.desc = connector_cpt
        connector_cpt += 1
        self.packet_cpt = 0
        Mock.__init__(self)

    def getAddress(self):
        return self.addr

    def makeClientConnection(self, addr):
        self.addr = addr

    def makeListeningConnection(self, addr):
        self.addr = addr

    def getDescriptor(self):
        return self.desc