Commit 59004b8c authored by Julien Muchembled's avatar Julien Muchembled

qa: code cleanup in non-threaded -u tests

parent bcf4afa0
...@@ -281,18 +281,6 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -281,18 +281,6 @@ class NeoUnitTestBase(NeoTestBase):
def getNextTID(self, ltid=None): def getNextTID(self, ltid=None):
return newTid(ltid) return newTid(ltid)
def getPTID(self, i=None):
""" Return an integer PTID """
if i is None:
return random.randint(1, 2**64)
return i
def getOID(self, i=None):
""" Return a 8-bytes OID """
if i is None:
return os.urandom(8)
return pack('!Q', i)
def getFakeConnector(self, descriptor=None): def getFakeConnector(self, descriptor=None):
return Mock({ return Mock({
'__repr__': 'FakeConnector', '__repr__': 'FakeConnector',
...@@ -321,18 +309,6 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -321,18 +309,6 @@ class NeoUnitTestBase(NeoTestBase):
""" Check if the ProtocolError exception was raised """ """ Check if the ProtocolError exception was raised """
self.assertRaises(protocol.ProtocolError, method, *args, **kwargs) self.assertRaises(protocol.ProtocolError, method, *args, **kwargs)
def checkUnexpectedPacketRaised(self, method, *args, **kwargs):
""" Check if the UnexpectedPacketError exception was raised """
self.assertRaises(protocol.UnexpectedPacketError, method, *args, **kwargs)
def checkIdenficationRequired(self, method, *args, **kwargs):
""" Check is the identification_required decorator is applied """
self.checkUnexpectedPacketRaised(method, *args, **kwargs)
def checkBrokenNodeDisallowedErrorRaised(self, method, *args, **kwargs):
""" Check if the BrokenNodeDisallowedError exception was raised """
self.assertRaises(protocol.BrokenNodeDisallowedError, method, *args, **kwargs)
def checkNotReadyErrorRaised(self, method, *args, **kwargs): def checkNotReadyErrorRaised(self, method, *args, **kwargs):
""" Check if the NotReadyError exception was raised """ """ Check if the NotReadyError exception was raised """
self.assertRaises(protocol.NotReadyError, method, *args, **kwargs) self.assertRaises(protocol.NotReadyError, method, *args, **kwargs)
...@@ -341,35 +317,18 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -341,35 +317,18 @@ class NeoUnitTestBase(NeoTestBase):
""" Ensure the connection was aborted """ """ Ensure the connection was aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1) self.assertEqual(len(conn.mockGetNamedCalls('abort')), 1)
def checkNotAborted(self, conn):
""" Ensure the connection was not aborted """
self.assertEqual(len(conn.mockGetNamedCalls('abort')), 0)
def checkClosed(self, conn): def checkClosed(self, conn):
""" Ensure the connection was closed """ """ Ensure the connection was closed """
self.assertEqual(len(conn.mockGetNamedCalls('close')), 1) self.assertEqual(len(conn.mockGetNamedCalls('close')), 1)
def checkNotClosed(self, conn):
""" Ensure the connection was not closed """
self.assertEqual(len(conn.mockGetNamedCalls('close')), 0)
def _checkNoPacketSend(self, conn, method_id): def _checkNoPacketSend(self, conn, method_id):
call_list = conn.mockGetNamedCalls(method_id) self.assertEqual([], conn.mockGetNamedCalls(method_id))
self.assertEqual(len(call_list), 0, call_list)
def checkNoPacketSent(self, conn, check_notify=True, check_answer=True, def checkNoPacketSent(self, conn):
check_ask=True):
""" check if no packet were sent """ """ check if no packet were sent """
if check_notify: self._checkNoPacketSend(conn, 'notify')
self._checkNoPacketSend(conn, 'notify') self._checkNoPacketSend(conn, 'answer')
if check_answer: self._checkNoPacketSend(conn, 'ask')
self._checkNoPacketSend(conn, 'answer')
if check_ask:
self._checkNoPacketSend(conn, 'ask')
def checkNoUUIDSet(self, conn):
""" ensure no UUID was set on the connection """
self.assertEqual(len(conn.mockGetNamedCalls('setUUID')), 0)
def checkUUIDSet(self, conn, uuid=None, check_intermediate=True): def checkUUIDSet(self, conn, uuid=None, check_intermediate=True):
""" ensure UUID was set on the connection """ """ ensure UUID was set on the connection """
...@@ -384,151 +343,41 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -384,151 +343,41 @@ class NeoUnitTestBase(NeoTestBase):
# in check(Ask|Answer|Notify)Packet we return the packet so it can be used # in check(Ask|Answer|Notify)Packet we return the packet so it can be used
# in tests if more accurate checks are required # in tests if more accurate checks are required
def checkErrorPacket(self, conn, decode=False): def checkErrorPacket(self, conn):
""" Check if an error packet was answered """ """ Check if an error packet was answered """
calls = conn.mockGetNamedCalls("answer") calls = conn.mockGetNamedCalls("answer")
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), Packets.Error) self.assertEqual(type(packet), Packets.Error)
if decode:
return packet.decode()
return packet return packet
def checkAskPacket(self, conn, packet_type, decode=False): def checkAskPacket(self, conn, packet_type):
""" Check if an ask-packet with the right type is sent """ """ Check if an ask-packet with the right type is sent """
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet return packet
def checkAnswerPacket(self, conn, packet_type, decode=False): def checkAnswerPacket(self, conn, packet_type):
""" Check if an answer-packet with the right type is sent """ """ Check if an answer-packet with the right type is sent """
calls = conn.mockGetNamedCalls('answer') calls = conn.mockGetNamedCalls('answer')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet return packet
def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False): def checkNotifyPacket(self, conn, packet_type, packet_number=0):
""" Check if a notify-packet with the right type is sent """ """ Check if a notify-packet with the right type is sent """
calls = conn.mockGetNamedCalls('notify') calls = conn.mockGetNamedCalls('notify')
packet = calls.pop(packet_number).getParam(0) packet = calls.pop(packet_number).getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(type(packet), packet_type) self.assertEqual(type(packet), packet_type)
if decode:
return packet.decode()
return packet return packet
def checkNotify(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.Notify, **kw)
def checkNotifyNodeInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation, **kw)
def checkSendPartitionTable(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.SendPartitionTable, **kw)
def checkStartOperation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.StartOperation, **kw)
def checkInvalidateObjects(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.InvalidateObjects, **kw)
def checkAbortTransaction(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.AbortTransaction, **kw)
def checkNotifyLastOID(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyLastOID, **kw)
def checkAnswerTransactionFinished(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTransactionFinished, **kw)
def checkAnswerInformationLocked(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerInformationLocked, **kw)
def checkAskLockInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskLockInformation, **kw)
def checkNotifyUnlockInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyUnlockInformation, **kw)
def checkNotifyTransactionFinished(self, conn, **kw):
return self.checkNotifyPacket(conn, Packets.NotifyTransactionFinished, **kw)
def checkRequestIdentification(self, conn, **kw):
return self.checkAskPacket(conn, Packets.RequestIdentification, **kw)
def checkAskPrimary(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskPrimary)
def checkAskUnfinishedTransactions(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskUnfinishedTransactions)
def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskObject, **kw)
def checkAskStoreObject(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskStoreObject, **kw)
def checkAskStoreTransaction(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskStoreTransaction, **kw)
def checkAskFinishTransaction(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskFinishTransaction, **kw)
def checkAskNewTid(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskBeginTransaction, **kw)
def checkAskLastIDs(self, conn, **kw):
return self.checkAskPacket(conn, Packets.AskLastIDs, **kw)
def checkAcceptIdentification(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AcceptIdentification, **kw)
def checkAnswerPrimary(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPrimary, **kw)
def checkAnswerLastIDs(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerLastIDs, **kw)
def checkAnswerUnfinishedTransactions(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerUnfinishedTransactions, **kw)
def checkAnswerObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObject, **kw)
def checkAnswerTransactionInformation(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTransactionInformation, **kw)
def checkAnswerBeginTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction, **kw)
def checkAnswerTids(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)
def checkAnswerTidsFrom(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerTIDsFrom, **kw)
def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
class Patch(object): class Patch(object):
""" """
......
...@@ -68,6 +68,9 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -68,6 +68,9 @@ class ClientApplicationTests(NeoUnitTestBase):
# some helpers # some helpers
def checkAskObject(self, conn):
return self.checkAskPacket(conn, Packets.AskObject)
def _begin(self, app, txn, tid): def _begin(self, app, txn, tid):
txn_context = app._txn_container.new(txn) txn_context = app._txn_container.new(txn)
txn_context['ttid'] = tid txn_context['ttid'] = tid
......
...@@ -21,6 +21,7 @@ from .. import NeoUnitTestBase ...@@ -21,6 +21,7 @@ from .. import NeoUnitTestBase
from neo.client.app import ConnectionPool from neo.client.app import ConnectionPool
from neo.client.exception import NEOStorageError from neo.client.exception import NEOStorageError
from neo.client import pool from neo.client import pool
from neo.lib.util import p64
class ConnectionPoolTests(NeoUnitTestBase): class ConnectionPoolTests(NeoUnitTestBase):
...@@ -54,7 +55,7 @@ class ConnectionPoolTests(NeoUnitTestBase): ...@@ -54,7 +55,7 @@ class ConnectionPoolTests(NeoUnitTestBase):
def test_iterateForObject_noStorageAvailable(self): def test_iterateForObject_noStorageAvailable(self):
# no node available # no node available
oid = self.getOID(1) oid = p64(1)
app = Mock() app = Mock()
app.pt = Mock({'getCellList': []}) app.pt = Mock({'getCellList': []})
pool = ConnectionPool(app) pool = ConnectionPool(app)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from mock import Mock from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.lib.protocol import NodeTypes, NodeStates, Packets from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -62,6 +63,9 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -62,6 +63,9 @@ class MasterClientHandlerTests(NeoUnitTestBase):
) )
return uuid return uuid
def checkAnswerBeginTransaction(self, conn):
return self.checkAnswerPacket(conn, Packets.AnswerBeginTransaction)
# Tests # Tests
def test_07_askBeginTransaction(self): def test_07_askBeginTransaction(self):
tid1 = self.getNextTID() tid1 = self.getNextTID()
...@@ -87,12 +91,12 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -87,12 +91,12 @@ class MasterClientHandlerTests(NeoUnitTestBase):
calls = tm.mockGetNamedCalls('begin') calls = tm.mockGetNamedCalls('begin')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(client_node, None) calls[0].checkArgs(client_node, None)
args = self.checkAnswerBeginTransaction(conn, decode=True) packet = self.checkAnswerBeginTransaction(conn)
self.assertEqual(args, (tid1, )) self.assertEqual(packet.decode(), (tid1, ))
def test_08_askNewOIDs(self): def test_08_askNewOIDs(self):
service = self.service service = self.service
oid1, oid2 = self.getOID(1), self.getOID(2) oid1, oid2 = p64(1), p64(2)
self.app.tm.setLastOID(oid1) self.app.tm.setLastOID(oid1)
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
...@@ -136,7 +140,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -136,7 +140,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.setStorageReady(storage_uuid) self.app.setStorageReady(storage_uuid)
self.assertTrue(self.app.isStorageReady(storage_uuid)) self.assertTrue(self.app.isStorageReady(storage_uuid))
service.askFinishTransaction(conn, ttid, (), ()) service.askFinishTransaction(conn, ttid, (), ())
self.checkAskLockInformation(storage_conn) self.checkAskPacket(storage_conn, Packets.AskLockInformation)
self.assertEqual(len(self.app.tm.registerForNotification(storage_uuid)), 1) self.assertEqual(len(self.app.tm.registerForNotification(storage_uuid)), 1)
txn = self.app.tm[ttid] txn = self.app.tm[ttid]
pending_ttid = list(self.app.tm.registerForNotification(storage_uuid))[0] pending_ttid = list(self.app.tm.registerForNotification(storage_uuid))[0]
...@@ -170,8 +174,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -170,8 +174,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
ptid = self.checkAskPacket(storage_conn, Packets.AskPack, ptid = self.checkAskPacket(storage_conn, Packets.AskPack).decode()[0]
decode=True)[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertTrue(self.app.packing[0] is conn) self.assertTrue(self.app.packing[0] is conn)
self.assertEqual(self.app.packing[1], peer_id) self.assertEqual(self.app.packing[1], peer_id)
...@@ -183,8 +186,7 @@ class MasterClientHandlerTests(NeoUnitTestBase): ...@@ -183,8 +186,7 @@ class MasterClientHandlerTests(NeoUnitTestBase):
self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn) self.app.nm.getByUUID(storage_uuid).setConnection(storage_conn)
self.service.askPack(conn, tid) self.service.askPack(conn, tid)
self.checkNoPacketSent(storage_conn) self.checkNoPacketSent(storage_conn)
status = self.checkAnswerPacket(conn, Packets.AnswerPack, status = self.checkAnswerPacket(conn, Packets.AnswerPack).decode()[0]
decode=True)[0]
self.assertFalse(status) self.assertFalse(status)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
from mock import Mock from mock import Mock
from neo.lib import protocol from neo.lib import protocol
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import NodeTypes, NodeStates from neo.lib.protocol import NodeTypes, NodeStates, Packets
from neo.master.handlers.election import ClientElectionHandler, \ from neo.master.handlers.election import ClientElectionHandler, \
ServerElectionHandler ServerElectionHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -48,6 +48,9 @@ class MasterClientElectionTestBase(NeoUnitTestBase): ...@@ -48,6 +48,9 @@ class MasterClientElectionTestBase(NeoUnitTestBase):
node.setConnection(conn) node.setConnection(conn)
return (node, conn) return (node, conn)
def checkAcceptIdentification(self, conn):
return self.checkAnswerPacket(conn, Packets.AcceptIdentification)
class MasterClientElectionTests(MasterClientElectionTestBase): class MasterClientElectionTests(MasterClientElectionTestBase):
def setUp(self): def setUp(self):
...@@ -91,7 +94,7 @@ class MasterClientElectionTests(MasterClientElectionTestBase): ...@@ -91,7 +94,7 @@ class MasterClientElectionTests(MasterClientElectionTestBase):
self.election.connectionCompleted(conn) self.election.connectionCompleted(conn)
self._checkUnconnected(node) self._checkUnconnected(node)
self.assertTrue(node.isUnknown()) self.assertTrue(node.isUnknown())
self.checkRequestIdentification(conn) self.checkAskPacket(conn, Packets.RequestIdentification)
def _setNegociating(self, node): def _setNegociating(self, node):
self._checkUnconnected(node) self._checkUnconnected(node)
...@@ -252,9 +255,8 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -252,9 +255,8 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
self.election.requestIdentification(conn, self.election.requestIdentification(conn,
NodeTypes.MASTER, *args) NodeTypes.MASTER, *args)
self.checkUUIDSet(conn, node.getUUID()) self.checkUUIDSet(conn, node.getUUID())
args = self.checkAcceptIdentification(conn, decode=True)
(node_type, uuid, partitions, replicas, new_uuid, primary_uuid, (node_type, uuid, partitions, replicas, new_uuid, primary_uuid,
master_list) = args master_list) = self.checkAcceptIdentification(conn).decode()
self.assertEqual(node.getUUID(), new_uuid) self.assertEqual(node.getUUID(), new_uuid)
self.assertNotEqual(node.getUUID(), uuid) self.assertNotEqual(node.getUUID(), uuid)
...@@ -290,7 +292,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase): ...@@ -290,7 +292,7 @@ class MasterServerElectionTests(MasterClientElectionTestBase):
None, None,
) )
node_type, uuid, partitions, replicas, _peer_uuid, primary, \ node_type, uuid, partitions, replicas, _peer_uuid, primary, \
master_list = self.checkAcceptIdentification(conn, decode=True) master_list = self.checkAcceptIdentification(conn).decode()
self.assertEqual(node_type, NodeTypes.MASTER) self.assertEqual(node_type, NodeTypes.MASTER)
self.assertEqual(uuid, self.app.uuid) self.assertEqual(uuid, self.app.uuid)
self.assertEqual(partitions, self.app.pt.getPartitions()) self.assertEqual(partitions, self.app.pt.getPartitions())
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import unittest import unittest
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.protocol import Packets
from neo.master.app import Application from neo.master.app import Application
class MasterAppTests(NeoUnitTestBase): class MasterAppTests(NeoUnitTestBase):
...@@ -31,6 +32,9 @@ class MasterAppTests(NeoUnitTestBase): ...@@ -31,6 +32,9 @@ class MasterAppTests(NeoUnitTestBase):
self.app.close() self.app.close()
NeoUnitTestBase._tearDown(self, success) NeoUnitTestBase._tearDown(self, success)
def checkNotifyNodeInformation(self, conn):
return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation)
def test_06_broadcastNodeInformation(self): def test_06_broadcastNodeInformation(self):
# defined some nodes to which data will be send # defined some nodes to which data will be send
master_uuid = self.getMasterUUID() master_uuid = self.getMasterUUID()
......
...@@ -71,10 +71,9 @@ class MasterStorageHandlerTests(NeoUnitTestBase): ...@@ -71,10 +71,9 @@ class MasterStorageHandlerTests(NeoUnitTestBase):
self.checkNoPacketSent(client_conn) self.checkNoPacketSent(client_conn)
self.assertEqual(self.app.packing[2], {conn2.getUUID()}) self.assertEqual(self.app.packing[2], {conn2.getUUID()})
self.service.answerPack(conn2, False) self.service.answerPack(conn2, False)
status = self.checkAnswerPacket(client_conn, Packets.AnswerPack, packet = self.checkAnswerPacket(client_conn, Packets.AnswerPack)
decode=True)[0]
# TODO: verify packet peer id # TODO: verify packet peer id
self.assertTrue(status) self.assertTrue(packet.decode()[0])
self.assertEqual(self.app.packing, None) self.assertEqual(self.app.packing, None)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -20,6 +20,7 @@ from collections import deque ...@@ -20,6 +20,7 @@ from collections import deque
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.lib.util import p64
from neo.lib.protocol import INVALID_TID, INVALID_OID, Packets, LockState from neo.lib.protocol import INVALID_TID, INVALID_OID, Packets, LockState
class StorageClientHandlerTests(NeoUnitTestBase): class StorageClientHandlerTests(NeoUnitTestBase):
...@@ -91,7 +92,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -91,7 +92,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
calls = self.app.dm.mockGetNamedCalls('getTIDList') calls = self.app.dm.mockGetNamedCalls('getTIDList')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
calls[0].checkArgs(1, 1, [1, ]) calls[0].checkArgs(1, 1, [1, ])
self.checkAnswerTids(conn) self.checkAnswerPacket(conn, Packets.AnswerTIDs)
def test_26_askObjectHistory1(self): def test_26_askObjectHistory1(self):
# invalid offsets => error # invalid offsets => error
...@@ -108,7 +109,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -108,7 +109,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
ltid = self.getNextTID() ltid = self.getNextTID()
undone_tid = self.getNextTID() undone_tid = self.getNextTID()
# Keep 2 entries here, so we check findUndoTID is called only once. # Keep 2 entries here, so we check findUndoTID is called only once.
oid_list = [self.getOID(1), self.getOID(2)] oid_list = map(p64, (1, 2))
obj2_data = [] # Marker obj2_data = [] # Marker
self.app.tm = Mock({ self.app.tm = Mock({
'getObjectFromTransaction': None, 'getObjectFromTransaction': None,
...@@ -134,7 +135,7 @@ class StorageClientHandlerTests(NeoUnitTestBase): ...@@ -134,7 +135,7 @@ class StorageClientHandlerTests(NeoUnitTestBase):
conn = self._getConnection() conn = self._getConnection()
self.operation.askHasLock(conn, tid_1, oid) self.operation.askHasLock(conn, tid_1, oid)
p_oid, p_status = self.checkAnswerPacket(conn, p_oid, p_status = self.checkAnswerPacket(conn,
Packets.AnswerHasLock, decode=True) Packets.AnswerHasLock).decode()
self.assertEqual(oid, p_oid) self.assertEqual(oid, p_oid)
self.assertEqual(status, p_status) self.assertEqual(status, p_status)
......
...@@ -103,20 +103,19 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -103,20 +103,19 @@ class StorageDBTests(NeoUnitTestBase):
def test_15_PTID(self): def test_15_PTID(self):
db = self.getDB() db = self.getDB()
self.checkConfigEntry(db.getPTID, db.setPTID, self.getPTID(1)) self.checkConfigEntry(db.getPTID, db.setPTID, 1)
def test_getPartitionTable(self): def test_getPartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1)
uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID() uuid1, uuid2 = self.getStorageUUID(), self.getStorageUUID()
cell1 = (0, uuid1, CellStates.OUT_OF_DATE) cell1 = (0, uuid1, CellStates.OUT_OF_DATE)
cell2 = (1, uuid1, CellStates.UP_TO_DATE) cell2 = (1, uuid1, CellStates.UP_TO_DATE)
db.changePartitionTable(ptid, [cell1, cell2], 1) db.changePartitionTable(1, [cell1, cell2], 1)
result = db.getPartitionTable() result = db.getPartitionTable()
self.assertEqual(set(result), {cell1, cell2}) self.assertEqual(set(result), {cell1, cell2})
def getOIDs(self, count): def getOIDs(self, count):
return map(self.getOID, xrange(count)) return map(p64, xrange(count))
def getTIDs(self, count): def getTIDs(self, count):
tid_list = [self.getNextTID()] tid_list = [self.getNextTID()]
...@@ -198,7 +197,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -198,7 +197,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_setPartitionTable(self): def test_setPartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1) ptid = 1
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE cell2 = 1, uuid, CellStates.UP_TO_DATE
...@@ -220,7 +219,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -220,7 +219,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_changePartitionTable(self): def test_changePartitionTable(self):
db = self.getDB() db = self.getDB()
ptid = self.getPTID(1) ptid = 1
uuid = self.getStorageUUID() uuid = self.getStorageUUID()
cell1 = 0, uuid, CellStates.OUT_OF_DATE cell1 = 0, uuid, CellStates.OUT_OF_DATE
cell2 = 1, uuid, CellStates.UP_TO_DATE cell2 = 1, uuid, CellStates.UP_TO_DATE
...@@ -301,7 +300,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -301,7 +300,7 @@ class StorageDBTests(NeoUnitTestBase):
def test_deleteRange(self): def test_deleteRange(self):
np = 4 np = 4
self.setNumPartitions(np) self.setNumPartitions(np)
t1, t2, t3 = map(self.getOID, (1, 2, 3)) t1, t2, t3 = map(p64, (1, 2, 3))
oid_list = self.getOIDs(np * 2) oid_list = self.getOIDs(np * 2)
for tid in t1, t2, t3: for tid in t1, t2, t3:
txn, objs = self.getTransaction(oid_list) txn, objs = self.getTransaction(oid_list)
...@@ -339,7 +338,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -339,7 +338,7 @@ class StorageDBTests(NeoUnitTestBase):
self.assertEqual(self.db.getTransaction(tid2, False), None) self.assertEqual(self.db.getTransaction(tid2, False), None)
def test_getObjectHistory(self): def test_getObjectHistory(self):
oid = self.getOID(1) oid = p64(1)
tid1, tid2, tid3 = self.getTIDs(3) tid1, tid2, tid3 = self.getTIDs(3)
txn1, objs1 = self.getTransaction([oid]) txn1, objs1 = self.getTransaction([oid])
txn2, objs2 = self.getTransaction([oid]) txn2, objs2 = self.getTransaction([oid])
...@@ -362,7 +361,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -362,7 +361,7 @@ class StorageDBTests(NeoUnitTestBase):
def _storeTransactions(self, count): def _storeTransactions(self, count):
# use OID generator to know result of tid % N # use OID generator to know result of tid % N
tid_list = self.getOIDs(count) tid_list = self.getOIDs(count)
oid = self.getOID(1) oid = p64(1)
for tid in tid_list: for tid in tid_list:
txn, objs = self.getTransaction([oid]) txn, objs = self.getTransaction([oid])
self.db.storeTransaction(tid, objs, txn, False) self.db.storeTransaction(tid, objs, txn, False)
...@@ -446,7 +445,7 @@ class StorageDBTests(NeoUnitTestBase): ...@@ -446,7 +445,7 @@ class StorageDBTests(NeoUnitTestBase):
tid3 = self.getNextTID() tid3 = self.getNextTID()
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid5 = self.getNextTID() tid5 = self.getNextTID()
oid1 = self.getOID(1) oid1 = p64(1)
foo = db.holdData("3" * 20, 'foo', 0) foo = db.holdData("3" * 20, 'foo', 0)
bar = db.holdData("4" * 20, 'bar', 0) bar = db.holdData("4" * 20, 'bar', 0)
db.releaseData((foo, bar)) db.releaseData((foo, bar))
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from mock import Mock from mock import Mock
from .. import NeoUnitTestBase from .. import NeoUnitTestBase
from neo.lib.util import p64
from neo.storage.transactions import TransactionManager from neo.storage.transactions import TransactionManager
...@@ -36,7 +37,7 @@ class TransactionManagerTests(NeoUnitTestBase): ...@@ -36,7 +37,7 @@ class TransactionManagerTests(NeoUnitTestBase):
def test_updateObjectDataForPack(self): def test_updateObjectDataForPack(self):
ram_serial = self.getNextTID() ram_serial = self.getNextTID()
oid = self.getOID(1) oid = p64(1)
orig_serial = self.getNextTID() orig_serial = self.getNextTID()
uuid = self.getClientUUID() uuid = self.getClientUUID()
locking_serial = self.getNextTID() locking_serial = self.getNextTID()
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
from . import NeoUnitTestBase from . import NeoUnitTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.lib.bootstrap import BootstrapManager from neo.lib.bootstrap import BootstrapManager
from neo.lib.protocol import NodeTypes from neo.lib.protocol import NodeTypes, Packets
class BootstrapManagerTests(NeoUnitTestBase): class BootstrapManagerTests(NeoUnitTestBase):
...@@ -46,7 +46,7 @@ class BootstrapManagerTests(NeoUnitTestBase): ...@@ -46,7 +46,7 @@ class BootstrapManagerTests(NeoUnitTestBase):
conn = self.getFakeConnection(address=address) conn = self.getFakeConnection(address=address)
self.bootstrap.current = self.app.nm.createMaster(address=address) self.bootstrap.current = self.app.nm.createMaster(address=address)
self.bootstrap.connectionCompleted(conn) self.bootstrap.connectionCompleted(conn)
self.checkRequestIdentification(conn) self.checkAskPacket(conn, Packets.RequestIdentification)
def testHandleNotReady(self): def testHandleNotReady(self):
# the primary is not ready # the primary is not ready
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment