__init__.py 14.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
#
# Copyright (C) 2009  Nexedi SA
# 
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import os
import unittest
import tempfile
import MySQLdb
22
from neo import logging
23 24
from mock import Mock
from neo import protocol
25
from neo.protocol import PacketTypes
26 27 28 29 30 31

# Defaults used by NeoTestBase to create and address the test databases.
# Prefix of every test database name; the database index is appended to it.
DB_PREFIX = 'test_neo_'
# MySQL account used by prepareDatabase() to drop/create the databases.
DB_ADMIN = 'root'
# Password of the admin account; None means "connect without a password".
# NOTE(review): getStorageConfiguration() also puts this value in the
# connection string built for DB_USER, whereas prepareDatabase() grants
# DB_USER access with an empty password -- confirm this is intended.
DB_PASSWD = None
# MySQL account that is granted all privileges on each test database.
DB_USER = 'test'

32 33 34 35 36 37 38
def getNewUUID():
    """ Return 16 random bytes that differ from protocol.INVALID_UUID """
    while True:
        candidate = os.urandom(16)
        # retry in the (very unlikely) case the reserved value was drawn
        if candidate != protocol.INVALID_UUID:
            return candidate

39 40 41 42 43 44 45 46 47 48 49 50 51 52
class NeoTestBase(unittest.TestCase):
    """ Base class for neo tests, implements common checks

    It provides three families of helpers:
      - environment helpers (database creation, node configurations,
        UUID generation, fake connections),
      - exception checks (check*Raised),
      - checks on the calls recorded by a mocked connection (checkAborted,
        checkAskPacket, checkAnswerPacket, checkNotifyPacket and their
        per-packet-type shortcuts).
    """

    def prepareDatabase(self, number, admin=DB_ADMIN, password=DB_PASSWD,
            user=DB_USER, prefix=DB_PREFIX):
        """ Create `number` empty databases named <prefix>0 .. <prefix>N-1

        Each database is dropped first if it already exists, then created,
        and `user`@localhost is granted all privileges on it with an empty
        password. The connection is opened as `admin`; `password` is only
        passed to MySQLdb when it is not None.
        """
        # SQL connection
        connect_arg_dict = {'user': admin}
        if password is not None:
            connect_arg_dict['passwd'] = password
        sql_connection = MySQLdb.Connect(**connect_arg_dict)
        # always release the cursor and the connection, even when one of
        # the statements fails
        try:
            cursor = sql_connection.cursor()
            try:
                # drop and create each database
                for i in range(number):
                    database = "%s%d" % (prefix, i)
                    cursor.execute('DROP DATABASE IF EXISTS %s' % (database, ))
                    cursor.execute('CREATE DATABASE %s' % (database, ))
                    cursor.execute(
                        'GRANT ALL ON %s.* TO "%s"@"localhost" IDENTIFIED BY ""'
                        % (database, user))
            finally:
                cursor.close()
        finally:
            sql_connection.close()

    def getMasterConfiguration(self, cluster='main', master_number=2,
            replicas=2, partitions=1009, uuid=None):
        """ Return a configuration dict for a master node

        `master_number` (1 to 10) is the number of masters listed in the
        'masters' entry; the first of them is used as bind address.
        """
        assert 1 <= master_number <= 10
        masters = ['127.0.0.1:1001%d' % i for i in range(master_number)]
        return {
                'cluster': cluster,
                'bind': masters[0],
                'masters': '/'.join(masters),
                'replicas': replicas,
                'partitions': partitions,
                'uuid': uuid,
        }

    def getStorageConfiguration(self, cluster='main', master_number=2,
            index=0, prefix=DB_PREFIX, uuid=None):
        """ Return a configuration dict for the storage node number `index`

        `index` (0 to 9) selects both the bind port and the database name
        (<prefix><index>); `master_number` (1 to 10) is the number of
        masters listed in the 'masters' entry.
        """
        assert 1 <= master_number <= 10
        assert 0 <= index <= 9
        masters = ['127.0.0.1:1001%d' % i for i in range(master_number)]
        if DB_PASSWD is None:
            database = '%s:@%s%d' % (DB_USER, prefix, index)
        else:
            database = '%s:%s@%s%d' % (DB_USER, DB_PASSWD, prefix, index)
        return {
                'cluster': cluster,
                'bind': '127.0.0.1:1002%d' % (index, ),
                'masters': '/'.join(masters),
                'database': database,
                'uuid': uuid,
                'reset': False,
        }

    # XXX: according to changes with namespaced UUIDs, it would be better to
    # implement get<NodeType>UUID() methods
    def getNewUUID(self):
        """ Generate a new UUID, remember it in self.uuid and return it """
        # this calls the module-level getNewUUID(), not this method
        self.uuid = getNewUUID()
        return self.uuid

    def getTwoIDs(self):
        """ Return a tuple of two sorted UUIDs """
        # generate two ptid, first is lower
        uuids = self.getNewUUID(), self.getNewUUID()
        return min(uuids), max(uuids)

    def getFakeConnection(self, uuid=None, address=('127.0.0.1', 10000)):
        """ Return a Mock whose getUUID/getAddress return the given values """
        return Mock({
            'getUUID': uuid,
            'getAddress': address,
        })

    # Exception checks: each one calls method(*args, **kwargs) and fails
    # unless the expected exception is raised.

    def checkProtocolErrorRaised(self, method, *args, **kwargs):
        """ Check if the ProtocolError exception was raised """
        self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)

    def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
        """ Check if the UnexpectedPacketError exception was raised """
        self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)

    def checkIdenficationRequired(self, method, *args, **kwargs):
        """ Check if the identification_required decorator is applied """
        self.checkUnexpectedPacketRaised(method, *args, **kwargs)

    def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
        """ Check if the BrokenNodeDisallowedError exception was raised """
        self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)

    def checkNotReadyErrorRaised(self, method, *args, **kwargs):
        """ Check if the NotReadyError exception was raised """
        self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)

    # Checks on the calls recorded by a mocked connection.

    def checkAborted(self, conn):
        """ Ensure the connection was aborted """
        self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)

    def checkNotAborted(self, conn):
        """ Ensure the connection was not aborted """
        self.assertEqual(len(conn.mockGetNamedCalls('abort')), 0)

    def checkClosed(self, conn):
        """ Ensure the connection was closed """
        self.assertEqual(len(conn.mockGetNamedCalls('close')), 1)

    def checkNotClosed(self, conn):
        """ Ensure the connection was not closed """
        self.assertEqual(len(conn.mockGetNamedCalls('close')), 0)

    def checkNoPacketSent(self, conn):
        """ Check that no packet was sent (no notify, answer nor ask) """
        self.assertEqual(len(conn.mockGetNamedCalls('notify')), 0)
        self.assertEqual(len(conn.mockGetNamedCalls('answer')), 0)
        self.assertEqual(len(conn.mockGetNamedCalls('ask')), 0)

    def checkNoUUIDSet(self, conn):
        """ ensure no UUID was set on the connection """
        self.assertEqual(len(conn.mockGetNamedCalls('setUUID')), 0)

    def checkUUIDSet(self, conn, uuid=None):
        """ ensure a UUID was set exactly once on the connection

        When `uuid` is given, also check that it is the value that was set.
        """
        calls = conn.mockGetNamedCalls('setUUID')
        self.assertEqual(len(calls), 1)
        if uuid is not None:
            self.assertEqual(calls[0].getParam(0), uuid)

    # in check(Ask|Answer|Notify)Packet we return the packet so it can be used
    # in tests if more accurate checks are required

    def _checkPacketType(self, call, packet_type):
        """ Return the packet given as first parameter of the recorded
        `call`, after checking it is a Packet of type `packet_type` """
        packet = call.getParam(0)
        self.assertTrue(isinstance(packet, protocol.Packet))
        self.assertEqual(packet.getType(), packet_type)
        return packet

    def _decodePacket(self, packet, decode):
        """ Return the decoded body of `packet` if `decode` is true, the
        packet itself otherwise """
        if decode:
            return protocol.decode_table[packet.getType()](packet._body)
        return packet

    def checkErrorPacket(self, conn, decode=False):
        """ Check if an error packet was answered """
        return self.checkAnswerPacket(conn, PacketTypes.ERROR, decode=decode)

    def checkAskPacket(self, conn, packet_type, decode=False):
        """ Check if exactly one ask-packet with the right type is sent """
        calls = conn.mockGetNamedCalls('ask')
        self.assertEqual(len(calls), 1)
        packet = self._checkPacketType(calls[0], packet_type)
        return self._decodePacket(packet, decode)

    def checkAnswerPacket(self, conn, packet_type, answered_packet=None, decode=False):
        """ Check if exactly one answer-packet with the right type is sent

        When `answered_packet` is given, the second parameter of the
        recorded call must equal answered_packet.getId().
        """
        calls = conn.mockGetNamedCalls('answer')
        self.assertEqual(len(calls), 1)
        packet = self._checkPacketType(calls[0], packet_type)
        if answered_packet is not None:
            msg_id = calls[0].getParam(1)
            self.assertEqual(msg_id, answered_packet.getId())
        return self._decodePacket(packet, decode)

    def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False):
        """ Check if a notify-packet with the right type is sent

        `packet_number` is the index of the notify call to inspect.
        """
        calls = conn.mockGetNamedCalls('notify')
        self.assertTrue(len(calls) > packet_number)
        packet = self._checkPacketType(calls[packet_number], packet_type)
        return self._decodePacket(packet, decode)

    # Shortcuts for notifications; **kw is forwarded to checkNotifyPacket
    # (packet_number, decode).

    def checkNotifyNodeInformation(self, conn, **kw):
        return self.checkNotifyPacket(conn, PacketTypes.NOTIFY_NODE_INFORMATION, **kw)

    def checkSendPartitionTable(self, conn, **kw):
        return self.checkNotifyPacket(conn, PacketTypes.SEND_PARTITION_TABLE, **kw)

    def checkStartOperation(self, conn, **kw):
        return self.checkNotifyPacket(conn, PacketTypes.START_OPERATION, **kw)

    def checkNotifyTransactionFinished(self, conn, **kw):
        return self.checkNotifyPacket(conn, PacketTypes.NOTIFY_TRANSACTION_FINISHED, **kw)

    def checkNotifyInformationLocked(self, conn, **kw):
        # despite its name, this packet is checked among the 'answer' calls
        return self.checkAnswerPacket(conn, PacketTypes.NOTIFY_INFORMATION_LOCKED, **kw)

    # Shortcuts for requests; **kw is forwarded to checkAskPacket (decode).

    def checkLockInformation(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.LOCK_INFORMATION, **kw)

    def checkUnlockInformation(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.UNLOCK_INFORMATION, **kw)

    def checkRequestNodeIdentification(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.REQUEST_NODE_IDENTIFICATION, **kw)

    def checkAskPrimaryMaster(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_PRIMARY_MASTER, **kw)

    def checkAskUnfinishedTransactions(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_UNFINISHED_TRANSACTIONS, **kw)

    def checkAskTransactionInformation(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_TRANSACTION_INFORMATION, **kw)

    def checkAskObjectPresent(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_OBJECT_PRESENT, **kw)

    def checkAskObject(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_OBJECT, **kw)

    def checkAskStoreObject(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_STORE_OBJECT, **kw)

    def checkAskStoreTransaction(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_STORE_TRANSACTION, **kw)

    def checkFinishTransaction(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.FINISH_TRANSACTION, **kw)

    def checkAskNewTid(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_BEGIN_TRANSACTION, **kw)

    def checkAskLastIDs(self, conn, **kw):
        return self.checkAskPacket(conn, PacketTypes.ASK_LAST_IDS, **kw)

    # Shortcuts for answers; **kw is forwarded to checkAnswerPacket
    # (answered_packet, decode).

    def checkAcceptNodeIdentification(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ACCEPT_NODE_IDENTIFICATION, **kw)

    def checkAnswerPrimaryMaster(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_PRIMARY_MASTER, **kw)

    def checkAnswerLastIDs(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_LAST_IDS, **kw)

    def checkAnswerUnfinishedTransactions(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_UNFINISHED_TRANSACTIONS, **kw)

    def checkAnswerObject(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OBJECT, **kw)

    def checkAnswerTransactionInformation(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_TRANSACTION_INFORMATION, **kw)

    def checkAnswerTids(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_TIDS, **kw)

    def checkAnswerObjectHistory(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OBJECT_HISTORY, **kw)

    def checkAnswerStoreTransaction(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_STORE_TRANSACTION, **kw)

    def checkAnswerStoreObject(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_STORE_OBJECT, **kw)

    def checkAnswerOids(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OIDS, **kw)

    def checkAnswerPartitionTable(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_PARTITION_TABLE, **kw)

    def checkAnswerObjectPresent(self, conn, **kw):
        return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OBJECT_PRESENT, **kw)
305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337


# XXX: imported from neo.master.test.connector since it's used at many places

# Number of connectors created so far; also the descriptor given to the
# next DoNothingConnector.
connector_cpt = 0


class DoNothingConnector(Mock):
    """ Mock-based connector identified by a module-wide sequence number """

    def __init__(self, s=None):
        """ Take the next descriptor and reset the packet counter.

        `s` is accepted but not used.
        """
        global connector_cpt
        logging.info("initializing connector")
        self.desc = connector_cpt
        connector_cpt += 1
        self.packet_cpt = 0
        Mock.__init__(self)

    def getAddress(self):
        """ Return the address stored by makeClientConnection() """
        return self.addr

    def makeClientConnection(self, addr):
        """ Only remember `addr` """
        self.addr = addr

    def getDescriptor(self):
        """ Return the sequence number assigned at creation """
        return self.desc


class TestElectionConnector(DoNothingConnector):
    """ Connector whose receive() replays a fixed election exchange """

    def receive(self):
        """ simulate behavior of election

        The first call returns an encoded acceptNodeIdentification packet,
        the second one an encoded answerPrimaryMaster packet; every later
        call raises ConnectorTryAgainException.
        """
        if self.packet_cpt == 0:
            # first : identify
            logging.info("in patched analyse / IDENTIFICATION")
            p = protocol.Packet()
            self.uuid = getNewUUID()
            address = self.getAddress()
            # NodeType is not imported at module level in this file, so the
            # bare name raised NameError; reach it through the protocol
            # module instead.
            # NOTE(review): assumes NodeType lives in neo.protocol, like
            # PacketTypes does -- confirm.
            p.acceptNodeIdentification(1,
                                       protocol.NodeType.MASTER,
                                       self.uuid,
                                       address[0],
                                       address[1],
                                       # presumably the numbers of partitions
                                       # and replicas (same values as the
                                       # getMasterConfiguration defaults)
                                       1009,
                                       2
                                       )
            self.packet_cpt += 1
            return p.encode()
        elif self.packet_cpt == 1:
            # second : answer primary master nodes
            logging.info("in patched analyse / ANSWER PM")
            p = protocol.Packet()
            p.answerPrimaryMaster(2, protocol.INVALID_UUID, [])
            self.packet_cpt += 1
            return p.encode()
        else:
            # then do nothing
            from neo.connector import ConnectorTryAgainException
            raise ConnectorTryAgainException