Commit 5f3b3db2 authored by zhifan huang's avatar zhifan huang Committed by zhifan huang

test: add unit tests for registry, conf, tunnel

registry: add test for RegistryServer and registryClient,
testRegistryServer mainly test methods concerned to http rpc request.
Other methods that call request_dump like getBootStrpPeer is not
include.
testRegistryClient test the rpc call with or not with "cn" parameter.

cli.conf: test each situation call the cli conf.

tunnel: add test for BaseTunnelManager, MultGatewayManager

tools is a util cotain method to make cert and ket files
parent 308eaa7f
__all__ = ["test_unit"]
-----BEGIN DH PARAMETERS-----
MIIBCAKCAQEAoXvAhNiPPi9WTYjDhkrLfSGV7lQdAnKJHohSdsR85SdH8u9whvlb
a4Jt2aEJCFqL1LMziF8Dy3ipcUe/xYbmZJ+w1wNAnuzzeJWH5z57duZy6jPvPxsW
uLTsnjlUn+nYG7vrkmWEqgzQDLY2aV9maPREmqAvxorIdffWXKsh6wyBhVuOghC/
8pqxDY5C+VewBMkqibZnNtQ8IWMw+6SmPKx4bLA44P1VlfpF+3VBNgs3JD26djFX
vJs+Pd4+j0GM2hPlxTSIB12cKSiDix0YXdHrVQtnatsnG3wyzVSTKbvxQRMEFM1X
BnUtiqlB3IlGCg6RWXRGcdEZ50blLQZ0kwIBAg==
-----END DH PARAMETERS-----
{
"verbose": 1,
"ca": "root.crt",
"port": 9090,
"anonymous_prefix_length": null,
"smtp_pwd": null,
"client_count": 10,
"authorized_origin": [
"127.0.0.1",
"::1"
],
"bind6": "::",
"ipv4": null,
"prefix_length": 16,
"min_protocol": 1,
"smtp_starttls": false,
"run": "run",
"bind4": "0.0.0.0",
"dh": "dh2048.pem",
"db": "registry.db",
"mailhost": "miku@miku.com",
"key": "registry.key",
"encrypt": false,
"logfile": "registry.log",
"max_clients": null,
"smtp_user": null,
"same_country": null,
"tunnel_refresh": 300,
"hello": 15
}
\ No newline at end of file
-----BEGIN RSA PRIVATE KEY-----
MIIEpQIBAAKCAQEA2yS7SlkxfMc8ydVbFPv8kZFQQp1dyda4Dt41P3+klO01TE0k
JY25pvpBcYFz3tbADB/m6zJlvuA9eVTy7cxodioRaYiUCZjgUbXZ2/BziUIGygDL
WcYbvf0A1aX7qGYPx1cLwLxvwZlHa+tVB/MKgRDC+L/qRAftRc6+uMrdjddUaKbC
B2k8GPxUWPpKeSB4ymXDXeuey0nHZ7KgjHqRcHKVEK24eBR9NR38Epm4bAkp8GAT
wA9qhDLlKJZXqMJCl3VilOaRpwoO9dmTvIPEHx0R9DNFFRUlc4DPq60UIYsfvUkp
467PfpJH7aXLM7PIyvU0aZ5QHtdcKHXnsRc+yQIDAQABAoIBACvgpOd0CGaVdeRr
pbsD4UQ8NjfAToEVTvEbKMo4AnoXLK7EW1JxmBSI0wWpB8w8b2N+F7xL8PdQ6r4a
djGK1fei4K2ivRFW3MM/iAlzkY6P+9ACbLTi57cYq0wb2dGT7eDZ2u6STEYVLKm9
Ct92mEnTU1Z/Bqbsd2Ocy68wX0AA2Ho/ktpL3hGhlBTOm9PwMV/aj/99YjEN+iLD
Q2e/YGNOslHvil4YErC01g6WR2SMujhQ/42ValYpacIVlirajaYgKdQ0PctwB1es
6Miwbwt/BrA8yH8gDCpHdqrfjAnb1J1xr5idu3i09Dt1F3FxHLKfU8DOng2WQD9v
HCNNaIkCgYEA7XllFs490pd7fraEde3IIRUVd/fXu/1qWjjpbijaMirtH9zWOE8b
JSklLrjYDXd6OS1dDft4DeRHbKOj3qe51Ux6kt/Zw1WdTfo3UtNjcK+EETpJuKyG
ctQHQG8HiQtgN36mkVUnc0W5IsURUnjNFO51r9V9U5ItYcBon5bGIUsCgYEA7D1A
3fMoDKbZzVrtDue9h+UmyCPCJgZASTt6iEm98guswbSQ/Qa1C94YbmQXAGO4NPwS
b/b78/JK6VTZpuuPB97m/antq1MlI8iNwomp1QR9m7jKg7WHui+nvSewh1o35B90
GstfIWmm6Pvqw0iu8UgSTEW7y3/SK51nEknPp7sCgYEAy4hiNfuqTRZ8SAxS12hn
QMN7VQldI8h9ILrqhvoImTrlZYu3JyfV0jHDppnSwygF33+b4+IF8ZIYDWrrhmgn
BEO6QqwNTjfQzQaJ6Dk5X1lvTfyxNtDXow9K79S5lqHjY2zvglyDpW660Kwqvo6+
5xPCVmQaOEhvEPsCMNXfFqUCgYEAjztINAm0c49KKNcDOfFJmbZXACumEBXkLkKQ
tUc4kiN/9+X5rl+9r1dWKsAmrgbH7eATca0m764suzHF0Q2rJ9N+67d2sVR1BTAY
uyVqQgw5+AtfReHvS/SO2AHTZw1NK9PiOkiqAgEjwMjUeth7sTDIX1Q8W1LBY85I
au8zpvcCgYEA2ZpA8QsbU6mijBsrc4aOSqnP4VBiqS5aLNqwOcxwlYtSn5/8xTvb
eI1imNQf5Js4l9/7fRTuxSYhZIPM8SLbHWCkesJubiiKc+R+m7uBZ+0h7W6ZtlY3
avFFoiUUtviOdvqBketnWkIiwzL8lzjHOwjYvdRiGv9Fkoty0SKXbp8=
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIID4jCCAsqgAwIBAgIHASABDbgAQjANBgkqhkiG9w0BAQsFADCBhjELMAkGA1UE
BhMCRlIxDjAMBgNVBAgMBUxJTExFMQ4wDAYDVQQHDAVMSUxMRTEhMB8GA1UECgwY
SW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMQwwCgYDVQQDDANIWkYxJjAkBgkqhkiG
9w0BCQEWF3poaWZhbi5odWFuZ0BuZXhlZGkuY29tMB4XDTIyMDIxNDA4NDQxN1oX
DTMyMDIxMjA4NDQxN1owgYYxCzAJBgNVBAYTAkZSMQ4wDAYDVQQIDAVMSUxMRTEO
MAwGA1UEBwwFTElMTEUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0
ZDEMMAoGA1UEAwwDSFpGMSYwJAYJKoZIhvcNAQkBFhd6aGlmYW4uaHVhbmdAbmV4
ZWRpLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMvG4CCFLEJ8
vl9KiCXkfqrPhlSwktKs0PrT2kxncIJGwlkMdVlgKY5AG41ayf3+ZhpXKOWA7CDw
D653ugPc6JADh9fxWCAxzeXyvkBVMRBZ1zADDdiLKrTox8pCPu0wVPnxdsq+4TZW
lB3CGkn2wxPNenI4aFiK6J79C35kX6pMupDgbwOUuh4e+Z+fTe2ouAGRlG/s69F5
5ehjChUpNKRrIekKdFzVF76XK5twH9N2x6Iyd+dFfrQ0qiAhDugKWbqVxD1bdPlQ
rtU4LpQ9z1c3cXVGViopL7CMZN3qT4x9bA/j41ISDpdRm9cwJGLuEo+bWTf+nFaM
jzQAoPbyPpMCAwEAAaNTMFEwHQYDVR0OBBYEFOMXLeN2qnczE67jRn6PTAarT/nO
MB8GA1UdIwQYMBaAFOMXLeN2qnczE67jRn6PTAarT/nOMA8GA1UdEwEB/wQFMAMB
Af8wDQYJKoZIhvcNAQELBQADggEBAEfYZF1NijKxcB8dU58mJWUtvx3LNUfOlasB
4ykaEzDsA2zpnJ31msJ86G4VHy5umA4bbX80Eo1fBXT4W7GmfakbYIahb5A05Vrf
b1lggQAGsVEptfAAFgRynaPOCgyBmor55izPBt64jnZOS1Hgx8kmSowDYR+CVqc6
Ur+9e71jmkTv2LQwOl0fRD77vw2QZMV68C7y3SY32ErPb2anGuBzdrlrGHFy4Jam
FdiYcw7uEkdrJX3eXRI9gUBcIuljTiYQv2NZzhmeL+qbWDb/DI0NXNvZ8oe0ZHOd
UKfcEWjnPMyZHpuyPOeV6ywTLOdHXG9GqwTgAfln+vOhoKQ8fkE=
-----END CERTIFICATE-----
"""Re6st unittest module
"""
# contatin the test case
__all__ = ["test_registry",
"test_registry_client",
"test_conf",
"test_tunnel"]
#!/usr/bin/python2
""" unit test for re6st-conf
"""
import os
import sys
import unittest
from shutil import rmtree
from StringIO import StringIO
from mock import patch
from re6st.cli import conf
from re6st.tests.tools import generate_cert, serial2prefix
# gloable value from conf.py
conf_path = 're6stnet.conf'
ca_path = 'ca.crt'
cert_path = 'cert.crt'
key_path = 'cert.key'
# TODO test for is needed
class TestConf(unittest.TestCase):
""" Unit test case for re6st-conf"""
@classmethod
def setUpClass(cls):
# because conf will change directory
cls.origin_dir = os.getcwd()
cls.work_dir = "temp"
if not os.path.exists(cls.work_dir):
os.makedirs(cls.work_dir)
# mocked service cert and pkey
with open("root.crt") as f:
cls.cert = f.read()
with open("registry.key") as f:
cls.pkey = f.read()
cls.command = "re6st-conf --registry http://localhost/" \
" --dir %s" % cls.work_dir
cls.serial = 0
cls.stdout = sys.stdout
cls.null = open(os.devnull, 'w')
sys.stdout = cls.null
@classmethod
def tearDownClass(cls):
# remove work directory
rmtree(cls.work_dir)
cls.null.close()
sys.stdout = cls.stdout
def setUp(self):
patcher = patch("re6st.registry.RegistryClient")
self.addCleanup(patcher.stop)
self.client = patcher.start()()
self.client.getCa.return_value = self.cert
prefix = serial2prefix(self.serial)
self.client.requestCertificate.side_effect = \
lambda _, req: generate_cert(self.cert, self.pkey, req, prefix, self.serial)
self.serial += 1
def tearDown(self):
# go back to original dir
os.chdir(self.origin_dir)
@patch("__builtin__.raw_input")
def test_basic(self, mock_raw_input):
""" go through all the step
getCa, requestToken, requestCertificate
"""
mail = "example@email.com"
token = "a_token"
mock_raw_input.side_effect = [mail, token]
command = self.command \
+ " --fingerprint sha1:a1861330f1299b98b529fa52c3d8e5d1a94dc63a" \
+ " --req L lille"
sys.argv = command.split()
conf.main()
self.client.requestToken.assert_called_once_with(mail)
self.assertEqual(self.client.requestCertificate.call_args[0][0],
token)
# created file part
self.assertTrue(os.path.exists(ca_path))
self.assertTrue(os.path.exists(key_path))
self.assertTrue(os.path.exists(cert_path))
self.assertTrue(os.path.exists(conf_path))
def test_fingerprint_mismatch(self):
""" wrong fingerprint with same size,
"""
command = self.command \
+ " --fingerprint sha1:a1861330f1299b98b529fa52c3d8e5d1a94dc000"
sys.argv = command.split()
with self.assertRaises(SystemExit) as e:
conf.main()
self.assertIn("fingerprint doesn't match", str(e.exception))
def test_ca_only(self):
""" only create ca file and exit
"""
command = self.command + " --ca-only"
sys.argv = command.split()
with self.assertRaises(SystemExit):
conf.main()
self.assertTrue(os.path.exists(ca_path))
def test_anonymous(self):
""" with args anonymous, so script will use '' as token
"""
command = self.command + " --anonymous"
sys.argv = command.split()
conf.main()
self.assertEqual(self.client.requestCertificate.call_args[0][0],
'')
def test_anonymous_failed(self):
""" with args anonymous and token, so script will failed
"""
command = self.command + " --anonymous" \
+ " --token a"
sys.argv = command.split()
text = StringIO()
old_err = sys.stderr
sys.stderr = text
with self.assertRaises(SystemExit):
conf.main()
# check the error message
self.assertIn("anonymous conflicts", text.getvalue())
sys.stderr = old_err
def test_req_reserved(self):
""" with args req, but contain reserved value
"""
command = self.command + " --req CN 1111"
sys.argv = command.split()
with self.assertRaises(SystemExit) as e:
conf.main()
self.assertIn("CN field", str(e.exception))
def test_get_null_cert(self):
""" simulate fake token, and get null cert
"""
command = self.command + " --token a"
sys.argv = command.split()
self.client.requestCertificate.side_effect = "",
with self.assertRaises(SystemExit) as e:
conf.main()
self.assertIn("invalid or expired token", str(e.exception))
if __name__ == "__main__":
unittest.main()
import sys
import os
import random
import string
import json
import httplib
import base64
import unittest
import hmac
import hashlib
import time
from argparse import Namespace
from OpenSSL import crypto
from mock import Mock, patch
from re6st import registry
from re6st.tests.tools import *
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
def load_config(filename="registry.json"):
with open(filename) as f:
config = json.load(f)
return Namespace(**config)
def get_cert(cur, prefix):
res = cur.execute(
"SELECT cert FROM cert WHERE prefix=?", (prefix,)).fetchone()
return res[0]
def insert_cert(cur, ca, prefix, not_after=None, email=None):
key, csr = generate_csr()
cert = generate_cert(ca.ca, ca.key, csr, prefix, insert_cert.serial, not_after)
cur.execute("INSERT INTO cert VALUES (?,?,?)", (prefix, email, cert))
insert_cert.serial += 1
return key, cert
insert_cert.serial = 0
def delete_cert(cur, prefix):
cur.execute("DELETE FROM cert WHERE prefix = ?", (prefix,))
# TODO function for get a unique prefix
class TestRegistryServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
# instance a server
cls.config = load_config()
cls.server = registry.RegistryServer(cls.config)
@classmethod
def tearDownClass(cls):
# remove database
try:
os.unlink(cls.config.db)
except Exception:
pass
def setUp(self):
self.email = ''.join(random.sample(string.ascii_lowercase, 4)) \
+ "@mail.com"
def test_recv(self):
recv = self.server.sock.recv = Mock()
recv.side_effect = [
"0001001001001a_msg",
"0001001001002\0001dqdq",
"0001001001001\000a_msg",
"0001001001001\000\4a_msg",
"0000000000000\0" # ERROR, IndexError: msg is null
]
res1 = self.server.recv(4)
res2 = self.server.recv(4)
res3 = self.server.recv(4)
res4 = self.server.recv(4)
self.assertEqual(res1, (None, None)) # not contain \0
self.assertEqual(res2, (None, None)) # binary to digital failed
self.assertEqual(res3, (None, None)) # code don't match
self.assertEqual(res4, ("0001001001001", "a_msg"))
del self.server.sock.recv
def test_onTimeout(self):
# old token, cert, not old token, cert
# not after will equal to now -1
# condtion prefix == self.prefix not covered
cur = self.server.db.cursor()
token_old, token = "bbbbdddd", "ddddbbbb"
prefix_old, prefix = "1110", "1111"
# 20 magic number, make sure we create old enough new cert/token
now = int(time.time()) - registry.GRACE_PERIOD + 20
# makeup data
insert_cert(cur, self.server.cert, prefix_old, 1)
insert_cert(cur, self.server.cert, prefix, now -1)
cur.execute("INSERT INTO token VALUES (?,?,?,?)",
(token_old, self.email, 4, 2))
cur.execute("INSERT INTO token VALUES (?,?,?,?)",
(token, self.email, 4, now))
cur.close()
self.server.onTimeout()
self.assertIsNone(self.server.isToken(token_old))
self.assertIsNotNone(self.server.isToken(token))
cur = self.server.db.cursor()
self.assertIsNone(get_cert(cur, prefix_old), "old cert not deleted")
self.assertIsNotNone(get_cert(cur, prefix))
self.assertEqual(self.server.timeout,
now - 1 + registry.GRACE_PERIOD,
"time_out set wrongly")
delete_cert(cur, prefix)
cur.close()
self.server.deleteToken(token)
@patch("re6st.registry.RegistryServer.func", create=True)
def test_handle_request(self, func):
'''rpc with cn and have result'''
prefix = "0000000011111111"
method = "func"
protocol = 7
params = {"cn" : prefix, "a" : 1, "b" : 2}
func.getcallargs.return_value = params
del func._private
func.return_value = result = "this_is_a_result"
key = "this_is_a_key"
self.server.sessions[prefix] = [(key, protocol)]
request = Mock()
request.path = "/func?a=1&b=2&cn=0000000011111111"
request.headers = {registry.HMAC_HEADER: base64.b64encode(
hmac.HMAC(key, request.path, hashlib.sha1).digest())}
self.server.handle_request(request, method, params)
# hmac check
key = hashlib.sha1(key).digest()
self.assertEqual(self.server.sessions[prefix],
[(hashlib.sha1(key).digest(), protocol)])
func.assert_called_once_with(**params)
# http response check
request.send_response.assert_called_once_with(httplib.OK)
request.send_header.assert_any_call("Content-Length", str(len(result)))
request.send_header.assert_any_call(
registry.HMAC_HEADER,
base64.b64encode(hmac.HMAC(key, result, hashlib.sha1).digest()))
request.wfile.write.assert_called_once_with(result)
# remove the create session \n
del self.server.sessions[prefix]
@patch("re6st.registry.RegistryServer.func", create=True)
def test_handle_request_private(self, func):
"""case request with _private attr"""
method = "func"
params = {"a" : 1, "b" : 2}
func.getcallargs.return_value = params
func.return_value = None
request_good = Mock()
request_good.client_address = self.config.authorized_origin
request_good.headers = {'X-Forwarded-For':self.config.authorized_origin[0]}
request_bad = Mock()
request_bad.client_address = ["wrong_address"]
self.server.handle_request(request_good, method, params)
self.server.handle_request(request_bad, method, params)
func.assert_called_once_with(**params)
request_bad.send_error.assert_called_once_with(httplib.FORBIDDEN)
request_good.send_response.assert_called_once_with(httplib.NO_CONTENT)
# will cause valueError, if a node send hello twice to a registry
def test_getPeerProtocol(self):
prefix = "0000000011111110"
insert_cert(self.server.db, self.server.cert, prefix)
protocol = 7
self.server.hello(prefix, protocol)
# self.server.hello(prefix)
res = self.server.getPeerProtocol(prefix)
self.assertEqual(res, protocol)
def test_hello(self):
prefix = "0000000011111111"
protocol = 7
cur = self.server.db.cursor()
pkey, _ = insert_cert(cur, self.server.cert, prefix)
res = self.server.hello(prefix, protocol=protocol)
# decrypt
length = len(res)/2
key, sign = res[:length], res[length:]
key = decrypt(pkey, key)
self.assertEqual(self.server.sessions[prefix][-1][0], key,
"different hmac key")
self.assertEqual(self.server.sessions[prefix][-1][1], protocol)
self.server.sessions[prefix][-1] = None
delete_cert(cur, prefix)
def test_addToken(self):
# generate random token
token_spec = "aaaabbbb"
token = self.server.addToken(self.email, None)
self.server.addToken(self.email, token_spec)
self.assertIsNotNone(token)
self.assertTrue(self.server.isToken(token))
self.assertTrue(self.server.isToken(token_spec))
# remove the affect of the function
self.server.deleteToken(token)
self.server.deleteToken(token_spec)
def test_newPrefix(self):
length = 16
res = self.server.newPrefix(length)
self.assertEqual(len(res), length)
self.assertLessEqual(set(res), {'0', '1'}, "%s is not a binary" % res)
# TODO test too many prefix
@patch("re6st.registry.RegistryServer.sendto", Mock())
@patch("re6st.registry.RegistryServer.createCertificate")
def test_requestCertificate(self, mock_func):
token = self.server.addToken(self.email, None)
fake_token = "aaaabbbb"
_, csr = generate_csr()
# unvalide token
self.server.requestCertificate(fake_token, csr)
# valide token
self.server.requestCertificate(token, csr)
self.assertIsNone(self.server.isToken(token), "token not delete")
mock_func.assert_called_once()
# check the call parameter
prefix, subject, pubkey = mock_func.call_args[0]
self.assertIsNotNone(subject.serialNumber)
def test_requestCertificate_anoymous(self):
_, csr = generate_csr()
if self.config.anonymous_prefix_length is None:
with self.assertRaises(registry.HTTPError):
self.server.requestCertificate(None, csr)
def test_getSubjectSerial(self):
serial = self.server.getSubjectSerial()
self.assertIsInstance(serial, int)
# test the smallest unique possible
nb_less = 0
for cert in self.server.iterCert():
s = cert[0].get_subject().serialNumber
if(s and int(s) <= serial):
nb_less += 1
self.assertEqual(nb_less, serial)
def test_createCertificate(self):
_, csr = generate_csr()
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
prefix = "00011111101001110"
subject = req.get_subject()
subject.serialNumber = str(self.server.getSubjectSerial())
self.server.db.execute("INSERT INTO cert VALUES (?,null,null)", (prefix,))
cert = self.server.createCertificate(prefix, subject, req.get_pubkey())
cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert)
self.assertEqual(cert.get_subject().CN, prefix2cn(prefix))
self.assertEqual(cert.get_serial_number(), self.server.getConfig('serial', 0))
self.assertIsNotNone(get_cert(self.server.db, prefix))
@patch("re6st.registry.RegistryServer.createCertificate")
def test_renewCertificate(self, mock_func):
# TODO condition crl
cur = self.server.db.cursor()
prefix_old = "11111"
prefix_new = "11110"
insert_cert(cur, self.server.cert, prefix_old, 1)
_, cert_new = insert_cert(cur, self.server.cert, prefix_new,
time.time() + 2 * registry.RENEW_PERIOD)
cur.close()
# need renew
self.server.renewCertificate(prefix_old)
# no need renew
res_new = self.server.renewCertificate(prefix_new)
prefix, subject, pubkey, not_after = mock_func.call_args[0]
self.assertEqual(prefix, prefix_old)
self.assertEqual(not_after, None)
self.assertEqual(res_new, cert_new)
cur = self.server.db.cursor()
delete_cert(cur, prefix_old)
delete_cert(cur, prefix_new)
cur.close()
@patch("re6st.registry.RegistryServer.sendto", Mock())
@patch("re6st.registry.RegistryServer.recv")
@patch("select.select", Mock(return_value=[1]))
def test_queryAddress(self, recv):
prefix = "000100100010001"
# one bad, one correct prefix
recv.side_effect = [("0", "a msg"), (prefix, "other msg")]
res = self.server._queryAddress(prefix)
self.assertEqual(res, "other msg")
@patch('re6st.registry.RegistryServer.updateNetworkConfig')
def test_revoke(self, mock_func):
# case: no ValueError
serial = insert_cert.serial
prefix = bin(serial)[2:].rjust(16, '0') # length 16 prefix
insert_cert(self.server.db, self.server.cert, prefix)
self.server.revoke(serial)
# ValueError if serial correspond cert not exist
mock_func.assert_called_once()
@patch('re6st.registry.RegistryServer.updateNetworkConfig', Mock())
def test_revoke_value(self):
# case: ValueError
serial = insert_cert.serial
prefix = bin(serial)[2:].rjust(16, '0') # length 16 prefix
insert_cert(self.server.db, self.server.cert, prefix, 1)
self.server.sessions.setdefault(prefix, "something")
self.server.revoke("%u/16" % serial) # 16 is length
self.assertIsNone(self.server.sessions.get(prefix))
self.assertIsNone(get_cert(self.server.db, prefix))
@patch("re6st.registry.RegistryServer.sendto", Mock())
def test_updateHMAC(self):
def get_hmac():
return [self.server.getConfig(registry.BABEL_HMAC[i], None)
for i in range(3)]
for i in range(3):
self.server.delHMAC(i)
# step 1
self.server.updateHMAC()
hmacs = get_hmac()
key_1 = hmacs[1]
self.assertEqual(hmacs, [None, key_1, ''])
# step 2
self.server.updateHMAC()
self.assertEqual(get_hmac(), [key_1, None, None])
# step 3
self.server.updateHMAC()
hmacs = get_hmac()
key_2 = hmacs[1]
self.assertEqual(get_hmac(), [key_1, key_2, None])
# step 4
self.server.updateHMAC()
self.assertEqual(get_hmac(), [None, key_2, key_1])
#setp 5
self.server.updateHMAC()
self.assertEqual(get_hmac(), [key_2, None, None])
def test_getNodePrefix(self):
# prefix in short format
prefix = "0000000101"
insert_cert(self.server.db, self.server.cert, prefix, email=self.email)
res = self.server.getNodePrefix(self.email)
self.assertEqual(res, prefix2cn(prefix))
@patch("select.select")
@patch("re6st.registry.RegistryServer.recv")
@patch("re6st.registry.RegistryServer.sendto", Mock())
# use case which recored form demo
def test_topology(self, recv, select):
recv_case = [
('0000000000000000', '2 6/16 7/16 1/16 3/16 36893488147419103232/80 4/16'),
('00000000000000100000000000000000000000000000000000000000000000000000000000000000', '2 0/16 7/16'),
('0000000000000011', '2 0/16 7/16'),
('0000000000000111', '2 4/16 6/16 0/16 3/16 36893488147419103232/80'),
('0000000000000111', '2 4/16 6/16 0/16 3/16 36893488147419103232/80'),
('0000000000000001', '2 0/16 6/16')
]
recv.side_effect = recv_case
def side_effct(rlist, wlist, elist, timeout):
# rlist is true until the len(recv_case)th call
side_effct.i -= side_effct.i > 0
return [side_effct.i, wlist, None]
side_effct.i = len(recv_case) + 1
select.side_effect = side_effct
res = self.server.topology()
expect_res = '{"36893488147419103232/80": ["0/16", "7/16"], ' \
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], ' \
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '\
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}'''
self.assertEqual(res, expect_res)
if __name__ == "__main__":
unittest.main()
import sys
import os
import unittest
import hmac
import httplib
import base64
import hashlib
from mock import Mock, patch
from re6st import registry
class TestRegistryClient(unittest.TestCase):
@classmethod
def setUpClass(cls):
server_url = "http://10.0.0.2/"
cls.client = registry.RegistryClient(server_url)
cls.client._conn = Mock()
def test_init(self):
url1 = "https://localhost/example/"
url2 = "http://10.0.0.2/"
client1 = registry.RegistryClient(url1)
client2 = registry.RegistryClient(url2)
self.assertEqual(client1._path, "/example")
self.assertEqual(client1._conn.host, "localhost")
self.assertIsInstance(client1._conn, httplib.HTTPSConnection)
self.assertIsInstance(client2._conn, httplib.HTTPConnection)
def test_rpc_hello(self):
prefix = "0000000011111111"
protocol = "7"
body = "a_hmac_key"
query = "/hello?client_prefix=0000000011111111&protocol=7"
response = fakeResponse(body, httplib.OK)
self.client._conn.getresponse.return_value = response
res = self.client.hello(prefix, protocol)
self.assertEqual(res, body)
conn = self.client._conn
conn.putrequest.assert_called_once_with('GET', query, skip_accept_encoding=1)
conn.close.assert_not_called()
conn.endheaders.assert_called_once()
def test_rpc_with_cn(self):
query = "/getNetworkConfig?cn=0000000011111111"
cn = "0000000011111111"
# hmac part
self.client._hmac = None
self.client.hello = Mock(return_value = "aaabbb")
self.client.cert = Mock()
key = "this_is_a_key"
self.client.cert.decrypt.return_value = key
h = hmac.HMAC(key, query, hashlib.sha1).digest()
key = hashlib.sha1(key).digest()
# response part
body = None
response = fakeResponse(body, httplib.NO_CONTENT)
response.msg = dict(Re6stHMAC=hmac.HMAC(key, body, hashlib.sha1).digest())
self.client._conn.getresponse.return_value = response
res = self.client.getNetworkConfig(cn)
self.client.cert.verify.assert_called_once_with("bbb", "aaa")
self.assertEqual(self.client._hmac, hashlib.sha1(key).digest())
conn = self.client._conn
conn.putheader.assert_called_with("Re6stHMAC", base64.b64encode(h))
conn.close.assert_called_once()
self.assertEqual(res, body)
class fakeResponse:
def __init__(self, body, status, reason = None):
self.body = body
self.status = status
self.reason = reason
def read(self):
return self.body
if __name__ == "__main__":
unittest.main()
__all__ = ["test_multi_gateway_manager", "test_base_tunnel_manager"]
#!/usr/bin/python2
import os
import sys
import unittest
import time
from mock import patch, Mock
from re6st import tunnel
from re6st import x509
from re6st import cache
from re6st.tests import tools
class testBaseTunnelManager(unittest.TestCase):
@classmethod
def setUpClass(cls):
ca_key, ca = tools.create_ca_file("ca.key", "ca.cert")
tools.create_cert_file("node.key", "node.cert", ca, ca_key, "00000001", 1)
cls.cert = x509.Cert("ca.cert", "node.key", "node.cert")
cls.control_socket = "babeld.sock"
def setUp(self):
patcher = patch("re6st.cache.Cache")
pacher_sock = patch("socket.socket")
self.addCleanup(patcher.stop)
self.addCleanup(pacher_sock.stop)
self.cache = patcher.start()()
self.sock = pacher_sock.start()
self.cache.same_country = False
address = [(2, [('10.0.0.2', '1194', 'udp'), ('10.0.0.2', '1194', 'tcp')])]
self.tunnel = tunnel.BaseTunnelManager(self.control_socket,
self.cache, self.cert, None, address)
def tearDown(self):
self.tunnel.close()
del self.tunnel
@patch("re6st.tunnel.BaseTunnelManager._babel_dump_one", create=True)
@patch("re6st.tunnel.BaseTunnelManager._babel_dump_two", create=True)
def test_babel_dump(self, two, one):
""" case two func in requesting_dump"""
self.tunnel._BaseTunnelManager__requesting_dump = set(['one', 'two'])
self.tunnel.babel_dump()
# assert is empty
self.assertFalse(self.tunnel._BaseTunnelManager__requesting_dump)
one.assert_called_once()
two.assert_called_once()
@patch("re6st.ctl.Babel.request_dump")
def test_request_dump_empty(self, request_dump):
"""case when self.__requesting_dump is None or empty"""
reason = "rina"
self.tunnel._BaseTunnelManager__request_dump(reason)
self.assertEqual(self.tunnel._BaseTunnelManager__requesting_dump, set([reason]))
request_dump.assert_called_once()
@patch("re6st.ctl.Babel.request_dump")
def test___request_dump_not_empty(self, request_dump):
"""case when self.__requesting_dump is not empty"""
self.tunnel._BaseTunnelManager__requesting_dump = set(["rina"])
reason = "reason"
self.tunnel._BaseTunnelManager__request_dump(reason)
self.assertEqual(self.tunnel._BaseTunnelManager__requesting_dump, set([reason, "rina"]))
request_dump.assert_not_called()
def test_selectTimeout_add_callback(self):
"""case add new callback"""
self.tunnel._timeouts = [(1, self.tunnel.close)]
callback = self.tunnel.babel_dump
self.tunnel.selectTimeout(10, callback)
self.assertIn((10, callback), self.tunnel._timeouts)
def test_selectTimeout_removing(self):
"""case remove a callback"""
removed = self.tunnel.babel_dump
self.tunnel._timeouts = [(1, self.tunnel.close), (10, removed)]
self.tunnel.selectTimeout(None, removed)
self.assertEqual(self.tunnel._timeouts, [(1, self.tunnel.close)])
def test_selectTimeout_update(self):
"""case update a callback"""
updated = self.tunnel.babel_dump
self.tunnel._timeouts = [(1, self.tunnel.close), (10, updated)]
self.tunnel.selectTimeout(100, updated)
self.assertEqual(self.tunnel._timeouts, [(1, self.tunnel.close), (100, updated)])
@patch("re6st.tunnel.BaseTunnelManager.selectTimeout")
def test_invalidatePeers(self, selectTimeout):
"""normal case, stop_date: p2 < now < p1 < p3
expect:
_peers -> [p1, p3]
next = p1.stoptime
"""
p1 = x509.Peer("00")
p2 = x509.Peer("01")
p3 = x509.Peer("10")
p1.stop_date = time.time() + 1000
p2.stop_date = 1
p3.stop_date = p1.stop_date + 500
self.tunnel._peers = [p1, p2, p3]
self.tunnel.invalidatePeers()
self.assertEqual(self.tunnel._peers, [p1, p3])
selectTimeout.assert_called_once_with(p1.stop_date, self.tunnel.invalidatePeers)
# Because _makeTunnel is defined in sub class of BaseTunnelManager, so i comment
# the follow test
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """
# c = chr(1)
# msg = "address"
# peer = x509.Peer("000001")
# self.tunnel._connecting = {peer}
# self.tunnel._processPacket(c + msg, peer)
# self.cache.addPeer.assert_called_once_with(peer, msg)
# self.assertFalse(self.tunnel._connecting)
# makeTunnel.assert_called_once_with(peer, msg)
def test_processPacket_address(self):
"""code is 1, for address. And peer or msg are none"""
c = chr(1)
self.tunnel._address = {1: "1,1", 2: "2,2"}
res = self.tunnel._processPacket(c)
self.assertEqual(res, "1,1;2,2")
def test_processPacket_address_with_peer(self):
"""code is 1, peer is not none, msg is none
in my opion, this function return address in form address,port,portocl
and each address join by ;
it will truncate address which has more than 3 element
"""
c = chr(1)
peer = x509.Peer("000001")
peer.protocol = 1
self.tunnel._peers.append(peer)
self.tunnel._address = {1: "1,1,1;0,0,0", 2: "2,2,2,2"}
res = self.tunnel._processPacket(c, peer)
self.assertEqual(res, "1,1,1;0,0,0;2,2,2")
@patch("re6st.tunnel.BaseTunnelManager.selectTimeout")
def test_processPacket_version(self, selectTimeout):
"""code is 0, for network version, peer is none"""
c = chr(0)
self.tunnel._processPacket(c)
self.assertEqual(selectTimeout.call_args[0][1], self.tunnel.newVersion)
@patch("re6st.x509.Cert.verifyVersion", Mock(return_value=True))
@patch("re6st.tunnel.BaseTunnelManager.selectTimeout")
def test_processPacket_version(self, selectTimeout):
"""code is 0, for network version, peer is not none
2 case, one modify the version, one not
"""
c = chr(0)
peer = x509.Peer("000001")
version1 = "00003"
version2 = "00007"
self.tunnel._version = "00005"
self.tunnel._peers.append(peer)
res = self.tunnel._processPacket(c + version1, peer)
self.tunnel._processPacket(c + version2, peer)
self.assertEqual(res, "00005")
self.assertEqual(self.tunnel._version, version2)
self.assertEqual(peer.version, version2)
self.assertEqual(selectTimeout.call_args[0][1], self.tunnel.newVersion)
if __name__ == "__main__":
unittest.main()
#!/usr/bin/python2
import os
import sys
import unittest
from mock import patch
from re6st import tunnel
class testMultGatewayManager(unittest.TestCase):
def setUp(self):
self.manager = tunnel.MultiGatewayManager(lambda x:x+x)
patcher = patch("subprocess.check_call")
self.addCleanup(patcher.stop)
self.sub = patcher.start()
@patch("logging.trace", create=True)
def test_add(self, log_trace):
"""add new dest twice"""
dest = "dest"
self.manager.add(dest, True)
self.manager.add(dest, True)
self.assertEqual(self.manager[dest][1], 1)
self.sub.assert_called_once()
cmd = log_trace.call_args[0][1]
self.assertIn(dest+dest, cmd)
self.assertIn("add", cmd)
def test_add_null_route(self):
""" add two dest which don't call ip route"""
dest1 = "dest1"
dest2 = ""
self.manager.add(dest1, False)
self.manager.add(dest2, True)
self.sub.assert_not_called()
@patch("logging.trace", create=True)
def test_remove(self, log_trace):
"remove a dest twice"
dest = "dest"
gw = "gw"
self.manager[dest] = [gw,1]
self.manager.remove(dest)
self.assertEqual(self.manager[dest][1], 0)
self.manager.remove(dest)
self.sub.assert_called_once()
self.assertIsNone(self.manager.get(dest))
cmd = log_trace.call_args[0][1]
self.assertIn(gw, cmd)
self.assertIn("del", cmd)
def test_remove_null_gw(self):
""" remove a dest which don't have gw"""
dest = "dest"
gw = ""
self.manager[dest] = [gw, 0]
self.manager.remove(dest)
self.assertIsNone(self.manager.get(dest))
self.sub.assert_not_called()
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
import sys
import os
import time
import subprocess
from OpenSSL import crypto
from re6st import registry
def generate_csr():
"""generate a certificate request
return:
crypto.Pekey and crypto.X509Req both in pem format
"""
key = crypto.PKey()
key.generate_key(crypto.TYPE_RSA, 2048)
req = crypto.X509Req()
req.set_pubkey(key)
req.get_subject().CN = "test ca"
req.sign(key, 'sha256')
csr = crypto.dump_certificate_request(crypto.FILETYPE_PEM, req)
pkey = crypto.dump_privatekey(crypto.FILETYPE_PEM, key)
return pkey, csr
def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
"""generate a certificate
return
crypto.X509Cert in pem format
"""
if type(ca) is str:
ca = crypto.load_certificate(crypto.FILETYPE_PEM, ca)
if type(ca_key) is str:
ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, ca_key)
req = crypto.load_certificate_request(crypto.FILETYPE_PEM, csr)
cert = crypto.X509()
cert.gmtime_adj_notBefore(0)
if not_after:
cert.set_notAfter(
time.strftime("%Y%m%d%H%M%SZ", time.gmtime(not_after)))
else:
cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration)
subject = req.get_subject()
if prefix:
subject.CN = prefix2cn(prefix)
cert.set_subject(req.get_subject())
cert.set_issuer(ca.get_subject())
cert.set_pubkey(req.get_pubkey())
cert.set_serial_number(serial)
cert.sign(ca_key, 'sha512')
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
return cert
def create_cert_file(pkey_file, cert_file, ca, ca_key, prefix, serial):
pkey, csr = generate_csr()
cert = generate_cert(ca, ca_key, csr, prefix, serial)
with open(pkey_file, 'w') as f:
f.write(pkey)
with open(cert_file, 'w') as f:
f.write(cert)
return pkey, cert
def create_ca_file(pkey_file, cert_file):
key = crypto.PKey()
key.generate_key(crypto.TYPE_RSA, 2048)
cert = crypto.X509()
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(registry.RegistryServer.cert_duration)
subject= cert.get_subject()
subject.C = "FR"
subject.ST = "Lille"
subject.L = "Lille"
subject.O = "nexedi"
subject.CN = "TEST-CA"
cert.set_issuer(cert.get_subject())
cert.set_serial_number(10000)
cert.set_pubkey(key)
cert.sign(key, "sha512")
with open(pkey_file, 'w') as pkey_file:
pkey_file.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
with open(cert_file, 'w') as cert_file:
cert_file.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
return key, cert
def prefix2cn(prefix):
return "%u/%u" % (int(prefix, 2), len(prefix))
def serial2prefix(serial):
return bin(serial)[2:].rjust(16, '0')
# pkey: private key
def decrypt(pkey, incontent):
with open("node.key", 'w') as f:
f.write(pkey)
args = "openssl rsautl -decrypt -inkey node.key".split()
p = subprocess.Popen(
args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
outcontent, err = p.communicate(incontent)
return outcontent
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment