Commit bcb18c38 authored by Julien Muchembled's avatar Julien Muchembled

Fix most race conditions causing bad cleanup

parent cd653523
...@@ -159,11 +159,11 @@ class TunnelManager(object): ...@@ -159,11 +159,11 @@ class TunnelManager(object):
while iface_list: while iface_list:
self._tuntap(iface_list.pop()) self._tuntap(iface_list.pop())
def getFreeInterface(self, prefix): def _getFreeInterface(self, prefix):
try: try:
iface = self._free_iface_list.pop() iface = self._free_iface_list.pop()
except IndexError: except IndexError:
iface = self._tuntap() iface = self._tuntap()
self._iface_to_prefix[iface] = prefix self._iface_to_prefix[iface] = prefix
return iface return iface
...@@ -222,8 +222,9 @@ class TunnelManager(object): ...@@ -222,8 +222,9 @@ class TunnelManager(object):
return False return False
logging.info('Establishing a connection with %u/%u', logging.info('Establishing a connection with %u/%u',
int(prefix, 2), len(prefix)) int(prefix, 2), len(prefix))
iface = self.getFreeInterface(prefix) with utils.exit:
self._connection_dict[prefix] = c = Connection(address, iface, prefix) iface = self._getFreeInterface(prefix)
self._connection_dict[prefix] = c = Connection(address, iface, prefix)
if self._gateway_manager is not None: if self._gateway_manager is not None:
for ip in c: for ip in c:
self._gateway_manager.add(ip, True) self._gateway_manager.add(ip, True)
......
...@@ -98,6 +98,45 @@ class ArgParser(argparse.ArgumentParser): ...@@ -98,6 +98,45 @@ class ArgParser(argparse.ArgumentParser):
ca /etc/re6stnet/ca.crt""", **kw) ca /etc/re6stnet/ca.crt""", **kw)
class exit(object):
status = None
def __init__(self):
l = threading.Lock()
self.acquire = l.acquire
r = l.release
def release():
try:
if self.status is not None:
self.release = r
sys.exit(self.status)
finally:
r()
self.release = release
def __enter__(self):
self.acquire()
def __exit__(self, t, v, tb):
self.release()
def kill_main(self, status):
self.status = status
os.kill(os.getpid(), signal.SIGTERM)
def signal(self, status, *sigs):
def handler(*args):
if self.status is None:
self.status = status
if self.acquire(0):
self.release()
for sig in sigs:
signal.signal(sig, handler)
exit = exit()
class Popen(subprocess.Popen): class Popen(subprocess.Popen):
def stop(self): def stop(self):
......
...@@ -5,6 +5,7 @@ from collections import deque ...@@ -5,6 +5,7 @@ from collections import deque
from OpenSSL import crypto from OpenSSL import crypto
from re6st import db, plib, tunnel, utils from re6st import db, plib, tunnel, utils
from re6st.registry import RegistryClient, RENEW_PERIOD from re6st.registry import RegistryClient, RENEW_PERIOD
from re6st.utils import exit
class ReexecException(Exception): class ReexecException(Exception):
pass pass
...@@ -148,10 +149,6 @@ def maybe_renew(path, cert, info, renew): ...@@ -148,10 +149,6 @@ def maybe_renew(path, cert, info, renew):
info, exc_info=exc_info) info, exc_info=exc_info)
return cert, time.time() + 86400 return cert, time.time() + 86400
def exit(status):
exit.status = status
os.kill(os.getpid(), signal.SIGTERM)
def main(): def main():
# Get arguments # Get arguments
config = getConfig() config = getConfig()
...@@ -179,9 +176,8 @@ def main(): ...@@ -179,9 +176,8 @@ def main():
if config.ovpnlog: if config.ovpnlog:
plib.ovpn_log = config.log plib.ovpn_log = config.log
signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1)) exit.signal(0, signal.SIGINT, signal.SIGTERM)
signal.signal(signal.SIGTERM, lambda *args: exit.signal(-1, signal.SIGHUP, signal.SIGUSR2)
sys.exit(getattr(exit, 'status', None)))
registry = RegistryClient(config.registry, config.key, ca) registry = RegistryClient(config.registry, config.key, ca)
cert, next_renew = maybe_renew(config.cert, cert, "Certificate", cert, next_renew = maybe_renew(config.cert, cert, "Certificate",
...@@ -299,6 +295,7 @@ def main(): ...@@ -299,6 +295,7 @@ def main():
tunnel_manager = write_pipe = None tunnel_manager = write_pipe = None
try: try:
exit.acquire()
# Source address selection is defined by RFC 6724, and in most # Source address selection is defined by RFC 6724, and in most
# applications, it usually works thanks to rule 5 (prefer outgoing # applications, it usually works thanks to rule 5 (prefer outgoing
# interface). But here, it rarely applies because we use several # interface). But here, it rarely applies because we use several
...@@ -358,21 +355,23 @@ def main(): ...@@ -358,21 +355,23 @@ def main():
call(if_rt[:3] + ['add', 'proto', 'static'] + if_rt[4:]) call(if_rt[:3] + ['add', 'proto', 'static'] + if_rt[4:])
else: else:
def check_no_default_route(): def check_no_default_route():
for route in call(('ip', '-6', 'route', 'show',
'default')).splitlines():
if ' proto 42 ' not in route:
sys.exit("Detected default route (%s)"
" whereas you specified --table=0."
" Fix your configuration." % route)
check_no_default_route()
def check_no_default_route_thread():
try: try:
while True: while True:
for route in call(('ip', '-6', 'route', 'show',
'default')).splitlines():
if ' proto 42 ' not in route:
logging.fatal("Detected default route (%s)"
" whereas you specified --table=0."
" Fix your configuration.", route)
return
time.sleep(60) time.sleep(60)
check_no_default_route()
except: except:
utils.log_exception() utils.log_exception()
finally: finally:
exit(1) exit.kill_main(1)
t = threading.Thread(target=check_no_default_route) t = threading.Thread(target=check_no_default_route_thread)
t.daemon = True t.daemon = True
t.start() t.start()
ip('route', 'unreachable', *x) ip('route', 'unreachable', *x)
...@@ -384,17 +383,21 @@ def main(): ...@@ -384,17 +383,21 @@ def main():
config.babel_pidfile, tunnel_interfaces, config.babel_pidfile, tunnel_interfaces,
*config.babel_args).stop) *config.babel_args).stop)
if config.up: if config.up:
exit.release()
r = os.system(config.up) r = os.system(config.up)
if r: if r:
sys.exit(r) sys.exit(r)
exit.acquire()
for cmd in config.daemon or (): for cmd in config.daemon or ():
cleanup.append(utils.Popen(cmd, shell=True).stop) cleanup.append(utils.Popen(cmd, shell=True).stop)
# main loop # main loop
if tunnel_manager is None: if tunnel_manager is None:
exit.release()
time.sleep(max(0, next_renew - time.time())) time.sleep(max(0, next_renew - time.time()))
raise ReexecException("Restart to renew certificate") raise ReexecException("Restart to renew certificate")
cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll
exit.release()
while True: while True:
next = tunnel_manager.next_refresh next = tunnel_manager.next_refresh
if forwarder: if forwarder:
...@@ -418,18 +421,20 @@ def main(): ...@@ -418,18 +421,20 @@ def main():
if forwarder and t >= forwarder.next_refresh: if forwarder and t >= forwarder.next_refresh:
forwarder.refresh() forwarder.refresh()
finally: finally:
# XXX: We have a possible race condition if a signal is handled at
# the beginning of this clause, just before the following line.
exit.acquire(0) # inhibit signals
while cleanup: while cleanup:
try: try:
cleanup.pop()() cleanup.pop()()
except: except:
pass pass
exit.release()
except sqlite3.Error: except sqlite3.Error:
logging.exception("Restarting with empty cache") logging.exception("Restarting with empty cache")
os.rename(db_path, db_path + '.bak') os.rename(db_path, db_path + '.bak')
except ReexecException, e: except ReexecException, e:
logging.info(e) logging.info(e)
except KeyboardInterrupt:
return 0
except Exception: except Exception:
utils.log_exception() utils.log_exception()
sys.exit(1) sys.exit(1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment