Commit f7378a70 authored by Julien Muchembled's avatar Julien Muchembled

Remove overkill Packet.getType method

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2761 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 86ce8602
...@@ -494,13 +494,13 @@ class Connection(BaseConnection): ...@@ -494,13 +494,13 @@ class Connection(BaseConnection):
self.getHandler()._packetMalformed(self, msg) self.getHandler()._packetMalformed(self, msg)
return return
self._timeout.refresh(time()) self._timeout.refresh(time())
packet_type = packet.getType() packet_type = type(packet)
if packet_type == Packets.Ping: if packet_type is Packets.Ping:
# Send a pong notification # Send a pong notification
PACKET_LOGGER.dispatch(self, packet, False) PACKET_LOGGER.dispatch(self, packet, False)
if not self.aborted: if not self.aborted:
self.answer(Packets.Pong(), packet.getId()) self.answer(Packets.Pong(), packet.getId())
elif packet_type == Packets.Pong: elif packet_type is Packets.Pong:
# Skip PONG packets, its only purpose is refresh the timeout # Skip PONG packets, its only purpose is refresh the timeout
# generated upong ping. But still log them. # generated upong ping. But still log them.
PACKET_LOGGER.dispatch(self, packet, False) PACKET_LOGGER.dispatch(self, packet, False)
...@@ -772,7 +772,7 @@ class MTClientConnection(ClientConnection): ...@@ -772,7 +772,7 @@ class MTClientConnection(ClientConnection):
msg_id = self._getNextId() msg_id = self._getNextId()
packet.setId(msg_id) packet.setId(msg_id)
if queue is None: if queue is None:
if not isinstance(packet, Packets.Ping): if type(packet) is not Packets.Ping:
raise TypeError, 'Only Ping packet can be asked ' \ raise TypeError, 'Only Ping packet can be asked ' \
'without a queue, got a %r.' % (packet, ) 'without a queue, got a %r.' % (packet, )
else: else:
......
...@@ -33,7 +33,7 @@ class EventHandler(object): ...@@ -33,7 +33,7 @@ class EventHandler(object):
def __unexpectedPacket(self, conn, packet, message=None): def __unexpectedPacket(self, conn, packet, message=None):
"""Handle an unexpected packet.""" """Handle an unexpected packet."""
if message is None: if message is None:
message = 'unexpected packet type %s in %s' % (packet.getType(), message = 'unexpected packet type %s in %s' % (type(packet),
self.__class__.__name__) self.__class__.__name__)
else: else:
message = 'unexpected packet: %s in %s' % (message, message = 'unexpected packet: %s in %s' % (message,
......
...@@ -193,9 +193,6 @@ class Packet(object): ...@@ -193,9 +193,6 @@ class Packet(object):
assert self._id is not None, "No identifier applied on the packet" assert self._id is not None, "No identifier applied on the packet"
return self._id return self._id
def getType(self):
return self.__class__
def encode(self): def encode(self):
""" Encode a packet as a string to send it over the network """ """ Encode a packet as a string to send it over the network """
content = self._body content = self._body
......
...@@ -334,10 +334,10 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -334,10 +334,10 @@ class NeoUnitTestBase(NeoTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(packet.getType(), Packets.Error) self.assertEqual(type(packet), Packets.Error)
if decode: if decode:
return packet.decode() return packet.decode()
return protocol.decode_table[packet.getType()](packet._body) return protocol.decode_table[type(packet)](packet._body)
return packet return packet
def checkAskPacket(self, conn, packet_type, decode=False): def checkAskPacket(self, conn, packet_type, decode=False):
...@@ -346,7 +346,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -346,7 +346,7 @@ class NeoUnitTestBase(NeoTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(packet.getType(), packet_type) self.assertEqual(type(packet), packet_type)
if decode: if decode:
return packet.decode() return packet.decode()
return packet return packet
...@@ -357,7 +357,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -357,7 +357,7 @@ class NeoUnitTestBase(NeoTestBase):
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
packet = calls.pop().getParam(0) packet = calls.pop().getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(packet.getType(), packet_type) self.assertEqual(type(packet), packet_type)
if decode: if decode:
return packet.decode() return packet.decode()
return packet return packet
...@@ -367,7 +367,7 @@ class NeoUnitTestBase(NeoTestBase): ...@@ -367,7 +367,7 @@ class NeoUnitTestBase(NeoTestBase):
calls = conn.mockGetNamedCalls('notify') calls = conn.mockGetNamedCalls('notify')
packet = calls.pop(packet_number).getParam(0) packet = calls.pop(packet_number).getParam(0)
self.assertTrue(isinstance(packet, protocol.Packet)) self.assertTrue(isinstance(packet, protocol.Packet))
self.assertEqual(packet.getType(), packet_type) self.assertEqual(type(packet), packet_type)
if decode: if decode:
return packet.decode() return packet.decode()
return packet return packet
......
...@@ -934,7 +934,7 @@ class ClientApplicationTests(NeoUnitTestBase): ...@@ -934,7 +934,7 @@ class ClientApplicationTests(NeoUnitTestBase):
now = time.time() now = time.time()
app.pack(now) app.pack(now)
self.assertEqual(len(marker), 1) self.assertEqual(len(marker), 1)
self.assertEqual(marker[0].getType(), Packets.AskPack) self.assertEqual(type(marker[0]), Packets.AskPack)
# XXX: how to validate packet content ? # XXX: how to validate packet content ?
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -121,7 +121,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -121,7 +121,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')] packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
packet_list, next_range = packet_list[:-1], packet_list[-1] packet_list, next_range = packet_list[:-1], packet_list[-1]
self.assertEqual(next_range.getType(), Packets.AskCheckTIDRange) self.assertEqual(type(next_range), Packets.AskCheckTIDRange)
pmin_tid, plength, ppartition = next_range.decode() pmin_tid, plength, ppartition = next_range.decode()
self.assertEqual(pmin_tid, add64(next_tid, 1)) self.assertEqual(pmin_tid, add64(next_tid, 1))
self.assertEqual(plength, RANGE_LENGTH) self.assertEqual(plength, RANGE_LENGTH)
...@@ -132,7 +132,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -132,7 +132,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
self.assertEqual(len(packet_list), len(tid_list)) self.assertEqual(len(packet_list), len(tid_list))
for packet in packet_list: for packet in packet_list:
self.assertEqual(packet.getType(), self.assertEqual(type(packet),
Packets.AskTransactionInformation) Packets.AskTransactionInformation)
ptid = packet.decode()[0] ptid = packet.decode()[0]
for tid in tid_list: for tid in tid_list:
...@@ -147,7 +147,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -147,7 +147,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')] packet_list = [x.getParam(0) for x in conn.mockGetNamedCalls('ask')]
packet_list, next_range = packet_list[:-1], packet_list[-1] packet_list, next_range = packet_list[:-1], packet_list[-1]
self.assertEqual(next_range.getType(), Packets.AskCheckSerialRange) self.assertEqual(type(next_range), Packets.AskCheckSerialRange)
pmin_oid, pmin_serial, plength, ppartition = next_range.decode() pmin_oid, pmin_serial, plength, ppartition = next_range.decode()
self.assertEqual(pmin_oid, next_oid) self.assertEqual(pmin_oid, next_oid)
self.assertEqual(pmin_serial, add64(next_serial, 1)) self.assertEqual(pmin_serial, add64(next_serial, 1))
...@@ -422,7 +422,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -422,7 +422,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
tid_packet = calls[0].getParam(0) tid_packet = calls[0].getParam(0)
self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom) self.assertEqual(type(tid_packet), Packets.AskTIDsFrom)
pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode() pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid) self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid) self.assertEqual(pmax_tid, critical_tid)
...@@ -449,7 +449,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -449,7 +449,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 2) self.assertEqual(len(calls), 2)
tid_packet = calls[0].getParam(0) tid_packet = calls[0].getParam(0)
self.assertEqual(tid_packet.getType(), Packets.AskTIDsFrom) self.assertEqual(type(tid_packet), Packets.AskTIDsFrom)
pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode() pmin_tid, pmax_tid, plength, ppartition = tid_packet.decode()
self.assertEqual(pmin_tid, min_tid) self.assertEqual(pmin_tid, min_tid)
self.assertEqual(pmax_tid, critical_tid) self.assertEqual(pmax_tid, critical_tid)
...@@ -577,7 +577,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase): ...@@ -577,7 +577,7 @@ class StorageReplicationHandlerTests(NeoUnitTestBase):
calls = conn.mockGetNamedCalls('ask') calls = conn.mockGetNamedCalls('ask')
self.assertEqual(len(calls), 1) self.assertEqual(len(calls), 1)
serial_packet = calls[0].getParam(0) serial_packet = calls[0].getParam(0)
self.assertEqual(serial_packet.getType(), Packets.AskObjectHistoryFrom) self.assertEqual(type(serial_packet), Packets.AskObjectHistoryFrom)
pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \ pmin_oid, pmin_serial, pmax_serial, plength, ppartition = \
serial_packet.decode() serial_packet.decode()
self.assertEqual(pmin_oid, min_oid) self.assertEqual(pmin_oid, min_oid)
......
...@@ -115,7 +115,8 @@ class StorageReplicatorTests(NeoUnitTestBase): ...@@ -115,7 +115,8 @@ class StorageReplicatorTests(NeoUnitTestBase):
act() act()
unfinished_tids = app.master_conn.mockGetNamedCalls('ask')[0].getParam(0) unfinished_tids = app.master_conn.mockGetNamedCalls('ask')[0].getParam(0)
self.assertTrue(replicator.new_partition_set) self.assertTrue(replicator.new_partition_set)
self.assertEqual(unfinished_tids.getType(), Packets.AskUnfinishedTransactions) self.assertEqual(type(unfinished_tids),
Packets.AskUnfinishedTransactions)
self.assertTrue(replicator.waiting_for_unfinished_tids) self.assertTrue(replicator.waiting_for_unfinished_tids)
# nothing happens until waiting_for_unfinished_tids becomes False # nothing happens until waiting_for_unfinished_tids becomes False
act() act()
......
...@@ -413,7 +413,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -413,7 +413,7 @@ class ConnectionTests(NeoUnitTestBase):
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), p.getType()) self.assertEqual(type(data), type(p))
self.assertEqual(data.getId(), p.getId()) self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode()) self.assertEqual(data.decode(), p.decode())
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
...@@ -455,13 +455,13 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -455,13 +455,13 @@ class ConnectionTests(NeoUnitTestBase):
# packet 1 # packet 1
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), p1.getType()) self.assertEqual(type(data), type(p1))
self.assertEqual(data.getId(), p1.getId()) self.assertEqual(data.getId(), p1.getId())
self.assertEqual(data.decode(), p1.decode()) self.assertEqual(data.decode(), p1.decode())
# packet 2 # packet 2
call = bc._queue.mockGetNamedCalls("append")[1] call = bc._queue.mockGetNamedCalls("append")[1]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), p2.getType()) self.assertEqual(type(data), type(p2))
self.assertEqual(data.getId(), p2.getId()) self.assertEqual(data.getId(), p2.getId())
self.assertEqual(data.decode(), p2.decode()) self.assertEqual(data.decode(), p2.decode())
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
...@@ -497,7 +497,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -497,7 +497,7 @@ class ConnectionTests(NeoUnitTestBase):
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), p.getType()) self.assertEqual(type(data), type(p))
self.assertEqual(data.getId(), p.getId()) self.assertEqual(data.getId(), p.getId())
self.assertEqual(data.decode(), p.decode()) self.assertEqual(data.decode(), p.decode())
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
...@@ -519,7 +519,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -519,7 +519,7 @@ class ConnectionTests(NeoUnitTestBase):
buffer.append(chunk) buffer.append(chunk)
answer = Packets.parse(buffer, parser_state) answer = Packets.parse(buffer, parser_state)
self.assertTrue(answer is not None) self.assertTrue(answer is not None)
self.assertTrue(answer.getType() == Packets.Pong) self.assertTrue(type(answer) == Packets.Pong)
self.assertEqual(answer.getId(), p.getId()) self.assertEqual(answer.getId(), p.getId())
def test_Connection_analyse6(self): def test_Connection_analyse6(self):
...@@ -636,7 +636,7 @@ class ConnectionTests(NeoUnitTestBase): ...@@ -636,7 +636,7 @@ class ConnectionTests(NeoUnitTestBase):
self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1) self.assertEqual(len(bc._queue.mockGetNamedCalls("append")), 1)
call = bc._queue.mockGetNamedCalls("append")[0] call = bc._queue.mockGetNamedCalls("append")[0]
data = call.getParam(0) data = call.getParam(0)
self.assertEqual(data.getType(), Packets.AnswerPrimary) self.assertEqual(type(data), Packets.AnswerPrimary)
self.assertEqual(data.getId(), 1) self.assertEqual(data.getId(), 1)
self._checkReadBuf(bc, '') self._checkReadBuf(bc, '')
# check not aborted # check not aborted
......
...@@ -259,7 +259,7 @@ class ProtocolTests(NeoUnitTestBase): ...@@ -259,7 +259,7 @@ class ProtocolTests(NeoUnitTestBase):
tid = self.getNextTID() tid = self.getNextTID()
oid_list = [self.getOID(1), self.getOID(2)] oid_list = [self.getOID(1), self.getOID(2)]
p = Packets.DeleteTransaction(tid, oid_list) p = Packets.DeleteTransaction(tid, oid_list)
self.assertEqual(p.getType(), Packets.DeleteTransaction) self.assertEqual(type(p), Packets.DeleteTransaction)
self.assertEqual(p.decode(), (tid, oid_list)) self.assertEqual(p.decode(), (tid, oid_list))
def test_31_commitTransaction(self): def test_31_commitTransaction(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment