Commit 42495e18 authored by Julien Muchembled's avatar Julien Muchembled

Implement automatic renewal of client certificate

parent f2e11d86
#!/usr/bin/python #!/usr/bin/python
import argparse, atexit, errno, os, subprocess, sqlite3, sys import argparse, atexit, errno, os, subprocess, sqlite3, sys, time
from OpenSSL import crypto from OpenSSL import crypto
from re6st import registry, utils from re6st import registry, utils
...@@ -139,7 +139,11 @@ def main(): ...@@ -139,7 +139,11 @@ def main():
os.ftruncate(cert_fd, len(cert)) os.ftruncate(cert_fd, len(cert))
os.close(cert_fd) os.close(cert_fd)
print "Certificate setup complete." cert = loadCert(cert)
not_after = utils.notAfter(cert)
print("Setup complete. Certificate is valid until %s"
" and will be automatically renewed after %s" % (
time.ctime(not_after), time.ctime(not_after - registry.RENEW_PERIOD)))
if not os.path.lexists(conf_path): if not os.path.lexists(conf_path):
create(conf_path, """\ create(conf_path, """\
...@@ -160,7 +164,7 @@ dh %s ...@@ -160,7 +164,7 @@ dh %s
""" % (config.registry, ca_path, cert_path, key_path, dh_path)) """ % (config.registry, ca_path, cert_path, key_path, dh_path))
print "Sample configuration file created." print "Sample configuration file created."
cn = utils.subnetFromCert(loadCert(cert)) cn = utils.subnetFromCert(cert)
subnet = network + utils.binFromSubnet(cn) subnet = network + utils.binFromSubnet(cn)
print "Your subnet: %s/%u (CN=%s)" \ print "Your subnet: %s/%u (CN=%s)" \
% (utils.ipFromBin(subnet), len(subnet), cn) % (utils.ipFromBin(subnet), len(subnet), cn)
......
...@@ -8,6 +8,7 @@ from urllib import splittype, splithost, splitport, urlencode ...@@ -8,6 +8,7 @@ from urllib import splittype, splithost, splitport, urlencode
from . import tunnel, utils from . import tunnel, utils
HMAC_HEADER = "Re6stHMAC" HMAC_HEADER = "Re6stHMAC"
RENEW_PERIOD = 30 * 86400
class getcallargs(type): class getcallargs(type):
...@@ -190,28 +191,38 @@ class RegistryServer(object): ...@@ -190,28 +191,38 @@ class RegistryServer(object):
(token,)).next() (token,)).next()
except StopIteration: except StopIteration:
return return
self.db.execute("DELETE FROM token WHERE token = ?", (token,)) self.db.execute("DELETE FROM token WHERE token = ?",
(token,))
# Get a new prefix
prefix = self._getPrefix(prefix_len) prefix = self._getPrefix(prefix_len)
self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?",
(email, prefix))
return self._createCertificate(prefix, req.get_subject(),
req.get_pubkey())
# Create certificate def _createCertificate(self, client_prefix, subject, pubkey):
cert = crypto.X509() cert = crypto.X509()
cert.set_serial_number(0) # required for libssl < 1.0 cert.set_serial_number(0) # required for libssl < 1.0
cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(self.cert_duration) cert.gmtime_adj_notAfter(self.cert_duration)
cert.set_issuer(self.ca.get_subject()) cert.set_issuer(self.ca.get_subject())
subject = req.get_subject() subject.CN = "%u/%u" % (int(client_prefix, 2), len(client_prefix))
subject.CN = "%u/%u" % (int(prefix, 2), prefix_len) cert.set_subject(subject)
cert.set_subject(subject) cert.set_pubkey(pubkey)
cert.set_pubkey(req.get_pubkey()) cert.sign(self.key, 'sha1')
cert.sign(self.key, 'sha1') cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) self.db.execute("UPDATE cert SET cert = ? WHERE prefix = ?",
(cert, client_prefix))
# Insert certificate into db return cert
self.db.execute("UPDATE cert SET email = ?, cert = ? WHERE prefix = ?", (email, cert, prefix))
return cert def renewCertificate(self, cn):
with self.lock:
with self.db:
pem = self._getCert(cn)
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
if utils.notAfter(cert) - RENEW_PERIOD < time.time():
pem = self._createCertificate(cn, cert.get_subject(),
cert.get_pubkey())
return pem
def getCa(self): def getCa(self):
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca) return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)
......
...@@ -132,6 +132,9 @@ def networkFromCa(ca): ...@@ -132,6 +132,9 @@ def networkFromCa(ca):
def subnetFromCert(cert): def subnetFromCert(cert):
return cert.get_subject().CN return cert.get_subject().CN
def notAfter(cert):
return time.mktime(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ'))
def dump_address(address): def dump_address(address):
return ';'.join(map(','.join, address)) return ';'.join(map(','.join, address))
......
...@@ -4,8 +4,10 @@ import sqlite3, subprocess, sys, time, traceback ...@@ -4,8 +4,10 @@ import sqlite3, subprocess, sys, time, traceback
from collections import deque from collections import deque
from OpenSSL import crypto from OpenSSL import crypto
from re6st import db, plib, tunnel, utils from re6st import db, plib, tunnel, utils
from re6st.registry import RegistryClient from re6st.registry import RegistryClient, RENEW_PERIOD
class ReexecException(Exception):
pass
def getConfig(): def getConfig():
parser = utils.ArgParser(fromfile_prefix_chars='@', parser = utils.ArgParser(fromfile_prefix_chars='@',
...@@ -112,6 +114,8 @@ def getConfig(): ...@@ -112,6 +114,8 @@ def getConfig():
return parser.parse_args() return parser.parse_args()
def renew(*args):
raise ReexecException("Restart to renew certificate")
def main(): def main():
# Get arguments # Get arguments
...@@ -144,6 +148,24 @@ def main(): ...@@ -144,6 +148,24 @@ def main():
signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1)) signal.signal(signal.SIGHUP, lambda *args: sys.exit(-1))
signal.signal(signal.SIGTERM, lambda *args: sys.exit()) signal.signal(signal.SIGTERM, lambda *args: sys.exit())
registry = RegistryClient(config.registry, config.key, ca)
while True:
next_renew = utils.notAfter(cert) - RENEW_PERIOD
if time.time() < next_renew:
break
pem = registry.renewCertificate(prefix)
if not pem or pem == crypto.dump_certificate(crypto.FILETYPE_PEM, cert):
logging.warning("Certificate not renewed. Will retry tomorrow.")
next_renew = time.time() + 86400
break
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
path = config.cert + '.new'
with open(path, 'w') as f:
f.write(pem)
os.rename(path, config.cert)
logging.info("Certificate renewed until %s",
time.ctime(utils.notAfter(cert)))
if config.max_clients is None: if config.max_clients is None:
config.max_clients = config.client_count * 2 config.max_clients = config.client_count * 2
...@@ -232,7 +254,6 @@ def main(): ...@@ -232,7 +254,6 @@ def main():
# Create and open read_only pipe to get server events # Create and open read_only pipe to get server events
r_pipe, write_pipe = os.pipe() r_pipe, write_pipe = os.pipe()
read_pipe = os.fdopen(r_pipe) read_pipe = os.fdopen(r_pipe)
registry = RegistryClient(config.registry, config.key, ca)
peer_db = db.PeerDB(db_path, registry, config.key, prefix) peer_db = db.PeerDB(db_path, registry, config.key, prefix)
tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db, tunnel_manager = tunnel.TunnelManager(write_pipe, peer_db,
config.openvpn_args, timeout, config.tunnel_refresh, config.openvpn_args, timeout, config.tunnel_refresh,
...@@ -313,7 +334,12 @@ def main(): ...@@ -313,7 +334,12 @@ def main():
# main loop # main loop
if tunnel_manager is None: if tunnel_manager is None:
sys.exit(os.WEXITSTATUS(os.wait()[1])) signal.signal(signal.SIGALRM, renew)
signal.alarm(int(next_renew - time.time()))
try:
sys.exit(os.WEXITSTATUS(os.wait()[1]))
finally:
signal.alarm(0)
cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll cleanup += tunnel_manager.delInterfaces, tunnel_manager.killAll
while True: while True:
next = tunnel_manager.next_refresh next = tunnel_manager.next_refresh
...@@ -333,6 +359,8 @@ def main(): ...@@ -333,6 +359,8 @@ def main():
t = time.time() t = time.time()
if t >= tunnel_manager.next_refresh: if t >= tunnel_manager.next_refresh:
tunnel_manager.refresh() tunnel_manager.refresh()
if t >= next_renew:
renew()
if forwarder and t >= forwarder.next_refresh: if forwarder and t >= forwarder.next_refresh:
forwarder.refresh() forwarder.refresh()
finally: finally:
...@@ -344,16 +372,18 @@ def main(): ...@@ -344,16 +372,18 @@ def main():
except sqlite3.Error: except sqlite3.Error:
logging.exception("Restarting with empty cache") logging.exception("Restarting with empty cache")
os.rename(db_path, db_path + '.bak') os.rename(db_path, db_path + '.bak')
try: except ReexecException, e:
sys.exitfunc() logging.info(e)
finally:
os.execvp(sys.argv[0], sys.argv)
except KeyboardInterrupt: except KeyboardInterrupt:
return 0 return 0
except Exception: except Exception:
f = traceback.format_exception(*sys.exc_info()) f = traceback.format_exception(*sys.exc_info())
logging.error('%s%s', f.pop(), ''.join(f)) logging.error('%s%s', f.pop(), ''.join(f))
sys.exit(1) sys.exit(1)
try:
sys.exitfunc()
finally:
os.execvp(sys.argv[0], sys.argv)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment