Commit b38f1439 authored by Vincent Pelletier's avatar Vincent Pelletier

Keep client consistent after close.

- set master_conn to None to clarify disconnection
- purge node pool after closing all connections
- allow restarting polling thread after its shutdown
Also, only start polling thread when needed (side-effect of last point).

git-svn-id: https://svn.erp5.org/repos/neo/trunk@2414 71dcc9de-d417-0410-9af5-da40c76e7ee4
parent f9b02534
...@@ -126,13 +126,12 @@ class Application(object): ...@@ -126,13 +126,12 @@ class Application(object):
# Start polling thread # Start polling thread
self.em = EventManager() self.em = EventManager()
self.poll_thread = ThreadedPoll(self.em, name=name) self.poll_thread = ThreadedPoll(self.em, name=name)
neo.logging.debug('Started %s', self.poll_thread)
psThreadedPoll() psThreadedPoll()
# Internal Attributes common to all thread # Internal Attributes common to all thread
self._db = None self._db = None
self.name = name self.name = name
self.connector_handler = getConnectorHandler(connector) self.connector_handler = getConnectorHandler(connector)
self.dispatcher = Dispatcher() self.dispatcher = Dispatcher(self.poll_thread)
self.nm = NodeManager() self.nm = NodeManager()
self.cp = ConnectionPool(self) self.cp = ConnectionPool(self)
self.pt = None self.pt = None
...@@ -1209,6 +1208,8 @@ class Application(object): ...@@ -1209,6 +1208,8 @@ class Application(object):
# down zope, so use __del__ to close connections # down zope, so use __del__ to close connections
for conn in self.em.getConnectionList(): for conn in self.em.getConnectionList():
conn.close() conn.close()
self.cp.flush()
self.master_conn = None
# Stop polling thread # Stop polling thread
neo.logging.debug('Stopping %s', self.poll_thread) neo.logging.debug('Stopping %s', self.poll_thread)
self.poll_thread.stop() self.poll_thread.stop()
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
from threading import Thread, Event, enumerate as thread_enum from threading import Thread, Event, enumerate as thread_enum
import neo import neo
class ThreadedPoll(Thread): class _ThreadedPoll(Thread):
"""Polling thread.""" """Polling thread."""
# Garbage collector hint: # Garbage collector hint:
...@@ -31,20 +31,55 @@ class ThreadedPoll(Thread): ...@@ -31,20 +31,55 @@ class ThreadedPoll(Thread):
self.em = em self.em = em
self.setDaemon(True) self.setDaemon(True)
self._stop = Event() self._stop = Event()
self.start()
def run(self): def run(self):
while not self._stop.isSet(): neo.logging.debug('Started %s', self)
while not self.stopping():
# First check if we receive any new message from other node # First check if we receive any new message from other node
try: try:
self.em.poll() # XXX: Delay cannot be infinite here, unless we have a way to
# interrupt this call when stopping.
self.em.poll(1)
except: except:
self.neo.logging.error('poll raised, retrying', exc_info=1) self.neo.logging.error('poll raised, retrying', exc_info=1)
self.neo.logging.debug('Threaded poll stopped') self.neo.logging.debug('Threaded poll stopped')
self._stop.clear()
def stop(self): def stop(self):
self._stop.set() self._stop.set()
def stopping(self):
return self._stop.isSet()
class ThreadedPoll(object):
"""
Wrapper for polloing thread, just to be able to start it again when
it stopped.
"""
_thread = None
_started = False
def __init__(self, *args, **kw):
self._args = args
self._kw = kw
self.newThread()
def newThread(self):
self._thread = _ThreadedPoll(*self._args, **self._kw)
def start(self):
if self._started:
self.newThread()
else:
self._started = True
self._thread.start()
def __getattr__(self, key):
return getattr(self._thread, key)
def __repr__(self):
return repr(self._thread)
def psThreadedPoll(log=None): def psThreadedPoll(log=None):
""" """
Logs alive ThreadedPoll threads. Logs alive ThreadedPoll threads.
......
...@@ -178,3 +178,7 @@ class ConnectionPool(object): ...@@ -178,3 +178,7 @@ class ConnectionPool(object):
"""Explicitly remove connection when a node is broken.""" """Explicitly remove connection when a node is broken."""
self.connection_dict.pop(node.getUUID(), None) self.connection_dict.pop(node.getUUID(), None)
def flush(self):
"""Remove all connections"""
self.connection_dict.clear()
...@@ -705,6 +705,7 @@ class MTClientConnection(ClientConnection): ...@@ -705,6 +705,7 @@ class MTClientConnection(ClientConnection):
self.acquire = lock.acquire self.acquire = lock.acquire
self.release = lock.release self.release = lock.release
self.dispatcher = kwargs.pop('dispatcher') self.dispatcher = kwargs.pop('dispatcher')
self.dispatcher.needPollThread()
self.lock() self.lock()
try: try:
super(MTClientConnection, self).__init__(*args, **kwargs) super(MTClientConnection, self).__init__(*args, **kwargs)
......
...@@ -44,12 +44,13 @@ def giant_lock(func): ...@@ -44,12 +44,13 @@ def giant_lock(func):
class Dispatcher: class Dispatcher:
"""Register a packet, connection pair as expecting a response packet.""" """Register a packet, connection pair as expecting a response packet."""
def __init__(self): def __init__(self, poll_thread=None):
self.message_table = {} self.message_table = {}
self.queue_dict = {} self.queue_dict = {}
lock = Lock() lock = Lock()
self.lock_acquire = lock.acquire self.lock_acquire = lock.acquire
self.lock_release = lock.release self.lock_release = lock.release
self.poll_thread = poll_thread
@giant_lock @giant_lock
@profiler_decorator @profiler_decorator
...@@ -64,10 +65,27 @@ class Dispatcher: ...@@ -64,10 +65,27 @@ class Dispatcher:
queue.put(data) queue.put(data)
return True return True
def needPollThread(self):
thread = self.poll_thread
# If thread has been stopped, wait for it to stop
# Note: This is not, ironically, thread safe: if one thread is
# stopping poll thread while we are checking its state here, a
# race condition will occur. If safety is required, locks should
# be added to control the access to thread's "start", "stopping"
# and "stop" methods.
if thread.stopping():
# XXX: ideally, we should wake thread up here, to be sure not
# to wait forever.
thread.join()
if not thread.isAlive():
thread.start()
@giant_lock @giant_lock
@profiler_decorator @profiler_decorator
def register(self, conn, msg_id, queue): def register(self, conn, msg_id, queue):
"""Register an expectation for a reply.""" """Register an expectation for a reply."""
if self.poll_thread is not None:
self.needPollThread()
self.message_table.setdefault(id(conn), {})[msg_id] = queue self.message_table.setdefault(id(conn), {})[msg_id] = queue
queue_dict = self.queue_dict queue_dict = self.queue_dict
key = id(queue) key = id(queue)
......
...@@ -25,7 +25,8 @@ class DispatcherTests(NeoTestBase): ...@@ -25,7 +25,8 @@ class DispatcherTests(NeoTestBase):
def setUp(self): def setUp(self):
NeoTestBase.setUp(self) NeoTestBase.setUp(self)
self.dispatcher = Dispatcher() self.fake_thread = Mock({'stopping': True})
self.dispatcher = Dispatcher(self.fake_thread)
def testRegister(self): def testRegister(self):
conn = object() conn = object()
...@@ -38,6 +39,7 @@ class DispatcherTests(NeoTestBase): ...@@ -38,6 +39,7 @@ class DispatcherTests(NeoTestBase):
self.assertTrue(queue.get(block=False) is MARKER) self.assertTrue(queue.get(block=False) is MARKER)
self.assertTrue(queue.empty()) self.assertTrue(queue.empty())
self.assertFalse(self.dispatcher.dispatch(conn, 2, None)) self.assertFalse(self.dispatcher.dispatch(conn, 2, None))
self.assertEqual(len(self.fake_thread.mockGetNamedCalls('start')), 1)
def testUnregister(self): def testUnregister(self):
conn = object() conn = object()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment