Commit a30aec39 authored by Julien Muchembled's avatar Julien Muchembled

registry: whitelist RPCs rather than blacklist methods to not publish

Here, it's simpler and safer. We will also want to have private methods that
don't start with an underscore.
parent 19f6cacc
...@@ -29,8 +29,8 @@ class RequestHandler(BaseHTTPRequestHandler): ...@@ -29,8 +29,8 @@ class RequestHandler(BaseHTTPRequestHandler):
query = dict(parse_qsl(query, keep_blank_values=1, query = dict(parse_qsl(query, keep_blank_values=1,
strict_parsing=1)) strict_parsing=1))
_, path = path.split('/') _, path = path.split('/')
if not _ and path[0] != '_': if not _:
return self.server._handle_request(self, path, query) return self.server.handle_request(self, path, query)
except Exception: except Exception:
logging.info(self.requestline, exc_info=1) logging.info(self.requestline, exc_info=1)
self.send_error(httplib.BAD_REQUEST) self.send_error(httplib.BAD_REQUEST)
...@@ -104,13 +104,13 @@ def main(): ...@@ -104,13 +104,13 @@ def main():
empty_list = [] empty_list = []
while True: while True:
while True: while True:
next = server._timeout next = server.timeout
if next is None: if next is None:
break break
next -= time.time() next -= time.time()
if next > 0: if next > 0:
break break
server._onTimeout() server.onTimeout()
try: try:
r = select.select(server_list[:], empty_list, empty_list, r = select.select(server_list[:], empty_list, empty_list,
next)[0] next)[0]
......
...@@ -32,27 +32,16 @@ HMAC_HEADER = "Re6stHMAC" ...@@ -32,27 +32,16 @@ HMAC_HEADER = "Re6stHMAC"
RENEW_PERIOD = 30 * 86400 RENEW_PERIOD = 30 * 86400
GRACE_PERIOD = RENEW_PERIOD GRACE_PERIOD = RENEW_PERIOD
def rpc(f):
class getcallargs(type):
def __init__(cls, name, bases, d):
type.__init__(cls, name, bases, d)
for n, f in d.iteritems():
if n[0] == '_':
continue
try:
args, varargs, varkw, defaults = inspect.getargspec(f) args, varargs, varkw, defaults = inspect.getargspec(f)
except TypeError: assert not (varargs or varkw or defaults), f
continue
if varargs or varkw or defaults:
continue
f.getcallargs = eval("lambda %s: locals()" % ','.join(args[1:])) f.getcallargs = eval("lambda %s: locals()" % ','.join(args[1:]))
return f
class RegistryServer(object): class RegistryServer(object):
__metaclass__ = getcallargs peers = 0, ()
_peers = 0, ()
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
...@@ -98,12 +87,12 @@ class RegistryServer(object): ...@@ -98,12 +87,12 @@ class RegistryServer(object):
self.network = utils.networkFromCa(self.ca) self.network = utils.networkFromCa(self.ca)
logging.info("Network: %s/%u", utils.ipFromBin(self.network), logging.info("Network: %s/%u", utils.ipFromBin(self.network),
len(self.network)) len(self.network))
self._email = self.ca.get_subject().emailAddress self.email = self.ca.get_subject().emailAddress
self._onTimeout() self.onTimeout()
def _onTimeout(self): def onTimeout(self):
# XXX: Because we use threads to process requests, the statements # XXX: Because we use threads to process requests, the statements
# 'self._timeout = 1' below have no effect as long as the # 'self.timeout = 1' below have no effect as long as the
# 'select' call does not return. Ideally, we should interrupt it. # 'select' call does not return. Ideally, we should interrupt it.
logging.info("Checking if there's any old entry in the database ...") logging.info("Checking if there's any old entry in the database ...")
not_after = None not_after = None
...@@ -140,10 +129,10 @@ class RegistryServer(object): ...@@ -140,10 +129,10 @@ class RegistryServer(object):
elif not_after is None or x < not_after: elif not_after is None or x < not_after:
not_after = x not_after = x
# TODO: reduce 'cert' table by merging free slots # TODO: reduce 'cert' table by merging free slots
# (IOW, do the contrary of _newPrefix) # (IOW, do the contrary of newPrefix)
self._timeout = not_after and not_after + GRACE_PERIOD self.timeout = not_after and not_after + GRACE_PERIOD
def _handle_request(self, request, method, kw): def handle_request(self, request, method, kw):
m = getattr(self, method) m = getattr(self, method)
if method in ('topology',) and \ if method in ('topology',) and \
request.client_address[0] not in ('127.0.0.1', '::1'): request.client_address[0] not in ('127.0.0.1', '::1'):
...@@ -177,9 +166,10 @@ class RegistryServer(object): ...@@ -177,9 +166,10 @@ class RegistryServer(object):
if result: if result:
request.wfile.write(result) request.wfile.write(result)
@rpc
def hello(self, client_prefix): def hello(self, client_prefix):
with self.lock: with self.lock:
cert = self._getCert(client_prefix) cert = self.getCert(client_prefix)
key = hashlib.sha1(struct.pack('Q', key = hashlib.sha1(struct.pack('Q',
random.getrandbits(64))).digest() random.getrandbits(64))).digest()
self.sessions.setdefault(client_prefix, [])[1:] = key, self.sessions.setdefault(client_prefix, [])[1:] = key,
...@@ -188,11 +178,12 @@ class RegistryServer(object): ...@@ -188,11 +178,12 @@ class RegistryServer(object):
assert len(key) == len(sign) assert len(key) == len(sign)
return key + sign return key + sign
def _getCert(self, client_prefix): def getCert(self, client_prefix):
assert self.lock.locked() assert self.lock.locked()
return self.db.execute("SELECT cert FROM cert WHERE prefix = ?", return self.db.execute("SELECT cert FROM cert WHERE prefix = ?",
(client_prefix,)).next()[0] (client_prefix,)).next()[0]
@rpc
def requestToken(self, email): def requestToken(self, email):
with self.lock: with self.lock:
while True: while True:
...@@ -205,14 +196,14 @@ class RegistryServer(object): ...@@ -205,14 +196,14 @@ class RegistryServer(object):
break break
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
pass pass
self._timeout = 1 self.timeout = 1
# Creating and sending email # Creating and sending email
msg = MIMEText('Hello, your token to join re6st network is: %s\n' msg = MIMEText('Hello, your token to join re6st network is: %s\n'
% token) % token)
msg['Subject'] = '[re6stnet] Token Request' msg['Subject'] = '[re6stnet] Token Request'
if self._email: if self.email:
msg['From'] = self._email msg['From'] = self.email
msg['To'] = email msg['To'] = email
if os.path.isabs(self.config.mailhost) or \ if os.path.isabs(self.config.mailhost) or \
os.path.isfile(self.config.mailhost): os.path.isfile(self.config.mailhost):
...@@ -224,10 +215,10 @@ class RegistryServer(object): ...@@ -224,10 +215,10 @@ class RegistryServer(object):
m.close() m.close()
else: else:
s = smtplib.SMTP(self.config.mailhost) s = smtplib.SMTP(self.config.mailhost)
s.sendmail(self._email, email, msg.as_string()) s.sendmail(self.email, email, msg.as_string())
s.quit() s.quit()
def _newPrefix(self, prefix_len): def newPrefix(self, prefix_len):
max_len = 128 - len(self.network) max_len = 128 - len(self.network)
assert 0 < prefix_len <= max_len assert 0 < prefix_len <= max_len
try: try:
...@@ -243,8 +234,9 @@ class RegistryServer(object): ...@@ -243,8 +234,9 @@ class RegistryServer(object):
if len(prefix) < max_len or '1' in prefix: if len(prefix) < max_len or '1' in prefix:
return prefix return prefix
self.db.execute("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,)) self.db.execute("UPDATE cert SET cert = 'reserved' WHERE prefix = ?", (prefix,))
return self._newPrefix(prefix_len) return self.newPrefix(prefix_len)
@rpc
def requestCertificate(self, token, req): def requestCertificate(self, token, req):
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req) req = crypto.load_certificate_request(crypto.FILETYPE_PEM, req)
with self.lock: with self.lock:
...@@ -263,17 +255,17 @@ class RegistryServer(object): ...@@ -263,17 +255,17 @@ class RegistryServer(object):
if not prefix_len: if not prefix_len:
return return
email = None email = None
prefix = self._newPrefix(prefix_len) prefix = self.newPrefix(prefix_len)
self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?", self.db.execute("UPDATE cert SET email = ? WHERE prefix = ?",
(email, prefix)) (email, prefix))
if self.prefix is None: if self.prefix is None:
self.prefix = prefix self.prefix = prefix
self.db.execute( self.db.execute(
"INSERT INTO config VALUES ('prefix',?)", (prefix,)) "INSERT INTO config VALUES ('prefix',?)", (prefix,))
return self._createCertificate(prefix, req.get_subject(), return self.createCertificate(prefix, req.get_subject(),
req.get_pubkey()) req.get_pubkey())
def _createCertificate(self, client_prefix, subject, pubkey): def createCertificate(self, client_prefix, subject, pubkey):
cert = crypto.X509() cert = crypto.X509()
cert.set_serial_number(0) # required for libssl < 1.0 cert.set_serial_number(0) # required for libssl < 1.0
cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notBefore(0)
...@@ -286,37 +278,36 @@ class RegistryServer(object): ...@@ -286,37 +278,36 @@ class RegistryServer(object):
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
self.db.execute("UPDATE cert SET cert = ? WHERE prefix = ?", self.db.execute("UPDATE cert SET cert = ? WHERE prefix = ?",
(cert, client_prefix)) (cert, client_prefix))
self._timeout = 1 self.timeout = 1
return cert return cert
@rpc
def renewCertificate(self, cn): def renewCertificate(self, cn):
with self.lock: with self.lock:
with self.db: with self.db:
pem = self._getCert(cn) pem = self.getCert(cn)
cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem) cert = crypto.load_certificate(crypto.FILETYPE_PEM, pem)
if utils.notAfter(cert) - RENEW_PERIOD < time.time(): if utils.notAfter(cert) - RENEW_PERIOD < time.time():
pem = self._createCertificate(cn, cert.get_subject(), pem = self.createCertificate(cn, cert.get_subject(),
cert.get_pubkey()) cert.get_pubkey())
return pem return pem
@rpc
def getCa(self): def getCa(self):
return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca) return crypto.dump_certificate(crypto.FILETYPE_PEM, self.ca)
@rpc
def getPrefix(self, cn): def getPrefix(self, cn):
return self.prefix return self.prefix
def getPrivateAddress(self, cn): @rpc
# BBB: Deprecated by getPrefix.
return utils.ipFromBin(self.network + self.prefix)
def getBootstrapPeer(self, cn): def getBootstrapPeer(self, cn):
with self.lock: with self.lock:
cert = self._getCert(cn) age, peers = self.peers
age, peers = self._peers
if age < time.time() or not peers: if age < time.time() or not peers:
peers = [x[1] for x in utils.iterRoutes(self.network)] peers = [x[1] for x in utils.iterRoutes(self.network)]
random.shuffle(peers) random.shuffle(peers)
self._peers = time.time() + 60, peers self.peers = time.time() + 60, peers
peer = peers.pop() peer = peers.pop()
if peer == cn: if peer == cn:
# Very unlikely (e.g. peer restarted with empty cache), # Very unlikely (e.g. peer restarted with empty cache),
...@@ -341,9 +332,11 @@ class RegistryServer(object): ...@@ -341,9 +332,11 @@ class RegistryServer(object):
else: else:
logging.info("Timeout while querying [%s]:%u", *address) logging.info("Timeout while querying [%s]:%u", *address)
return return
cert = self.getCert(cn)
logging.info("Sending bootstrap peer: %s", msg) logging.info("Sending bootstrap peer: %s", msg)
return utils.encrypt(cert, msg) return utils.encrypt(cert, msg)
@rpc
def topology(self): def topology(self):
with self.lock: with self.lock:
peers = deque(('%u/%u' % (int(self.prefix, 2), len(self.prefix)),)) peers = deque(('%u/%u' % (int(self.prefix, 2), len(self.prefix)),))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment