Commit e3368849 authored by Grégory Wisniewski's avatar Grégory Wisniewski

Update tests to follow changes introduced in previous commit.


git-svn-id: https://svn.erp5.org/repos/neo/trunk@1357 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent a6aae268
...@@ -22,7 +22,7 @@ import MySQLdb ...@@ -22,7 +22,7 @@ import MySQLdb
from neo import logging from neo import logging
from mock import Mock from mock import Mock
from neo import protocol from neo import protocol
from neo.protocol import PacketTypes from neo.protocol import Packets
DB_PREFIX = 'test_neo_' DB_PREFIX = 'test_neo_'
DB_ADMIN = 'root' DB_ADMIN = 'root'
...@@ -169,8 +169,9 @@ class NeoTestBase(unittest.TestCase): ...@@ -169,8 +169,9 @@ class NeoTestBase(unittest.TestCase):
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
packet = calls[0].getParam(0) packet = calls[0].getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), PacketTypes.ERROR) self.assertEquals(packet.getType(), Packets.Error)
if decode: if decode:
return packet.decode()
return protocol.decode_table[packet.getType()](packet._body) return protocol.decode_table[packet.getType()](packet._body)
return packet return packet
...@@ -182,7 +183,7 @@ class NeoTestBase(unittest.TestCase): ...@@ -182,7 +183,7 @@ class NeoTestBase(unittest.TestCase):
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), packet_type) self.assertEquals(packet.getType(), packet_type)
if decode: if decode:
return protocol.decode_table[packet.getType()](packet._body) return packet.decode()
return packet return packet
def checkAnswerPacket(self, conn, packet_type, answered_packet=None, decode=False): def checkAnswerPacket(self, conn, packet_type, answered_packet=None, decode=False):
...@@ -196,7 +197,7 @@ class NeoTestBase(unittest.TestCase): ...@@ -196,7 +197,7 @@ class NeoTestBase(unittest.TestCase):
msg_id = calls[0].getParam(1) msg_id = calls[0].getParam(1)
self.assertEqual(msg_id, answered_packet.getId()) self.assertEqual(msg_id, answered_packet.getId())
if decode: if decode:
return protocol.decode_table[packet.getType()](packet._body) return packet.decode()
return packet return packet
def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False): def checkNotifyPacket(self, conn, packet_type, packet_number=0, decode=False):
...@@ -207,101 +208,101 @@ class NeoTestBase(unittest.TestCase): ...@@ -207,101 +208,101 @@ class NeoTestBase(unittest.TestCase):
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), packet_type) self.assertEquals(packet.getType(), packet_type)
if decode: if decode:
return protocol.decode_table[packet.getType()](packet._body) return packet.decode()
return packet return packet
def checkNotifyNodeInformation(self, conn, **kw): def checkNotifyNodeInformation(self, conn, **kw):
return self.checkNotifyPacket(conn, PacketTypes.NOTIFY_NODE_INFORMATION, **kw) return self.checkNotifyPacket(conn, Packets.NotifyNodeInformation, **kw)
def checkSendPartitionTable(self, conn, **kw): def checkSendPartitionTable(self, conn, **kw):
return self.checkNotifyPacket(conn, PacketTypes.SEND_PARTITION_TABLE, **kw) return self.checkNotifyPacket(conn, Packets.SendPartitionTable, **kw)
def checkStartOperation(self, conn, **kw): def checkStartOperation(self, conn, **kw):
return self.checkNotifyPacket(conn, PacketTypes.START_OPERATION, **kw) return self.checkNotifyPacket(conn, Packets.StartOperation, **kw)
def checkNotifyTransactionFinished(self, conn, **kw): def checkNotifyTransactionFinished(self, conn, **kw):
return self.checkNotifyPacket(conn, PacketTypes.NOTIFY_TRANSACTION_FINISHED, **kw) return self.checkNotifyPacket(conn, Packets.NotifyTransactionFinished, **kw)
def checkNotifyInformationLocked(self, conn, **kw): def checkNotifyInformationLocked(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.NOTIFY_INFORMATION_LOCKED, **kw) return self.checkAnswerPacket(conn, Packets.NotifyInformationLocked, **kw)
def checkLockInformation(self, conn, **kw): def checkLockInformation(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.LOCK_INFORMATION, **kw) return self.checkAskPacket(conn, Packets.LockInformation, **kw)
def checkUnlockInformation(self, conn, **kw): def checkUnlockInformation(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.UNLOCK_INFORMATION, **kw) return self.checkAskPacket(conn, Packets.UnlockInformation, **kw)
def checkRequestNodeIdentification(self, conn, **kw): def checkRequestIdentification(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.REQUEST_NODE_IDENTIFICATION, **kw) return self.checkAskPacket(conn, Packets.RequestIdentification, **kw)
def checkAskPrimaryMaster(self, conn, **kw): def checkAskPrimary(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_PRIMARY_MASTER) return self.checkAskPacket(conn, Packets.AskPrimary)
def checkAskUnfinishedTransactions(self, conn, **kw): def checkAskUnfinishedTransactions(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_UNFINISHED_TRANSACTIONS) return self.checkAskPacket(conn, Packets.AskUnfinishedTransactions)
def checkAskTransactionInformation(self, conn, **kw): def checkAskTransactionInformation(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_TRANSACTION_INFORMATION, **kw) return self.checkAskPacket(conn, Packets.AskTransactionInformation, **kw)
def checkAskObjectPresent(self, conn, **kw): def checkAskObjectPresent(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_OBJECT_PRESENT, **kw) return self.checkAskPacket(conn, Packets.AskObjectPresent, **kw)
def checkAskObject(self, conn, **kw): def checkAskObject(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_OBJECT, **kw) return self.checkAskPacket(conn, Packets.AskObject, **kw)
def checkAskStoreObject(self, conn, **kw): def checkAskStoreObject(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_STORE_OBJECT, **kw) return self.checkAskPacket(conn, Packets.AskStoreObject, **kw)
def checkAskStoreTransaction(self, conn, **kw): def checkAskStoreTransaction(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_STORE_TRANSACTION, **kw) return self.checkAskPacket(conn, Packets.AskStoreTransaction, **kw)
def checkFinishTransaction(self, conn, **kw): def checkFinishTransaction(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.FINISH_TRANSACTION, **kw) return self.checkAskPacket(conn, Packets.FinishTransaction, **kw)
def checkAskNewTid(self, conn, **kw): def checkAskNewTid(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_BEGIN_TRANSACTION, **kw) return self.checkAskPacket(conn, Packets.AskBeginTransaction, **kw)
def checkAskLastIDs(self, conn, **kw): def checkAskLastIDs(self, conn, **kw):
return self.checkAskPacket(conn, PacketTypes.ASK_LAST_IDS, **kw) return self.checkAskPacket(conn, Packets.AskLastIDs, **kw)
def checkAcceptNodeIdentification(self, conn, **kw): def checkAcceptIdentification(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ACCEPT_NODE_IDENTIFICATION, **kw) return self.checkAnswerPacket(conn, Packets.AcceptIdentification, **kw)
def checkAnswerPrimaryMaster(self, conn, **kw): def checkAnswerPrimary(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_PRIMARY_MASTER, **kw) return self.checkAnswerPacket(conn, Packets.AnswerPrimary, **kw)
def checkAnswerLastIDs(self, conn, **kw): def checkAnswerLastIDs(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_LAST_IDS, **kw) return self.checkAnswerPacket(conn, Packets.AnswerLastIDs, **kw)
def checkAnswerUnfinishedTransactions(self, conn, **kw): def checkAnswerUnfinishedTransactions(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_UNFINISHED_TRANSACTIONS, **kw) return self.checkAnswerPacket(conn, Packets.AnswerUnfinishedTransactions, **kw)
def checkAnswerObject(self, conn, **kw): def checkAnswerObject(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OBJECT, **kw) return self.checkAnswerPacket(conn, Packets.AnswerObject, **kw)
def checkAnswerTransactionInformation(self, conn, **kw): def checkAnswerTransactionInformation(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_TRANSACTION_INFORMATION, **kw) return self.checkAnswerPacket(conn, Packets.AnswerTransactionInformation, **kw)
def checkAnswerTids(self, conn, **kw): def checkAnswerTids(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_TIDS, **kw) return self.checkAnswerPacket(conn, Packets.AnswerTIDs, **kw)
def checkAnswerObjectHistory(self, conn, **kw): def checkAnswerObjectHistory(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OBJECT_HISTORY, **kw) return self.checkAnswerPacket(conn, Packets.AnswerObjectHistory, **kw)
def checkAnswerStoreTransaction(self, conn, **kw): def checkAnswerStoreTransaction(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_STORE_TRANSACTION, **kw) return self.checkAnswerPacket(conn, Packets.AnswerStoreTransaction, **kw)
def checkAnswerStoreObject(self, conn, **kw): def checkAnswerStoreObject(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_STORE_OBJECT, **kw) return self.checkAnswerPacket(conn, Packets.AnswerStoreObject, **kw)
def checkAnswerOids(self, conn, **kw): def checkAnswerOids(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OIDS, **kw) return self.checkAnswerPacket(conn, Packets.AnswerOIDs, **kw)
def checkAnswerPartitionTable(self, conn, **kw): def checkAnswerPartitionTable(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_PARTITION_TABLE, **kw) return self.checkAnswerPacket(conn, Packets.AnswerPartitionTable, **kw)
def checkAnswerObjectPresent(self, conn, **kw): def checkAnswerObjectPresent(self, conn, **kw):
return self.checkAnswerPacket(conn, PacketTypes.ANSWER_OBJECT_PRESENT, **kw) return self.checkAnswerPacket(conn, Packets.AnswerObjectPresent, **kw)
# XXX: imported from neo.master.test.connector since it's used at many places # XXX: imported from neo.master.test.connector since it's used at many places
...@@ -335,7 +336,7 @@ class TestElectionConnector(DoNothingConnector): ...@@ -335,7 +336,7 @@ class TestElectionConnector(DoNothingConnector):
logging.info("in patched analyse / IDENTIFICATION") logging.info("in patched analyse / IDENTIFICATION")
p = protocol.Packet() p = protocol.Packet()
self.uuid = getNewUUID() self.uuid = getNewUUID()
p.acceptNodeIdentification(1, NodeType.MASTER, self.uuid, p.AcceptIdentification(1, NodeType.MASTER, self.uuid,
self.getAddress()[0], self.getAddress()[1], 1009, 2) self.getAddress()[0], self.getAddress()[1], 1009, 2)
self.packet_cpt += 1 self.packet_cpt += 1
return p.encode() return p.encode()
...@@ -343,7 +344,7 @@ class TestElectionConnector(DoNothingConnector): ...@@ -343,7 +344,7 @@ class TestElectionConnector(DoNothingConnector):
# second : answer primary master nodes # second : answer primary master nodes
logging.info("in patched analyse / ANSWER PM") logging.info("in patched analyse / ANSWER PM")
p = protocol.Packet() p = protocol.Packet()
p.answerPrimaryMaster(2, protocol.INVALID_UUID, []) p.answerPrimary(2, protocol.INVALID_UUID, [])
self.packet_cpt += 1 self.packet_cpt += 1
return p.encode() return p.encode()
else: else:
......
...@@ -23,7 +23,7 @@ from neo.client.app import Application ...@@ -23,7 +23,7 @@ from neo.client.app import Application
from neo.client.exception import NEOStorageError, NEOStorageNotFoundError, \ from neo.client.exception import NEOStorageError, NEOStorageNotFoundError, \
NEOStorageConflictError NEOStorageConflictError
from neo import protocol from neo import protocol
from neo.protocol import INVALID_TID from neo.protocol import Packets, INVALID_TID
from neo.util import makeChecksum from neo.util import makeChecksum
import neo.connection import neo.connection
...@@ -77,7 +77,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -77,7 +77,7 @@ class ClientApplicationTests(NeoTestBase):
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEquals(packet.getType(), packet_type) self.assertEquals(packet.getType(), packet_type)
if decode: if decode:
return protocol.decode_table[packet.getType()](packet._body) return packet.decode()
return packet return packet
def getApp(self, master_nodes='127.0.0.1:10010', name='test', def getApp(self, master_nodes='127.0.0.1:10010', name='test',
...@@ -111,7 +111,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -111,7 +111,7 @@ class ClientApplicationTests(NeoTestBase):
if oid is None: if oid is None:
oid = self.makeOID() oid = self.makeOID()
obj = (oid, tid, 'DATA', '', app.local_var.txn) obj = (oid, tid, 'DATA', '', app.local_var.txn)
packet = protocol.answerStoreObject(conflicting=0, oid=oid, serial=tid) packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid)
conn = Mock({ 'getNextId': 1, 'fakeReceived': packet, }) conn = Mock({ 'getNextId': 1, 'fakeReceived': packet, })
cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', }) cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', })
app.cp = Mock({ 'getConnForCell': conn}) app.cp = Mock({ 'getConnForCell': conn})
...@@ -121,7 +121,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -121,7 +121,7 @@ class ClientApplicationTests(NeoTestBase):
def voteTransaction(self, app): def voteTransaction(self, app):
tid = app.local_var.tid tid = app.local_var.tid
txn = app.local_var.txn txn = app.local_var.txn
packet = protocol.answerStoreTransaction(tid=tid) packet = Packets.AnswerStoreTransaction(tid=tid)
conn = Mock({ 'getNextId': 1, 'fakeReceived': packet, }) conn = Mock({ 'getNextId': 1, 'fakeReceived': packet, })
cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', }) cell = Mock({ 'getAddress': 'FakeServer', 'getState': 'FakeState', })
app.pt = Mock({ 'getCellListForID': (cell, cell, ) }) app.pt = Mock({ 'getCellListForID': (cell, cell, ) })
...@@ -131,7 +131,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -131,7 +131,7 @@ class ClientApplicationTests(NeoTestBase):
def finishTransaction(self, app): def finishTransaction(self, app):
txn = app.local_var.txn txn = app.local_var.txn
tid = app.local_var.tid tid = app.local_var.tid
packet = protocol.notifyTransactionFinished(tid) packet = Packets.NotifyTransactionFinished(tid)
app.master_conn = Mock({ app.master_conn = Mock({
'getNextId': 1, 'getNextId': 1,
'getAddress': ('127.0.0.1', 10010), 'getAddress': ('127.0.0.1', 10010),
...@@ -164,7 +164,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -164,7 +164,7 @@ class ClientApplicationTests(NeoTestBase):
app = self.getApp() app = self.getApp()
test_msg_id = 50 test_msg_id = 50
test_oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02'] test_oid_list = ['\x00\x00\x00\x00\x00\x00\x00\x01', '\x00\x00\x00\x00\x00\x00\x00\x02']
response_packet = protocol.answerNewOIDs(test_oid_list[:]) response_packet = Packets.AnswerNewOIDs(test_oid_list[:])
app.master_conn = Mock({'getNextId': test_msg_id, '_addPacket': None, app.master_conn = Mock({'getNextId': test_msg_id, '_addPacket': None,
'expectMessage': None, 'lock': None, 'expectMessage': None, 'lock': None,
'unlock': None, 'unlock': None,
...@@ -232,7 +232,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -232,7 +232,7 @@ class ClientApplicationTests(NeoTestBase):
self.assertRaises(NEOStorageNotFoundError, app.load, oid) self.assertRaises(NEOStorageNotFoundError, app.load, oid)
self.checkAskObject(conn) self.checkAskObject(conn)
# object found on storage nodes and put in cache # object found on storage nodes and put in cache
packet = protocol.answerObject(*an_object[1:]) packet = Packets.AnswerObject(*an_object[1:])
conn = Mock({ conn = Mock({
'getAddress': ('127.0.0.1', 0), 'getAddress': ('127.0.0.1', 0),
'fakeReceived': packet, 'fakeReceived': packet,
...@@ -277,7 +277,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -277,7 +277,7 @@ class ClientApplicationTests(NeoTestBase):
mq.store(oid, (tid1, 'WRONG')) mq.store(oid, (tid1, 'WRONG'))
self.assertTrue(oid in mq) self.assertTrue(oid in mq)
another_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum('RIGHT'), 'RIGHT') another_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum('RIGHT'), 'RIGHT')
packet = protocol.answerObject(*another_object[1:]) packet = Packets.AnswerObject(*another_object[1:])
conn = Mock({ conn = Mock({
'getAddress': ('127.0.0.1', 0), 'getAddress': ('127.0.0.1', 0),
'fakeReceived': packet, 'fakeReceived': packet,
...@@ -310,7 +310,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -310,7 +310,7 @@ class ClientApplicationTests(NeoTestBase):
self.checkAskObject(conn) self.checkAskObject(conn)
# no previous versions -> return None # no previous versions -> return None
an_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum(''), '') an_object = (1, oid, tid2, INVALID_SERIAL, 0, makeChecksum(''), '')
packet = protocol.answerObject(*an_object[1:]) packet = Packets.AnswerObject(*an_object[1:])
conn = Mock({ conn = Mock({
'getAddress': ('127.0.0.1', 0), 'getAddress': ('127.0.0.1', 0),
'fakeReceived': packet, 'fakeReceived': packet,
...@@ -325,7 +325,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -325,7 +325,7 @@ class ClientApplicationTests(NeoTestBase):
mq.store(oid, (tid1, 'WRONG')) mq.store(oid, (tid1, 'WRONG'))
self.assertTrue(oid in mq) self.assertTrue(oid in mq)
another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'), 'RIGHT') another_object = (1, oid, tid1, tid2, 0, makeChecksum('RIGHT'), 'RIGHT')
packet = protocol.answerObject(*another_object[1:]) packet = Packets.AnswerObject(*another_object[1:])
conn = Mock({ conn = Mock({
'getAddress': ('127.0.0.1', 0), 'getAddress': ('127.0.0.1', 0),
'fakeReceived': packet, 'fakeReceived': packet,
...@@ -357,7 +357,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -357,7 +357,7 @@ class ClientApplicationTests(NeoTestBase):
# no connection -> NEOStorageError (wait until connected to primary) # no connection -> NEOStorageError (wait until connected to primary)
#self.assertRaises(NEOStorageError, app.tpc_begin, transaction=txn, tid=None) #self.assertRaises(NEOStorageError, app.tpc_begin, transaction=txn, tid=None)
# ask a tid to pmn # ask a tid to pmn
packet = protocol.answerBeginTransaction(tid=tid) packet = Packets.AnswerBeginTransaction(tid=tid)
app.master_conn = Mock({ app.master_conn = Mock({
'getNextId': 1, 'getNextId': 1,
'expectMessage': None, 'expectMessage': None,
...@@ -401,7 +401,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -401,7 +401,7 @@ class ClientApplicationTests(NeoTestBase):
# build conflicting state # build conflicting state
app.local_var.txn = txn app.local_var.txn = txn
app.local_var.tid = tid app.local_var.tid = tid
packet = protocol.answerStoreObject(conflicting=1, oid=oid, serial=tid) packet = Packets.AnswerStoreObject(conflicting=1, oid=oid, serial=tid)
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': packet, 'fakeReceived': packet,
...@@ -430,7 +430,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -430,7 +430,7 @@ class ClientApplicationTests(NeoTestBase):
# case with no conflict # case with no conflict
app.local_var.txn = txn app.local_var.txn = txn
app.local_var.tid = tid app.local_var.tid = tid
packet = protocol.answerStoreObject(conflicting=0, oid=oid, serial=tid) packet = Packets.AnswerStoreObject(conflicting=0, oid=oid, serial=tid)
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': packet, 'fakeReceived': packet,
...@@ -469,7 +469,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -469,7 +469,7 @@ class ClientApplicationTests(NeoTestBase):
app.local_var.txn = txn app.local_var.txn = txn
app.local_var.tid = tid app.local_var.tid = tid
# wrong answer -> failure # wrong answer -> failure
packet = protocol.answerNewOIDs(()) packet = Packets.AnswerNewOIDs(())
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': packet, 'fakeReceived': packet,
...@@ -489,7 +489,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -489,7 +489,7 @@ class ClientApplicationTests(NeoTestBase):
self.assertEquals(len(calls), 1) self.assertEquals(len(calls), 1)
packet = calls[0].getParam(0) packet = calls[0].getParam(0)
self.assertTrue(isinstance(packet, Packet)) self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet._type, ASK_STORE_TRANSACTION) self.assertEquals(packet._type, AskStoreTransaction)
def test_tpc_vote3(self): def test_tpc_vote3(self):
app = self.getApp() app = self.getApp()
...@@ -498,7 +498,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -498,7 +498,7 @@ class ClientApplicationTests(NeoTestBase):
app.local_var.txn = txn app.local_var.txn = txn
app.local_var.tid = tid app.local_var.tid = tid
# response -> OK # response -> OK
packet = protocol.answerStoreTransaction(tid=tid) packet = Packets.AnswerStoreTransaction(tid=tid)
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'fakeReceived': packet, 'fakeReceived': packet,
...@@ -555,9 +555,9 @@ class ClientApplicationTests(NeoTestBase): ...@@ -555,9 +555,9 @@ class ClientApplicationTests(NeoTestBase):
app.local_var.data_dict = {oid1: '', oid2: ''} app.local_var.data_dict = {oid1: '', oid2: ''}
app.tpc_abort(txn) app.tpc_abort(txn)
# will check if there was just one call/packet : # will check if there was just one call/packet :
self.checkNotifyPacket(conn1, ABORT_TRANSACTION) self.checkNotifyPacket(conn1, AbortTransaction)
self.checkNotifyPacket(conn2, ABORT_TRANSACTION) self.checkNotifyPacket(conn2, AbortTransaction)
self.checkNotifyPacket(app.master_conn, ABORT_TRANSACTION) self.checkNotifyPacket(app.master_conn, AbortTransaction)
self.assertEquals(app.local_var.tid, None) self.assertEquals(app.local_var.tid, None)
self.assertEquals(app.local_var.txn, None) self.assertEquals(app.local_var.txn, None)
self.assertEquals(app.local_var.data_dict, {}) self.assertEquals(app.local_var.data_dict, {})
...@@ -594,7 +594,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -594,7 +594,7 @@ class ClientApplicationTests(NeoTestBase):
def hook(tid): def hook(tid):
self.f_called = True self.f_called = True
self.f_called_with_tid = tid self.f_called_with_tid = tid
packet = protocol.answerBeginTransaction(INVALID_TID) packet = Packets.AnswerBeginTransaction(INVALID_TID)
app.master_conn = Mock({ app.master_conn = Mock({
'getNextId': 1, 'getNextId': 1,
'getAddress': ('127.0.0.1', 10000), 'getAddress': ('127.0.0.1', 10000),
...@@ -619,7 +619,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -619,7 +619,7 @@ class ClientApplicationTests(NeoTestBase):
def hook(tid): def hook(tid):
self.f_called = True self.f_called = True
self.f_called_with_tid = tid self.f_called_with_tid = tid
packet = protocol.notifyTransactionFinished(tid) packet = Packets.NotifyTransactionFinished(tid)
app.master_conn = Mock({ app.master_conn = Mock({
'getNextId': 1, 'getNextId': 1,
'getAddress': ('127.0.0.1', 10010), 'getAddress': ('127.0.0.1', 10010),
...@@ -685,19 +685,19 @@ class ClientApplicationTests(NeoTestBase): ...@@ -685,19 +685,19 @@ class ClientApplicationTests(NeoTestBase):
self.voteTransaction(app) self.voteTransaction(app)
self.finishTransaction(app) self.finishTransaction(app)
# undo 1 -> no previous revision # undo 1 -> no previous revision
u1p1 = protocol.answerTransactionInformation(tid1, '', '', '', (oid1, )) u1p1 = Packets.AnswerTransactionInformation(tid1, '', '', '', (oid1, ))
u1p2 = protocol.oidNotFound('oid not found') u1p2 = protocol.oidNotFound('oid not found')
# undo 2 -> not end tid # undo 2 -> not end tid
u2p1 = protocol.answerTransactionInformation(tid2, '', '', '', (oid2, )) u2p1 = Packets.AnswerTransactionInformation(tid2, '', '', '', (oid2, ))
u2p2 = protocol.answerObject(oid2, tid2, tid3, 0, makeChecksum('O2V1'), 'O2V1') u2p2 = Packets.AnswerObject(oid2, tid2, tid3, 0, makeChecksum('O2V1'), 'O2V1')
# undo 3 -> conflict # undo 3 -> conflict
u3p1 = protocol.answerTransactionInformation(tid3, '', '', '', (oid2, )) u3p1 = Packets.AnswerTransactionInformation(tid3, '', '', '', (oid2, ))
u3p2 = protocol.answerObject(oid2, tid3, tid3, 0, makeChecksum('O2V2'), 'O2V2') u3p2 = Packets.AnswerObject(oid2, tid3, tid3, 0, makeChecksum('O2V2'), 'O2V2')
u3p3 = protocol.answerStoreObject(conflicting=1, oid=oid2, serial=tid2) u3p3 = Packets.AnswerStoreObject(conflicting=1, oid=oid2, serial=tid2)
# undo 4 -> ok # undo 4 -> ok
u4p1 = protocol.answerTransactionInformation(tid3, '', '', '', (oid2, )) u4p1 = Packets.AnswerTransactionInformation(tid3, '', '', '', (oid2, ))
u4p2 = protocol.answerObject(oid2, tid3, tid3, 0, makeChecksum('O2V2'), 'O2V2') u4p2 = Packets.AnswerObject(oid2, tid3, tid3, 0, makeChecksum('O2V2'), 'O2V2')
u4p3 = protocol.answerStoreObject(conflicting=0, oid=oid2, serial=tid2) u4p3 = Packets.AnswerStoreObject(conflicting=0, oid=oid2, serial=tid2)
# test logic # test logic
packets = (u1p1, u1p2, u2p1, u2p2, u3p1, u3p2, u3p3, u3p1, u4p2, u4p3) packets = (u1p1, u1p2, u2p1, u2p2, u3p1, u3p2, u3p3, u3p1, u4p2, u4p3)
conn = Mock({ conn = Mock({
...@@ -729,8 +729,8 @@ class ClientApplicationTests(NeoTestBase): ...@@ -729,8 +729,8 @@ class ClientApplicationTests(NeoTestBase):
oid1, oid2 = self.makeOID(1), self.makeOID(2) oid1, oid2 = self.makeOID(1), self.makeOID(2)
# TIDs packets supplied by _waitMessage hook # TIDs packets supplied by _waitMessage hook
# TXN info packets # TXN info packets
p3 = protocol.answerTransactionInformation(tid1, '', '', '', (oid1, )) p3 = Packets.AnswerTransactionInformation(tid1, '', '', '', (oid1, ))
p4 = protocol.answerTransactionInformation(tid2, '', '', '', (oid2, )) p4 = Packets.AnswerTransactionInformation(tid2, '', '', '', (oid2, ))
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
'getUUID': ReturnValues(uuid1, uuid2), 'getUUID': ReturnValues(uuid1, uuid2),
...@@ -760,10 +760,10 @@ class ClientApplicationTests(NeoTestBase): ...@@ -760,10 +760,10 @@ class ClientApplicationTests(NeoTestBase):
tid1, tid2 = self.makeTID(1), self.makeTID(2) tid1, tid2 = self.makeTID(1), self.makeTID(2)
object_history = ( (tid1, 42), (tid2, 42),) object_history = ( (tid1, 42), (tid2, 42),)
# object history, first is a wrong oid, second is valid # object history, first is a wrong oid, second is valid
p2 = protocol.answerObjectHistory(oid, object_history) p2 = Packets.AnswerObjectHistory(oid, object_history)
# transaction history # transaction history
p3 = protocol.answerTransactionInformation(tid1, 'u', 'd', 'e', (oid, )) p3 = Packets.AnswerTransactionInformation(tid1, 'u', 'd', 'e', (oid, ))
p4 = protocol.answerTransactionInformation(tid2, 'u', 'd', 'e', (oid, )) p4 = Packets.AnswerTransactionInformation(tid2, 'u', 'd', 'e', (oid, ))
# faked environnement # faked environnement
conn = Mock({ conn = Mock({
'getNextId': 1, 'getNextId': 1,
...@@ -786,7 +786,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -786,7 +786,7 @@ class ClientApplicationTests(NeoTestBase):
self.assertEquals(result[0]['size'], 42) self.assertEquals(result[0]['size'], 42)
self.assertEquals(result[1]['size'], 42) self.assertEquals(result[1]['size'], 42)
def test_connectToPrimaryMasterNode(self): def test_connectToPrimaryNode(self):
# here we have three master nodes : # here we have three master nodes :
# the connection to the first will fail # the connection to the first will fail
# the second will have changed # the second will have changed
...@@ -843,7 +843,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -843,7 +843,7 @@ class ClientApplicationTests(NeoTestBase):
app.em = Mock({}) app.em = Mock({})
app.pt = Mock({ 'operational': False}) app.pt = Mock({ 'operational': False})
try: try:
app.master_conn = app._connectToPrimaryMasterNode() app.master_conn = app._connectToPrimaryNode()
self.assertEqual(len(all_passed), 1) self.assertEqual(len(all_passed), 1)
self.assertTrue(app.master_conn is not None) self.assertTrue(app.master_conn is not None)
self.assertTrue(app.pt.operational()) self.assertTrue(app.pt.operational())
...@@ -859,7 +859,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -859,7 +859,7 @@ class ClientApplicationTests(NeoTestBase):
def _waitMessage_hook(app, conn=None, msg_id=None, handler=None): def _waitMessage_hook(app, conn=None, msg_id=None, handler=None):
self.test_ok = True self.test_ok = True
_waitMessage_old = Application._waitMessage _waitMessage_old = Application._waitMessage
packet = protocol.askBeginTransaction(None) packet = Packets.AskBeginTransaction(None)
Application._waitMessage = _waitMessage_hook Application._waitMessage = _waitMessage_hook
try: try:
app._askStorage(conn, packet) app._askStorage(conn, packet)
...@@ -868,7 +868,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -868,7 +868,7 @@ class ClientApplicationTests(NeoTestBase):
# check packet sent, connection unlocked and dispatcher updated # check packet sent, connection unlocked and dispatcher updated
self.checkAskNewTid(conn) self.checkAskNewTid(conn)
self.assertEquals(len(conn.mockGetNamedCalls('unlock')), 1) self.assertEquals(len(conn.mockGetNamedCalls('unlock')), 1)
self.checkDispatcherRegisterCalled() self.checkDispatcherRegisterCalled(app, conn)
# and _waitMessage called # and _waitMessage called
self.assertTrue(self.test_ok) self.assertTrue(self.test_ok)
...@@ -885,7 +885,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -885,7 +885,7 @@ class ClientApplicationTests(NeoTestBase):
self.test_ok = True self.test_ok = True
_waitMessage_old = Application._waitMessage _waitMessage_old = Application._waitMessage
Application._waitMessage = _waitMessage_hook Application._waitMessage = _waitMessage_hook
packet = protocol.askBeginTransaction(None) packet = Packets.AskBeginTransaction(None)
try: try:
app._askPrimary(packet) app._askPrimary(packet)
finally: finally:
...@@ -894,7 +894,7 @@ class ClientApplicationTests(NeoTestBase): ...@@ -894,7 +894,7 @@ class ClientApplicationTests(NeoTestBase):
self.checkAskNewTid(conn) self.checkAskNewTid(conn)
self.assertEquals(len(conn.mockGetNamedCalls('lock')), 1) self.assertEquals(len(conn.mockGetNamedCalls('lock')), 1)
self.assertEquals(len(conn.mockGetNamedCalls('unlock')), 1) self.assertEquals(len(conn.mockGetNamedCalls('unlock')), 1)
self.checkDispatcherRegisterCalled() self.checkDispatcherRegisterCalled(app, conn)
# and _waitMessage called # and _waitMessage called
self.assertTrue(self.test_ok) self.assertTrue(self.test_ok)
# check NEOStorageError is raised when the primary connection is lost # check NEOStorageError is raised when the primary connection is lost
......
...@@ -47,7 +47,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -47,7 +47,7 @@ class ClientHandlerTests(NeoTestBase):
'unlock': None}) 'unlock': None})
def getDispatcher(self, queue=None): def getDispatcher(self, queue=None):
return Mock({'getQueue': queue, 'connectToPrimaryMasterNode': None}) return Mock({'getQueue': queue, 'connectToPrimaryNode': None})
def buildHandler(self, handler_class, app, dispatcher): def buildHandler(self, handler_class, app, dispatcher):
# some handlers do not accept the second argument # some handlers do not accept the second argument
...@@ -64,7 +64,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -64,7 +64,7 @@ class ClientHandlerTests(NeoTestBase):
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
client_handler = BaseHandler(None, dispatcher) client_handler = BaseHandler(None, dispatcher)
conn = self.getConnection() conn = self.getConnection()
client_handler.packetReceived(conn, protocol.ping()) client_handler.packetReceived(conn, Packets.Ping())
self.checkAnswerPacket(conn, protocol.PONG) self.checkAnswerPacket(conn, protocol.PONG)
def _testInitialMasterWithMethod(self, method): def _testInitialMasterWithMethod(self, method):
...@@ -77,7 +77,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -77,7 +77,7 @@ class ClientHandlerTests(NeoTestBase):
def _testMasterWithMethod(self, method, handler_class): def _testMasterWithMethod(self, method, handler_class):
uuid = self.getNewUUID() uuid = self.getNewUUID()
app = Mock({'connectToPrimaryMasterNode': None}) app = Mock({'connectToPrimaryNode': None})
app.primary_master_node = Mock({'getUUID': uuid}) app.primary_master_node = Mock({'getUUID': uuid})
app.master_conn = Mock({'close': None, 'getUUID': uuid, 'getAddress': ('127.0.0.1', 10000)}) app.master_conn = Mock({'close': None, 'getUUID': uuid, 'getAddress': ('127.0.0.1', 10000)})
dispatcher = self.getDispatcher() dispatcher = self.getDispatcher()
...@@ -209,7 +209,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -209,7 +209,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler.notReady(conn, None, None) client_handler.notReady(conn, None, None)
self.assertEquals(len(app.mockGetNamedCalls('setNodeNotReady')), 1) self.assertEquals(len(app.mockGetNamedCalls('setNodeNotReady')), 1)
def test_clientAcceptNodeIdentification(self): def test_clientAcceptIdentification(self):
class App: class App:
nm = Mock({'getByAddress': None}) nm = Mock({'getByAddress': None})
storage_node = None storage_node = None
...@@ -220,7 +220,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -220,7 +220,7 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection() conn = self.getConnection()
uuid = self.getNewUUID() uuid = self.getNewUUID()
app.uuid = 'C' * 16 app.uuid = 'C' * 16
client_handler.acceptNodeIdentification( client_handler.AcceptIdentification(
conn, None, conn, None,
NodeTypes.CLIENT, NodeTypes.CLIENT,
uuid, ('127.0.0.1', 10010), uuid, ('127.0.0.1', 10010),
...@@ -231,7 +231,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -231,7 +231,7 @@ class ClientHandlerTests(NeoTestBase):
self.assertEquals(app.pt, None) self.assertEquals(app.pt, None)
self.assertEquals(app.uuid, 'C' * 16) self.assertEquals(app.uuid, 'C' * 16)
def test_masterAcceptNodeIdentification(self): def test_masterAcceptIdentification(self):
node = Mock({'setUUID': None}) node = Mock({'setUUID': None})
class FakeLocal: class FakeLocal:
from Queue import Queue from Queue import Queue
...@@ -248,7 +248,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -248,7 +248,7 @@ class ClientHandlerTests(NeoTestBase):
uuid = self.getNewUUID() uuid = self.getNewUUID()
your_uuid = 'C' * 16 your_uuid = 'C' * 16
app.uuid = INVALID_UUID app.uuid = INVALID_UUID
client_handler.acceptNodeIdentification(conn, None, client_handler.AcceptIdentification(conn, None,
NodeTypes.MASTER, uuid, ('127.0.0.1', 10010), 10, 2, your_uuid) NodeTypes.MASTER, uuid, ('127.0.0.1', 10010), 10, 2, your_uuid)
self.checkNotClosed(conn) self.checkNotClosed(conn)
self.checkUUIDSet(conn, uuid) self.checkUUIDSet(conn, uuid)
...@@ -256,7 +256,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -256,7 +256,7 @@ class ClientHandlerTests(NeoTestBase):
self.assertTrue(app.pt is not None) self.assertTrue(app.pt is not None)
self.assertEquals(app.uuid, your_uuid) self.assertEquals(app.uuid, your_uuid)
def test_storageAcceptNodeIdentification(self): def test_storageAcceptIdentification(self):
node = Mock({'setUUID': None}) node = Mock({'setUUID': None})
class App: class App:
nm = Mock({'getByAddress': node}) nm = Mock({'getByAddress': node})
...@@ -268,7 +268,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -268,7 +268,7 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection() conn = self.getConnection()
uuid = self.getNewUUID() uuid = self.getNewUUID()
app.uuid = 'C' * 16 app.uuid = 'C' * 16
client_handler.acceptNodeIdentification(conn, None, client_handler.AcceptIdentification(conn, None,
NodeTypes.STORAGE, uuid, ('127.0.0.1', 10010), 0, 0, INVALID_UUID) NodeTypes.STORAGE, uuid, ('127.0.0.1', 10010), 0, 0, INVALID_UUID)
self.checkNotClosed(conn) self.checkNotClosed(conn)
self.checkUUIDSet(conn, uuid) self.checkUUIDSet(conn, uuid)
...@@ -279,7 +279,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -279,7 +279,7 @@ class ClientHandlerTests(NeoTestBase):
self.assertRaises(UnexpectedPacketError, method, *args, **dict(kw)) self.assertRaises(UnexpectedPacketError, method, *args, **dict(kw))
# Master node handler # Master node handler
def test_nonMasterAnswerPrimaryMaster(self): def test_nonMasterAnswerPrimary(self):
for node_type in (NodeTypes.CLIENT, NodeTypes.STORAGE): for node_type in (NodeTypes.CLIENT, NodeTypes.STORAGE):
node = Mock({'getType': node_type}) node = Mock({'getType': node_type})
class App: class App:
...@@ -288,12 +288,12 @@ class ClientHandlerTests(NeoTestBase): ...@@ -288,12 +288,12 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerPrimaryMaster(conn, None, 0, []) client_handler.answerPrimary(conn, None, 0, [])
# Check that nothing happened # Check that nothing happened
self.assertEqual(len(app.nm.mockGetNamedCalls('getByAddress')), 0) self.assertEqual(len(app.nm.mockGetNamedCalls('getByAddress')), 0)
self.assertEqual(len(app.nm.mockGetNamedCalls('add')), 0) self.assertEqual(len(app.nm.mockGetNamedCalls('add')), 0)
def test_unknownNodeAnswerPrimaryMaster(self): def test_unknownNodeAnswerPrimary(self):
node = Mock({'getType': NodeTypes.MASTER}) node = Mock({'getType': NodeTypes.MASTER})
class App: class App:
nm = Mock({'getByAddress': None, 'add': None}) nm = Mock({'getByAddress': None, 'add': None})
...@@ -302,7 +302,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -302,7 +302,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_master_list = [(('127.0.0.1', 10010), self.getNewUUID())] test_master_list = [(('127.0.0.1', 10010), self.getNewUUID())]
client_handler.answerPrimaryMaster(conn, None, INVALID_UUID, test_master_list) client_handler.answerPrimary(conn, None, INVALID_UUID, test_master_list)
# Check that yet-unknown master node got added # Check that yet-unknown master node got added
getByAddress_call_list = app.nm.mockGetNamedCalls('getNodeByServer') getByAddress_call_list = app.nm.mockGetNamedCalls('getNodeByServer')
add_call_list = app.nm.mockGetNamedCalls('add') add_call_list = app.nm.mockGetNamedCalls('add')
...@@ -318,7 +318,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -318,7 +318,7 @@ class ClientHandlerTests(NeoTestBase):
# hence INVALID_UUID in call). # hence INVALID_UUID in call).
self.assertEquals(app.primary_master_node, None) self.assertEquals(app.primary_master_node, None)
def test_knownNodeUnknownUUIDNodeAnswerPrimaryMaster(self): def test_knownNodeUnknownUUIDNodeAnswerPrimary(self):
node = Mock({'getType': NodeTypes.MASTER, 'getUUID': None, 'setUUID': None}) node = Mock({'getType': NodeTypes.MASTER, 'getUUID': None, 'setUUID': None})
class App: class App:
nm = Mock({'getByAddress': node, 'add': None}) nm = Mock({'getByAddress': node, 'add': None})
...@@ -328,7 +328,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -328,7 +328,7 @@ class ClientHandlerTests(NeoTestBase):
conn = self.getConnection() conn = self.getConnection()
test_node_uuid = self.getNewUUID() test_node_uuid = self.getNewUUID()
test_master_list = [(('127.0.0.1', 10010), test_node_uuid)] test_master_list = [(('127.0.0.1', 10010), test_node_uuid)]
client_handler.answerPrimaryMaster(conn, None, INVALID_UUID, test_master_list) client_handler.answerPrimary(conn, None, INVALID_UUID, test_master_list)
# Test sanity checks # Test sanity checks
getByAddress_call_list = app.nm.mockGetNamedCalls('getNodeByServer') getByAddress_call_list = app.nm.mockGetNamedCalls('getNodeByServer')
self.assertEqual(len(getByAddress_call_list), 1) self.assertEqual(len(getByAddress_call_list), 1)
...@@ -344,7 +344,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -344,7 +344,7 @@ class ClientHandlerTests(NeoTestBase):
# hence INVALID_UUID in call). # hence INVALID_UUID in call).
self.assertEquals(app.primary_master_node, None) self.assertEquals(app.primary_master_node, None)
def test_knownNodeKnownUUIDNodeAnswerPrimaryMaster(self): def test_knownNodeKnownUUIDNodeAnswerPrimary(self):
test_node_uuid = self.getNewUUID() test_node_uuid = self.getNewUUID()
node = Mock({'getType': NodeTypes.MASTER, 'getUUID': test_node_uuid, 'setUUID': None}) node = Mock({'getType': NodeTypes.MASTER, 'getUUID': test_node_uuid, 'setUUID': None})
class App: class App:
...@@ -354,7 +354,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -354,7 +354,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_master_list = [(('127.0.0.1', 10010), test_node_uuid)] test_master_list = [(('127.0.0.1', 10010), test_node_uuid)]
client_handler.answerPrimaryMaster(conn, None, INVALID_UUID, test_master_list) client_handler.answerPrimary(conn, None, INVALID_UUID, test_master_list)
# Test sanity checks # Test sanity checks
getByAddress_call_list = app.nm.mockGetNamedCalls('getNodeByServer') getByAddress_call_list = app.nm.mockGetNamedCalls('getNodeByServer')
self.assertEqual(len(getByAddress_call_list), 1) self.assertEqual(len(getByAddress_call_list), 1)
...@@ -374,7 +374,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -374,7 +374,7 @@ class ClientHandlerTests(NeoTestBase):
# TODO: test known node, known but different uuid (not detected in code, # TODO: test known node, known but different uuid (not detected in code,
# desired behaviour unknown) # desired behaviour unknown)
def test_alreadyDifferentPrimaryAnswerPrimaryMaster(self): def test_alreadyDifferentPrimaryAnswerPrimary(self):
test_node_uuid = self.getNewUUID() test_node_uuid = self.getNewUUID()
test_primary_node_uuid = test_node_uuid test_primary_node_uuid = test_node_uuid
while test_primary_node_uuid == test_node_uuid: while test_primary_node_uuid == test_node_uuid:
...@@ -391,7 +391,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -391,7 +391,7 @@ class ClientHandlerTests(NeoTestBase):
# If primary master is already set *and* is not given primary master # If primary master is already set *and* is not given primary master
# handle call raises. # handle call raises.
# Check that the call doesn't raise # Check that the call doesn't raise
client_handler.answerPrimaryMaster(conn, None, test_node_uuid, []) client_handler.answerPrimary(conn, None, test_node_uuid, [])
# Check that the primary master changed # Check that the primary master changed
self.assertTrue(app.primary_master_node is node) self.assertTrue(app.primary_master_node is node)
# Test sanity checks # Test sanity checks
...@@ -401,7 +401,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -401,7 +401,7 @@ class ClientHandlerTests(NeoTestBase):
getByAddress_call_list = app.nm.mockGetNamedCalls('getNodeByServer') getByAddress_call_list = app.nm.mockGetNamedCalls('getNodeByServer')
self.assertEqual(len(getByAddress_call_list), 0) self.assertEqual(len(getByAddress_call_list), 0)
def test_alreadySamePrimaryAnswerPrimaryMaster(self): def test_alreadySamePrimaryAnswerPrimary(self):
test_node_uuid = self.getNewUUID() test_node_uuid = self.getNewUUID()
node = Mock({'getType': NodeTypes.MASTER, 'getUUID': test_node_uuid, 'setUUID': None}) node = Mock({'getType': NodeTypes.MASTER, 'getUUID': test_node_uuid, 'setUUID': None})
class App: class App:
...@@ -411,11 +411,11 @@ class ClientHandlerTests(NeoTestBase): ...@@ -411,11 +411,11 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerPrimaryMaster(conn, None, test_node_uuid, []) client_handler.answerPrimary(conn, None, test_node_uuid, [])
# Check that primary node is (still) node. # Check that primary node is (still) node.
self.assertTrue(app.primary_master_node is node) self.assertTrue(app.primary_master_node is node)
def test_unknownNewPrimaryAnswerPrimaryMaster(self): def test_unknownNewPrimaryAnswerPrimary(self):
test_node_uuid = self.getNewUUID() test_node_uuid = self.getNewUUID()
test_primary_node_uuid = test_node_uuid test_primary_node_uuid = test_node_uuid
while test_primary_node_uuid == test_node_uuid: while test_primary_node_uuid == test_node_uuid:
...@@ -428,7 +428,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -428,7 +428,7 @@ class ClientHandlerTests(NeoTestBase):
app = App() app = App()
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
client_handler.answerPrimaryMaster(conn, None, test_primary_node_uuid, []) client_handler.answerPrimary(conn, None, test_primary_node_uuid, [])
# Test sanity checks # Test sanity checks
getByUUID_call_list = app.nm.mockGetNamedCalls('getNodeByUUID') getByUUID_call_list = app.nm.mockGetNamedCalls('getNodeByUUID')
self.assertEqual(len(getByUUID_call_list), 1) self.assertEqual(len(getByUUID_call_list), 1)
...@@ -436,7 +436,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -436,7 +436,7 @@ class ClientHandlerTests(NeoTestBase):
# Check that primary node was not updated. # Check that primary node was not updated.
self.assertTrue(app.primary_master_node is None) self.assertTrue(app.primary_master_node is None)
def test_AnswerPrimaryMaster(self): def test_AnswerPrimary(self):
test_node_uuid = self.getNewUUID() test_node_uuid = self.getNewUUID()
node = Mock({'getType': NodeTypes.MASTER, 'getUUID': test_node_uuid, 'setUUID': None}) node = Mock({'getType': NodeTypes.MASTER, 'getUUID': test_node_uuid, 'setUUID': None})
class App: class App:
...@@ -447,7 +447,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -447,7 +447,7 @@ class ClientHandlerTests(NeoTestBase):
client_handler = PrimaryBootstrapHandler(app) client_handler = PrimaryBootstrapHandler(app)
conn = self.getConnection() conn = self.getConnection()
test_master_list = [(('127.0.0.1', 10010), test_node_uuid)] test_master_list = [(('127.0.0.1', 10010), test_node_uuid)]
client_handler.answerPrimaryMaster(conn, None, test_node_uuid, test_master_list) client_handler.answerPrimary(conn, None, test_node_uuid, test_master_list)
# Test sanity checks # Test sanity checks
getByUUID_call_list = app.nm.mockGetNamedCalls('getNodeByUUID') getByUUID_call_list = app.nm.mockGetNamedCalls('getNodeByUUID')
self.assertEqual(len(getByUUID_call_list), 1) self.assertEqual(len(getByUUID_call_list), 1)
...@@ -652,7 +652,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -652,7 +652,7 @@ class ClientHandlerTests(NeoTestBase):
self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
def test_noPrimaryMasterNotifyPartitionChanges(self): def test_noPrimaryNotifyPartitionChanges(self):
node = Mock({'getType': NodeTypes.MASTER}) node = Mock({'getType': NodeTypes.MASTER})
class App: class App:
nm = Mock({'getByUUID': node}) nm = Mock({'getByUUID': node})
...@@ -667,7 +667,7 @@ class ClientHandlerTests(NeoTestBase): ...@@ -667,7 +667,7 @@ class ClientHandlerTests(NeoTestBase):
self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('setCell')), 0)
self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('removeCell')), 0)
def test_nonPrimaryMasterNotifyPartitionChanges(self): def test_nonPrimaryNotifyPartitionChanges(self):
test_master_uuid = self.getNewUUID() test_master_uuid = self.getNewUUID()
test_sender_uuid = test_master_uuid test_sender_uuid = test_master_uuid
while test_sender_uuid == test_master_uuid: while test_sender_uuid == test_master_uuid:
......
...@@ -292,7 +292,7 @@ class NEOCluster(object): ...@@ -292,7 +292,7 @@ class NEOCluster(object):
def _killMaster(self, primary=False, all=False): def _killMaster(self, primary=False, all=False):
killed_uuid_list = [] killed_uuid_list = []
primary_uuid = self.neoctl.getPrimaryMaster() primary_uuid = self.neoctl.getPrimary()
for master in self.getMasterProcessList(): for master in self.getMasterProcessList():
master_uuid = master.getUUID() master_uuid = master.getUUID()
is_primary = master_uuid == primary_uuid is_primary = master_uuid == primary_uuid
...@@ -304,7 +304,7 @@ class NEOCluster(object): ...@@ -304,7 +304,7 @@ class NEOCluster(object):
break break
return killed_uuid_list return killed_uuid_list
def killPrimaryMaster(self): def killPrimary(self):
return self._killMaster(primary=True) return self._killMaster(primary=True)
def killSecondaryMaster(self, all=False): def killSecondaryMaster(self, all=False):
...@@ -312,7 +312,7 @@ class NEOCluster(object): ...@@ -312,7 +312,7 @@ class NEOCluster(object):
def killMasters(self): def killMasters(self):
secondary_list = self.killSecondaryMaster(all=True) secondary_list = self.killSecondaryMaster(all=True)
primary_list = self.killPrimaryMaster() primary_list = self.killPrimary()
return secondary_list + primary_list return secondary_list + primary_list
def killStorage(self, all=False): def killStorage(self, all=False):
...@@ -347,9 +347,9 @@ class NEOCluster(object): ...@@ -347,9 +347,9 @@ class NEOCluster(object):
def getMasterNodeState(self, uuid): def getMasterNodeState(self, uuid):
return self.__getNodeState(NodeTypes.MASTER, uuid) return self.__getNodeState(NodeTypes.MASTER, uuid)
def getPrimaryMaster(self): def getPrimary(self):
try: try:
current_try = self.neoctl.getPrimaryMaster() current_try = self.neoctl.getPrimary()
except NotReadyException: except NotReadyException:
current_try = None current_try = None
return current_try return current_try
...@@ -394,9 +394,9 @@ class NEOCluster(object): ...@@ -394,9 +394,9 @@ class NEOCluster(object):
self.__expectNodeState(NodeTypes.STORAGE, uuid, state, self.__expectNodeState(NodeTypes.STORAGE, uuid, state,
timeout,delay) timeout,delay)
def expectPrimaryMaster(self, uuid=None, timeout=0, delay=1): def expectPrimary(self, uuid=None, timeout=0, delay=1):
def callback(last_try): def callback(last_try):
current_try = self.getPrimaryMaster() current_try = self.getPrimary()
if None not in (uuid, current_try) and uuid != current_try: if None not in (uuid, current_try) and uuid != current_try:
raise AssertionError, 'An unexpected primary arised: %r, ' \ raise AssertionError, 'An unexpected primary arised: %r, ' \
'expected %r' % (dump(current_try), dump(uuid)) 'expected %r' % (dump(current_try), dump(uuid))
......
...@@ -49,25 +49,25 @@ class MasterTests(NEOFunctionalTest): ...@@ -49,25 +49,25 @@ class MasterTests(NEOFunctionalTest):
# Check node state has changed. # Check node state has changed.
self.neo.expectMasterState(uuid, None) self.neo.expectMasterState(uuid, None)
def testStoppingPrimaryMasterWithTwoSecondaries(self): def testStoppingPrimaryWithTwoSecondaries(self):
# Wait for masters to stabilize # Wait for masters to stabilize
self.neo.expectAllMasters(MASTER_NODE_COUNT) self.neo.expectAllMasters(MASTER_NODE_COUNT)
# Kill # Kill
killed_uuid_list = self.neo.killPrimaryMaster() killed_uuid_list = self.neo.killPrimary()
# Test sanity check. # Test sanity check.
self.assertEqual(len(killed_uuid_list), 1) self.assertEqual(len(killed_uuid_list), 1)
uuid = killed_uuid_list[0] uuid = killed_uuid_list[0]
# Check the state of the primary we just killed # Check the state of the primary we just killed
self.neo.expectMasterState(uuid, (None, NodeStates.UNKNOWN)) self.neo.expectMasterState(uuid, (None, NodeStates.UNKNOWN))
self.assertEqual(self.neo.getPrimaryMaster(), None) self.assertEqual(self.neo.getPrimary(), None)
# Check that a primary master arised. # Check that a primary master arised.
self.neo.expectPrimaryMaster(timeout=10) self.neo.expectPrimary(timeout=10)
# Check that the uuid really changed. # Check that the uuid really changed.
new_uuid = self.neo.getPrimaryMaster() new_uuid = self.neo.getPrimary()
self.assertNotEqual(new_uuid, uuid) self.assertNotEqual(new_uuid, uuid)
def testStoppingPrimaryMasterWithOneSecondary(self): def testStoppingPrimaryWithOneSecondary(self):
self.neo.expectAllMasters(MASTER_NODE_COUNT, self.neo.expectAllMasters(MASTER_NODE_COUNT,
state=NodeStates.RUNNING) state=NodeStates.RUNNING)
...@@ -78,17 +78,17 @@ class MasterTests(NEOFunctionalTest): ...@@ -78,17 +78,17 @@ class MasterTests(NEOFunctionalTest):
self.neo.expectMasterState(killed_uuid_list[0], None) self.neo.expectMasterState(killed_uuid_list[0], None)
self.assertEqual(len(self.neo.getMasterList()), 2) self.assertEqual(len(self.neo.getMasterList()), 2)
killed_uuid_list = self.neo.killPrimaryMaster() killed_uuid_list = self.neo.killPrimary()
# Test sanity check. # Test sanity check.
self.assertEqual(len(killed_uuid_list), 1) self.assertEqual(len(killed_uuid_list), 1)
uuid = killed_uuid_list[0] uuid = killed_uuid_list[0]
# Check the state of the primary we just killed # Check the state of the primary we just killed
self.neo.expectMasterState(uuid, (None, NodeStates.UNKNOWN)) self.neo.expectMasterState(uuid, (None, NodeStates.UNKNOWN))
self.assertEqual(self.neo.getPrimaryMaster(), None) self.assertEqual(self.neo.getPrimary(), None)
# Check that a primary master arised. # Check that a primary master arised.
self.neo.expectPrimaryMaster(timeout=10) self.neo.expectPrimary(timeout=10)
# Check that the uuid really changed. # Check that the uuid really changed.
new_uuid = self.neo.getPrimaryMaster() new_uuid = self.neo.getPrimary()
self.assertNotEqual(new_uuid, uuid) self.assertNotEqual(new_uuid, uuid)
def testMasterSequentialStart(self): def testMasterSequentialStart(self):
...@@ -104,7 +104,7 @@ class MasterTests(NEOFunctionalTest): ...@@ -104,7 +104,7 @@ class MasterTests(NEOFunctionalTest):
first_master.start() first_master.start()
first_master_uuid = first_master.getUUID() first_master_uuid = first_master.getUUID()
# Check that the master node we started elected itself. # Check that the master node we started elected itself.
self.neo.expectPrimaryMaster(first_master_uuid, timeout=30) self.neo.expectPrimary(first_master_uuid, timeout=30)
# Check that no other node is known as running. # Check that no other node is known as running.
self.assertEqual(len(self.neo.getMasterList( self.assertEqual(len(self.neo.getMasterList(
state=NodeStates.RUNNING)), 1) state=NodeStates.RUNNING)), 1)
...@@ -119,7 +119,7 @@ class MasterTests(NEOFunctionalTest): ...@@ -119,7 +119,7 @@ class MasterTests(NEOFunctionalTest):
self.neo.expectMasterState(second_master.getUUID(), self.neo.expectMasterState(second_master.getUUID(),
NodeStates.RUNNING) NodeStates.RUNNING)
# Check that the primary master didn't change. # Check that the primary master didn't change.
self.assertEqual(self.neo.getPrimaryMaster(), first_master_uuid) self.assertEqual(self.neo.getPrimary(), first_master_uuid)
# Start a third master. # Start a third master.
third_master = master_list[2] third_master = master_list[2]
...@@ -131,7 +131,7 @@ class MasterTests(NEOFunctionalTest): ...@@ -131,7 +131,7 @@ class MasterTests(NEOFunctionalTest):
self.neo.expectMasterState(third_master.getUUID(), self.neo.expectMasterState(third_master.getUUID(),
NodeStates.RUNNING) NodeStates.RUNNING)
# Check that the primary master didn't change. # Check that the primary master didn't change.
self.assertEqual(self.neo.getPrimaryMaster(), first_master_uuid) self.assertEqual(self.neo.getPrimary(), first_master_uuid)
def test_suite(): def test_suite():
return unittest.makeSuite(MasterTests) return unittest.makeSuite(MasterTests)
......
...@@ -20,7 +20,7 @@ from mock import Mock ...@@ -20,7 +20,7 @@ from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo import protocol from neo import protocol
from neo.protocol import Packet, PacketTypes, NodeTypes, NodeStates from neo.protocol import Packet, Packets, NodeTypes, NodeStates
from neo.master.handlers.client import ClientServiceHandler from neo.master.handlers.client import ClientServiceHandler
from neo.master.app import Application from neo.master.app import Application
from neo.exception import OperationFailure from neo.exception import OperationFailure
...@@ -74,7 +74,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -74,7 +74,7 @@ class MasterClientHandlerTests(NeoTestBase):
def test_05_notifyNodeInformation(self): def test_05_notifyNodeInformation(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.NOTIFY_NODE_INFORMATION) packet = Packets.NotifyNodeInformation()
# tell the master node that is not running any longer, it must raises # tell the master node that is not running any longer, it must raises
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
node_list = [(NodeTypes.MASTER, ('127.0.0.1', self.master_port), node_list = [(NodeTypes.MASTER, ('127.0.0.1', self.master_port),
...@@ -145,7 +145,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -145,7 +145,7 @@ class MasterClientHandlerTests(NeoTestBase):
def test_06_answerLastIDs(self): def test_06_answerLastIDs(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ANSWER_LAST_IDS) packet = Packets.AnswerLastIDs()
loid = self.app.loid loid = self.app.loid
ltid = self.app.ltid ltid = self.app.ltid
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
...@@ -172,7 +172,8 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -172,7 +172,8 @@ class MasterClientHandlerTests(NeoTestBase):
def test_07_askBeginTransaction(self): def test_07_askBeginTransaction(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ASK_BEGIN_TRANSACTION) packet = Packets.AskBeginTransaction()
packet.setId(0)
ltid = self.app.ltid ltid = self.app.ltid
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
...@@ -183,11 +184,11 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -183,11 +184,11 @@ class MasterClientHandlerTests(NeoTestBase):
tid = self.app.finishing_transaction_dict.keys()[0] tid = self.app.finishing_transaction_dict.keys()[0]
self.assertEquals(tid, self.app.ltid) self.assertEquals(tid, self.app.ltid)
def test_08_askNewOIDs(self): def test_08_askNewOIDs(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ASK_NEW_OIDS) packet = Packets.AskNewOIDs()
packet.setId(0)
loid = self.app.loid loid = self.app.loid
# client call it # client call it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
...@@ -198,7 +199,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -198,7 +199,7 @@ class MasterClientHandlerTests(NeoTestBase):
def test_09_finishTransaction(self): def test_09_finishTransaction(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.FINISH_TRANSACTION) packet = Packets.FinishTransaction()
packet.setId(9) packet.setId(9)
# give an older tid than the PMN known, must abort # give an older tid than the PMN known, must abort
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
...@@ -236,7 +237,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -236,7 +237,7 @@ class MasterClientHandlerTests(NeoTestBase):
def test_11_abortTransaction(self): def test_11_abortTransaction(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ABORT_TRANSACTION) packet = Packets.AbortTransaction()
# give a bad tid, must not failed, just ignored it # give a bad tid, must not failed, just ignored it
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port) client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, port=self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
...@@ -255,7 +256,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -255,7 +256,7 @@ class MasterClientHandlerTests(NeoTestBase):
def test_12_askLastIDs(self): def test_12_askLastIDs(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ASK_LAST_IDS) packet = Packets.AskLastIDs()
# give a uuid # give a uuid
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
ptid = self.app.pt.getID() ptid = self.app.pt.getID()
...@@ -272,7 +273,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -272,7 +273,7 @@ class MasterClientHandlerTests(NeoTestBase):
def test_13_askUnfinishedTransactions(self): def test_13_askUnfinishedTransactions(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ASK_UNFINISHED_TRANSACTIONS) packet = Packets.AskUnfinishedTransactions()
# give a uuid # give a uuid
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
service.askUnfinishedTransactions(conn, packet) service.askUnfinishedTransactions(conn, packet)
...@@ -323,7 +324,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -323,7 +324,7 @@ class MasterClientHandlerTests(NeoTestBase):
port = self.client_port) port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
packet = Packet(msg_type=ASK_BEGIN_TRANSACTION) packet = AskBeginTransaction()
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
...@@ -367,7 +368,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -367,7 +368,7 @@ class MasterClientHandlerTests(NeoTestBase):
port = self.client_port) port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
packet = Packet(msg_type=ASK_BEGIN_TRANSACTION) packet = AskBeginTransaction()
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
...@@ -411,7 +412,7 @@ class MasterClientHandlerTests(NeoTestBase): ...@@ -411,7 +412,7 @@ class MasterClientHandlerTests(NeoTestBase):
port = self.client_port) port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
packet = Packet(msg_type=ASK_BEGIN_TRANSACTION) packet = AskBeginTransaction()
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
......
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
import unittest import unittest
from mock import Mock from mock import Mock
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo import protocol from neo.protocol import Packet, Packets, NodeTypes, NodeStates, INVALID_UUID
from neo.protocol import Packet, PacketTypes, NodeTypes, NodeStates, INVALID_UUID
from neo.master.handlers.election import ClientElectionHandler, ServerElectionHandler from neo.master.handlers.election import ClientElectionHandler, ServerElectionHandler
from neo.master.app import Application from neo.master.app import Application
from neo.exception import ElectionFailure from neo.exception import ElectionFailure
...@@ -92,7 +91,7 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -92,7 +91,7 @@ class MasterClientElectionTests(NeoTestBase):
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
self.election.connectionCompleted(conn) self.election.connectionCompleted(conn)
self.checkAskPrimaryMaster(conn) self.checkAskPrimary(conn)
def test_03_connectionFailed(self): def test_03_connectionFailed(self):
...@@ -112,22 +111,23 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -112,22 +111,23 @@ class MasterClientElectionTests(NeoTestBase):
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(), self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(),
NodeStates.TEMPORARILY_DOWN) NodeStates.TEMPORARILY_DOWN)
def test_11_askPrimaryMaster(self): def test_11_askPrimary(self):
election = self.election election = self.election
uuid = self.identifyToMasterNode(port=self.master_port) uuid = self.identifyToMasterNode(port=self.master_port)
packet = protocol.askPrimaryMaster() packet = Packets.AskPrimary()
packet.setId(0)
conn = Mock({"_addPacket" : None, conn = Mock({"_addPacket" : None,
"getUUID" : uuid, "getUUID" : uuid,
"isServer" : True, "isServer" : True,
"getConnector": Mock(), "getConnector": Mock(),
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
self.assertEqual(len(self.app.nm.getMasterList()), 2) self.assertEqual(len(self.app.nm.getMasterList()), 2)
election.askPrimaryMaster(conn, packet) election.askPrimary(conn, packet)
self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1) self.assertEquals(len(conn.mockGetNamedCalls("answer")), 1)
self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0) self.assertEquals(len(conn.mockGetNamedCalls("abort")), 0)
self.checkAnswerPrimaryMaster(conn) self.checkAnswerPrimary(conn)
def test_09_answerPrimaryMaster1(self): def test_09_answerPrimary1(self):
# test with master node and greater uuid # test with master node and greater uuid
uuid = self.getNewUUID() uuid = self.getNewUUID()
if uuid < self.app.uuid: if uuid < self.app.uuid:
...@@ -135,12 +135,12 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -135,12 +135,12 @@ class MasterClientElectionTests(NeoTestBase):
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
connector_handler = DoNothingConnector) connector_handler = DoNothingConnector)
conn.setUUID(uuid) conn.setUUID(uuid)
p = protocol.askPrimaryMaster() p = Packets.AskPrimary()
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 1) self.assertEqual(len(self.app.negotiating_master_node_set), 1)
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
self.election.answerPrimaryMaster(conn, p, INVALID_UUID, []) self.election.answerPrimary(conn, p, INVALID_UUID, [])
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(self.app.primary, False) self.assertEqual(self.app.primary, False)
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
...@@ -148,7 +148,7 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -148,7 +148,7 @@ class MasterClientElectionTests(NeoTestBase):
self.assertEqual(len(self.app.negotiating_master_node_set), 0) self.assertEqual(len(self.app.negotiating_master_node_set), 0)
def test_09_answerPrimaryMaster2(self): def test_09_answerPrimary2(self):
# test with master node and lesser uuid # test with master node and lesser uuid
uuid = self.getNewUUID() uuid = self.getNewUUID()
if uuid > self.app.uuid: if uuid > self.app.uuid:
...@@ -156,12 +156,12 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -156,12 +156,12 @@ class MasterClientElectionTests(NeoTestBase):
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
connector_handler = DoNothingConnector) connector_handler = DoNothingConnector)
conn.setUUID(uuid) conn.setUUID(uuid)
p = protocol.askPrimaryMaster() p = Packets.AskPrimary()
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 1) self.assertEqual(len(self.app.negotiating_master_node_set), 1)
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
self.election.answerPrimaryMaster(conn, p, INVALID_UUID, []) self.election.answerPrimary(conn, p, INVALID_UUID, [])
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(self.app.primary, None) self.assertEqual(self.app.primary, None)
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
...@@ -169,20 +169,20 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -169,20 +169,20 @@ class MasterClientElectionTests(NeoTestBase):
self.assertEqual(len(self.app.negotiating_master_node_set), 0) self.assertEqual(len(self.app.negotiating_master_node_set), 0)
def test_09_answerPrimaryMaster3(self): def test_09_answerPrimary3(self):
# test with master node and given uuid for PMN # test with master node and given uuid for PMN
uuid = self.getNewUUID() uuid = self.getNewUUID()
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
connector_handler = DoNothingConnector) connector_handler = DoNothingConnector)
conn.setUUID(uuid) conn.setUUID(uuid)
p = protocol.askPrimaryMaster() p = Packets.AskPrimary()
self.app.nm.createMaster(address=("127.0.0.1", self.master_port), uuid=uuid) self.app.nm.createMaster(address=("127.0.0.1", self.master_port), uuid=uuid)
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 1) self.assertEqual(len(self.app.negotiating_master_node_set), 1)
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(len(self.app.nm.getMasterList()), 2) self.assertEqual(len(self.app.nm.getMasterList()), 2)
self.assertEqual(self.app.primary_master_node, None) self.assertEqual(self.app.primary_master_node, None)
self.election.answerPrimaryMaster(conn, p, uuid, []) self.election.answerPrimary(conn, p, uuid, [])
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(len(self.app.nm.getMasterList()), 2) self.assertEqual(len(self.app.nm.getMasterList()), 2)
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
...@@ -191,19 +191,19 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -191,19 +191,19 @@ class MasterClientElectionTests(NeoTestBase):
self.assertEqual(self.app.primary, False) self.assertEqual(self.app.primary, False)
def test_09_answerPrimaryMaster4(self): def test_09_answerPrimary4(self):
# test with master node and unknown uuid for PMN # test with master node and unknown uuid for PMN
uuid = self.getNewUUID() uuid = self.getNewUUID()
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
connector_handler = DoNothingConnector) connector_handler = DoNothingConnector)
conn.setUUID(uuid) conn.setUUID(uuid)
p = protocol.askPrimaryMaster() p = Packets.AskPrimary()
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 1) self.assertEqual(len(self.app.negotiating_master_node_set), 1)
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
self.assertEqual(self.app.primary_master_node, None) self.assertEqual(self.app.primary_master_node, None)
self.election.answerPrimaryMaster(conn, p, uuid, []) self.election.answerPrimary(conn, p, uuid, [])
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
...@@ -212,13 +212,13 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -212,13 +212,13 @@ class MasterClientElectionTests(NeoTestBase):
self.assertEqual(self.app.primary, None) self.assertEqual(self.app.primary, None)
def test_09_answerPrimaryMaster5(self): def test_09_answerPrimary5(self):
# test with master node and new uuid for PMN # test with master node and new uuid for PMN
uuid = self.getNewUUID() uuid = self.getNewUUID()
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
connector_handler = DoNothingConnector) connector_handler = DoNothingConnector)
conn.setUUID(uuid) conn.setUUID(uuid)
p = protocol.askPrimaryMaster() p = Packets.AskPrimary()
self.app.nm.createMaster(address=("127.0.0.1", self.master_port), uuid=uuid) self.app.nm.createMaster(address=("127.0.0.1", self.master_port), uuid=uuid)
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 1) self.assertEqual(len(self.app.negotiating_master_node_set), 1)
...@@ -226,7 +226,7 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -226,7 +226,7 @@ class MasterClientElectionTests(NeoTestBase):
self.assertEqual(len(self.app.nm.getMasterList()), 2) self.assertEqual(len(self.app.nm.getMasterList()), 2)
self.assertEqual(self.app.primary_master_node, None) self.assertEqual(self.app.primary_master_node, None)
master_uuid = self.getNewUUID() master_uuid = self.getNewUUID()
self.election.answerPrimaryMaster(conn, p, master_uuid, self.election.answerPrimary(conn, p, master_uuid,
[(("127.0.0.1", self.master_port+1), master_uuid,)]) [(("127.0.0.1", self.master_port+1), master_uuid,)])
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.assertEqual(len(self.app.nm.getMasterList()), 3) self.assertEqual(len(self.app.nm.getMasterList()), 3)
...@@ -235,7 +235,7 @@ class MasterClientElectionTests(NeoTestBase): ...@@ -235,7 +235,7 @@ class MasterClientElectionTests(NeoTestBase):
self.assertNotEqual(self.app.primary_master_node, None) self.assertNotEqual(self.app.primary_master_node, None)
self.assertEqual(self.app.primary, False) self.assertEqual(self.app.primary, False)
# Now tell it's another node which is primary, it must raise # Now tell it's another node which is primary, it must raise
self.assertRaises(ElectionFailure, self.election.answerPrimaryMaster, conn, p, uuid, []) self.assertRaises(ElectionFailure, self.election.answerPrimary, conn, p, uuid, [])
...@@ -280,12 +280,12 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -280,12 +280,12 @@ class MasterServerElectionTests(NeoTestBase):
return uuid return uuid
def checkCalledAskPrimaryMaster(self, conn, packet_number=0): def checkCalledAskPrimary(self, conn, packet_number=0):
""" Check ask primary master has been send""" """ Check ask primary master has been send"""
call = conn.mockGetNamedCalls("_addPacket")[packet_number] call = conn.mockGetNamedCalls("_addPacket")[packet_number]
packet = call.getParam(0) packet = call.getParam(0)
self.assertTrue(isinstance(packet, Packet)) self.assertTrue(isinstance(packet, Packet))
self.assertEquals(packet.getType(),ASK_PRIMARY_MASTER) self.assertEquals(packet.getType(),AskPrimary)
# Tests # Tests
...@@ -352,7 +352,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -352,7 +352,7 @@ class MasterServerElectionTests(NeoTestBase):
def test_07_packetReceived(self): def test_07_packetReceived(self):
uuid = self.identifyToMasterNode(port=self.master_port) uuid = self.identifyToMasterNode(port=self.master_port)
p = protocol.acceptNodeIdentification(NodeTypes.MASTER, uuid, p = Packets.AcceptIdentification(NodeTypes.MASTER, uuid,
("127.0.0.1", self.master_port), 1009, 2, self.app.uuid) ("127.0.0.1", self.master_port), 1009, 2, self.app.uuid)
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
...@@ -369,20 +369,20 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -369,20 +369,20 @@ class MasterServerElectionTests(NeoTestBase):
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(), self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getState(),
NodeStates.RUNNING) NodeStates.RUNNING)
def test_08_acceptNodeIdentification1(self): def test_08_AcceptIdentification1(self):
# test with storage node, must be rejected # test with storage node, must be rejected
uuid = self.getNewUUID() uuid = self.getNewUUID()
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
connector_handler = DoNothingConnector) connector_handler = DoNothingConnector)
args = (NodeTypes.MASTER, uuid, ('127.0.0.1', self.master_port), args = (NodeTypes.MASTER, uuid, ('127.0.0.1', self.master_port),
self.app.pt.getPartitions(), self.app.pt.getReplicas(), self.app.uuid) self.app.pt.getPartitions(), self.app.pt.getReplicas(), self.app.uuid)
p = protocol.acceptNodeIdentification(*args) p = Packets.AcceptIdentification(*args)
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 1) self.assertEqual(len(self.app.negotiating_master_node_set), 1)
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getUUID(), None) self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getUUID(), None)
self.assertEqual(conn.getUUID(), None) self.assertEqual(conn.getUUID(), None)
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.election.acceptNodeIdentification(conn, p, NodeTypes.STORAGE, self.election.AcceptIdentification(conn, p, NodeTypes.STORAGE,
uuid, "127.0.0.1", self.master_port, uuid, "127.0.0.1", self.master_port,
self.app.pt.getPartitions(), self.app.pt.getPartitions(),
self.app.pt.getReplicas(), self.app.pt.getReplicas(),
...@@ -392,41 +392,41 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -392,41 +392,41 @@ class MasterServerElectionTests(NeoTestBase):
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 0) self.assertEqual(len(self.app.negotiating_master_node_set), 0)
def test_08_acceptNodeIdentification2(self): def test_08_AcceptIdentification2(self):
# test with bad address, must be rejected # test with bad address, must be rejected
uuid = self.getNewUUID() uuid = self.getNewUUID()
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
connector_handler = DoNothingConnector) connector_handler = DoNothingConnector)
args = (NodeTypes.MASTER, uuid, ('127.0.0.1', self.master_port), args = (NodeTypes.MASTER, uuid, ('127.0.0.1', self.master_port),
self.app.pt.getPartitions(), self.app.pt.getReplicas(), self.app.uuid) self.app.pt.getPartitions(), self.app.pt.getReplicas(), self.app.uuid)
p = protocol.acceptNodeIdentification(*args) p = Packets.AcceptIdentification(*args)
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 1) self.assertEqual(len(self.app.negotiating_master_node_set), 1)
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getUUID(), None) self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getUUID(), None)
self.assertEqual(conn.getUUID(), None) self.assertEqual(conn.getUUID(), None)
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.election.acceptNodeIdentification(conn, p, NodeTypes.STORAGE, self.election.AcceptIdentification(conn, p, NodeTypes.STORAGE,
uuid, ("127.0.0.2", self.master_port), uuid, ("127.0.0.2", self.master_port),
self.app.pt.getPartitions(), self.app.pt.getPartitions(),
self.app.pt.getReplicas(), self.app.pt.getReplicas(),
self.app.uuid) self.app.uuid)
self.assertEqual(conn.getConnector(), None) self.assertEqual(conn.getConnector(), None)
def test_08_acceptNodeIdentification3(self): def test_08_AcceptIdentification3(self):
# test with master node, must be ok # test with master node, must be ok
uuid = self.getNewUUID() uuid = self.getNewUUID()
conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port), conn = ClientConnection(self.app.em, self.election, addr = ("127.0.0.1", self.master_port),
connector_handler = DoNothingConnector) connector_handler = DoNothingConnector)
args = (NodeTypes.MASTER, uuid, ('127.0.0.1', self.master_port), args = (NodeTypes.MASTER, uuid, ('127.0.0.1', self.master_port),
self.app.pt.getPartitions(), self.app.pt.getReplicas(), self.app.uuid) self.app.pt.getPartitions(), self.app.pt.getReplicas(), self.app.uuid)
p = protocol.acceptNodeIdentification(*args) p = Packets.AcceptIdentification(*args)
self.assertEqual(len(self.app.unconnected_master_node_set), 0) self.assertEqual(len(self.app.unconnected_master_node_set), 0)
self.assertEqual(len(self.app.negotiating_master_node_set), 1) self.assertEqual(len(self.app.negotiating_master_node_set), 1)
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getUUID(), None) self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getUUID(), None)
self.assertEqual(conn.getUUID(), None) self.assertEqual(conn.getUUID(), None)
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),1)
self.election.acceptNodeIdentification(conn, p, NodeTypes.MASTER, self.election.AcceptIdentification(conn, p, NodeTypes.MASTER,
uuid, ("127.0.0.1", self.master_port), uuid, ("127.0.0.1", self.master_port),
self.app.pt.getPartitions(), self.app.pt.getPartitions(),
self.app.pt.getReplicas(), self.app.pt.getReplicas(),
...@@ -434,20 +434,20 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -434,20 +434,20 @@ class MasterServerElectionTests(NeoTestBase):
self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getUUID(), uuid) self.assertEqual(self.app.nm.getByAddress(conn.getAddress()).getUUID(), uuid)
self.assertEqual(conn.getUUID(), uuid) self.assertEqual(conn.getUUID(), uuid)
self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),2) self.assertEqual(len(conn.getConnector().mockGetNamedCalls("_addPacket")),2)
self.checkCalledAskPrimaryMaster(conn.getConnector(), 1) self.checkCalledAskPrimary(conn.getConnector(), 1)
def test_10_requestNodeIdentification(self): def test_10_RequestIdentification(self):
election = self.election election = self.election
uuid = self.getNewUUID() uuid = self.getNewUUID()
args = (NodeTypes.MASTER, uuid, ('127.0.0.1', self.storage_port), args = (NodeTypes.MASTER, uuid, ('127.0.0.1', self.storage_port),
'INVALID_NAME') 'INVALID_NAME')
packet = protocol.requestNodeIdentification(*args) packet = Packets.RequestIdentification(*args)
# test alien cluster # test alien cluster
conn = Mock({"_addPacket" : None, "abort" : None, conn = Mock({"_addPacket" : None, "abort" : None,
"isServer" : True}) "isServer" : True})
self.checkProtocolErrorRaised( self.checkProtocolErrorRaised(
election.requestNodeIdentification, election.RequestIdentification,
conn, conn,
packet=packet, packet=packet,
node_type=NodeTypes.MASTER, node_type=NodeTypes.MASTER,
...@@ -458,7 +458,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -458,7 +458,7 @@ class MasterServerElectionTests(NeoTestBase):
conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None, conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None,
"isServer" : True}) "isServer" : True})
self.checkNotReadyErrorRaised( self.checkNotReadyErrorRaised(
election.requestNodeIdentification, election.RequestIdentification,
conn, conn,
packet=packet, packet=packet,
node_type=NodeTypes.STORAGE, node_type=NodeTypes.STORAGE,
...@@ -473,7 +473,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -473,7 +473,7 @@ class MasterServerElectionTests(NeoTestBase):
node = self.app.nm.getMasterList()[0] node = self.app.nm.getMasterList()[0]
self.assertEqual(node.getUUID(), None) self.assertEqual(node.getUUID(), None)
self.assertEqual(node.getState(), NodeStates.RUNNING) self.assertEqual(node.getState(), NodeStates.RUNNING)
election.requestNodeIdentification(conn, election.RequestIdentification(conn,
packet=packet, packet=packet,
node_type=NodeTypes.MASTER, node_type=NodeTypes.MASTER,
uuid=uuid, uuid=uuid,
...@@ -482,7 +482,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -482,7 +482,7 @@ class MasterServerElectionTests(NeoTestBase):
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
self.assertEqual(node.getUUID(), uuid) self.assertEqual(node.getUUID(), uuid)
self.assertEqual(node.getState(), NodeStates.RUNNING) self.assertEqual(node.getState(), NodeStates.RUNNING)
self.checkAcceptNodeIdentification(conn, answered_packet=packet) self.checkAcceptIdentification(conn, answered_packet=packet)
# unknown node # unknown node
conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None, conn = Mock({"_addPacket" : None, "abort" : None, "expectMessage" : None,
"isServer" : True}) "isServer" : True})
...@@ -490,7 +490,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -490,7 +490,7 @@ class MasterServerElectionTests(NeoTestBase):
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
self.assertEqual(len(self.app.unconnected_master_node_set), 1) self.assertEqual(len(self.app.unconnected_master_node_set), 1)
self.assertEqual(len(self.app.negotiating_master_node_set), 0) self.assertEqual(len(self.app.negotiating_master_node_set), 0)
election.requestNodeIdentification(conn, election.RequestIdentification(conn,
packet=packet, packet=packet,
node_type=NodeTypes.MASTER, node_type=NodeTypes.MASTER,
uuid=new_uuid, uuid=new_uuid,
...@@ -498,7 +498,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -498,7 +498,7 @@ class MasterServerElectionTests(NeoTestBase):
self.master_port+1), self.master_port+1),
name=self.app.name,) name=self.app.name,)
self.assertEqual(len(self.app.nm.getMasterList()), 2) self.assertEqual(len(self.app.nm.getMasterList()), 2)
self.checkAcceptNodeIdentification(conn, answered_packet=packet) self.checkAcceptIdentification(conn, answered_packet=packet)
self.assertEqual(len(self.app.unconnected_master_node_set), 2) self.assertEqual(len(self.app.unconnected_master_node_set), 2)
self.assertEqual(len(self.app.negotiating_master_node_set), 0) self.assertEqual(len(self.app.negotiating_master_node_set), 0)
# broken node # broken node
...@@ -510,7 +510,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -510,7 +510,7 @@ class MasterServerElectionTests(NeoTestBase):
node.setState(NodeStates.BROKEN) node.setState(NodeStates.BROKEN)
self.assertEqual(node.getState(), NodeStates.BROKEN) self.assertEqual(node.getState(), NodeStates.BROKEN)
self.checkBrokenNodeDisallowedErrorRaised( self.checkBrokenNodeDisallowedErrorRaised(
election.requestNodeIdentification, election.RequestIdentification,
conn, conn,
packet=packet, packet=packet,
node_type=NodeTypes.MASTER, node_type=NodeTypes.MASTER,
...@@ -520,17 +520,17 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -520,17 +520,17 @@ class MasterServerElectionTests(NeoTestBase):
name=self.app.name,) name=self.app.name,)
def test_12_announcePrimaryMaster(self): def test_12_announcePrimary(self):
election = self.election election = self.election
uuid = self.identifyToMasterNode(port=self.master_port) uuid = self.identifyToMasterNode(port=self.master_port)
packet = Packet(msg_type=PacketTypes.ANNOUNCE_PRIMARY_MASTER) packet = Packets.AnnouncePrimary()
# No uuid # No uuid
conn = Mock({"_addPacket" : None, conn = Mock({"_addPacket" : None,
"getUUID" : None, "getUUID" : None,
"isServer" : True, "isServer" : True,
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
self.assertEqual(len(self.app.nm.getMasterList()), 1) self.assertEqual(len(self.app.nm.getMasterList()), 1)
self.checkIdenficationRequired(election.announcePrimaryMaster, conn, packet) self.checkIdenficationRequired(election.announcePrimary, conn, packet)
# announce # announce
conn = Mock({"_addPacket" : None, conn = Mock({"_addPacket" : None,
"getUUID" : uuid, "getUUID" : uuid,
...@@ -538,7 +538,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -538,7 +538,7 @@ class MasterServerElectionTests(NeoTestBase):
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
self.assertEqual(self.app.primary, None) self.assertEqual(self.app.primary, None)
self.assertEqual(self.app.primary_master_node, None) self.assertEqual(self.app.primary_master_node, None)
election.announcePrimaryMaster(conn, packet) election.announcePrimary(conn, packet)
self.assertEqual(self.app.primary, False) self.assertEqual(self.app.primary, False)
self.assertNotEqual(self.app.primary_master_node, None) self.assertNotEqual(self.app.primary_master_node, None)
# set current as primary, and announce another, must raise # set current as primary, and announce another, must raise
...@@ -548,24 +548,24 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -548,24 +548,24 @@ class MasterServerElectionTests(NeoTestBase):
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
self.app.primary = True self.app.primary = True
self.assertEqual(self.app.primary, True) self.assertEqual(self.app.primary, True)
self.assertRaises(ElectionFailure, election.announcePrimaryMaster, conn, packet) self.assertRaises(ElectionFailure, election.announcePrimary, conn, packet)
def test_13_reelectPrimaryMaster(self): def test_13_reelectPrimary(self):
election = self.election election = self.election
uuid = self.identifyToMasterNode(port=self.master_port) uuid = self.identifyToMasterNode(port=self.master_port)
packet = protocol.askPrimaryMaster() packet = Packets.AskPrimary()
# No uuid # No uuid
conn = Mock({"_addPacket" : None, conn = Mock({"_addPacket" : None,
"getUUID" : None, "getUUID" : None,
"isServer" : True, "isServer" : True,
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
self.assertRaises(ElectionFailure, election.reelectPrimaryMaster, conn, packet) self.assertRaises(ElectionFailure, election.reelectPrimary, conn, packet)
def test_14_notifyNodeInformation(self): def test_14_notifyNodeInformation(self):
election = self.election election = self.election
uuid = self.identifyToMasterNode(port=self.master_port) uuid = self.identifyToMasterNode(port=self.master_port)
packet = Packet(msg_type=PacketTypes.NOTIFY_NODE_INFORMATION) packet = Packets.NotifyNodeInformation()
# do not answer if no uuid # do not answer if no uuid
conn = Mock({"getUUID" : None, conn = Mock({"getUUID" : None,
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
...@@ -594,7 +594,7 @@ class MasterServerElectionTests(NeoTestBase): ...@@ -594,7 +594,7 @@ class MasterServerElectionTests(NeoTestBase):
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
node_list = [(NodeTypes.CLIENT, ('127.0.0.1', self.master_port - 1), node_list = [(NodeTypes.CLIENT, ('127.0.0.1', self.master_port - 1),
self.getNewUUID(), NodeStates.DOWN),] self.getNewUUID(), NodeStates.DOWN),]
self.assertEqual(len(self.app.nm.getNodeList()), 0) self.assertEqual(len(self.app.nm.getList()), 0)
election.notifyNodeInformation(conn, packet, node_list) election.notifyNodeInformation(conn, packet, node_list)
self.assertEqual(len(self.app.nm.getNodeList()), 0) self.assertEqual(len(self.app.nm.getNodeList()), 0)
# tell about another master node # tell about another master node
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import unittest import unittest
from struct import pack, unpack from struct import pack, unpack
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.protocol import Packet, PacketTypes from neo.protocol import Packet, Packets
from neo.protocol import NodeTypes, NodeStates, CellStates from neo.protocol import NodeTypes, NodeStates, CellStates
from neo.master.handlers.recovery import RecoveryHandler from neo.master.handlers.recovery import RecoveryHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -93,7 +93,7 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -93,7 +93,7 @@ class MasterRecoveryTests(NeoTestBase):
def test_08_notifyNodeInformation(self): def test_08_notifyNodeInformation(self):
recovery = self.recovery recovery = self.recovery
uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port) uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port)
packet = Packet(msg_type=PacketTypes.NOTIFY_NODE_INFORMATION) packet = Packets.NotifyNodeInformation()
# tell about a client node, do nothing # tell about a client node, do nothing
conn = self.getFakeConnection(uuid, self.master_address) conn = self.getFakeConnection(uuid, self.master_address)
node_list = [(NodeTypes.CLIENT, '127.0.0.1', self.client_port, node_list = [(NodeTypes.CLIENT, '127.0.0.1', self.client_port,
...@@ -151,7 +151,7 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -151,7 +151,7 @@ class MasterRecoveryTests(NeoTestBase):
def test_09_answerLastIDs(self): def test_09_answerLastIDs(self):
recovery = self.recovery recovery = self.recovery
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ANSWER_LAST_IDS) packet = Packets.AnswerLastIDs()
loid = self.app.loid loid = self.app.loid
ltid = self.app.ltid ltid = self.app.ltid
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
...@@ -178,7 +178,7 @@ class MasterRecoveryTests(NeoTestBase): ...@@ -178,7 +178,7 @@ class MasterRecoveryTests(NeoTestBase):
def test_10_answerPartitionTable(self): def test_10_answerPartitionTable(self):
recovery = self.recovery recovery = self.recovery
uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port) uuid = self.identifyToMasterNode(NodeTypes.MASTER, port=self.master_port)
packet = Packet(msg_type=PacketTypes.ANSWER_PARTITION_TABLE) packet = Packets.AnswerPartitionTable()
# not from target node, ignore # not from target node, ignore
uuid = self.identifyToMasterNode(NodeTypes.STORAGE, port=self.storage_port) uuid = self.identifyToMasterNode(NodeTypes.STORAGE, port=self.storage_port)
conn = self.getFakeConnection(uuid, self.storage_port) conn = self.getFakeConnection(uuid, self.storage_port)
......
...@@ -21,7 +21,7 @@ from mock import Mock ...@@ -21,7 +21,7 @@ from mock import Mock
from struct import pack, unpack from struct import pack, unpack
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo import protocol from neo import protocol
from neo.protocol import Packet, PacketTypes from neo.protocol import Packet, Packets
from neo.protocol import NodeTypes, NodeStates, CellStates from neo.protocol import NodeTypes, NodeStates, CellStates
from neo.master.handlers.storage import StorageServiceHandler from neo.master.handlers.storage import StorageServiceHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -66,7 +66,7 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -66,7 +66,7 @@ class MasterStorageHandlerTests(NeoTestBase):
def test_05_notifyNodeInformation(self): def test_05_notifyNodeInformation(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.NOTIFY_NODE_INFORMATION) packet = Packets.NotifyNodeInformation()
# tell the master node that is not running any longer, it must raises # tell the master node that is not running any longer, it must raises
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
node_list = [(NodeTypes.MASTER, '127.0.0.1', self.master_port, node_list = [(NodeTypes.MASTER, '127.0.0.1', self.master_port,
...@@ -136,7 +136,7 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -136,7 +136,7 @@ class MasterStorageHandlerTests(NeoTestBase):
def test_06_answerLastIDs(self): def test_06_answerLastIDs(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ANSWER_LAST_IDS) packet = Packets.AnswerLastIDs()
loid = self.app.loid loid = self.app.loid
ltid = self.app.ltid ltid = self.app.ltid
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
...@@ -154,7 +154,7 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -154,7 +154,7 @@ class MasterStorageHandlerTests(NeoTestBase):
def test_10_notifyInformationLocked(self): def test_10_notifyInformationLocked(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode(port=10020) uuid = self.identifyToMasterNode(port=10020)
packet = Packet(msg_type=PacketTypes.NOTIFY_INFORMATION_LOCKED) packet = Packets.NotifyInformationLocked()
# give an older tid than the PMN known, must abort # give an older tid than the PMN known, must abort
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
oid_list = [] oid_list = []
...@@ -197,7 +197,8 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -197,7 +197,8 @@ class MasterStorageHandlerTests(NeoTestBase):
def test_12_askLastIDs(self): def test_12_askLastIDs(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ASK_LAST_IDS) packet = Packets.AskLastIDs()
packet.setId(0)
# give a uuid # give a uuid
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
ptid = self.app.pt.getID() ptid = self.app.pt.getID()
...@@ -205,7 +206,7 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -205,7 +206,7 @@ class MasterStorageHandlerTests(NeoTestBase):
oid = self.app.loid oid = self.app.loid
service.askLastIDs(conn, packet) service.askLastIDs(conn, packet)
packet = self.checkAnswerLastIDs(conn, answered_packet=packet) packet = self.checkAnswerLastIDs(conn, answered_packet=packet)
loid, ltid, lptid = protocol._decodeAnswerLastIDs(packet._body) loid, ltid, lptid = packet.decode()
self.assertEqual(loid, oid) self.assertEqual(loid, oid)
self.assertEqual(ltid, tid) self.assertEqual(ltid, tid)
self.assertEqual(lptid, ptid) self.assertEqual(lptid, ptid)
...@@ -214,12 +215,13 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -214,12 +215,13 @@ class MasterStorageHandlerTests(NeoTestBase):
def test_13_askUnfinishedTransactions(self): def test_13_askUnfinishedTransactions(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ASK_UNFINISHED_TRANSACTIONS) packet = Packets.AskUnfinishedTransactions()
packet.setId(0)
# give a uuid # give a uuid
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
service.askUnfinishedTransactions(conn, packet) service.askUnfinishedTransactions(conn, packet)
packet = self.checkAnswerUnfinishedTransactions(conn, answered_packet=packet) packet = self.checkAnswerUnfinishedTransactions(conn, answered_packet=packet)
tid_list = protocol._decodeAnswerUnfinishedTransactions(packet._body)[0] tid_list = packet.decode()
self.assertEqual(len(tid_list), 0) self.assertEqual(len(tid_list), 0)
# create some transaction # create some transaction
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
...@@ -238,7 +240,7 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -238,7 +240,7 @@ class MasterStorageHandlerTests(NeoTestBase):
def test_14_notifyPartitionChanges(self): def test_14_notifyPartitionChanges(self):
service = self.service service = self.service
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.NOTIFY_PARTITION_CHANGES) packet = Packets.NotifyPartitionChanges()
# do not answer if not a storage node # do not answer if not a storage node
client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT, client_uuid = self.identifyToMasterNode(node_type=NodeTypes.CLIENT,
port=self.client_port) port=self.client_port)
...@@ -335,7 +337,7 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -335,7 +337,7 @@ class MasterStorageHandlerTests(NeoTestBase):
port = self.client_port) port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
packet = Packet(msg_type=ASK_BEGIN_TRANSACTION) packet = AskBeginTransaction()
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
...@@ -384,7 +386,7 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -384,7 +386,7 @@ class MasterStorageHandlerTests(NeoTestBase):
port = self.client_port) port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
packet = Packet(msg_type=ASK_BEGIN_TRANSACTION) packet = AskBeginTransaction()
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
...@@ -433,7 +435,7 @@ class MasterStorageHandlerTests(NeoTestBase): ...@@ -433,7 +435,7 @@ class MasterStorageHandlerTests(NeoTestBase):
port = self.client_port) port = self.client_port)
conn = self.getFakeConnection(client_uuid, self.client_address) conn = self.getFakeConnection(client_uuid, self.client_address)
lptid = self.app.pt.getID() lptid = self.app.pt.getID()
packet = Packet(msg_type=ASK_BEGIN_TRANSACTION) packet = AskBeginTransaction()
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
service.askBeginTransaction(conn, packet) service.askBeginTransaction(conn, packet)
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
import unittest import unittest
from struct import pack, unpack from struct import pack, unpack
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.protocol import Packet, PacketTypes from neo import protocol
from neo.protocol import Packet, Packets
from neo.protocol import NodeTypes, NodeStates, ErrorCodes from neo.protocol import NodeTypes, NodeStates, ErrorCodes
from neo.master.handlers.verification import VerificationHandler from neo.master.handlers.verification import VerificationHandler
from neo.master.app import Application from neo.master.app import Application
...@@ -127,7 +128,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -127,7 +128,7 @@ class MasterVerificationTests(NeoTestBase):
def test_09_answerLastIDs(self): def test_09_answerLastIDs(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ANSWER_LAST_IDS) packet = Packets.AnswerLastIDs()
loid = self.app.loid loid = self.app.loid
ltid = self.app.ltid ltid = self.app.ltid
lptid = '\0' * 8 lptid = '\0' * 8
...@@ -151,7 +152,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -151,7 +152,7 @@ class MasterVerificationTests(NeoTestBase):
def test_11_answerUnfinishedTransactions(self): def test_11_answerUnfinishedTransactions(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ANSWER_UNFINISHED_TRANSACTIONS) packet = Packets.AnswerUnfinishedTransactions()
# do nothing # do nothing
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
...@@ -178,7 +179,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -178,7 +179,7 @@ class MasterVerificationTests(NeoTestBase):
def test_12_answerTransactionInformation(self): def test_12_answerTransactionInformation(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ANSWER_TRANSACTION_INFORMATION) packet = Packets.AnswerTransactionInformation()
# do nothing, as unfinished_oid_set is None # do nothing, as unfinished_oid_set is None
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
...@@ -229,7 +230,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -229,7 +230,7 @@ class MasterVerificationTests(NeoTestBase):
def test_13_tidNotFound(self): def test_13_tidNotFound(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=ErrorCodes.TID_NOT_FOUND) packet = protocol.tidNotFound('')
# do nothing as asking_uuid_dict is True # do nothing as asking_uuid_dict is True
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
...@@ -250,7 +251,7 @@ class MasterVerificationTests(NeoTestBase): ...@@ -250,7 +251,7 @@ class MasterVerificationTests(NeoTestBase):
def test_14_answerObjectPresent(self): def test_14_answerObjectPresent(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=PacketTypes.ANSWER_OBJECT_PRESENT) packet = Packets.AnswerObjectPresent()
# do nothing as asking_uuid_dict is True # do nothing as asking_uuid_dict is True
upper, lower = unpack('!LL', self.app.ltid) upper, lower = unpack('!LL', self.app.ltid)
new_tid = pack('!LL', upper, lower + 10) new_tid = pack('!LL', upper, lower + 10)
...@@ -272,8 +273,8 @@ class MasterVerificationTests(NeoTestBase): ...@@ -272,8 +273,8 @@ class MasterVerificationTests(NeoTestBase):
def test_15_oidNotFound(self): def test_15_oidNotFound(self):
verification = self.verification verification = self.verification
uuid = self.identifyToMasterNode() uuid = self.identifyToMasterNode()
packet = Packet(msg_type=ErrorCodes.OID_NOT_FOUND) packet = protocol.oidNotFound('')
# do nothinf as asking_uuid_dict is True # do nothing as asking_uuid_dict is True
conn = self.getFakeConnection(uuid, self.storage_address) conn = self.getFakeConnection(uuid, self.storage_address)
self.assertEquals(len(self.app.asking_uuid_dict), 0) self.assertEquals(len(self.app.asking_uuid_dict), 0)
self.app.asking_uuid_dict[uuid] = True self.app.asking_uuid_dict[uuid] = True
......
...@@ -27,7 +27,7 @@ from neo.storage.handlers.client import TransactionInformation ...@@ -27,7 +27,7 @@ from neo.storage.handlers.client import TransactionInformation
from neo.storage.handlers.client import ClientOperationHandler from neo.storage.handlers.client import ClientOperationHandler
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.pt import PartitionTable from neo.pt import PartitionTable
from neo.protocol import PacketTypes, Packet, INVALID_PARTITION from neo.protocol import Packets, Packet, INVALID_PARTITION
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
class StorageClientHandlerTests(NeoTestBase): class StorageClientHandlerTests(NeoTestBase):
...@@ -107,14 +107,16 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -107,14 +107,16 @@ class StorageClientHandlerTests(NeoTestBase):
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = Mock({ }) conn = Mock({ })
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
self.operation.askTransactionInformation(conn, packet, INVALID_TID) self.operation.askTransactionInformation(conn, packet, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_18_askTransactionInformation2(self): def test_18_askTransactionInformation2(self):
# answer # answer
conn = Mock({ }) conn = Mock({ })
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), }) dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), })
self.app.dm = dm self.app.dm = dm
self.operation.askTransactionInformation(conn, packet, INVALID_TID) self.operation.askTransactionInformation(conn, packet, INVALID_TID)
...@@ -124,7 +126,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -124,7 +126,8 @@ class StorageClientHandlerTests(NeoTestBase):
# delayed response # delayed response
conn = Mock({}) conn = Mock({})
self.app.dm = Mock() self.app.dm = Mock()
packet = Packet(msg_type=PacketTypes.ASK_OBJECT) packet = Packets.AskObject()
packet.setId(0)
self.app.load_lock_dict[INVALID_OID] = object() self.app.load_lock_dict[INVALID_OID] = object()
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, packet,
...@@ -139,7 +142,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -139,7 +142,8 @@ class StorageClientHandlerTests(NeoTestBase):
# invalid serial / tid / packet not found # invalid serial / tid / packet not found
self.app.dm = Mock({'getObject': None}) self.app.dm = Mock({'getObject': None})
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT) packet = Packets.AskObject()
packet.setId(0)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, packet,
oid=INVALID_OID, oid=INVALID_OID,
...@@ -155,7 +159,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -155,7 +159,8 @@ class StorageClientHandlerTests(NeoTestBase):
# object found => answer # object found => answer
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )}) self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )})
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT) packet = Packets.AskObject()
packet.setId(0)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, packet,
oid=INVALID_OID, oid=INVALID_OID,
...@@ -170,7 +175,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -170,7 +175,8 @@ class StorageClientHandlerTests(NeoTestBase):
app.pt = Mock() app.pt = Mock()
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_TIDS) packet = Packets.AskTIDs()
packet.setId(0)
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, packet, 1, 1, None) self.checkProtocolErrorRaised(self.operation.askTIDs, conn, packet, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0)
...@@ -178,7 +184,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -178,7 +184,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_25_askTIDs2(self): def test_25_askTIDs2(self):
# well case => answer # well case => answer
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_TIDS) packet = Packets.AskTIDs()
packet.setId(0)
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.operation.askTIDs(conn, packet, 1, 2, 1) self.operation.askTIDs(conn, packet, 1, 2, 1)
...@@ -190,7 +197,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -190,7 +197,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_25_askTIDs3(self): def test_25_askTIDs3(self):
# invalid partition => answer usable partitions # invalid partition => answer usable partitions
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_TIDS) packet = Packets.AskTIDs()
packet.setId(0)
cell = Mock({'getUUID':self.app.uuid}) cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getCellList': (cell, ), 'getPartitions': 1}) self.app.pt = Mock({'getCellList': (cell, ), 'getPartitions': 1})
...@@ -206,13 +214,15 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -206,13 +214,15 @@ class StorageClientHandlerTests(NeoTestBase):
app = self.app app = self.app
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT_HISTORY) packet = Packets.AskObjectHistory()
packet.setId(0)
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn, packet, 1, 1, None) self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn, packet, 1, 1, None)
self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_26_askObjectHistory2(self): def test_26_askObjectHistory2(self):
# first case: empty history # first case: empty history
packet = Packet(msg_type=PacketTypes.ASK_OBJECT_HISTORY) packet = Packets.AskObjectHistory()
packet.setId(0)
conn = Mock({}) conn = Mock({})
self.app.dm = Mock({'getObjectHistory': None}) self.app.dm = Mock({'getObjectHistory': None})
self.operation.askObjectHistory(conn, packet, INVALID_OID, 1, 2) self.operation.askObjectHistory(conn, packet, INVALID_OID, 1, 2)
...@@ -225,7 +235,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -225,7 +235,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_27_askStoreTransaction2(self): def test_27_askStoreTransaction2(self):
# add transaction entry # add transaction entry
packet = Packet(msg_type=PacketTypes.ASK_STORE_TRANSACTION) packet = Packets.AskStoreTransaction()
packet.setId(0)
conn = Mock({'getUUID': self.getNewUUID()}) conn = Mock({'getUUID': self.getNewUUID()})
self.operation.askStoreTransaction(conn, packet, self.operation.askStoreTransaction(conn, packet,
INVALID_TID, '', '', '', ()) INVALID_TID, '', '', '', ())
...@@ -237,7 +248,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -237,7 +248,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_28_askStoreObject2(self): def test_28_askStoreObject2(self):
# locked => delayed response # locked => delayed response
packet = Packet(msg_type=PacketTypes.ASK_STORE_OBJECT) packet = Packets.AskStoreObject()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
oid = '\x02' * 8 oid = '\x02' * 8
tid1, tid2 = self.getTwoIDs() tid1, tid2 = self.getTwoIDs()
...@@ -254,7 +266,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -254,7 +266,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_28_askStoreObject3(self): def test_28_askStoreObject3(self):
# locked => unresolvable conflict => answer # locked => unresolvable conflict => answer
packet = Packet(msg_type=PacketTypes.ASK_STORE_OBJECT) packet = Packets.AskStoreObject()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
tid1, tid2 = self.getTwoIDs() tid1, tid2 = self.getTwoIDs()
self.app.store_lock_dict[INVALID_OID] = tid2 self.app.store_lock_dict[INVALID_OID] = tid2
...@@ -268,7 +281,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -268,7 +281,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_28_askStoreObject4(self): def test_28_askStoreObject4(self):
# resolvable conflict => answer # resolvable conflict => answer
packet = Packet(msg_type=PacketTypes.ASK_STORE_OBJECT) packet = Packets.AskStoreObject()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
self.app.dm = Mock({'getObjectHistory':((self.getNewUUID(), ), )}) self.app.dm = Mock({'getObjectHistory':((self.getNewUUID(), ), )})
self.assertEquals(self.app.store_lock_dict.get(INVALID_OID, None), None) self.assertEquals(self.app.store_lock_dict.get(INVALID_OID, None), None)
...@@ -282,7 +296,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -282,7 +296,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_28_askStoreObject5(self): def test_28_askStoreObject5(self):
# no conflict => answer # no conflict => answer
packet = Packet(msg_type=PacketTypes.ASK_STORE_OBJECT) packet = Packets.AskStoreObject()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
self.operation.askStoreObject(conn, packet, INVALID_OID, self.operation.askStoreObject(conn, packet, INVALID_OID,
INVALID_SERIAL, 0, 0, '', INVALID_TID) INVALID_SERIAL, 0, 0, '', INVALID_TID)
...@@ -297,7 +312,8 @@ class StorageClientHandlerTests(NeoTestBase): ...@@ -297,7 +312,8 @@ class StorageClientHandlerTests(NeoTestBase):
def test_29_abortTransaction(self): def test_29_abortTransaction(self):
# remove transaction # remove transaction
packet = Packet(msg_type=PacketTypes.ABORT_TRANSACTION) packet = Packets.AbortTransaction()
packet.setId(0)
conn = Mock({'getUUID': self.app.uuid}) conn = Mock({'getUUID': self.app.uuid})
transaction = Mock({ 'getObjectList': ((0, ), ), }) transaction = Mock({ 'getObjectList': ((0, ), ), })
self.called = False self.called = False
......
...@@ -21,7 +21,7 @@ from neo.tests import NeoTestBase ...@@ -21,7 +21,7 @@ from neo.tests import NeoTestBase
from neo.pt import PartitionTable from neo.pt import PartitionTable
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.initialization import InitializationHandler from neo.storage.handlers.initialization import InitializationHandler
from neo.protocol import Packet, PacketTypes, CellStates from neo.protocol import Packet, Packets, CellStates
from neo.exception import PrimaryFailure from neo.exception import PrimaryFailure
class StorageInitializationHandlerTests(NeoTestBase): class StorageInitializationHandlerTests(NeoTestBase):
...@@ -80,7 +80,7 @@ class StorageInitializationHandlerTests(NeoTestBase): ...@@ -80,7 +80,7 @@ class StorageInitializationHandlerTests(NeoTestBase):
self.checkNoPacketSent(conn) self.checkNoPacketSent(conn)
def test_09_sendPartitionTable(self): def test_09_sendPartitionTable(self):
packet = Packet(msg_type=PacketTypes.SEND_PARTITION_TABLE) packet = Packets.SendPartitionTable()
uuid = self.getNewUUID() uuid = self.getNewUUID()
# send a table # send a table
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
......
...@@ -26,7 +26,7 @@ from neo.storage.app import Application ...@@ -26,7 +26,7 @@ from neo.storage.app import Application
from neo.storage.handlers.master import MasterOperationHandler from neo.storage.handlers.master import MasterOperationHandler
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.pt import PartitionTable from neo.pt import PartitionTable
from neo.protocol import CellStates, PacketTypes, Packet from neo.protocol import CellStates, Packets, Packet
from neo.protocol import INVALID_TID, INVALID_OID from neo.protocol import INVALID_TID, INVALID_OID
class StorageMasterHandlerTests(NeoTestBase): class StorageMasterHandlerTests(NeoTestBase):
...@@ -100,7 +100,7 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -100,7 +100,7 @@ class StorageMasterHandlerTests(NeoTestBase):
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
app.replicator = Mock({}) app.replicator = Mock({})
packet = Packet(msg_type=PacketTypes.NOTIFY_PARTITION_CHANGES) packet = Packets.NotifyPartitionChanges()
self.app.pt = Mock({'getID': 1}) self.app.pt = Mock({'getID': 1})
count = len(self.app.nm.getList()) count = len(self.app.nm.getList())
self.operation.notifyPartitionChanges(conn, packet, 0, ()) self.operation.notifyPartitionChanges(conn, packet, 0, ())
...@@ -124,7 +124,7 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -124,7 +124,7 @@ class StorageMasterHandlerTests(NeoTestBase):
"isServer": False, "isServer": False,
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
packet = Packet(msg_type=PacketTypes.NOTIFY_PARTITION_CHANGES) packet = Packets.NotifyPartitionChanges()
app = self.app app = self.app
# register nodes # register nodes
app.nm.createStorage(uuid=uuid1) app.nm.createStorage(uuid=uuid1)
...@@ -147,14 +147,14 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -147,14 +147,14 @@ class StorageMasterHandlerTests(NeoTestBase):
def test_16_stopOperation1(self): def test_16_stopOperation1(self):
# OperationFailure # OperationFailure
conn = Mock({ 'isServer': False }) conn = Mock({ 'isServer': False })
packet = Packet(msg_type=PacketTypes.STOP_OPERATION) packet = Packets.StopOperation()
self.assertRaises(OperationFailure, self.operation.stopOperation, conn, packet) self.assertRaises(OperationFailure, self.operation.stopOperation, conn, packet)
def test_22_lockInformation2(self): def test_22_lockInformation2(self):
# load transaction informations # load transaction informations
conn = Mock({ 'isServer': False, }) conn = Mock({ 'isServer': False, })
self.app.dm = Mock({ }) self.app.dm = Mock({ })
packet = Packet(msg_type=PacketTypes.LOCK_INFORMATION) packet = Packets.LockInformation()
packet.setId(1) packet.setId(1)
transaction = Mock({ 'getObjectList': ((0, ), ), }) transaction = Mock({ 'getObjectList': ((0, ), ), })
self.app.transaction_dict[INVALID_TID] = transaction self.app.transaction_dict[INVALID_TID] = transaction
...@@ -173,7 +173,7 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -173,7 +173,7 @@ class StorageMasterHandlerTests(NeoTestBase):
# delete transaction informations # delete transaction informations
conn = Mock({ 'isServer': False, }) conn = Mock({ 'isServer': False, })
self.app.dm = Mock({ }) self.app.dm = Mock({ })
packet = Packet(msg_type=PacketTypes.LOCK_INFORMATION) packet = Packets.LockInformation()
packet.setId(1) packet.setId(1)
transaction = Mock({ 'getObjectList': ((0, ), ), }) transaction = Mock({ 'getObjectList': ((0, ), ), })
self.app.transaction_dict[INVALID_TID] = transaction self.app.transaction_dict[INVALID_TID] = transaction
...@@ -195,7 +195,7 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -195,7 +195,7 @@ class StorageMasterHandlerTests(NeoTestBase):
def test_30_answerLastIDs(self): def test_30_answerLastIDs(self):
# set critical TID on replicator # set critical TID on replicator
conn = Mock() conn = Mock()
packet = Packet(msg_type=PacketTypes.ANSWER_LAST_IDS) packet = Packets.AnswerLastIDs()
self.app.replicator = Mock() self.app.replicator = Mock()
self.operation.answerLastIDs( self.operation.answerLastIDs(
conn=conn, conn=conn,
...@@ -211,7 +211,7 @@ class StorageMasterHandlerTests(NeoTestBase): ...@@ -211,7 +211,7 @@ class StorageMasterHandlerTests(NeoTestBase):
def test_31_answerUnfinishedTransactions(self): def test_31_answerUnfinishedTransactions(self):
# set unfinished TID on replicator # set unfinished TID on replicator
conn = Mock() conn = Mock()
packet = Packet(msg_type=PacketTypes.ANSWER_UNFINISHED_TRANSACTIONS) packet = Packets.AnswerUnfinishedTransactions()
self.app.replicator = Mock() self.app.replicator = Mock()
self.operation.answerUnfinishedTransactions( self.operation.answerUnfinishedTransactions(
conn=conn, conn=conn,
......
...@@ -24,7 +24,7 @@ from collections import deque ...@@ -24,7 +24,7 @@ from collections import deque
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.storage import StorageOperationHandler from neo.storage.handlers.storage import StorageOperationHandler
from neo.protocol import PacketTypes, Packet, INVALID_PARTITION from neo.protocol import Packets, Packet, INVALID_PARTITION
from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL from neo.protocol import INVALID_TID, INVALID_OID, INVALID_SERIAL
class StorageStorageHandlerTests(NeoTestBase): class StorageStorageHandlerTests(NeoTestBase):
...@@ -65,14 +65,16 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -65,14 +65,16 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_18_askTransactionInformation1(self): def test_18_askTransactionInformation1(self):
# transaction does not exists # transaction does not exists
conn = Mock({ }) conn = Mock({ })
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
self.operation.askTransactionInformation(conn, packet, INVALID_TID) self.operation.askTransactionInformation(conn, packet, INVALID_TID)
self.checkErrorPacket(conn) self.checkErrorPacket(conn)
def test_18_askTransactionInformation2(self): def test_18_askTransactionInformation2(self):
# answer # answer
conn = Mock({ }) conn = Mock({ })
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), }) dm = Mock({ "getTransaction": (INVALID_TID, 'user', 'desc', '', ), })
self.app.dm = dm self.app.dm = dm
self.operation.askTransactionInformation(conn, packet, INVALID_TID) self.operation.askTransactionInformation(conn, packet, INVALID_TID)
...@@ -82,7 +84,7 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -82,7 +84,7 @@ class StorageStorageHandlerTests(NeoTestBase):
# delayed response # delayed response
conn = Mock({}) conn = Mock({})
self.app.dm = Mock() self.app.dm = Mock()
packet = Packet(msg_type=PacketTypes.ASK_OBJECT) packet = Packets.AskObject()
self.app.load_lock_dict[INVALID_OID] = object() self.app.load_lock_dict[INVALID_OID] = object()
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, packet,
...@@ -97,7 +99,8 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -97,7 +99,8 @@ class StorageStorageHandlerTests(NeoTestBase):
# invalid serial / tid / packet not found # invalid serial / tid / packet not found
self.app.dm = Mock({'getObject': None}) self.app.dm = Mock({'getObject': None})
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT) packet = Packets.AskObject()
packet.setId(0)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, packet,
oid=INVALID_OID, oid=INVALID_OID,
...@@ -113,7 +116,8 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -113,7 +116,8 @@ class StorageStorageHandlerTests(NeoTestBase):
# object found => answer # object found => answer
self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )}) self.app.dm = Mock({'getObject': ('', '', 0, 0, '', )})
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT) packet = Packets.AskObject()
packet.setId(0)
self.assertEquals(len(self.app.event_queue), 0) self.assertEquals(len(self.app.event_queue), 0)
self.operation.askObject(conn, packet, self.operation.askObject(conn, packet,
oid=INVALID_OID, oid=INVALID_OID,
...@@ -128,7 +132,7 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -128,7 +132,7 @@ class StorageStorageHandlerTests(NeoTestBase):
app.pt = Mock() app.pt = Mock()
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_TIDS) packet = Packets.AskTIDs()
self.checkProtocolErrorRaised(self.operation.askTIDs, conn, packet, 1, 1, None) self.checkProtocolErrorRaised(self.operation.askTIDs, conn, packet, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getTIDList')), 0)
...@@ -136,7 +140,8 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -136,7 +140,8 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_25_askTIDs2(self): def test_25_askTIDs2(self):
# well case => answer # well case => answer
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_TIDS) packet = Packets.AskTIDs()
packet.setId(0)
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.operation.askTIDs(conn, packet, 1, 2, 1) self.operation.askTIDs(conn, packet, 1, 2, 1)
...@@ -148,7 +153,8 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -148,7 +153,8 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_25_askTIDs3(self): def test_25_askTIDs3(self):
# invalid partition => answer usable partitions # invalid partition => answer usable partitions
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_TIDS) packet = Packets.AskTIDs()
packet.setId(0)
cell = Mock({'getUUID':self.app.uuid}) cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getTIDList': (INVALID_TID, )}) self.app.dm = Mock({'getTIDList': (INVALID_TID, )})
self.app.pt = Mock({'getCellList': (cell, ), 'getPartitions': 1}) self.app.pt = Mock({'getCellList': (cell, ), 'getPartitions': 1})
...@@ -164,13 +170,15 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -164,13 +170,15 @@ class StorageStorageHandlerTests(NeoTestBase):
app = self.app app = self.app
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT_HISTORY) packet = Packets.AskObjectHistory()
packet.setId(0)
self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn, packet, 1, 1, None) self.checkProtocolErrorRaised(self.operation.askObjectHistory, conn, packet, 1, 1, None)
self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getObjectHistory')), 0)
def test_26_askObjectHistory2(self): def test_26_askObjectHistory2(self):
# first case: empty history # first case: empty history
packet = Packet(msg_type=PacketTypes.ASK_OBJECT_HISTORY) packet = Packets.AskObjectHistory()
packet.setId(0)
conn = Mock({}) conn = Mock({})
self.app.dm = Mock({'getObjectHistory': None}) self.app.dm = Mock({'getObjectHistory': None})
self.operation.askObjectHistory(conn, packet, INVALID_OID, 1, 2) self.operation.askObjectHistory(conn, packet, INVALID_OID, 1, 2)
...@@ -187,7 +195,8 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -187,7 +195,8 @@ class StorageStorageHandlerTests(NeoTestBase):
app.pt = Mock() app.pt = Mock()
app.dm = Mock() app.dm = Mock()
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OIDS) packet = Packets.AskOIDs()
packet.setId(0)
self.checkProtocolErrorRaised(self.operation.askOIDs, conn, packet, 1, 1, None) self.checkProtocolErrorRaised(self.operation.askOIDs, conn, packet, 1, 1, None)
self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0) self.assertEquals(len(app.pt.mockGetNamedCalls('getCellList')), 0)
self.assertEquals(len(app.dm.mockGetNamedCalls('getOIDList')), 0) self.assertEquals(len(app.dm.mockGetNamedCalls('getOIDList')), 0)
...@@ -195,7 +204,8 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -195,7 +204,8 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_25_askOIDs2(self): def test_25_askOIDs2(self):
# well case > answer OIDs # well case > answer OIDs
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OIDS) packet = Packets.AskOIDs()
packet.setId(0)
self.app.pt = Mock({'getPartitions': 1}) self.app.pt = Mock({'getPartitions': 1})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
self.operation.askOIDs(conn, packet, 1, 2, 1) self.operation.askOIDs(conn, packet, 1, 2, 1)
...@@ -207,7 +217,8 @@ class StorageStorageHandlerTests(NeoTestBase): ...@@ -207,7 +217,8 @@ class StorageStorageHandlerTests(NeoTestBase):
def test_25_askOIDs3(self): def test_25_askOIDs3(self):
# invalid partition => answer usable partitions # invalid partition => answer usable partitions
conn = Mock({}) conn = Mock({})
packet = Packet(msg_type=PacketTypes.ASK_OIDS) packet = Packets.AskOIDs()
packet.setId(0)
cell = Mock({'getUUID':self.app.uuid}) cell = Mock({'getUUID':self.app.uuid})
self.app.dm = Mock({'getOIDList': (INVALID_OID, )}) self.app.dm = Mock({'getOIDList': (INVALID_OID, )})
self.app.pt = Mock({'getCellList': (cell, ), 'getPartitions': 1}) self.app.pt = Mock({'getCellList': (cell, ), 'getPartitions': 1})
......
...@@ -21,7 +21,7 @@ from neo.tests import NeoTestBase ...@@ -21,7 +21,7 @@ from neo.tests import NeoTestBase
from neo.pt import PartitionTable from neo.pt import PartitionTable
from neo.storage.app import Application from neo.storage.app import Application
from neo.storage.handlers.verification import VerificationHandler from neo.storage.handlers.verification import VerificationHandler
from neo.protocol import Packet, PacketTypes, CellStates, ErrorCodes from neo.protocol import Packet, Packets, CellStates, ErrorCodes
from neo.protocol import INVALID_OID, INVALID_TID from neo.protocol import INVALID_OID, INVALID_TID
from neo.exception import PrimaryFailure, OperationFailure from neo.exception import PrimaryFailure, OperationFailure
from neo.storage.mysqldb import p64, u64 from neo.storage.mysqldb import p64, u64
...@@ -172,7 +172,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -172,7 +172,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
"isServer": False, "isServer": False,
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
packet = Packet(msg_type=PacketTypes.NOTIFY_PARTITION_CHANGES) packet = Packets.NotifyPartitionChanges()
self.verification.notifyPartitionChanges(conn, packet, 1, ()) self.verification.notifyPartitionChanges(conn, packet, 1, ())
self.verification.notifyPartitionChanges(conn, packet, 0, ()) self.verification.notifyPartitionChanges(conn, packet, 0, ())
self.assertEqual(self.app.pt.getID(), 1) self.assertEqual(self.app.pt.getID(), 1)
...@@ -182,7 +182,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -182,7 +182,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
"isServer": False, "isServer": False,
"getAddress" : ("127.0.0.1", self.master_port), "getAddress" : ("127.0.0.1", self.master_port),
}) })
packet = Packet(msg_type=PacketTypes.NOTIFY_PARTITION_CHANGES) packet = Packets.NotifyPartitionChanges()
new_uuid = self.getNewUUID() new_uuid = self.getNewUUID()
cell = (0, new_uuid, CellStates.UP_TO_DATE) cell = (0, new_uuid, CellStates.UP_TO_DATE)
self.app.nm.createStorage(uuid=new_uuid) self.app.nm.createStorage(uuid=new_uuid)
...@@ -201,21 +201,22 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -201,21 +201,22 @@ class StorageVerificationHandlerTests(NeoTestBase):
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
self.assertFalse(self.app.operational) self.assertFalse(self.app.operational)
packet = Packet(msg_type=PacketTypes.STOP_OPERATION) packet = Packets.StopOperation()
self.verification.startOperation(conn, packet) self.verification.startOperation(conn, packet)
self.assertTrue(self.app.operational) self.assertTrue(self.app.operational)
def test_12_stopOperation(self): def test_12_stopOperation(self):
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
packet = Packet(msg_type=PacketTypes.STOP_OPERATION) packet = Packets.StopOperation()
self.assertRaises(OperationFailure, self.verification.stopOperation, conn, packet) self.assertRaises(OperationFailure, self.verification.stopOperation, conn, packet)
def test_13_askUnfinishedTransactions(self): def test_13_askUnfinishedTransactions(self):
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packet(msg_type=PacketTypes.ASK_UNFINISHED_TRANSACTIONS) packet = Packets.AskUnfinishedTransactions()
packet.setId(0)
self.verification.askUnfinishedTransactions(conn, packet) self.verification.askUnfinishedTransactions(conn, packet)
(tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True) (tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 0) self.assertEqual(len(tid_list), 0)
...@@ -227,7 +228,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -227,7 +228,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
self.app.dm.commit() self.app.dm.commit()
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packet(msg_type=PacketTypes.ASK_UNFINISHED_TRANSACTIONS) packet = Packets.AskUnfinishedTransactions()
packet.setId(0)
self.verification.askUnfinishedTransactions(conn, packet) self.verification.askUnfinishedTransactions(conn, packet)
(tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True) (tid_list, ) = self.checkAnswerUnfinishedTransactions(conn, decode=True)
self.assertEqual(len(tid_list), 1) self.assertEqual(len(tid_list), 1)
...@@ -237,7 +239,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -237,7 +239,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
# ask from client conn with no data # ask from client conn with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(1)) self.verification.askTransactionInformation(conn, packet, p64(1))
code, message = self.checkErrorPacket(conn, decode=True) code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.TID_NOT_FOUND) self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)
...@@ -252,7 +255,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -252,7 +255,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
# object from trans # object from trans
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(1)) self.verification.askTransactionInformation(conn, packet, p64(1))
tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True) tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 1) self.assertEqual(u64(tid), 1)
...@@ -264,7 +268,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -264,7 +268,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
# object from ttrans # object from ttrans
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False }) 'isServer': False })
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(3)) self.verification.askTransactionInformation(conn, packet, p64(3))
tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True) tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 3) self.assertEqual(u64(tid), 3)
...@@ -278,7 +283,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -278,7 +283,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': True }) 'isServer': True })
# find the one in trans # find the one in trans
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(1)) self.verification.askTransactionInformation(conn, packet, p64(1))
tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True) tid, user, desc, ext, oid_list = self.checkAnswerTransactionInformation(conn, decode=True)
self.assertEqual(u64(tid), 1) self.assertEqual(u64(tid), 1)
...@@ -290,7 +296,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -290,7 +296,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
# do not find the one in ttrans # do not find the one in ttrans
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': True }) 'isServer': True })
packet = Packet(msg_type=PacketTypes.ASK_TRANSACTION_INFORMATION) packet = Packets.AskTransactionInformation()
packet.setId(0)
self.verification.askTransactionInformation(conn, packet, p64(2)) self.verification.askTransactionInformation(conn, packet, p64(2))
code, message = self.checkErrorPacket(conn, decode=True) code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.TID_NOT_FOUND) self.assertEqual(code, ErrorCodes.TID_NOT_FOUND)
...@@ -299,7 +306,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -299,7 +306,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT_PRESENT) packet = Packets.AskObjectPresent()
packet.setId(0)
self.verification.askObjectPresent(conn, packet, p64(1), p64(2)) self.verification.askObjectPresent(conn, packet, p64(1), p64(2))
code, message = self.checkErrorPacket(conn, decode=True) code, message = self.checkErrorPacket(conn, decode=True)
self.assertEqual(code, ErrorCodes.OID_NOT_FOUND) self.assertEqual(code, ErrorCodes.OID_NOT_FOUND)
...@@ -311,7 +319,8 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -311,7 +319,8 @@ class StorageVerificationHandlerTests(NeoTestBase):
self.app.dm.commit() self.app.dm.commit()
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT_PRESENT) packet = Packets.AskObjectPresent()
packet.setId(0)
self.verification.askObjectPresent(conn, packet, p64(1), p64(2)) self.verification.askObjectPresent(conn, packet, p64(1), p64(2))
oid, tid = self.checkAnswerObjectPresent(conn, decode=True) oid, tid = self.checkAnswerObjectPresent(conn, decode=True)
self.assertEqual(u64(tid), 2) self.assertEqual(u64(tid), 2)
...@@ -321,7 +330,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -321,7 +330,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
# client connection with no data # client connection with no data
conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port), conn = Mock({ "getAddress" : ("127.0.0.1", self.master_port),
'isServer': False}) 'isServer': False})
packet = Packet(msg_type=PacketTypes.ASK_OBJECT_PRESENT) packet = Packets.AskObjectPresent()
self.verification.deleteTransaction(conn, packet, p64(1)) self.verification.deleteTransaction(conn, packet, p64(1))
# client connection with data # client connection with data
self.app.dm.begin() self.app.dm.begin()
...@@ -338,7 +347,7 @@ class StorageVerificationHandlerTests(NeoTestBase): ...@@ -338,7 +347,7 @@ class StorageVerificationHandlerTests(NeoTestBase):
'isServer': False }) 'isServer': False })
dm = Mock() dm = Mock()
self.app.dm = dm self.app.dm = dm
packet = Packet(msg_type=PacketTypes.COMMIT_TRANSACTION) packet = Packets.CommitTransaction()
self.verification.commitTransaction(conn, packet, p64(1)) self.verification.commitTransaction(conn, packet, p64(1))
self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1) self.assertEqual(len(dm.mockGetNamedCalls("finishTransaction")), 1)
call = dm.mockGetNamedCalls("finishTransaction")[0] call = dm.mockGetNamedCalls("finishTransaction")[0]
......
...@@ -51,7 +51,7 @@ class BootstrapManagerTests(NeoTestBase): ...@@ -51,7 +51,7 @@ class BootstrapManagerTests(NeoTestBase):
conn = Mock({"getUUID" : uuid, conn = Mock({"getUUID" : uuid,
"getAddress" : ("127.0.0.1", self.master_port)}) "getAddress" : ("127.0.0.1", self.master_port)})
self.bootstrap.connectionCompleted(conn) self.bootstrap.connectionCompleted(conn)
self.checkAskPrimaryMaster(conn) self.checkAskPrimary(conn)
def testHandleNotReady(self): def testHandleNotReady(self):
# the primary is not ready # the primary is not ready
......
...@@ -23,8 +23,7 @@ from neo.handler import EventHandler ...@@ -23,8 +23,7 @@ from neo.handler import EventHandler
from neo.tests import DoNothingConnector from neo.tests import DoNothingConnector
from neo.connector import ConnectorException, ConnectorTryAgainException, \ from neo.connector import ConnectorException, ConnectorTryAgainException, \
ConnectorInProgressException, ConnectorConnectionRefusedException ConnectorInProgressException, ConnectorConnectionRefusedException
from neo.protocol import PacketTypes from neo.protocol import Packets
from neo import protocol
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
class ConnectionTests(NeoTestBase): class ConnectionTests(NeoTestBase):
...@@ -394,7 +393,7 @@ class ConnectionTests(NeoTestBase): ...@@ -394,7 +393,7 @@ class ConnectionTests(NeoTestBase):
def test_07_Connection_addPacket(self): def test_07_Connection_addPacket(self):
# no connector # no connector
p = Mock({"encode" : "testdata"}) p = Mock({"__str__" : "testdata"})
em = Mock() em = Mock()
handler = Mock() handler = Mock()
bc = Connection(em, handler, connector_handler=DoNothingConnector, bc = Connection(em, handler, connector_handler=DoNothingConnector,
...@@ -467,9 +466,9 @@ class ConnectionTests(NeoTestBase): ...@@ -467,9 +466,9 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2133), self.getNewUUID()), (("127.0.0.1", 2133), self.getNewUUID()),
(("127.0.0.1", 2435), self.getNewUUID()), (("127.0.0.1", 2435), self.getNewUUID()),
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
p = protocol.answerPrimaryMaster(self.getNewUUID(), master_list) p = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p.setId(1) p.setId(1)
data = p.encode() data = str(p)
bc.read_buf += data bc.read_buf += data
self.assertEqual(len(bc.event_dict), 0) self.assertEqual(len(bc.event_dict), 0)
bc.analyse() bc.analyse()
...@@ -500,9 +499,9 @@ class ConnectionTests(NeoTestBase): ...@@ -500,9 +499,9 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2133), self.getNewUUID()), (("127.0.0.1", 2133), self.getNewUUID()),
(("127.0.0.1", 2435), self.getNewUUID()), (("127.0.0.1", 2435), self.getNewUUID()),
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
p1 = protocol.answerPrimaryMaster(self.getNewUUID(), master_list) p1 = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p1.setId(1) p1.setId(1)
data = p1.encode() data = str(p1)
bc.read_buf += data bc.read_buf += data
# packet 2 # packet 2
master_list = ( master_list = (
...@@ -514,11 +513,11 @@ class ConnectionTests(NeoTestBase): ...@@ -514,11 +513,11 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2133), self.getNewUUID()), (("127.0.0.1", 2133), self.getNewUUID()),
(("127.0.0.1", 2435), self.getNewUUID()), (("127.0.0.1", 2435), self.getNewUUID()),
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
p2 = protocol.answerPrimaryMaster( self.getNewUUID(), master_list) p2 = Packets.AnswerPrimary( self.getNewUUID(), master_list)
p2.setId(2) p2.setId(2)
data = p2.encode() data = str(p2)
bc.read_buf += data bc.read_buf += data
self.assertEqual(len(bc.read_buf), len(p1.encode()) + len(p2.encode())) self.assertEqual(len(bc.read_buf), len(p1) + len(p2))
self.assertEqual(len(bc.event_dict), 0) self.assertEqual(len(bc.event_dict), 0)
bc.analyse() bc.analyse()
# check two packets decoded # check two packets decoded
...@@ -569,9 +568,9 @@ class ConnectionTests(NeoTestBase): ...@@ -569,9 +568,9 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2133), self.getNewUUID()), (("127.0.0.1", 2133), self.getNewUUID()),
(("127.0.0.1", 2435), self.getNewUUID()), (("127.0.0.1", 2435), self.getNewUUID()),
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
p = protocol.answerPrimaryMaster(self.getNewUUID(), master_list) p = Packets.AnswerPrimary(self.getNewUUID(), master_list)
p.setId(1) p.setId(1)
data = p.encode() data = str(p)
bc.read_buf += data bc.read_buf += data
self.assertEqual(len(bc.event_dict), 0) self.assertEqual(len(bc.event_dict), 0)
bc.expectMessage(1) bc.expectMessage(1)
...@@ -700,10 +699,9 @@ class ConnectionTests(NeoTestBase): ...@@ -700,10 +699,9 @@ class ConnectionTests(NeoTestBase):
(("127.0.0.1", 2435), self.getNewUUID()), (("127.0.0.1", 2435), self.getNewUUID()),
(("127.0.0.1", 2132), self.getNewUUID())) (("127.0.0.1", 2132), self.getNewUUID()))
uuid = self.getNewUUID() uuid = self.getNewUUID()
p = protocol.answerPrimaryMaster(uuid, master_list) p = Packets.AnswerPrimary(uuid, master_list)
p.setId(1) p.setId(1)
data = p.encode() return str(p)
return data
DoNothingConnector.receive = receive DoNothingConnector.receive = receive
connector = DoNothingConnector() connector = DoNothingConnector()
bc = Connection(em, handler, connector_handler=DoNothingConnector, bc = Connection(em, handler, connector_handler=DoNothingConnector,
...@@ -719,7 +717,7 @@ class ConnectionTests(NeoTestBase): ...@@ -719,7 +717,7 @@ class ConnectionTests(NeoTestBase):
self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEquals(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), PacketTypes.ANSWER_PRIMARY_MASTER) self.assertEqual(data.getType(), Packets.AnswerPrimary)
self.assertEqual(data.getId(), 1) self.assertEqual(data.getId(), 1)
self.assertEqual(len(bc.event_dict), 0) self.assertEqual(len(bc.event_dict), 0)
self.assertEqual(bc.read_buf, '') self.assertEqual(bc.read_buf, '')
......
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
import unittest, os import unittest, os
from mock import Mock from mock import Mock
from neo import protocol from neo import protocol
from neo.protocol import Packets
from neo.protocol import NodeTypes, NodeStates, CellStates from neo.protocol import NodeTypes, NodeStates, CellStates
from neo.protocol import ErrorCodes, PacketTypes, Packet from neo.protocol import ErrorCodes, Packets, Packet
from neo.protocol import INVALID_TID, PACKET_HEADER_SIZE from neo.protocol import INVALID_TID, PACKET_HEADER_SIZE
from neo.tests import NeoTestBase from neo.tests import NeoTestBase
from neo.util import getNextTID from neo.util import getNextTID
...@@ -37,17 +38,6 @@ class ProtocolTests(NeoTestBase): ...@@ -37,17 +38,6 @@ class ProtocolTests(NeoTestBase):
self.ltid = getNextTID(self.ltid) self.ltid = getNextTID(self.ltid)
return self.ltid return self.ltid
def test_01_Packet_init(self):
p = Packet(msg_type=PacketTypes.ASK_PRIMARY_MASTER, body=None)
self.assertEqual(p.getType(), PacketTypes.ASK_PRIMARY_MASTER)
self.assertEqual(len(p), PACKET_HEADER_SIZE)
def test_02_error(self):
p = protocol._error(0, "error message")
code, msg = protocol._decodeError(p._body)
self.assertEqual(code, ErrorCodes.NO_ERROR)
self.assertEqual(msg, "error message")
def test_03_protocolError(self): def test_03_protocolError(self):
p = protocol.protocolError("bad protocol") p = protocol.protocolError("bad protocol")
error_code, error_msg = p.decode() error_code, error_msg = p.decode()
...@@ -78,16 +68,16 @@ class ProtocolTests(NeoTestBase): ...@@ -78,16 +68,16 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(error_msg, "tid not found: no tid") self.assertEqual(error_msg, "tid not found: no tid")
def test_09_ping(self): def test_09_ping(self):
p = protocol.ping() p = Packets.Ping()
self.assertEqual(None, p.decode()) self.assertEqual(p.decode(), ())
def test_10_pong(self): def test_10_pong(self):
p = protocol.pong() p = Packets.Pong()
self.assertEqual(None, p.decode()) self.assertEqual(p.decode(), ())
def test_11_requestNodeIdentification(self): def test_11_RequestIdentification(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
p = protocol.requestNodeIdentification(NodeTypes.CLIENT, uuid, p = Packets.RequestIdentification(NodeTypes.CLIENT, uuid,
("127.0.0.1", 9080), "unittest") ("127.0.0.1", 9080), "unittest")
node, p_uuid, (ip, port), name = p.decode() node, p_uuid, (ip, port), name = p.decode()
self.assertEqual(node, NodeTypes.CLIENT) self.assertEqual(node, NodeTypes.CLIENT)
...@@ -96,9 +86,9 @@ class ProtocolTests(NeoTestBase): ...@@ -96,9 +86,9 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(port, 9080) self.assertEqual(port, 9080)
self.assertEqual(name, "unittest") self.assertEqual(name, "unittest")
def test_12_acceptNodeIdentification(self): def test_12_AcceptIdentification(self):
uuid1, uuid2 = self.getNewUUID(), self.getNewUUID() uuid1, uuid2 = self.getNewUUID(), self.getNewUUID()
p = protocol.acceptNodeIdentification(NodeTypes.CLIENT, uuid1, p = Packets.AcceptIdentification(NodeTypes.CLIENT, uuid1,
("127.0.0.1", 9080), 10, 20, uuid2) ("127.0.0.1", 9080), 10, 20, uuid2)
node, p_uuid, (ip, port), nb_partitions, nb_replicas, your_uuid = p.decode() node, p_uuid, (ip, port), nb_partitions, nb_replicas, your_uuid = p.decode()
self.assertEqual(node, NodeTypes.CLIENT) self.assertEqual(node, NodeTypes.CLIENT)
...@@ -109,11 +99,11 @@ class ProtocolTests(NeoTestBase): ...@@ -109,11 +99,11 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(nb_replicas, 20) self.assertEqual(nb_replicas, 20)
self.assertEqual(your_uuid, uuid2) self.assertEqual(your_uuid, uuid2)
def test_13_askPrimaryMaster(self): def test_13_askPrimary(self):
p = protocol.askPrimaryMaster() p = Packets.AskPrimary()
self.assertEqual(None, p.decode()) self.assertEqual(p.decode(), ())
def test_14_answerPrimaryMaster(self): def test_14_answerPrimary(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
uuid1 = self.getNewUUID() uuid1 = self.getNewUUID()
uuid2 = self.getNewUUID() uuid2 = self.getNewUUID()
...@@ -121,18 +111,18 @@ class ProtocolTests(NeoTestBase): ...@@ -121,18 +111,18 @@ class ProtocolTests(NeoTestBase):
master_list = [(("127.0.0.1", 1), uuid1), master_list = [(("127.0.0.1", 1), uuid1),
(("127.0.0.2", 2), uuid2), (("127.0.0.2", 2), uuid2),
(("127.0.0.3", 3), uuid3)] (("127.0.0.3", 3), uuid3)]
p = protocol.answerPrimaryMaster(uuid, master_list) p = Packets.AnswerPrimary(uuid, master_list)
primary_uuid, p_master_list = p.decode() primary_uuid, p_master_list = p.decode()
self.assertEqual(primary_uuid, uuid) self.assertEqual(primary_uuid, uuid)
self.assertEqual(master_list, p_master_list) self.assertEqual(master_list, p_master_list)
def test_15_announcePrimaryMaster(self): def test_15_announcePrimary(self):
p = protocol.announcePrimaryMaster() p = Packets.AnnouncePrimary()
self.assertEqual(p.decode(), None) self.assertEqual(p.decode(), ())
def test_16_reelectPrimaryMaster(self): def test_16_reelectPrimary(self):
p = protocol.reelectPrimaryMaster() p = Packets.ReelectPrimary()
self.assertEqual(p.decode(), None) self.assertEqual(p.decode(), ())
def test_17_notifyNodeInformation(self): def test_17_notifyNodeInformation(self):
uuid = self.getNewUUID() uuid = self.getNewUUID()
...@@ -142,19 +132,19 @@ class ProtocolTests(NeoTestBase): ...@@ -142,19 +132,19 @@ class ProtocolTests(NeoTestBase):
node_list = [(NodeTypes.CLIENT, ("127.0.0.1", 1), uuid1, NodeStates.RUNNING), node_list = [(NodeTypes.CLIENT, ("127.0.0.1", 1), uuid1, NodeStates.RUNNING),
(NodeTypes.CLIENT, ("127.0.0.2", 2), uuid2, NodeStates.DOWN), (NodeTypes.CLIENT, ("127.0.0.2", 2), uuid2, NodeStates.DOWN),
(NodeTypes.CLIENT, ("127.0.0.3", 3), uuid3, NodeStates.BROKEN)] (NodeTypes.CLIENT, ("127.0.0.3", 3), uuid3, NodeStates.BROKEN)]
p = protocol.notifyNodeInformation(node_list) p = Packets.NotifyNodeInformation(node_list)
p_node_list = p.decode()[0] p_node_list = p.decode()[0]
self.assertEqual(node_list, p_node_list) self.assertEqual(node_list, p_node_list)
def test_18_askLastIDs(self): def test_18_askLastIDs(self):
p = protocol.askLastIDs() p = Packets.AskLastIDs()
self.assertEqual(p.decode(), None) self.assertEqual(p.decode(), ())
def test_19_answerLastIDs(self): def test_19_answerLastIDs(self):
oid = self.getNextTID() oid = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
ptid = self.getNextTID() ptid = self.getNextTID()
p = protocol.answerLastIDs(oid, tid, ptid) p = Packets.AnswerLastIDs(oid, tid, ptid)
loid, ltid, lptid = p.decode() loid, ltid, lptid = p.decode()
self.assertEqual(loid, oid) self.assertEqual(loid, oid)
self.assertEqual(ltid, tid) self.assertEqual(ltid, tid)
...@@ -162,7 +152,7 @@ class ProtocolTests(NeoTestBase): ...@@ -162,7 +152,7 @@ class ProtocolTests(NeoTestBase):
def test_20_askPartitionTable(self): def test_20_askPartitionTable(self):
offset_list = [1, 523, 6, 124] offset_list = [1, 523, 6, 124]
p = protocol.askPartitionTable(offset_list) p = Packets.AskPartitionTable(offset_list)
p_offset_list = p.decode()[0] p_offset_list = p.decode()[0]
self.assertEqual(offset_list, p_offset_list) self.assertEqual(offset_list, p_offset_list)
...@@ -174,7 +164,7 @@ class ProtocolTests(NeoTestBase): ...@@ -174,7 +164,7 @@ class ProtocolTests(NeoTestBase):
cell_list = [(0, ((uuid1, CellStates.UP_TO_DATE), (uuid2, CellStates.OUT_OF_DATE))), cell_list = [(0, ((uuid1, CellStates.UP_TO_DATE), (uuid2, CellStates.OUT_OF_DATE))),
(43, ((uuid2, CellStates.OUT_OF_DATE),(uuid3, CellStates.DISCARDED))), (43, ((uuid2, CellStates.OUT_OF_DATE),(uuid3, CellStates.DISCARDED))),
(124, ((uuid1, CellStates.DISCARDED), (uuid3, CellStates.UP_TO_DATE)))] (124, ((uuid1, CellStates.DISCARDED), (uuid3, CellStates.UP_TO_DATE)))]
p = protocol.answerPartitionTable(ptid, cell_list) p = Packets.AnswerPartitionTable(ptid, cell_list)
pptid, p_cell_list = p.decode() pptid, p_cell_list = p.decode()
self.assertEqual(pptid, ptid) self.assertEqual(pptid, ptid)
self.assertEqual(p_cell_list, cell_list) self.assertEqual(p_cell_list, cell_list)
...@@ -187,7 +177,7 @@ class ProtocolTests(NeoTestBase): ...@@ -187,7 +177,7 @@ class ProtocolTests(NeoTestBase):
cell_list = [(0, ((uuid1, CellStates.UP_TO_DATE), (uuid2, CellStates.OUT_OF_DATE))), cell_list = [(0, ((uuid1, CellStates.UP_TO_DATE), (uuid2, CellStates.OUT_OF_DATE))),
(43, ((uuid2, CellStates.OUT_OF_DATE),(uuid3, CellStates.DISCARDED))), (43, ((uuid2, CellStates.OUT_OF_DATE),(uuid3, CellStates.DISCARDED))),
(124, ((uuid1, CellStates.DISCARDED), (uuid3, CellStates.UP_TO_DATE)))] (124, ((uuid1, CellStates.DISCARDED), (uuid3, CellStates.UP_TO_DATE)))]
p = protocol.answerPartitionTable(ptid, cell_list) p = Packets.AnswerPartitionTable(ptid, cell_list)
pptid, p_cell_list = p.decode() pptid, p_cell_list = p.decode()
self.assertEqual(pptid, ptid) self.assertEqual(pptid, ptid)
self.assertEqual(p_cell_list, cell_list) self.assertEqual(p_cell_list, cell_list)
...@@ -200,23 +190,23 @@ class ProtocolTests(NeoTestBase): ...@@ -200,23 +190,23 @@ class ProtocolTests(NeoTestBase):
cell_list = [(0, uuid1, CellStates.UP_TO_DATE), cell_list = [(0, uuid1, CellStates.UP_TO_DATE),
(43, uuid2, CellStates.OUT_OF_DATE), (43, uuid2, CellStates.OUT_OF_DATE),
(124, uuid1, CellStates.DISCARDED)] (124, uuid1, CellStates.DISCARDED)]
p = protocol.notifyPartitionChanges(ptid, p = Packets.NotifyPartitionChanges(ptid,
cell_list) cell_list)
pptid, p_cell_list = p.decode() pptid, p_cell_list = p.decode()
self.assertEqual(pptid, ptid) self.assertEqual(pptid, ptid)
self.assertEqual(p_cell_list, cell_list) self.assertEqual(p_cell_list, cell_list)
def test_24_startOperation(self): def test_24_startOperation(self):
p = protocol.startOperation() p = Packets.StartOperation()
self.assertEqual(p.decode(), None) self.assertEqual(p.decode(), ())
def test_25_stopOperation(self): def test_25_stopOperation(self):
p = protocol.stopOperation() p = Packets.StopOperation()
self.assertEqual(p.decode(), None) self.assertEqual(p.decode(), ())
def test_26_askUnfinishedTransaction(self): def test_26_askUnfinishedTransaction(self):
p = protocol.askUnfinishedTransactions() p = Packets.AskUnfinishedTransactions()
self.assertEqual(p.decode(), None) self.assertEqual(p.decode(), ())
def test_27_answerUnfinishedTransaction(self): def test_27_answerUnfinishedTransaction(self):
tid1 = self.getNextTID() tid1 = self.getNextTID()
...@@ -224,14 +214,14 @@ class ProtocolTests(NeoTestBase): ...@@ -224,14 +214,14 @@ class ProtocolTests(NeoTestBase):
tid3 = self.getNextTID() tid3 = self.getNextTID()
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid_list = [tid1, tid2, tid3, tid4] tid_list = [tid1, tid2, tid3, tid4]
p = protocol.answerUnfinishedTransactions(tid_list) p = Packets.AnswerUnfinishedTransactions(tid_list)
p_tid_list = p.decode()[0] p_tid_list = p.decode()[0]
self.assertEqual(p_tid_list, tid_list) self.assertEqual(p_tid_list, tid_list)
def test_28_askObjectPresent(self): def test_28_askObjectPresent(self):
oid = self.getNextTID() oid = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.askObjectPresent(oid, tid) p = Packets.AskObjectPresent(oid, tid)
loid, ltid = p.decode() loid, ltid = p.decode()
self.assertEqual(loid, oid) self.assertEqual(loid, oid)
self.assertEqual(ltid, tid) self.assertEqual(ltid, tid)
...@@ -239,21 +229,21 @@ class ProtocolTests(NeoTestBase): ...@@ -239,21 +229,21 @@ class ProtocolTests(NeoTestBase):
def test_29_answerObjectPresent(self): def test_29_answerObjectPresent(self):
oid = self.getNextTID() oid = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.answerObjectPresent(oid, tid) p = Packets.AnswerObjectPresent(oid, tid)
loid, ltid = p.decode() loid, ltid = p.decode()
self.assertEqual(loid, oid) self.assertEqual(loid, oid)
self.assertEqual(ltid, tid) self.assertEqual(ltid, tid)
def test_30_deleteTransaction(self): def test_30_deleteTransaction(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.deleteTransaction(tid) p = Packets.DeleteTransaction(tid)
self.assertEqual(p.getType(), PacketTypes.DELETE_TRANSACTION) self.assertEqual(p.getType(), Packets.DeleteTransaction)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
def test_31_commitTransaction(self): def test_31_commitTransaction(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.commitTransaction(tid) p = Packets.CommitTransaction(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
...@@ -261,21 +251,21 @@ class ProtocolTests(NeoTestBase): ...@@ -261,21 +251,21 @@ class ProtocolTests(NeoTestBase):
def test_32_askBeginTransaction(self): def test_32_askBeginTransaction(self):
# try with an invalid TID, None must be returned # try with an invalid TID, None must be returned
tid = '\0' * 8 tid = '\0' * 8
p = protocol.askBeginTransaction(tid) p = Packets.AskBeginTransaction(tid)
self.assertEqual(p.decode(), (None, )) self.assertEqual(p.decode(), (None, ))
# and with another TID # and with another TID
tid = '\1' * 8 tid = '\1' * 8
p = protocol.askBeginTransaction(tid) p = Packets.AskBeginTransaction(tid)
self.assertEqual(p.decode(), (tid, )) self.assertEqual(p.decode(), (tid, ))
def test_33_answerBeginTransaction(self): def test_33_answerBeginTransaction(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.answerBeginTransaction(tid) p = Packets.AnswerBeginTransaction(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
def test_34_askNewOIDs(self): def test_34_askNewOIDs(self):
p = protocol.askNewOIDs(10) p = Packets.AskNewOIDs(10)
nb = p.decode() nb = p.decode()
self.assertEqual(nb, (10,)) self.assertEqual(nb, (10,))
...@@ -285,7 +275,7 @@ class ProtocolTests(NeoTestBase): ...@@ -285,7 +275,7 @@ class ProtocolTests(NeoTestBase):
oid3 = self.getNextTID() oid3 = self.getNextTID()
oid4 = self.getNextTID() oid4 = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4] oid_list = [oid1, oid2, oid3, oid4]
p = protocol.answerNewOIDs(oid_list) p = Packets.AnswerNewOIDs(oid_list)
p_oid_list = p.decode()[0] p_oid_list = p.decode()[0]
self.assertEqual(p_oid_list, oid_list) self.assertEqual(p_oid_list, oid_list)
...@@ -296,26 +286,26 @@ class ProtocolTests(NeoTestBase): ...@@ -296,26 +286,26 @@ class ProtocolTests(NeoTestBase):
oid4 = self.getNextTID() oid4 = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4] oid_list = [oid1, oid2, oid3, oid4]
p = protocol.finishTransaction(oid_list, tid) p = Packets.FinishTransaction(oid_list, tid)
p_oid_list, ptid = p.decode() p_oid_list, ptid = p.decode()
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertEqual(p_oid_list, oid_list) self.assertEqual(p_oid_list, oid_list)
def test_37_notifyTransactionFinished(self): def test_37_notifyTransactionFinished(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.notifyTransactionFinished(tid) p = Packets.NotifyTransactionFinished(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
def test_38_lockInformation(self): def test_38_lockInformation(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.lockInformation(tid) p = Packets.LockInformation(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
def test_39_notifyInformationLocked(self): def test_39_notifyInformationLocked(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.notifyInformationLocked(tid) p = Packets.NotifyInformationLocked(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
...@@ -326,20 +316,20 @@ class ProtocolTests(NeoTestBase): ...@@ -326,20 +316,20 @@ class ProtocolTests(NeoTestBase):
oid4 = self.getNextTID() oid4 = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4] oid_list = [oid1, oid2, oid3, oid4]
p = protocol.invalidateObjects(oid_list, tid) p = Packets.InvalidateObjects(oid_list, tid)
p_oid_list, ptid = p.decode() p_oid_list, ptid = p.decode()
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertEqual(p_oid_list, oid_list) self.assertEqual(p_oid_list, oid_list)
def test_41_unlockInformation(self): def test_41_unlockInformation(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.unlockInformation(tid) p = Packets.UnlockInformation(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
def test_42_abortTransaction(self): def test_42_abortTransaction(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.abortTransaction(tid) p = Packets.AbortTransaction(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
...@@ -350,7 +340,7 @@ class ProtocolTests(NeoTestBase): ...@@ -350,7 +340,7 @@ class ProtocolTests(NeoTestBase):
oid3 = self.getNextTID() oid3 = self.getNextTID()
oid4 = self.getNextTID() oid4 = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4] oid_list = [oid1, oid2, oid3, oid4]
p = protocol.askStoreTransaction(tid, "moi", "transaction", "exti", oid_list) p = Packets.AskStoreTransaction(tid, "moi", "transaction", "exti", oid_list)
ptid, user, desc, ext, p_oid_list = p.decode() ptid, user, desc, ext, p_oid_list = p.decode()
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
self.assertEqual(p_oid_list, oid_list) self.assertEqual(p_oid_list, oid_list)
...@@ -360,7 +350,7 @@ class ProtocolTests(NeoTestBase): ...@@ -360,7 +350,7 @@ class ProtocolTests(NeoTestBase):
def test_44_answerStoreTransaction(self): def test_44_answerStoreTransaction(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.answerStoreTransaction(tid) p = Packets.AnswerStoreTransaction(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
...@@ -368,7 +358,7 @@ class ProtocolTests(NeoTestBase): ...@@ -368,7 +358,7 @@ class ProtocolTests(NeoTestBase):
oid = self.getNextTID() oid = self.getNextTID()
serial = self.getNextTID() serial = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.askStoreObject(oid, serial, 1, 55, "to", tid) p = Packets.AskStoreObject(oid, serial, 1, 55, "to", tid)
poid, pserial, compression, checksum, data, ptid = p.decode() poid, pserial, compression, checksum, data, ptid = p.decode()
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
self.assertEqual(serial, pserial) self.assertEqual(serial, pserial)
...@@ -380,7 +370,7 @@ class ProtocolTests(NeoTestBase): ...@@ -380,7 +370,7 @@ class ProtocolTests(NeoTestBase):
def test_46_answerStoreObject(self): def test_46_answerStoreObject(self):
oid = self.getNextTID() oid = self.getNextTID()
serial = self.getNextTID() serial = self.getNextTID()
p = protocol.answerStoreObject(1, oid, serial) p = Packets.AnswerStoreObject(1, oid, serial)
conflicting, poid, pserial = p.decode() conflicting, poid, pserial = p.decode()
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
self.assertEqual(serial, pserial) self.assertEqual(serial, pserial)
...@@ -390,7 +380,7 @@ class ProtocolTests(NeoTestBase): ...@@ -390,7 +380,7 @@ class ProtocolTests(NeoTestBase):
oid = self.getNextTID() oid = self.getNextTID()
serial = self.getNextTID() serial = self.getNextTID()
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.askObject(oid, serial, tid) p = Packets.AskObject(oid, serial, tid)
poid, pserial, ptid = p.decode() poid, pserial, ptid = p.decode()
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
self.assertEqual(serial, pserial) self.assertEqual(serial, pserial)
...@@ -400,7 +390,7 @@ class ProtocolTests(NeoTestBase): ...@@ -400,7 +390,7 @@ class ProtocolTests(NeoTestBase):
oid = self.getNextTID() oid = self.getNextTID()
serial_start = self.getNextTID() serial_start = self.getNextTID()
serial_end = self.getNextTID() serial_end = self.getNextTID()
p = protocol.answerObject(oid, serial_start, serial_end, 1, 55, "to",) p = Packets.AnswerObject(oid, serial_start, serial_end, 1, 55, "to",)
poid, pserial_start, pserial_end, compression, checksum, data= p.decode() poid, pserial_start, pserial_end, compression, checksum, data= p.decode()
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
self.assertEqual(serial_start, pserial_start) self.assertEqual(serial_start, pserial_start)
...@@ -410,7 +400,7 @@ class ProtocolTests(NeoTestBase): ...@@ -410,7 +400,7 @@ class ProtocolTests(NeoTestBase):
self.assertEqual(data, "to") self.assertEqual(data, "to")
def test_49_askTIDs(self): def test_49_askTIDs(self):
p = protocol.askTIDs(1, 10, 5) p = Packets.AskTIDs(1, 10, 5)
first, last, partition = p.decode() first, last, partition = p.decode()
self.assertEqual(first, 1) self.assertEqual(first, 1)
self.assertEqual(last, 10) self.assertEqual(last, 10)
...@@ -422,13 +412,13 @@ class ProtocolTests(NeoTestBase): ...@@ -422,13 +412,13 @@ class ProtocolTests(NeoTestBase):
tid3 = self.getNextTID() tid3 = self.getNextTID()
tid4 = self.getNextTID() tid4 = self.getNextTID()
tid_list = [tid1, tid2, tid3, tid4] tid_list = [tid1, tid2, tid3, tid4]
p = protocol.answerTIDs(tid_list) p = Packets.AnswerTIDs(tid_list)
p_tid_list = p.decode()[0] p_tid_list = p.decode()[0]
self.assertEqual(p_tid_list, tid_list) self.assertEqual(p_tid_list, tid_list)
def test_51_askTransactionInfomation(self): def test_51_askTransactionInfomation(self):
tid = self.getNextTID() tid = self.getNextTID()
p = protocol.askTransactionInformation(tid) p = Packets.AskTransactionInformation(tid)
ptid = p.decode()[0] ptid = p.decode()[0]
self.assertEqual(tid, ptid) self.assertEqual(tid, ptid)
...@@ -439,7 +429,7 @@ class ProtocolTests(NeoTestBase): ...@@ -439,7 +429,7 @@ class ProtocolTests(NeoTestBase):
oid3 = self.getNextTID() oid3 = self.getNextTID()
oid4 = self.getNextTID() oid4 = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4] oid_list = [oid1, oid2, oid3, oid4]
p = protocol.answerTransactionInformation(tid, "moi", p = Packets.AnswerTransactionInformation(tid, "moi",
"transaction", "exti", oid_list) "transaction", "exti", oid_list)
ptid, user, desc, ext, p_oid_list = p.decode() ptid, user, desc, ext, p_oid_list = p.decode()
self.assertEqual(ptid, tid) self.assertEqual(ptid, tid)
...@@ -450,7 +440,7 @@ class ProtocolTests(NeoTestBase): ...@@ -450,7 +440,7 @@ class ProtocolTests(NeoTestBase):
def test_53_askObjectHistory(self): def test_53_askObjectHistory(self):
oid = self.getNextTID() oid = self.getNextTID()
p = protocol.askObjectHistory(oid, 1, 10,) p = Packets.AskObjectHistory(oid, 1, 10,)
poid, first, last = p.decode() poid, first, last = p.decode()
self.assertEqual(first, 1) self.assertEqual(first, 1)
self.assertEqual(last, 10) self.assertEqual(last, 10)
...@@ -463,13 +453,13 @@ class ProtocolTests(NeoTestBase): ...@@ -463,13 +453,13 @@ class ProtocolTests(NeoTestBase):
hist3 = (self.getNextTID(), 326) hist3 = (self.getNextTID(), 326)
hist4 = (self.getNextTID(), 652) hist4 = (self.getNextTID(), 652)
hist_list = [hist1, hist2, hist3, hist4] hist_list = [hist1, hist2, hist3, hist4]
p = protocol.answerObjectHistory(oid, hist_list) p = Packets.AnswerObjectHistory(oid, hist_list)
poid, p_hist_list = p.decode() poid, p_hist_list = p.decode()
self.assertEqual(p_hist_list, hist_list) self.assertEqual(p_hist_list, hist_list)
self.assertEqual(oid, poid) self.assertEqual(oid, poid)
def test_55_askOIDs(self): def test_55_askOIDs(self):
p = protocol.askOIDs(1, 10, 5) p = Packets.AskOIDs(1, 10, 5)
first, last, partition = p.decode() first, last, partition = p.decode()
self.assertEqual(first, 1) self.assertEqual(first, 1)
self.assertEqual(last, 10) self.assertEqual(last, 10)
...@@ -481,7 +471,7 @@ class ProtocolTests(NeoTestBase): ...@@ -481,7 +471,7 @@ class ProtocolTests(NeoTestBase):
oid3 = self.getNextTID() oid3 = self.getNextTID()
oid4 = self.getNextTID() oid4 = self.getNextTID()
oid_list = [oid1, oid2, oid3, oid4] oid_list = [oid1, oid2, oid3, oid4]
p = protocol.answerOIDs(oid_list) p = Packets.AnswerOIDs(oid_list)
p_oid_list = p.decode()[0] p_oid_list = p.decode()[0]
self.assertEqual(p_oid_list, oid_list) self.assertEqual(p_oid_list, oid_list)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment