Commit 8ebdd500 authored by Julien Muchembled's avatar Julien Muchembled

Certificate revocation, with broadcast of CRL

parent f73c51ec
...@@ -20,3 +20,8 @@ ...@@ -20,3 +20,8 @@
- registry: add '--home PATH' command line option so that / display an HTML - registry: add '--home PATH' command line option so that / display an HTML
page from PATH (use new str.format for templating) page from PATH (use new str.format for templating)
- Better UI to revoke certificates, for example with a HTML form.
Currently, one have to forge the URL manually. Examples:
wget -O /dev/null http://re6st.example.com/revoke?cn_or_serial=123
wget -O /dev/null http://re6st.example.com/revoke?cn_or_serial=4/16
...@@ -4,6 +4,8 @@ from . import utils, version, x509 ...@@ -4,6 +4,8 @@ from . import utils, version, x509
class Cache(object): class Cache(object):
crl = ()
def __init__(self, db_path, registry, cert, db_size=200): def __init__(self, db_path, registry, cert, db_size=200):
self._prefix = cert.prefix self._prefix = cert.prefix
self._db_size = db_size self._db_size = db_size
...@@ -40,6 +42,7 @@ class Cache(object): ...@@ -40,6 +42,7 @@ class Cache(object):
# when it tried to send us new parameters. # when it tried to send us new parameters.
or self._prefix == self.registry_prefix): or self._prefix == self.registry_prefix):
self.updateConfig() self.updateConfig()
self.next_renew = cert.maybeRenew(self._registry, self.crl)
if version.protocol < self.min_protocol: if version.protocol < self.min_protocol:
logging.critical("Your version of re6stnet is too old." logging.critical("Your version of re6stnet is too old."
" Please update.") " Please update.")
...@@ -64,7 +67,11 @@ class Cache(object): ...@@ -64,7 +67,11 @@ class Cache(object):
cls = self.__class__ cls = self.__class__
logging.debug("Loading network parameters:") logging.debug("Loading network parameters:")
for k, v in config: for k, v in config:
hasattr(cls, k) or setattr(self, k, v) if k == 'crl':
v = set(json.loads(v))
elif hasattr(cls, k):
continue
setattr(self, k, v)
logging.debug("- %s: %r", k, v) logging.debug("- %s: %r", k, v)
def updateConfig(self): def updateConfig(self):
...@@ -77,6 +84,7 @@ class Cache(object): ...@@ -77,6 +84,7 @@ class Cache(object):
config = dict((str(k), v.decode('base64') if k in base64 else config = dict((str(k), v.decode('base64') if k in base64 else
str(v) if type(v) is unicode else v) str(v) if type(v) is unicode else v)
for k, v in config.iteritems()) for k, v in config.iteritems())
config['crl'] = json.dumps(config['crl'])
except socket.error, e: except socket.error, e:
logging.warning(e) logging.warning(e)
return return
......
...@@ -9,7 +9,7 @@ if script_type == 'up': ...@@ -9,7 +9,7 @@ if script_type == 'up':
os.execlp('ip', 'ip', 'link', 'set', os.environ['dev'], 'up', os.execlp('ip', 'ip', 'link', 'set', os.environ['dev'], 'up',
'mtu', os.environ['tun_mtu']) 'mtu', os.environ['tun_mtu'])
# Write into pipe external ip address received if script_type == 'route-up':
import time import time
os.write(int(sys.argv[1]), "%s %s %s %s\n" % (script_type, os.write(int(sys.argv[1]), repr((os.environ['common_name'], time.time(),
os.environ['common_name'], time.time(), os.environ['OPENVPN_external_ip'])) int(os.environ['tls_serial_0']), os.environ['OPENVPN_external_ip'])))
...@@ -2,15 +2,16 @@ ...@@ -2,15 +2,16 @@
import os, sys import os, sys
script_type = os.environ['script_type'] script_type = os.environ['script_type']
external_ip = lambda: os.getenv('trusted_ip') or os.environ['trusted_ip6'] external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events
fd = int(sys.argv[1])
os.write(fd, repr((script_type, (os.environ['common_name'], os.environ['dev'],
int(os.environ['tls_serial_0']), external_ip))))
if script_type == 'client-connect': if script_type == 'client-connect':
if os.read(fd, 1) == '\0':
sys.exit(1)
# Send client its external ip address # Send client its external ip address
with open(sys.argv[2], 'w') as f: with open(sys.argv[2], 'w') as f:
f.write('push "setenv-safe external_ip %s"\n' % external_ip()) f.write('push "setenv-safe external_ip %s"\n' % external_ip)
# Write into pipe connect/disconnect events
arg1 = sys.argv[1]
if arg1 != 'None':
os.write(int(arg1), '%s %s %s\n' % (
script_type, os.environ['common_name'], external_ip()))
...@@ -25,10 +25,8 @@ def openvpn(iface, encrypt, *args, **kw): ...@@ -25,10 +25,8 @@ def openvpn(iface, encrypt, *args, **kw):
ovpn_link_mtu_dict = {'udp': 1481, 'udp6': 1450} ovpn_link_mtu_dict = {'udp': 1481, 'udp6': 1450}
def server(iface, max_clients, dh_path, pipe_fd, port, proto, encrypt, *args, **kw): def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
client_script = '%s %s' % (ovpn_server, pipe_fd) client_script = '%s %s' % (ovpn_server, fd)
if pipe_fd is not None:
args = ('--client-disconnect', client_script) + args
try: try:
args = ('--link-mtu', str(ovpn_link_mtu_dict[proto]), args = ('--link-mtu', str(ovpn_link_mtu_dict[proto]),
# mtu-disc ignored for udp6 due to a bug in OpenVPN # mtu-disc ignored for udp6 due to a bug in OpenVPN
...@@ -39,6 +37,7 @@ def server(iface, max_clients, dh_path, pipe_fd, port, proto, encrypt, *args, ** ...@@ -39,6 +37,7 @@ def server(iface, max_clients, dh_path, pipe_fd, port, proto, encrypt, *args, **
'--tls-server', '--tls-server',
'--mode', 'server', '--mode', 'server',
'--client-connect', client_script, '--client-connect', client_script,
'--client-disconnect', client_script,
'--dh', dh_path, '--dh', dh_path,
'--max-clients', str(max_clients), '--max-clients', str(max_clients),
'--port', str(port), '--port', str(port),
......
...@@ -25,6 +25,7 @@ from collections import defaultdict, deque ...@@ -25,6 +25,7 @@ from collections import defaultdict, deque
from datetime import datetime from datetime import datetime
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText from email.mime.text import MIMEText
from operator import itemgetter
from OpenSSL import crypto from OpenSSL import crypto
from urllib import splittype, splithost, splitport, urlencode from urllib import splittype, splithost, splitport, urlencode
from . import ctl, tunnel, utils, version, x509 from . import ctl, tunnel, utils, version, x509
...@@ -71,6 +72,20 @@ class RegistryServer(object): ...@@ -71,6 +72,20 @@ class RegistryServer(object):
"email TEXT", "email TEXT",
"cert TEXT"): "cert TEXT"):
self.db.execute("INSERT INTO cert VALUES ('',null,null)") self.db.execute("INSERT INTO cert VALUES ('',null,null)")
if utils.sqliteCreateTable(self.db, "crl",
"serial INTEGER PRIMARY KEY NOT NULL",
# Expiration date of revoked certificate.
# TODO: purge rows with dates in the past.
"date INTEGER NOT NULL"):
# Revoke certificates produced by previous version.
# They all have serial 0.
try:
date = max(x509.notAfter(x[0]) for x in self.iterCert())
except ValueError:
pass
else:
if time.time() < date:
self.db.execute("INSERT INTO crl VALUES (0,?)", (date,))
self.cert = x509.Cert(self.config.ca, self.config.key) self.cert = x509.Cert(self.config.ca, self.config.key)
# Get vpn network prefix # Get vpn network prefix
...@@ -97,9 +112,11 @@ class RegistryServer(object): ...@@ -97,9 +112,11 @@ class RegistryServer(object):
self.db.execute("INSERT OR REPLACE INTO config VALUES (?, ?)", self.db.execute("INSERT OR REPLACE INTO config VALUES (?, ?)",
name_value) name_value)
def updateNetworkConfig(self): def updateNetworkConfig(self, _it0=itemgetter(0)):
kw = { kw = {
'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125', 'babel_default': 'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125',
'crl': map(_it0, self.db.execute(
"SELECT serial FROM crl ORDER BY serial")),
'protocol': version.protocol, 'protocol': version.protocol,
'registry_prefix': self.prefix, 'registry_prefix': self.prefix,
} }
...@@ -220,7 +237,7 @@ class RegistryServer(object): ...@@ -220,7 +237,7 @@ class RegistryServer(object):
def handle_request(self, request, method, kw, def handle_request(self, request, method, kw,
_localhost=('127.0.0.1', '::1')): _localhost=('127.0.0.1', '::1')):
m = getattr(self, method) m = getattr(self, method)
if method in ('versions', 'topology'): if method in ('revoke', 'versions', 'topology'):
x_forwarded_for = request.headers.get('X-Forwarded-For') x_forwarded_for = request.headers.get('X-Forwarded-For')
if request.client_address[0] not in _localhost or \ if request.client_address[0] not in _localhost or \
x_forwarded_for and x_forwarded_for not in _localhost: x_forwarded_for and x_forwarded_for not in _localhost:
...@@ -393,15 +410,16 @@ class RegistryServer(object): ...@@ -393,15 +410,16 @@ class RegistryServer(object):
@rpc @rpc
def renewCertificate(self, cn): def renewCertificate(self, cn):
with self.lock: with self.lock:
with self.db: with self.db as db:
pem = self.getCert(cn) pem = self.getCert(cn)
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem) cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
if x509.notAfter(cert) - RENEW_PERIOD < time.time(): if x509.notAfter(cert) - RENEW_PERIOD < time.time():
not_after = None not_after = None
elif cert.get_serial_number(): elif db.execute("SELECT count(*) FROM crl WHERE serial=?",
return pem (cert.get_serial_number(),)).fetchone()[0]:
else:
not_after = cert.get_notAfter() not_after = cert.get_notAfter()
else:
return pem
return self.createCertificate(cn, return self.createCertificate(cn,
cert.get_subject(), cert.get_pubkey(), not_after) cert.get_subject(), cert.get_pubkey(), not_after)
...@@ -452,6 +470,29 @@ class RegistryServer(object): ...@@ -452,6 +470,29 @@ class RegistryServer(object):
logging.info("Sending bootstrap peer: %s", msg) logging.info("Sending bootstrap peer: %s", msg)
return x509.encrypt(cert, msg) return x509.encrypt(cert, msg)
@rpc
def revoke(self, cn_or_serial):
with self.lock:
with self.db:
q = self.db.execute
try:
serial = int(cn_or_serial)
except ValueError:
prefix = utils.binFromSubnet(cn_or_serial)
cert = self.getCert(prefix)
q("UPDATE cert SET email=null, cert=null WHERE prefix=?",
(prefix,))
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
serial = cert.get_serial_number()
self.sessions.pop(prefix, None)
else:
cert, = (cert for cert, prefix, email in self.iterCert()
if cert.get_serial_number() == serial)
not_after = x509.notAfter(cert)
if time.time() < not_after:
q("INSERT INTO crl VALUES (?,?)", (serial, not_after))
self.updateNetworkConfig()
@rpc @rpc
def versions(self): def versions(self):
with self.peers_lock: with self.peers_lock:
......
import errno, logging, os, random, socket, subprocess, time, weakref import errno, logging, os, random, socket, subprocess, struct, time, weakref
from collections import defaultdict, deque from collections import defaultdict, deque
from bisect import bisect, insort from bisect import bisect, insort
from OpenSSL import crypto from OpenSSL import crypto
...@@ -40,6 +40,7 @@ class MultiGatewayManager(dict): ...@@ -40,6 +40,7 @@ class MultiGatewayManager(dict):
class Connection(object): class Connection(object):
_retry = 0 _retry = 0
serial = None
time = float('inf') time = float('inf')
def __init__(self, tunnel_manager, address_list, iface, prefix): def __init__(self, tunnel_manager, address_list, iface, prefix):
...@@ -69,15 +70,19 @@ class Connection(object): ...@@ -69,15 +70,19 @@ class Connection(object):
'--connect-retry-max', '3', '--tls-exit', '--connect-retry-max', '3', '--tls-exit',
'--remap-usr1', 'SIGTERM', '--remap-usr1', 'SIGTERM',
'--ping-exit', str(tm.timeout), '--ping-exit', str(tm.timeout),
'--route-up', '%s %u' % (plib.ovpn_client, tm.write_pipe), '--route-up', '%s %u' % (plib.ovpn_client, tm.write_sock.fileno()),
*tm.ovpn_args) *tm.ovpn_args)
tm.resetTunnelRefresh() tm.resetTunnelRefresh()
self._retry += 1 self._retry += 1
def connected(self): def connected(self, serial):
cache = self.tunnel_manager.cache
if serial in cache.crl:
self.tunnel_manager._kill(self._prefix)
return
self.serial = serial
i = self._retry - 1 i = self._retry - 1
self._retry = None self._retry = None
cache = self.tunnel_manager.cache
if i: if i:
cache.addPeer(self._prefix, ','.join(self.address_list[i]), True) cache.addPeer(self._prefix, ','.join(self.address_list[i]), True)
else: else:
...@@ -167,14 +172,14 @@ class BaseTunnelManager(object): ...@@ -167,14 +172,14 @@ class BaseTunnelManager(object):
_forward = None _forward = None
def __init__(self, cache, cert, cert_renew, address=()): def __init__(self, cache, cert, address=()):
self.cert = cert self.cert = cert
self._network = cert.network self._network = cert.network
self._prefix = cert.prefix self._prefix = cert.prefix
self.cache = cache self.cache = cache
self._connecting = set() self._connecting = set()
self._connection_dict = {} self._connection_dict = {}
self._served = set() self._served = defaultdict(dict)
self._version = cache.version self._version = cache.version
address_dict = defaultdict(list) address_dict = defaultdict(list)
...@@ -190,9 +195,9 @@ class BaseTunnelManager(object): ...@@ -190,9 +195,9 @@ class BaseTunnelManager(object):
self.sock.bind(('::', PORT)) self.sock.bind(('::', PORT))
p = x509.Peer(self._prefix) p = x509.Peer(self._prefix)
p.stop_date = cert_renew p.stop_date = cache.next_renew
self._peers = [p] self._peers = [p]
self._timeouts = [(cert_renew, self.invalidatePeers)] self._timeouts = [(p.stop_date, self.invalidatePeers)]
def select(self, r, w, t): def select(self, r, w, t):
r[self.sock] = self.handlePeerEvent r[self.sock] = self.handlePeerEvent
...@@ -307,6 +312,9 @@ class BaseTunnelManager(object): ...@@ -307,6 +312,9 @@ class BaseTunnelManager(object):
cert = self.cert.loadVerify(msg, cert = self.cert.loadVerify(msg,
True, crypto.FILETYPE_ASN1) True, crypto.FILETYPE_ASN1)
stop_date = x509.notAfter(cert) stop_date = x509.notAfter(cert)
serial = cert.get_serial_number()
if serial in self.cache.crl:
raise ValueError("revoked")
except (x509.VerifyError, ValueError), e: except (x509.VerifyError, ValueError), e:
logging.debug('ignored invalid certificate from %r (%s)', logging.debug('ignored invalid certificate from %r (%s)',
address, e.args[-1]) address, e.args[-1])
...@@ -320,6 +328,7 @@ class BaseTunnelManager(object): ...@@ -320,6 +328,7 @@ class BaseTunnelManager(object):
peer = x509.Peer(p) peer = x509.Peer(p)
insort(self._peers, peer) insort(self._peers, peer)
peer.cert = cert peer.cert = cert
peer.serial = serial
peer.stop_date = stop_date peer.stop_date = stop_date
self.selectTimeout(stop_date, self.invalidatePeers, False) self.selectTimeout(stop_date, self.invalidatePeers, False)
if seqno: if seqno:
...@@ -398,7 +407,7 @@ class BaseTunnelManager(object): ...@@ -398,7 +407,7 @@ class BaseTunnelManager(object):
raise utils.ReexecException( raise utils.ReexecException(
"Restart with new network parameters") "Restart with new network parameters")
def broadcastVersion(self): def _newVersion(self):
pass pass
def newVersion(self): def newVersion(self):
...@@ -410,32 +419,77 @@ class BaseTunnelManager(object): ...@@ -410,32 +419,77 @@ class BaseTunnelManager(object):
logging.info("changed: %r", changed) logging.info("changed: %r", changed)
self.selectTimeout(None, self.newVersion) self.selectTimeout(None, self.newVersion)
self._version = self.cache.version self._version = self.cache.version
self.broadcastVersion() self._newVersion()
self.cache.warnProtocol() self.cache.warnProtocol()
if not self.NEED_RESTART.isdisjoint(changed) or \ crl = self.cache.crl
version.protocol < self.cache.min_protocol: for i in reversed([i for i, peer in enumerate(self._peers)
if peer.serial in crl]):
del self._peers[i]
if self.cert.cert.get_serial_number() in crl:
raise utils.ReexecException("Our certificate has just been revoked."
" Let's try to renew it.")
if (not self.NEED_RESTART.isdisjoint(changed)
or version.protocol < self.cache.min_protocol
# TODO: With --management, we could kill clients without restarting.
or not all(crl.isdisjoint(serials.itervalues())
for serials in self._served.itervalues())):
# Wait at least 1 second to broadcast new version to neighbours. # Wait at least 1 second to broadcast new version to neighbours.
# If re6stnet is too old, don't abort now, because a new version # If re6stnet is too old, don't abort now, because a new version
# may have been installed without restart. # may have been installed without restart.
self.selectTimeout(time.time() + 1 + self.cache.delay_restart, self.selectTimeout(time.time() + 1 + self.cache.delay_restart,
self._restart) self._restart)
def handleServerEvent(self, sock):
event, args = eval(sock.recv(65536))
logging.debug("%s%r", event, args)
r = getattr(self, '_ovpn_' + event.replace('-', '_'))(*args)
if r is not None:
sock.send(chr(r))
def _ovpn_client_connect(self, common_name, iface, serial, trusted_ip):
if serial in self.cache.crl:
return False
prefix = utils.binFromSubnet(common_name)
self._served[prefix][iface] = serial
if isinstance(self, TunnelManager): # XXX
if self._gateway_manager is not None:
self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix)
self.cache.connecting(prefix, 0)
return True
def _ovpn_client_disconnect(self, common_name, iface, serial, trusted_ip):
prefix = utils.binFromSubnet(common_name)
serials = self._served.get(prefix)
try:
del serials[iface]
except (KeyError, TypeError):
logging.exception("ovpn_client_disconnect%r",
(common_name, iface, serial, trusted_ip))
return
if not serials:
del self._served[prefix]
if isinstance(self, TunnelManager): # XXX
self._abortTunnelKiller(prefix, iface)
if self._gateway_manager is not None:
self._gateway_manager.remove(trusted_ip)
class TunnelManager(BaseTunnelManager): class TunnelManager(BaseTunnelManager):
NEED_RESTART = BaseTunnelManager.NEED_RESTART.union(( NEED_RESTART = BaseTunnelManager.NEED_RESTART.union((
'client_count', 'max_clients', 'tunnel_refresh')) 'client_count', 'max_clients', 'tunnel_refresh'))
def __init__(self, control_socket, cache, cert, cert_renew, openvpn_args, def __init__(self, control_socket, cache, cert, openvpn_args,
timeout, client_count, iface_list, address, ip_changed, timeout, client_count, iface_list, address, ip_changed,
remote_gateway, disable_proto, neighbour_list=()): remote_gateway, disable_proto, neighbour_list=()):
super(TunnelManager, self).__init__(cache, cert, cert_renew, address) super(TunnelManager, self).__init__(cache, cert, address)
self.ctl = ctl.Babel(control_socket, weakref.proxy(self), self._network) self.ctl = ctl.Babel(control_socket, weakref.proxy(self), self._network)
self.ovpn_args = openvpn_args self.ovpn_args = openvpn_args
self.timeout = timeout self.timeout = timeout
# Create and open read_only pipe to get server events self._read_sock, self.write_sock = socket.socketpair(
r, self.write_pipe = os.pipe() socket.AF_UNIX, socket.SOCK_DGRAM)
self._read_pipe = os.fdopen(r)
self._disconnected = 0 self._disconnected = 0
self._distant_peers = [] self._distant_peers = []
self._iface_to_prefix = {} self._iface_to_prefix = {}
...@@ -497,7 +551,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -497,7 +551,7 @@ class TunnelManager(BaseTunnelManager):
def select(self, r, w, t): def select(self, r, w, t):
super(TunnelManager, self).select(r, w, t) super(TunnelManager, self).select(r, w, t)
r[self._read_pipe] = self.handleTunnelEvent r[self._read_sock] = self.handleClientEvent
if self._next_refresh: if self._next_refresh:
t.append((self._next_refresh, self.refresh)) t.append((self._next_refresh, self.refresh))
self.ctl.select(r, w, t) self.ctl.select(r, w, t)
...@@ -572,11 +626,13 @@ class TunnelManager(BaseTunnelManager): ...@@ -572,11 +626,13 @@ class TunnelManager(BaseTunnelManager):
prefix = min(peer_set, key=self._tunnelScore) prefix = min(peer_set, key=self._tunnelScore)
self._killing[prefix] = TunnelKiller(prefix, self, True) self._killing[prefix] = TunnelKiller(prefix, self, True)
def _abortTunnelKiller(self, prefix): def _abortTunnelKiller(self, prefix, iface=None):
tunnel_killer = self._killing.get(prefix) tunnel_killer = self._killing.get(prefix)
if tunnel_killer: if tunnel_killer:
if tunnel_killer.state: if tunnel_killer.state:
tunnel_killer.abort() if not iface or \
iface == self.ctl.interfaces[tunnel_killer.ifindex]:
tunnel_killer.abort()
else: else:
del self._killing[prefix] del self._killing[prefix]
...@@ -719,42 +775,15 @@ class TunnelManager(BaseTunnelManager): ...@@ -719,42 +775,15 @@ class TunnelManager(BaseTunnelManager):
for prefix in self._connection_dict.keys(): for prefix in self._connection_dict.keys():
self._kill(prefix) self._kill(prefix)
def handleTunnelEvent(self): def handleClientEvent(self):
try: msg = self._read_sock.recv(65536)
msg = self._read_pipe.readline().rstrip() logging.debug("route_up%s", msg)
args = msg.split() common_name, time, serial, ip = eval(msg)
m = getattr(self, '_ovpn_' + args.pop(0).replace('-', '_'))
except (AttributeError, ValueError):
logging.warning("Unknown message received from OpenVPN: %s", msg)
else:
logging.debug(msg)
m(*args)
def _ovpn_client_connect(self, common_name, trusted_ip):
prefix = utils.binFromSubnet(common_name)
self._served.add(prefix)
if self._gateway_manager is not None:
self._gateway_manager.add(trusted_ip, False)
if prefix in self._connection_dict and self._prefix < prefix:
self._kill(prefix)
self.cache.connecting(prefix, 0)
def _ovpn_client_disconnect(self, common_name, trusted_ip):
prefix = utils.binFromSubnet(common_name)
try:
self._served.remove(prefix)
except KeyError:
return
self._abortTunnelKiller(prefix)
if self._gateway_manager is not None:
self._gateway_manager.remove(trusted_ip)
def _ovpn_route_up(self, common_name, time, ip):
prefix = utils.binFromSubnet(common_name) prefix = utils.binFromSubnet(common_name)
c = self._connection_dict.get(prefix) c = self._connection_dict.get(prefix)
if c and c.time < float(time): if c and c.time < float(time):
try: try:
c.connected() c.connected(serial)
except (KeyError, TypeError), e: except (KeyError, TypeError), e:
logging.error("%s (route_up %s)", e, common_name) logging.error("%s (route_up %s)", e, common_name)
else: else:
...@@ -765,7 +794,7 @@ class TunnelManager(BaseTunnelManager): ...@@ -765,7 +794,7 @@ class TunnelManager(BaseTunnelManager):
if address: if address:
self._address[family] = utils.dump_address(address) self._address[family] = utils.dump_address(address)
def broadcastVersion(self): def _newVersion(self):
for prefix in self.ctl.neighbours: for prefix in self.ctl.neighbours:
if prefix: if prefix:
peer = self._getPeer(prefix) peer = self._getPeer(prefix)
...@@ -774,3 +803,6 @@ class TunnelManager(BaseTunnelManager): ...@@ -774,3 +803,6 @@ class TunnelManager(BaseTunnelManager):
elif (peer.version < self._version and elif (peer.version < self._version and
self.sendto(prefix, '\0' + self._version)): self.sendto(prefix, '\0' + self._version)):
peer.version = self._version peer.version = self._version
for prefix, c in self._connection_dict.items():
if c.serial in self.cache.crl:
self._kill(prefix)
...@@ -42,10 +42,12 @@ def encrypt(cert, data): ...@@ -42,10 +42,12 @@ def encrypt(cert, data):
def fingerprint(cert, alg='sha1'): def fingerprint(cert, alg='sha1'):
return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)) return hashlib.new(alg, crypto.dump_certificate(crypto.FILETYPE_ASN1, cert))
def maybe_renew(path, cert, info, renew): def maybe_renew(path, cert, info, renew, force=False):
from .registry import RENEW_PERIOD from .registry import RENEW_PERIOD
while True: while True:
if cert.get_serial_number(): if force:
force = False
else:
next_renew = notAfter(cert) - RENEW_PERIOD next_renew = notAfter(cert) - RENEW_PERIOD
if time.time() < next_renew: if time.time() < next_renew:
return cert, next_renew return cert, next_renew
...@@ -110,11 +112,10 @@ class Cert(object): ...@@ -110,11 +112,10 @@ class Cert(object):
'--cert', self.cert_path, '--cert', self.cert_path,
'--key', self.key_path) '--key', self.key_path)
def maybeRenew(self, registry): def maybeRenew(self, registry, crl):
from .registry import RegistryClient
registry = RegistryClient(registry, self)
self.cert, next_renew = maybe_renew(self.cert_path, self.cert, self.cert, next_renew = maybe_renew(self.cert_path, self.cert,
"Certificate", lambda: registry.renewCertificate(self.prefix)) "Certificate", lambda: registry.renewCertificate(self.prefix),
self.cert.get_serial_number() in crl)
self.ca, ca_renew = maybe_renew(self.ca_path, self.ca, self.ca, ca_renew = maybe_renew(self.ca_path, self.ca,
"CA Certificate", registry.getCa) "CA Certificate", registry.getCa)
return min(next_renew, ca_renew) return min(next_renew, ca_renew)
...@@ -181,6 +182,7 @@ class Peer(object): ...@@ -181,6 +182,7 @@ class Peer(object):
""" """
_hello = _last = 0 _hello = _last = 0
_key = newHmacSecret() _key = newHmacSecret()
serial = None
stop_date = float('inf') stop_date = float('inf')
version = '' version = ''
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import atexit, errno, logging, os, shutil, signal import atexit, errno, logging, os, shutil, signal
import socket, subprocess, sys, time, threading import socket, subprocess, sys, time, threading
from collections import deque from collections import deque
from functools import partial
from re6st import plib, tunnel, utils, version, x509 from re6st import plib, tunnel, utils, version, x509
from re6st.cache import Cache from re6st.cache import Cache
from re6st.utils import exit, ReexecException from re6st.utils import exit, ReexecException
...@@ -130,7 +131,6 @@ def main(): ...@@ -130,7 +131,6 @@ def main():
exit.signal(0, signal.SIGINT, signal.SIGTERM) exit.signal(0, signal.SIGINT, signal.SIGTERM)
exit.signal(-1, signal.SIGHUP, signal.SIGUSR2) exit.signal(-1, signal.SIGHUP, signal.SIGUSR2)
next_renew = cert.maybeRenew(config.registry)
cache = Cache(db_path, config.registry, cert) cache = Cache(db_path, config.registry, cert)
network = cert.network network = cert.network
...@@ -249,14 +249,12 @@ def main(): ...@@ -249,14 +249,12 @@ def main():
control_socket = os.path.join(config.run, 'babeld.sock') control_socket = os.path.join(config.run, 'babeld.sock')
if config.client_count and not config.client: if config.client_count and not config.client:
tunnel_manager = tunnel.TunnelManager(control_socket, tunnel_manager = tunnel.TunnelManager(control_socket,
cache, cert, next_renew, config.openvpn_args, timeout, cache, cert, config.openvpn_args, timeout,
config.client_count, config.iface_list, address, ip_changed, config.client_count, config.iface_list, address, ip_changed,
remote_gateway, config.disable_proto, config.neighbour) remote_gateway, config.disable_proto, config.neighbour)
tunnel_interfaces += tunnel_manager.new_iface_list tunnel_interfaces += tunnel_manager.new_iface_list
write_pipe = tunnel_manager.write_pipe
else: else:
write_pipe = None tunnel_manager = tunnel.BaseTunnelManager(cache, cert)
tunnel_manager = tunnel.BaseTunnelManager(cache, cert, next_renew)
cleanup.append(tunnel_manager.sock.close) cleanup.append(tunnel_manager.sock.close)
try: try:
...@@ -275,6 +273,7 @@ def main(): ...@@ -275,6 +273,7 @@ def main():
# an public IP so Babel must be changed to set a source # an public IP so Babel must be changed to set a source
# address on routes it installs. # address on routes it installs.
ip('addrlabel', 'prefix', my_network, 'label', '99') ip('addrlabel', 'prefix', my_network, 'label', '99')
R = {}
# prepare persistent interfaces # prepare persistent interfaces
if config.client: if config.client:
address_list = [x for x in utils.parse_address(config.client) address_list = [x for x in utils.parse_address(config.client)
...@@ -288,9 +287,13 @@ def main(): ...@@ -288,9 +287,13 @@ def main():
elif server_tunnels: elif server_tunnels:
required('dh') required('dh')
for iface, (port, proto) in server_tunnels.iteritems(): for iface, (port, proto) in server_tunnels.iteritems():
r, x = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
cleanup.append(plib.server(iface, config.max_clients, cleanup.append(plib.server(iface, config.max_clients,
config.dh, write_pipe, port, proto, cache.encrypt, config.dh, x.fileno(), port, proto, cache.encrypt,
'--ping-exit', str(timeout), *config.openvpn_args).stop) '--ping-exit', str(timeout), *config.openvpn_args,
preexec_fn=r.close).stop)
R[r] = partial(tunnel_manager.handleServerEvent, r)
x.close()
ip('addr', my_ip, 'dev', config.main_interface) ip('addr', my_ip, 'dev', config.main_interface)
if_rt = ['ip', '-6', 'route', 'del', if_rt = ['ip', '-6', 'route', 'del',
...@@ -371,7 +374,7 @@ def main(): ...@@ -371,7 +374,7 @@ def main():
select_list = [forwarder.select] if forwarder else [] select_list = [forwarder.select] if forwarder else []
select_list += tunnel_manager.select, utils.select select_list += tunnel_manager.select, utils.select
while True: while True:
args = {}, {}, [] args = R.copy(), {}, []
for s in select_list: for s in select_list:
s(*args) s(*args)
finally: finally:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment