Commit 9f97ea8d authored by Jim Fulton's avatar Jim Fulton

Finished SSL tests

parent 2fdc0283
...@@ -87,7 +87,7 @@ class Protocol(base.Protocol): ...@@ -87,7 +87,7 @@ class Protocol(base.Protocol):
ssl=self.ssl, server_hostname=self.ssl_server_hostname) ssl=self.ssl, server_hostname=self.ssl_server_hostname)
else: else:
cr = self.loop.create_unix_connection( cr = self.loop.create_unix_connection(
self.protocol_factory, self.addr) self.protocol_factory, self.addr, ssl=self.ssl)
self._connecting = cr = asyncio.async(cr, loop=self.loop) self._connecting = cr = asyncio.async(cr, loop=self.loop)
......
...@@ -31,7 +31,8 @@ class Loop: ...@@ -31,7 +31,8 @@ class Loop:
future.set_exception(ConnectionRefusedError()) future.set_exception(ConnectionRefusedError())
def create_connection( def create_connection(
self, protocol_factory, host=None, port=None, sock=None self, protocol_factory, host=None, port=None, sock=None,
ssl=None, server_hostname=None
): ):
future = asyncio.Future(loop=self) future = asyncio.Future(loop=self)
if sock is None: if sock is None:
......
...@@ -127,7 +127,8 @@ class CommitLockTests: ...@@ -127,7 +127,8 @@ class CommitLockTests:
# list is a socket domain (AF_INET, AF_UNIX, etc.) and an # list is a socket domain (AF_INET, AF_UNIX, etc.) and an
# address. # address.
addr = self._storage._addr addr = self._storage._addr
new = ZEO.ClientStorage.ClientStorage(addr, wait=1) new = ZEO.ClientStorage.ClientStorage(
addr, wait=1, **self._client_options())
new.registerDB(DummyDB()) new.registerDB(DummyDB())
return new return new
......
...@@ -36,6 +36,8 @@ import ZODB.tests.util ...@@ -36,6 +36,8 @@ import ZODB.tests.util
import transaction import transaction
from transaction import Transaction from transaction import Transaction
from . import testssl
logger = logging.getLogger('ZEO.tests.ConnectionTests') logger = logging.getLogger('ZEO.tests.ConnectionTests')
ZERO = '\0'*8 ZERO = '\0'*8
...@@ -66,7 +68,6 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -66,7 +68,6 @@ class CommonSetupTearDown(StorageTestBase):
keep = 0 keep = 0
invq = None invq = None
timeout = None timeout = None
monitor = 0
db_class = DummyDB db_class = DummyDB
def setUp(self, before=None): def setUp(self, before=None):
...@@ -147,18 +148,17 @@ class CommonSetupTearDown(StorageTestBase): ...@@ -147,18 +148,17 @@ class CommonSetupTearDown(StorageTestBase):
min_disconnect_poll=0.1, min_disconnect_poll=0.1,
read_only=read_only, read_only=read_only,
read_only_fallback=read_only_fallback, read_only_fallback=read_only_fallback,
username=username, **self._client_options())
password=password,
realm=realm)
storage.registerDB(DummyDB()) storage.registerDB(DummyDB())
return storage return storage
def _client_options(self):
return {}
def getServerConfig(self, addr, ro_svr): def getServerConfig(self, addr, ro_svr):
zconf = forker.ZEOConfig(addr) zconf = forker.ZEOConfig(addr)
if ro_svr: if ro_svr:
zconf.read_only = 1 zconf.read_only = 1
if self.monitor:
zconf.monitor_address = ("", 42000)
if self.invq: if self.invq:
zconf.invalidation_queue_size = self.invq zconf.invalidation_queue_size = self.invq
if self.timeout: if self.timeout:
...@@ -564,6 +564,18 @@ class ConnectionTests(CommonSetupTearDown): ...@@ -564,6 +564,18 @@ class ConnectionTests(CommonSetupTearDown):
self.assertRaises(ClientDisconnected, self.assertRaises(ClientDisconnected,
self._storage.load, b'\0'*8, '') self._storage.load, b'\0'*8, '')
class SSLConnectionTests(ConnectionTests):
def getServerConfig(self, addr, ro_svr):
return testssl.server_config.replace(
'127.0.0.1:0',
'{}: {}\nread-only {}'.format(
addr[0], addr[1], 'true' if ro_svr else 'false'))
def _client_options(self):
return {'ssl': testssl.client_ssl()}
class InvqTests(CommonSetupTearDown): class InvqTests(CommonSetupTearDown):
invq = 3 invq = 3
......
...@@ -25,7 +25,7 @@ import tempfile ...@@ -25,7 +25,7 @@ import tempfile
import six import six
import ZODB.tests.util import ZODB.tests.util
import zope.testing.setupstack import zope.testing.setupstack
from ZEO._compat import BytesIO from ZEO._compat import StringIO
logger = logging.getLogger('ZEO.tests.forker') logger = logging.getLogger('ZEO.tests.forker')
...@@ -69,9 +69,9 @@ class ZEOConfig: ...@@ -69,9 +69,9 @@ class ZEOConfig:
""" % (self.loglevel, self.logpath), file=f) """ % (self.loglevel, self.logpath), file=f)
def __str__(self): def __str__(self):
f = BytesIO() f = StringIO()
self.dump(f) self.dump(f)
return f.getvalue().decode() return f.getvalue()
def encode_format(fmt): def encode_format(fmt):
...@@ -194,7 +194,7 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False, ...@@ -194,7 +194,7 @@ def start_zeo_server(storage_conf=None, zeo_conf=None, port=None, keep=False,
# Store the config info in a temp file. # Store the config info in a temp file.
tmpfile = tempfile.mktemp(".conf", dir=os.getcwd()) tmpfile = tempfile.mktemp(".conf", dir=os.getcwd())
fp = open(tmpfile, 'w') fp = open(tmpfile, 'w')
fp.write(zeo_conf + '\n\n') fp.write(str(zeo_conf) + '\n\n')
fp.write(storage_conf) fp.write(storage_conf)
fp.close() fp.close()
......
...@@ -12,28 +12,14 @@ ...@@ -12,28 +12,14 @@
# #
############################################################################## ##############################################################################
import mock
import os
import ssl
import unittest import unittest
from zope.testing import setupstack from zope.testing import setupstack
from ZODB.config import storageFromString from ZODB.config import storageFromString
from ..Exceptions import ClientDisconnected
from .. import runzeo
from .forker import start_zeo_server from .forker import start_zeo_server
here = os.path.dirname(__file__)
server_cert = os.path.join(here, 'server.pem')
server_key = os.path.join(here, 'server_key.pem')
serverpw_cert = os.path.join(here, 'serverpw.pem')
serverpw_key = os.path.join(here, 'serverpw_key.pem')
client_cert = os.path.join(here, 'client.pem')
client_key = os.path.join(here, 'client_key.pem')
class ZEOConfigTest(setupstack.TestCase): class ZEOConfigTest(setupstack.TestCase):
setUp = setupstack.setUpDirectory setUp = setupstack.setUpDirectory
...@@ -129,301 +115,5 @@ class ZEOConfigTest(setupstack.TestCase): ...@@ -129,301 +115,5 @@ class ZEOConfigTest(setupstack.TestCase):
self.test_default_zeo_config(blob_cache_size=424242, self.test_default_zeo_config(blob_cache_size=424242,
blob_cache_size_check=50) blob_cache_size_check=50)
def test_ssl_basic(self):
# This shows that configuring ssl has an actual effect on connections.
# Other SSL configuration tests will be Mockiavellian.
# Also test that an SSL connection mismatch doesn't kill
# the server loop.
# An SSL client can't talk to a non-SSL server:
addr, stop = self.start_server()
with self.assertRaises(ClientDisconnected):
self.start_client(
addr,
"""<ssl>
certificate {}
key {}
</ssl>""".format(client_cert, client_key), wait_timeout=1)
# But a non-ssl one can:
client = self.start_client(addr)
self._client_assertions(client, addr)
client.close()
stop()
# A non-SSL client can't talk to an SSL server:
addr, stop = self.start_server(
"""<ssl>
certificate {}
key {}
authenticate {}
</ssl>""".format(server_cert, server_key, client_cert)
)
with self.assertRaises(ClientDisconnected):
self.start_client(addr, wait_timeout=1)
# But an SSL one can:
client = self.start_client(
addr,
"""<ssl>
certificate {}
key {}
authenticate {}
</ssl>""".format(client_cert, client_key, server_cert))
self._client_assertions(client, addr)
client.close()
stop()
def test_ssl_hostname_check(self):
addr, stop = self.start_server(
"""<ssl>
certificate {}
key {}
authenticate {}
</ssl>""".format(server_cert, server_key, client_cert)
)
# Connext with bad hostname fails:
with self.assertRaises(ClientDisconnected):
client = self.start_client(
addr,
"""<ssl>
certificate {}
key {}
authenticate {}
server-hostname example.org
</ssl>""".format(client_cert, client_key, server_cert),
wait_timeout=1)
# Connext with good hostname succeeds:
client = self.start_client(
addr,
"""<ssl>
certificate {}
key {}
authenticate {}
server-hostname zodb.org
</ssl>""".format(client_cert, client_key, server_cert))
self._client_assertions(client, addr)
client.close()
stop()
def test_ssl_pw(self):
addr, stop = self.start_server(
"""<ssl>
certificate {}
key {}
authenticate {}
password-function ZEO.tests.testConfig.pwfunc
</ssl>""".format(serverpw_cert, serverpw_key, client_cert)
)
stop()
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_no_ssl(self, factory):
server = create_server()
self.assertFalse(factory.called)
self.assertEqual(server.acceptor._Acceptor__ssl, None)
server.close()
def assert_context(
self, factory, context,
cert=(server_cert, server_key, None),
verify_mode=ssl.CERT_REQUIRED,
check_hostname=False,
cafile=None, capath=None,
):
factory.assert_called_with(
ssl.Purpose.CLIENT_AUTH, cafile=cafile, capath=capath)
context.load_cert_chain.assert_called_with(*cert)
self.assertEqual(context, factory.return_value)
self.assertEqual(context.verify_mode, verify_mode)
self.assertEqual(context.check_hostname, check_hostname)
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_ssl_no_auth(self, factory):
with self.assertRaises(SystemExit):
# auth is required
create_server(certificate=server_cert, key=server_key)
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_ssl_auth_file(self, factory):
server = create_server(
certificate=server_cert, key=server_key, authenticate=__file__)
context = server.acceptor._Acceptor__ssl
self.assert_context(factory, context, cafile=__file__)
server.close()
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_ssl_auth_dir(self, factory):
server = create_server(
certificate=server_cert, key=server_key, authenticate=here)
context = server.acceptor._Acceptor__ssl
self.assert_context(factory, context, capath=here)
server.close()
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_ssl_pw(self, factory):
server = create_server(
certificate=server_cert,
key=server_key,
password_function='ZEO.tests.testConfig.pwfunc',
authenticate=here,
)
context = server.acceptor._Acceptor__ssl
self.assert_context(
factory, context, (server_cert, server_key, pwfunc), capath=here)
server.close()
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_no_ssl(self, ClientStorage, factory):
client = ssl_client()
self.assertFalse('ssl' in ClientStorage.call_args[1])
self.assertFalse('ssl_server_hostname' in ClientStorage.call_args[1])
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_server_signed(
self, ClientStorage, factory
):
client = ssl_client(certificate=client_cert, key=client_key)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, None),
check_hostname=True)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_dir(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=here)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, None),
capath=here,
)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_file(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, None),
cafile=server_cert,
)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_pw(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key,
password_function='ZEO.tests.testConfig.pwfunc',
authenticate=server_cert)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, pwfunc),
check_hostname=False,
cafile=server_cert,
)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_server_hostname(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert,
server_hostname='example.com')
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
'example.com')
self.assert_context(
factory, context, (client_cert, client_key, None),
cafile=server_cert,
check_hostname=True,
)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_check_hostname(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert,
check_hostname=True)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, None),
cafile=server_cert,
check_hostname=True,
)
def args(*a, **kw):
return a, kw
def ssl_conf(**ssl_settings):
if ssl_settings:
ssl_conf = '<ssl>\n' + '\n'.join(
'{} {}'.format(name.replace('_', '-'), value)
for name, value in ssl_settings.items()
) + '\n</ssl>\n'
else:
ssl_conf = ''
return ssl_conf
def ssl_client(**ssl_settings):
return storageFromString(
"""%import ZEO
<clientstorage>
server localhost:0
{}
</clientstorage>
""".format(ssl_conf(**ssl_settings))
)
def create_server(**ssl_settings):
with open('conf', 'w') as f:
f.write(
"""
<zeo>
address localhost:0
{}
</zeo>
<mappingstorage>
</mappingstorage>
""".format(ssl_conf(**ssl_settings)))
options = runzeo.ZEOOptions()
options.realize(['-C', 'conf'])
s = runzeo.ZEOServer(options)
s.open_storages()
s.create_server()
return s.server
pwfunc = lambda : '1234'
def test_suite(): def test_suite():
return unittest.makeSuite(ZEOConfigTest) return unittest.makeSuite(ZEOConfigTest)
...@@ -79,6 +79,12 @@ class MappingStorageConnectionTests( ...@@ -79,6 +79,12 @@ class MappingStorageConnectionTests(
): ):
"""Mapping storage connection tests.""" """Mapping storage connection tests."""
class SSLConnectionTests(
MappingStorageConfig,
ConnectionTests.SSLConnectionTests,
):
pass
# The ReconnectionTests can't work with MappingStorage because it's only an # The ReconnectionTests can't work with MappingStorage because it's only an
# in-memory storage and has no persistent state. # in-memory storage and has no persistent state.
...@@ -88,6 +94,12 @@ class MappingStorageTimeoutTests( ...@@ -88,6 +94,12 @@ class MappingStorageTimeoutTests(
): ):
pass pass
class SSLConnectionTests(
MappingStorageConfig,
ConnectionTests.SSLConnectionTests,
):
pass
test_classes = [FileStorageConnectionTests, test_classes = [FileStorageConnectionTests,
FileStorageReconnectionTests, FileStorageReconnectionTests,
...@@ -95,6 +107,7 @@ test_classes = [FileStorageConnectionTests, ...@@ -95,6 +107,7 @@ test_classes = [FileStorageConnectionTests,
FileStorageTimeoutTests, FileStorageTimeoutTests,
MappingStorageConnectionTests, MappingStorageConnectionTests,
MappingStorageTimeoutTests, MappingStorageTimeoutTests,
SSLConnectionTests,
] ]
def invalidations_while_connecting(): def invalidations_while_connecting():
......
...@@ -39,6 +39,7 @@ import re ...@@ -39,6 +39,7 @@ import re
import shutil import shutil
import signal import signal
import stat import stat
import ssl
import sys import sys
import tempfile import tempfile
import threading import threading
...@@ -55,6 +56,8 @@ import ZODB.tests.util ...@@ -55,6 +56,8 @@ import ZODB.tests.util
import ZODB.utils import ZODB.utils
import zope.testing.setupstack import zope.testing.setupstack
from . import testssl
logger = logging.getLogger('ZEO.tests.testZEO') logger = logging.getLogger('ZEO.tests.testZEO')
class DummyDB: class DummyDB:
...@@ -99,7 +102,7 @@ class MiscZEOTests: ...@@ -99,7 +102,7 @@ class MiscZEOTests:
def checkZEOInvalidation(self): def checkZEOInvalidation(self):
addr = self._storage._addr addr = self._storage._addr
storage2 = self._wrap_client( storage2 = self._wrap_client(
ClientStorage(addr, wait=1, min_disconnect_poll=0.1)) ClientStorage(addr, wait=1, **self._client_options()))
try: try:
oid = self._storage.new_oid() oid = self._storage.new_oid()
ob = MinPO('first') ob = MinPO('first')
...@@ -128,13 +131,13 @@ class MiscZEOTests: ...@@ -128,13 +131,13 @@ class MiscZEOTests:
# Earlier, a ClientStorage would not have the last transaction id # Earlier, a ClientStorage would not have the last transaction id
# available right after successful connection, this is required now. # available right after successful connection, this is required now.
addr = self._storage._addr addr = self._storage._addr
storage2 = ClientStorage(addr) storage2 = ClientStorage(addr, **self._client_options())
self.assert_(storage2.is_connected()) self.assert_(storage2.is_connected())
self.assertEquals(ZODB.utils.z64, storage2.lastTransaction()) self.assertEquals(ZODB.utils.z64, storage2.lastTransaction())
storage2.close() storage2.close()
self._dostore() self._dostore()
storage3 = ClientStorage(addr) storage3 = ClientStorage(addr, **self._client_options())
self.assert_(storage3.is_connected()) self.assert_(storage3.is_connected())
self.assertEquals(8, len(storage3.lastTransaction())) self.assertEquals(8, len(storage3.lastTransaction()))
self.assertNotEquals(ZODB.utils.z64, storage3.lastTransaction()) self.assertNotEquals(ZODB.utils.z64, storage3.lastTransaction())
...@@ -164,26 +167,33 @@ class GenericTests( ...@@ -164,26 +167,33 @@ class GenericTests(
def setUp(self): def setUp(self):
StorageTestBase.StorageTestBase.setUp(self) StorageTestBase.StorageTestBase.setUp(self)
logger.info("setUp() %s", self.id()) logger.info("setUp() %s", self.id())
port = get_port(self)
zconf = forker.ZEOConfig(('', port))
zport, stop = forker.start_zeo_server(self.getConfig(), zport, stop = forker.start_zeo_server(self.getConfig(),
zconf, port) self.getZEOConfig())
self._servers = [stop] self._servers = [stop]
if not self.blob_cache_dir: if not self.blob_cache_dir:
# This is the blob cache for ClientStorage # This is the blob cache for ClientStorage
self.blob_cache_dir = tempfile.mkdtemp( self.blob_cache_dir = tempfile.mkdtemp(
'blob_cache', 'blob_cache',
dir=os.path.abspath(os.getcwd())) dir=os.path.abspath(os.getcwd()))
self._storage = self._wrap_client(ClientStorage( self._storage = self._wrap_client(
ClientStorage(
zport, '1', cache_size=20000000, zport, '1', cache_size=20000000,
min_disconnect_poll=0.5, wait=1, min_disconnect_poll=0.5, wait=1,
wait_timeout=60, blob_dir=self.blob_cache_dir, wait_timeout=60, blob_dir=self.blob_cache_dir,
shared_blob_dir=self.shared_blob_dir)) shared_blob_dir=self.shared_blob_dir,
**self._client_options()),
)
self._storage.registerDB(DummyDB()) self._storage.registerDB(DummyDB())
def getZEOConfig(self):
return forker.ZEOConfig(('', get_port(self)))
def _wrap_client(self, client): def _wrap_client(self, client):
return client return client
def _client_options(self):
return {}
def tearDown(self): def tearDown(self):
self._storage.close() self._storage.close()
for stop in self._servers: for stop in self._servers:
...@@ -204,7 +214,8 @@ class GenericTests( ...@@ -204,7 +214,8 @@ class GenericTests(
# cleaner way. # cleaner way.
addr = self._storage._addr addr = self._storage._addr
self._storage.close() self._storage.close()
self._storage = ClientStorage(addr, read_only=read_only, wait=1) self._storage = ClientStorage(
addr, read_only=read_only, wait=1, **self._client_options())
def checkWriteMethods(self): def checkWriteMethods(self):
# ReadOnlyStorage defines checkWriteMethods. The decision # ReadOnlyStorage defines checkWriteMethods. The decision
...@@ -223,7 +234,8 @@ class GenericTests( ...@@ -223,7 +234,8 @@ class GenericTests(
def _do_store_in_separate_thread(self, oid, revid, voted): def _do_store_in_separate_thread(self, oid, revid, voted):
def do_store(): def do_store():
store = ZEO.ClientStorage.ClientStorage(self._storage._addr) store = ZEO.ClientStorage.ClientStorage(
self._storage._addr, **self._client_options())
try: try:
t = transaction.get() t = transaction.get()
store.tpc_begin(t) store.tpc_begin(t)
...@@ -335,6 +347,16 @@ class FileStorageTests(FullGenericTests): ...@@ -335,6 +347,16 @@ class FileStorageTests(FullGenericTests):
self._storage._info['interfaces'] self._storage._info['interfaces']
) )
class FileStorageSSLTests(FileStorageTests):
def getZEOConfig(self):
return testssl.server_config
def _client_options(self):
return {'ssl': testssl.client_ssl()}
class FileStorageHexTests(FileStorageTests): class FileStorageHexTests(FileStorageTests):
_expected_interfaces = ( _expected_interfaces = (
('ZODB.interfaces', 'IStorageRestoreable'), ('ZODB.interfaces', 'IStorageRestoreable'),
...@@ -1486,7 +1508,8 @@ def can_use_empty_string_for_local_host_on_client(): ...@@ -1486,7 +1508,8 @@ def can_use_empty_string_for_local_host_on_client():
slow_test_classes = [ slow_test_classes = [
#BlobAdaptedFileStorageTests, BlobWritableCacheTests, #BlobAdaptedFileStorageTests, BlobWritableCacheTests,
MappingStorageTests, DemoStorageTests, MappingStorageTests, DemoStorageTests,
FileStorageTests, FileStorageHexTests, FileStorageClientHexTests, FileStorageTests, FileStorageSSLTests,
FileStorageHexTests, FileStorageClientHexTests,
] ]
quick_test_classes = [FileStorageRecoveryTests, ZRPCConnectionTests] quick_test_classes = [FileStorageRecoveryTests, ZRPCConnectionTests]
......
import mock
import os
import ssl
import unittest
from ZODB.config import storageFromString
from ..Exceptions import ClientDisconnected
from .. import runzeo
from .testConfig import ZEOConfigTest
here = os.path.dirname(__file__)
server_cert = os.path.join(here, 'server.pem')
server_key = os.path.join(here, 'server_key.pem')
serverpw_cert = os.path.join(here, 'serverpw.pem')
serverpw_key = os.path.join(here, 'serverpw_key.pem')
client_cert = os.path.join(here, 'client.pem')
client_key = os.path.join(here, 'client_key.pem')
class SSLConfigTest(ZEOConfigTest):
def test_ssl_basic(self):
# This shows that configuring ssl has an actual effect on connections.
# Other SSL configuration tests will be Mockiavellian.
# Also test that an SSL connection mismatch doesn't kill
# the server loop.
# An SSL client can't talk to a non-SSL server:
addr, stop = self.start_server()
with self.assertRaises(ClientDisconnected):
self.start_client(
addr,
"""<ssl>
certificate {}
key {}
</ssl>""".format(client_cert, client_key), wait_timeout=1)
# But a non-ssl one can:
client = self.start_client(addr)
self._client_assertions(client, addr)
client.close()
stop()
# A non-SSL client can't talk to an SSL server:
addr, stop = self.start_server(
"""<ssl>
certificate {}
key {}
authenticate {}
</ssl>""".format(server_cert, server_key, client_cert)
)
with self.assertRaises(ClientDisconnected):
self.start_client(addr, wait_timeout=1)
# But an SSL one can:
client = self.start_client(
addr,
"""<ssl>
certificate {}
key {}
authenticate {}
</ssl>""".format(client_cert, client_key, server_cert))
self._client_assertions(client, addr)
client.close()
stop()
def test_ssl_hostname_check(self):
addr, stop = self.start_server(
"""<ssl>
certificate {}
key {}
authenticate {}
</ssl>""".format(server_cert, server_key, client_cert)
)
# Connext with bad hostname fails:
with self.assertRaises(ClientDisconnected):
client = self.start_client(
addr,
"""<ssl>
certificate {}
key {}
authenticate {}
server-hostname example.org
</ssl>""".format(client_cert, client_key, server_cert),
wait_timeout=1)
# Connext with good hostname succeeds:
client = self.start_client(
addr,
"""<ssl>
certificate {}
key {}
authenticate {}
server-hostname zodb.org
</ssl>""".format(client_cert, client_key, server_cert))
self._client_assertions(client, addr)
client.close()
stop()
def test_ssl_pw(self):
addr, stop = self.start_server(
"""<ssl>
certificate {}
key {}
authenticate {}
password-function ZEO.tests.testssl.pwfunc
</ssl>""".format(serverpw_cert, serverpw_key, client_cert)
)
stop()
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_no_ssl(self, factory):
server = create_server()
self.assertFalse(factory.called)
self.assertEqual(server.acceptor._Acceptor__ssl, None)
server.close()
def assert_context(
self, factory, context,
cert=(server_cert, server_key, None),
verify_mode=ssl.CERT_REQUIRED,
check_hostname=False,
cafile=None, capath=None,
):
factory.assert_called_with(
ssl.Purpose.CLIENT_AUTH, cafile=cafile, capath=capath)
context.load_cert_chain.assert_called_with(*cert)
self.assertEqual(context, factory.return_value)
self.assertEqual(context.verify_mode, verify_mode)
self.assertEqual(context.check_hostname, check_hostname)
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_ssl_no_auth(self, factory):
with self.assertRaises(SystemExit):
# auth is required
create_server(certificate=server_cert, key=server_key)
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_ssl_auth_file(self, factory):
server = create_server(
certificate=server_cert, key=server_key, authenticate=__file__)
context = server.acceptor._Acceptor__ssl
self.assert_context(factory, context, cafile=__file__)
server.close()
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_ssl_auth_dir(self, factory):
server = create_server(
certificate=server_cert, key=server_key, authenticate=here)
context = server.acceptor._Acceptor__ssl
self.assert_context(factory, context, capath=here)
server.close()
@mock.patch('ssl.create_default_context')
def test_ssl_mockiavellian_server_ssl_pw(self, factory):
server = create_server(
certificate=server_cert,
key=server_key,
password_function='ZEO.tests.testssl.pwfunc',
authenticate=here,
)
context = server.acceptor._Acceptor__ssl
self.assert_context(
factory, context, (server_cert, server_key, pwfunc), capath=here)
server.close()
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_no_ssl(self, ClientStorage, factory):
client = ssl_client()
self.assertFalse('ssl' in ClientStorage.call_args[1])
self.assertFalse('ssl_server_hostname' in ClientStorage.call_args[1])
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_server_signed(
self, ClientStorage, factory
):
client = ssl_client(certificate=client_cert, key=client_key)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, None),
check_hostname=True)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_dir(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=here)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, None),
capath=here,
)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_auth_file(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, None),
cafile=server_cert,
)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_pw(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key,
password_function='ZEO.tests.testssl.pwfunc',
authenticate=server_cert)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, pwfunc),
check_hostname=False,
cafile=server_cert,
)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_server_hostname(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert,
server_hostname='example.com')
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
'example.com')
self.assert_context(
factory, context, (client_cert, client_key, None),
cafile=server_cert,
check_hostname=True,
)
@mock.patch('ssl.create_default_context')
@mock.patch('ZEO.ClientStorage.ClientStorage')
def test_ssl_mockiavellian_client_check_hostname(
self, ClientStorage, factory
):
client = ssl_client(
certificate=client_cert, key=client_key, authenticate=server_cert,
check_hostname=True)
context = ClientStorage.call_args[1]['ssl']
self.assertEqual(ClientStorage.call_args[1]['ssl_server_hostname'],
None)
self.assert_context(
factory, context, (client_cert, client_key, None),
cafile=server_cert,
check_hostname=True,
)
def args(*a, **kw):
return a, kw
def ssl_conf(**ssl_settings):
if ssl_settings:
ssl_conf = '<ssl>\n' + '\n'.join(
'{} {}'.format(name.replace('_', '-'), value)
for name, value in ssl_settings.items()
) + '\n</ssl>\n'
else:
ssl_conf = ''
return ssl_conf
def ssl_client(**ssl_settings):
return storageFromString(
"""%import ZEO
<clientstorage>
server localhost:0
{}
</clientstorage>
""".format(ssl_conf(**ssl_settings))
)
def create_server(**ssl_settings):
with open('conf', 'w') as f:
f.write(
"""
<zeo>
address localhost:0
{}
</zeo>
<mappingstorage>
</mappingstorage>
""".format(ssl_conf(**ssl_settings)))
options = runzeo.ZEOOptions()
options.realize(['-C', 'conf'])
s = runzeo.ZEOServer(options)
s.open_storages()
s.create_server()
return s.server
pwfunc = lambda : '1234'
def test_suite():
return unittest.makeSuite(SSLConfigTest)
# Helpers for other tests:
server_config = """
<zeo>
address 127.0.0.1:0
<ssl>
certificate {}
key {}
authenticate {}
</ssl>
</zeo>
""".format(server_cert, server_key, client_cert)
def client_ssl():
context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH, cafile=server_cert)
context.load_cert_chain(client_cert, client_key)
context.verify_mode = ssl.CERT_REQUIRED
context.check_hostname = False
return context
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment