Commit 8aed780b authored by Grégory Wisniewski's avatar Grégory Wisniewski

Fix a typo in protocol, add unregister() method on Dispatcher for consistency

adn remove an XXX. Unify imports for protocol, remove unused variable.


git-svn-id: https://svn.erp5.org/repos/neo/branches/prototype3@1085 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent 6772d446
...@@ -18,10 +18,9 @@ ...@@ -18,10 +18,9 @@
from neo import logging from neo import logging
from neo.client.handlers import BaseHandler, AnswerBaseHandler from neo.client.handlers import BaseHandler, AnswerBaseHandler
from neo.protocol import MASTER_NODE_TYPE, STORAGE_NODE_TYPE, \ from neo.node import MasterNode
RUNNING_STATE, TEMPORARILY_DOWN_STATE
from neo.node import MasterNode, StorageNode
from neo.pt import MTPartitionTable as PartitionTable from neo.pt import MTPartitionTable as PartitionTable
from neo import protocol
from neo.util import dump from neo.util import dump
class PrimaryBootstrapHandler(AnswerBaseHandler): class PrimaryBootstrapHandler(AnswerBaseHandler):
...@@ -37,7 +36,7 @@ class PrimaryBootstrapHandler(AnswerBaseHandler): ...@@ -37,7 +36,7 @@ class PrimaryBootstrapHandler(AnswerBaseHandler):
app = self.app app = self.app
node = app.nm.getNodeByServer(conn.getAddress()) node = app.nm.getNodeByServer(conn.getAddress())
# this must be a master node # this must be a master node
if node_type != MASTER_NODE_TYPE: if node_type != protocol.MASTER_NODE_TYPE:
conn.close() conn.close()
return return
if conn.getAddress() != address: if conn.getAddress() != address:
...@@ -162,24 +161,19 @@ class PrimaryNotificationsHandler(BaseHandler): ...@@ -162,24 +161,19 @@ class PrimaryNotificationsHandler(BaseHandler):
def handleNotifyNodeInformation(self, conn, packet, node_list): def handleNotifyNodeInformation(self, conn, packet, node_list):
app = self.app app = self.app
nm = app.nm
self.app.nm.update(node_list) self.app.nm.update(node_list)
for node_type, addr, uuid, state in node_list: for node_type, addr, uuid, state in node_list:
if node_type != STORAGE_NODE_TYPE or state != RUNNING_STATE: if node_type != protocol.STORAGE_NODE_TYPE \
or state != protocol.RUNNING_STATE:
continue continue
# close connection to this storage if no longer running # close connection to this storage if no longer running
conn = self.app.em.getConnectionByUUID(uuid) conn = self.app.em.getConnectionByUUID(uuid)
if conn is not None: if conn is not None:
conn.close() conn.close()
if node_type == STORAGE_NODE_TYPE: if node_type == protocol.STORAGE_NODE_TYPE:
# Remove from pool connection # Remove from pool connection
app.cp.removeConnection(n) app.cp.removeConnection(n)
# Put fake packets to task queues. self.dispatcher.unregister(conn)
# XXX: this should be done in MTClientConnection
for key in self.dispatcher.message_table.keys():
if id(conn) == key[0]:
queue = self.dispatcher.message_table.pop(key)
queue.put((conn, None))
class PrimaryAnswersHandler(AnswerBaseHandler): class PrimaryAnswersHandler(AnswerBaseHandler):
""" Handle that process expected packets from the primary master """ """ Handle that process expected packets from the primary master """
......
...@@ -38,6 +38,14 @@ class Dispatcher: ...@@ -38,6 +38,14 @@ class Dispatcher:
key = (id(conn), msg_id) key = (id(conn), msg_id)
self.message_table[key] = payload self.message_table[key] = payload
def unregister(self, conn):
""" Unregister a connection and put fake packet in queues to unlock
threads bloking it them """
for key in self.message_table.keys():
if id(conn) == key[0]:
queue = self.message_table.pop(key)
queue.put((conn, None))
def registered(self, conn): def registered(self, conn):
"""Check if a connection is registered into message table.""" """Check if a connection is registered into message table."""
# XXX: serch algorythm could be improved by improving data structure. # XXX: serch algorythm could be improved by improving data structure.
...@@ -46,4 +54,3 @@ class Dispatcher: ...@@ -46,4 +54,3 @@ class Dispatcher:
if searched_id == conn_id: if searched_id == conn_id:
return True return True
return False return False
...@@ -1276,7 +1276,7 @@ def abortTransaction(tid): ...@@ -1276,7 +1276,7 @@ def abortTransaction(tid):
def askStoreTransaction(tid, user, desc, ext, oid_list): def askStoreTransaction(tid, user, desc, ext, oid_list):
lengths = (len(oid_list), len(user), len(desc), len(ext)) lengths = (len(oid_list), len(user), len(desc), len(ext))
body = [pack('!8sLHHH', tid, *length)] body = [pack('!8sLHHH', tid, *lengths)]
body.append(user) body.append(user)
body.append(desc) body.append(desc)
body.append(ext) body.append(ext)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment