Commit 469279d9 authored by David Wilson's avatar David Wilson

master: refactor ThreadWatcher

In order to support a .remove() method, to prevent a minor but annoying
(log visible) memory leak while running the tests.
parent e3209d1d
...@@ -117,46 +117,67 @@ def scan_code_imports(co): ...@@ -117,46 +117,67 @@ def scan_code_imports(co):
co.co_consts[arg2] or ()) co.co_consts[arg2] or ())
_join_lock = threading.Lock() class ThreadWatcher(object):
_join_process_id = None """
_join_callbacks_by_target = {} Manage threads that waits for nother threads to shutdown, before invoking
_join_thread_by_target = {} `on_join()`. In CPython it seems possible to use this method to ensure a
non-main thread is signalled when the main thread has exitted, using yet
another thread as a proxy.
"""
_lock = threading.Lock()
_pid = None
_instances_by_target = {}
_thread_by_target = {}
def _join_thread_reset(): @classmethod
def _reset(cls):
"""If we have forked since the watch dictionaries were initialized, all """If we have forked since the watch dictionaries were initialized, all
that has is garbage, so clear it.""" that has is garbage, so clear it."""
global _join_process_id if os.getpid() != cls._pid:
cls._pid = os.getpid()
cls._instances_by_target.clear()
cls._thread_by_target.clear()
if os.getpid() != _join_process_id: def __init__(self, target, on_join):
_join_process_id = os.getpid() self.target = target
_join_callbacks_by_target.clear() self.on_join = on_join
_join_thread_by_target.clear()
@classmethod
def _watch(cls, target):
target.join()
for watcher in cls._instances_by_target[target]:
watcher.on_join()
def join_thread_async(target_thread, on_join): def install(self):
"""Start a thread that waits for another thread to shutdown, before self._lock.acquire()
invoking `on_join()`. In CPython it seems possible to use this method to
ensure a non-main thread is signalled when the main thread has exitted,
using yet another thread as a proxy."""
def _watch():
target_thread.join()
for on_join in _join_callbacks_by_target[target_thread]:
on_join()
_join_lock.acquire()
try: try:
_join_thread_reset() self._reset()
_join_callbacks_by_target.setdefault(target_thread, []).append(on_join) self._instances_by_target.setdefault(self.target, []).append(self)
if target_thread not in _join_thread_by_target: if self.target not in self._thread_by_target:
_join_thread_by_target[target_thread] = threading.Thread( self._thread_by_target[self.target] = threading.Thread(
name='mitogen.master.join_thread_async', name='mitogen.master.join_thread_async',
target=_watch, target=self._watch,
args=(self.target,)
) )
_join_thread_by_target[target_thread].start() self._thread_by_target[self.target].start()
finally: finally:
_join_lock.release() self._lock.release()
def remove(self):
self._lock.acquire()
try:
self._reset()
lst = self._instances_by_target.get(self.target, [])
if self in lst:
lst.remove(self)
finally:
self._lock.release()
@classmethod
def watch(cls, target, on_join):
watcher = cls(target, on_join)
watcher.install()
return watcher
class SelectError(mitogen.core.Error): class SelectError(mitogen.core.Error):
...@@ -604,12 +625,21 @@ class ModuleResponder(object): ...@@ -604,12 +625,21 @@ class ModuleResponder(object):
class Broker(mitogen.core.Broker): class Broker(mitogen.core.Broker):
shutdown_timeout = 5.0 shutdown_timeout = 5.0
_watcher = None
def __init__(self, install_watcher=True): def __init__(self, install_watcher=True):
if install_watcher: if install_watcher:
join_thread_async(threading.currentThread(), self.shutdown) self._watcher = ThreadWatcher.watch(
target=threading.currentThread(),
on_join=self.shutdown,
)
super(Broker, self).__init__() super(Broker, self).__init__()
def shutdown(self):
super(Broker, self).shutdown()
if self._watcher:
self._watcher.remove()
class Context(mitogen.core.Context): class Context(mitogen.core.Context):
via = None via = None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment