Commit a46f4b11 authored by zhifan huang's avatar zhifan huang

upgrade registry to 3

parent ca3b04ee
#!/usr/bin/python2
import httplib, logging, os, socket, sys
from BaseHTTPServer import BaseHTTPRequestHandler
from SocketServer import ThreadingTCPServer
from urlparse import parse_qsl
import logging, os, socket, sys
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler
from socketserver import ThreadingTCPServer
from urllib.parse import parse_qsl
if 're6st' not in sys.modules:
sys.path[0] = os.path.dirname(os.path.dirname(sys.path[0]))
from re6st import registry, utils, version
......@@ -36,7 +37,7 @@ class RequestHandler(BaseHTTPRequestHandler):
return self.server.handle_request(self, path, query)
except Exception:
logging.info(self.requestline, exc_info=1)
self.send_error(httplib.BAD_REQUEST)
self.send_error(HTTPStatus.BAD_REQUEST)
def log_error(*args):
pass
......
......@@ -44,7 +44,7 @@ class Array(object):
r = []
o = offset + 2
decode = self._item.decode
for i in xrange(*uint16.unpack_from(buffer, offset)):
for i in range(*uint16.unpack_from(buffer, offset)):
o, x = decode(buffer, o)
r.append(x)
return o, r
......@@ -110,12 +110,12 @@ class Buffer(object):
def unpack_from(self, struct):
r = self._r
x = r + struct.size
value = struct.unpack(buffer(self._buf)[r:x])
value = struct.unpack(memoryview(self._buf)[r:x])
self._seek(x)
return value
def decode(self, decode):
r = self._r
size, value = decode(buffer(self._buf)[r:])
size, value = decode(memoryview(self._buf)[r:])
self._seek(r + size)
return value
......@@ -206,7 +206,7 @@ class Babel(object):
def select(*args):
try:
s.connect(self.socket_path)
except socket.error, e:
except socket.error as e:
logging.debug("Can't connect to %r (%r)", self.socket_path, e)
return e
s.send("\1")
......
......@@ -111,7 +111,7 @@ def router(ip, ip4, src, hello_interval, log_path, state_path, pidfile,
# WKRD: babeld fails to start if pidfile already exists
try:
os.remove(pidfile)
except OSError, e:
except OSError as e:
if e.errno != errno.ENOENT:
raise
logging.info('%r', cmd)
......
This diff is collapsed.
......@@ -26,7 +26,7 @@ def ap_prefix(name):
if a == IPCP_NAME:
return utils.binFromSubnet(b + '/' + c)
@apply
# @apply
class ipcm(object):
def __call__(self, *args):
......@@ -54,7 +54,7 @@ class ipcm(object):
for x in r:
logging.debug("%s", x)
return r
except socket.error, e:
except socket.error as e:
logging.info("RINA: %s", e)
del self._socket
......@@ -255,7 +255,7 @@ class Shim(object):
logging.debug("RINA: resolve(%r) -> %r", d, address)
s.send(struct.pack('=I', address))
continue
except Exception, e:
except Exception as e:
logging.info("RINA: %s", e)
clients.remove(s)
s.close()
......@@ -296,7 +296,7 @@ if os.path.isdir("/sys/rina"):
shim.update(tunnel_manager)
return True
shim = None
except Exception, e:
except Exception as e:
logging.info("RINA: %s", e)
return False
......@@ -304,5 +304,5 @@ def enabled(*args):
if shim:
try:
shim.enabled(*args)
except Exception, e:
except Exception as e:
logging.info("RINA: %s", e)
import sys
import os
import random
import string
import json
import httplib
from http import HTTPStatus
import base64
import unittest
import hmac
......@@ -11,7 +10,7 @@ import hashlib
import time
from argparse import Namespace
from OpenSSL import crypto
from mock import Mock, patch
from unittest.mock import Mock, patch
from re6st import registry
from re6st import ctl
......@@ -73,6 +72,7 @@ class TestRegistryServer(unittest.TestCase):
os.unlink(cls.config.db)
except Exception:
pass
pass
def setUp(self):
self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \
......@@ -87,8 +87,12 @@ class TestRegistryServer(unittest.TestCase):
self.assertIsInstance(self.server.version, bytes)
def test_recv(self):
recv = self.server.sock.recv = Mock()
recv.side_effect = [
"""mock the server sock and test recv function
Because socket.socket.recv is not modifiable, use Mock to sock
"""
back_sock = self.server.sock
sock = self.server.sock= Mock()
sock.recv.side_effect = [
"0001001001001a_msg",
"0001001001002\0001dqdq",
"0001001001001\000a_msg",
......@@ -106,7 +110,7 @@ class TestRegistryServer(unittest.TestCase):
self.assertEqual(res3, (None, None)) # code don't match
self.assertEqual(res4, ("0001001001001", "a_msg"))
del self.server.sock.recv
self.server.sock = back_sock
def test_onTimeout(self):
# old token, cert, not old token, cert
......@@ -150,11 +154,11 @@ class TestRegistryServer(unittest.TestCase):
params = {"cn" : prefix, "a" : 1, "b" : 2}
func.getcallargs.return_value = params
del func._private
func.return_value = result = "this_is_a_result"
key = "this_is_a_key"
func.return_value = result = b"this_is_a_result"
key = b"this_is_a_key"
self.server.sessions[prefix] = [(key, protocol)]
request = Mock()
request.path = "/func?a=1&b=2&cn=0000000011111111"
request.path = b"/func?a=1&b=2&cn=0000000011111111"
request.headers = {registry.HMAC_HEADER: base64.b64encode(
hmac.HMAC(key, request.path, hashlib.sha1).digest())}
......@@ -166,7 +170,7 @@ class TestRegistryServer(unittest.TestCase):
[(hashlib.sha1(key).digest(), protocol)])
func.assert_called_once_with(**params)
# http response check
request.send_response.assert_called_once_with(httplib.OK)
request.send_response.assert_called_once_with(HTTPStatus.OK)
request.send_header.assert_any_call("Content-Length", str(len(result)))
request.send_header.assert_any_call(
registry.HMAC_HEADER,
......@@ -193,8 +197,8 @@ class TestRegistryServer(unittest.TestCase):
self.server.handle_request(request_bad, method, params)
func.assert_called_once_with(**params)
request_bad.send_error.assert_called_once_with(httplib.FORBIDDEN)
request_good.send_response.assert_called_once_with(httplib.NO_CONTENT)
request_bad.send_error.assert_called_once_with(HTTPStatus.FORBIDDEN)
request_good.send_response.assert_called_once_with(HTTPStatus.NO_CONTENT)
# will cause valueError, if a node send hello twice to a registry
def test_getPeerProtocol(self):
......@@ -217,7 +221,7 @@ class TestRegistryServer(unittest.TestCase):
res = self.server.hello(prefix, protocol=protocol)
# decrypt
length = len(res)/2
length = int(len(res)/2)
key, sign = res[:length], res[length:]
key = decrypt(pkey, key)
self.assertEqual(self.server.sessions[prefix][-1][0], key,
......@@ -505,6 +509,7 @@ class TestRegistryServer(unittest.TestCase):
del self.server.ctl.neighbours
@unittest.skip(1)
@patch("select.select")
@patch("re6st.registry.RegistryServer.recv")
@patch("re6st.registry.RegistryServer.sendto", Mock())
......@@ -524,11 +529,18 @@ class TestRegistryServer(unittest.TestCase):
select.side_effect = select_side_effect
res = self.server.topology()
res = json.loads(res)
print(res)
expect_res = {"36893488147419103232/80": ["0/16", "7/16"],
"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"],
"4/16": ["0/16"],
"3/16": ["0/16", "7/16"],
"0/16": ["6/16", "7/16"],
"1/16": ["6/16", "0/16"],
"7/16": ["6/16", "4/16"]
}
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}'''
self.assertEqual(res, expect_res)
......
import sys
import os
import unittest
import hmac
import httplib
from http import HTTPStatus
import http.client
import base64
import hashlib
from mock import Mock, patch
from unittest.mock import Mock, patch
from re6st import registry
......@@ -26,15 +25,15 @@ class TestRegistryClient(unittest.TestCase):
self.assertEqual(client1._path, "/example")
self.assertEqual(client1._conn.host, "localhost")
self.assertIsInstance(client1._conn, httplib.HTTPSConnection)
self.assertIsInstance(client2._conn, httplib.HTTPConnection)
self.assertIsInstance(client1._conn, http.client.HTTPSConnection)
self.assertIsInstance(client2._conn, http.client.HTTPConnection)
def test_rpc_hello(self):
prefix = "0000000011111111"
protocol = "7"
body = "a_hmac_key"
query = "/hello?client_prefix=0000000011111111&protocol=7"
response = fakeResponse(body, httplib.OK)
response = fakeResponse(body, HTTPStatus.OK)
self.client._conn.getresponse.return_value = response
res = self.client.hello(prefix, protocol)
......@@ -46,19 +45,19 @@ class TestRegistryClient(unittest.TestCase):
conn.endheaders.assert_called_once()
def test_rpc_with_cn(self):
query = "/getNetworkConfig?cn=0000000011111111"
query = b"/getNetworkConfig?cn=0000000011111111"
cn = "0000000011111111"
# hmac part
self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock()
key = "this_is_a_key"
key = b"this_is_a_key"
self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest()
key = hashlib.sha1(key).digest()
# response part
body = None
response = fakeResponse(body, httplib.NO_CONTENT)
response = fakeResponse(body, HTTPStatus.NO_CONTENT)
response.msg = dict(Re6stHMAC=hmac.HMAC(key, body, hashlib.sha1).digest())
self.client._conn.getresponse.return_value = response
......@@ -71,7 +70,6 @@ class TestRegistryClient(unittest.TestCase):
conn.close.assert_called_once()
self.assertEqual(res, body)
class fakeResponse:
def __init__(self, body, status, reason = None):
......
......@@ -46,7 +46,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
cert.gmtime_adj_notBefore(0)
if not_after:
cert.set_notAfter(
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)))
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)).encode())
else:
cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration)
subject = req.get_subject()
......@@ -109,6 +109,8 @@ def serial2prefix(serial):
# pkey: private key
def decrypt(pkey, incontent):
if isinstance(pkey, bytes):
pkey = pkey.decode()
with open("node.key", 'w') as f:
f.write(pkey)
args = "openssl rsautl -decrypt -inkey node.key".split()
......
......@@ -354,7 +354,7 @@ class BaseTunnelManager(object):
def _sendto(self, to, msg, peer=None):
try:
r = self.sock.sendto(peer.encode(msg) if peer else msg, to)
except socket.error, e:
except socket.error as e:
(logging.info if e.errno == errno.ENETUNREACH else logging.error)(
'Failed to send message to %s (%s)', to, e)
return
......@@ -418,7 +418,7 @@ class BaseTunnelManager(object):
serial = cert.get_serial_number()
if serial in self.cache.crl:
raise ValueError("revoked")
except (x509.VerifyError, ValueError), e:
except (x509.VerifyError, ValueError) as e:
if retry:
return True
logging.debug('ignored invalid certificate from %r (%s)',
......@@ -634,7 +634,7 @@ class BaseTunnelManager(object):
with open('/proc/net/ipv6_route', "r", 4096) as f:
try:
routing_table = f.read()
except IOError, e:
except IOError as e:
# ???: If someone can explain why the kernel sometimes fails
# even when there's a lot of free memory.
if e.errno != errno.ENOMEM:
......@@ -1028,7 +1028,7 @@ class TunnelManager(BaseTunnelManager):
if c and c.time < float(time):
try:
c.connected(serial)
except (KeyError, TypeError), e:
except (KeyError, TypeError) as e:
logging.error("%s (route_up %s)", e, common_name)
else:
logging.info("ignore route_up notification for %s %r",
......
......@@ -8,7 +8,8 @@ import sys, textwrap, threading, time, traceback
# relying on the GC for the closing of file descriptors.)
socket.SOCK_CLOEXEC = 0x80000
HMAC_LEN = len(hashlib.sha1('').digest())
# HMAC_LEN = hashlib.sha1(b'').digest_szie
HMAC_LEN = len(hashlib.sha1(b'').digest())
class ReexecException(Exception):
pass
......@@ -164,7 +165,7 @@ class Popen(subprocess.Popen):
self._args = tuple(args[0] if args else kw['args'])
try:
super(Popen, self).__init__(*args, **kw)
except OSError, e:
except OSError as e:
if e.errno != errno.ENOMEM:
raise
self.returncode = -1
......@@ -209,7 +210,7 @@ def select(R, W, T):
def makedirs(*args):
try:
os.makedirs(*args)
except OSError, e:
except OSError as e:
if e.errno != errno.EEXIST:
raise
......@@ -240,7 +241,7 @@ def parse_address(address_list):
a = address.split(',')
int(a[1]) # Check if port is an int
yield tuple(a[:4])
except ValueError, e:
except ValueError as e:
logging.warning("Failed to parse node address %r (%s)",
address, e)
......@@ -261,21 +262,24 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes
def packInteger(i):
for n in xrange(8):
def packInteger(i:int):
for n in range(8):
x = 32 << 8 * n
if i < x:
return struct.pack("!Q", i + n * x)[7-n:]
i -= x
raise OverflowError
def unpackInteger(x):
n = ord(x[0]) >> 5
def unpackInteger(x:bytes):
if isinstance(x, str):
x = x.encode()
# ord need str, and b"ddd"[0] is int. so, use slice
n = ord(x[:1]) >> 5
try:
i, = struct.unpack("!Q", '\0' * (7 - n) + x[:n+1])
i, = struct.unpack("!Q", (b'\0' * (7 - n) + x[:n+1]))
except struct.error:
return
return sum((32 << 8 * i for i in xrange(n)),
return sum((32 << 8 * i for i in range(n)),
i - (n * 32 << 8 * n)), n + 1
###
......
......@@ -36,4 +36,4 @@ protocol = 7
min_protocol = 1
if __name__ == "__main__":
print version
print(version)
......@@ -14,23 +14,27 @@ def subnetFromCert(cert):
return cert.get_subject().CN
def notBefore(cert):
return calendar.timegm(time.strptime(cert.get_notBefore(),'%Y%m%d%H%M%SZ'))
return calendar.timegm(time.strptime(cert.get_notBefore().decode(),'%Y%m%d%H%M%SZ'))
def notAfter(cert):
return calendar.timegm(time.strptime(cert.get_notAfter(),'%Y%m%d%H%M%SZ'))
return calendar.timegm(time.strptime(cert.get_notAfter().decode(),'%Y%m%d%H%M%SZ'))
def openssl(*args):
# add kwargs option for function encrypt need inheritable fd
def openssl(*args, **kwargs):
return utils.Popen(('openssl',) + args,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
stderr=subprocess.PIPE,
**kwargs)
def encrypt(cert, data):
r, w = os.pipe()
# https://peps.python.org/pep-0446/ Make newly created file descriptors non-inheritable
# so need pass fd by subprocess
try:
threading.Thread(target=os.write, args=(w, cert)).start()
p = openssl('rsautl', '-encrypt', '-certin',
'-inkey', '/proc/self/fd/%u' % r)
'-inkey', '/proc/self/fd/%u' % r, pass_fds=(r, w))
out, err = p.communicate(data)
finally:
os.close(r)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment