Commit fb3ff711 authored by Hendrik von Prince's avatar Hendrik von Prince Committed by oroulet

Enable usage of certificates that are already loaded into memory

Certificates, that are already loaded into memory as bytes, can now directly be handed over
to set_security and load_certificate.
parent a0c8a450
...@@ -159,10 +159,10 @@ class Client: ...@@ -159,10 +159,10 @@ class Client:
async def set_security( async def set_security(
self, self,
policy: Type[ua.SecurityPolicy], policy: Type[ua.SecurityPolicy],
certificate: Union[str, uacrypto.CertProperties], certificate: Union[str, uacrypto.CertProperties, bytes],
private_key: Union[str, uacrypto.CertProperties], private_key: Union[str, uacrypto.CertProperties, bytes],
private_key_password: Optional[Union[str, bytes]] = None, private_key_password: Optional[Union[str, bytes]] = None,
server_certificate: Optional[Union[str, uacrypto.CertProperties]] = None, server_certificate: Optional[Union[str, uacrypto.CertProperties, bytes]] = None,
mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt, mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt,
): ):
""" """
...@@ -196,10 +196,10 @@ class Client: ...@@ -196,10 +196,10 @@ class Client:
): ):
if isinstance(server_cert, uacrypto.CertProperties): if isinstance(server_cert, uacrypto.CertProperties):
server_cert = await uacrypto.load_certificate(server_cert.path, server_cert.extension) server_cert = await uacrypto.load_certificate(server_cert.path_or_content, server_cert.extension)
cert = await uacrypto.load_certificate(certificate.path, certificate.extension) cert = await uacrypto.load_certificate(certificate.path_or_content, certificate.extension)
pk = await uacrypto.load_private_key( pk = await uacrypto.load_private_key(
private_key.path, private_key.path_or_content,
private_key.password, private_key.password,
private_key.extension, private_key.extension,
) )
......
...@@ -17,18 +17,27 @@ from dataclasses import dataclass ...@@ -17,18 +17,27 @@ from dataclasses import dataclass
@dataclass @dataclass
class CertProperties: class CertProperties:
path: str path_or_content: Union[str, bytes]
extension: Optional[str] = None extension: Optional[str] = None
password: Optional[Union[str, bytes]] = None password: Optional[Union[str, bytes]] = None
async def load_certificate(path: str, extension: Optional[str] = None): async def get_content(path_or_content: Union[str, bytes]) -> bytes:
_, ext = os.path.splitext(path) if isinstance(path_or_content, bytes):
async with aiofiles.open(path, mode='rb') as f: return path_or_content
async with aiofiles.open(path_or_content, mode='rb') as f:
return await f.read()
async def load_certificate(path_or_content: Union[str, bytes], extension: Optional[str] = None):
_, ext = os.path.splitext(path_or_content)
content = await get_content(path_or_content)
if ext == ".pem" or extension == 'pem' or extension == 'PEM': if ext == ".pem" or extension == 'pem' or extension == 'PEM':
return x509.load_pem_x509_certificate(await f.read(), default_backend()) return x509.load_pem_x509_certificate(content, default_backend())
else: else:
return x509.load_der_x509_certificate(await f.read(), default_backend()) return x509.load_der_x509_certificate(content, default_backend())
def x509_from_der(data): def x509_from_der(data):
...@@ -37,17 +46,18 @@ def x509_from_der(data): ...@@ -37,17 +46,18 @@ def x509_from_der(data):
return x509.load_der_x509_certificate(data, default_backend()) return x509.load_der_x509_certificate(data, default_backend())
async def load_private_key(path: str, async def load_private_key(path_or_content: Union[str, bytes],
password: Optional[Union[str, bytes]] = None, password: Optional[Union[str, bytes]] = None,
extension: Optional[str] = None): extension: Optional[str] = None):
_, ext = os.path.splitext(path) _, ext = os.path.splitext(path_or_content)
if isinstance(password, str): if isinstance(password, str):
password = password.encode('utf-8') password = password.encode('utf-8')
async with aiofiles.open(path, mode='rb') as f:
content = await get_content(path_or_content)
if ext == ".pem" or extension == 'pem' or extension == 'PEM': if ext == ".pem" or extension == 'pem' or extension == 'PEM':
return serialization.load_pem_private_key(await f.read(), password=password, backend=default_backend()) return serialization.load_pem_private_key(content, password=password, backend=default_backend())
else: else:
return serialization.load_der_private_key(await f.read(), password=password, backend=default_backend()) return serialization.load_der_private_key(content, password=password, backend=default_backend())
def der_from_x509(certificate): def der_from_x509(certificate):
......
...@@ -7,7 +7,7 @@ import logging ...@@ -7,7 +7,7 @@ import logging
import math import math
from datetime import timedelta, datetime from datetime import timedelta, datetime
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Coroutine, Optional, Tuple from typing import Coroutine, Optional, Tuple, Union
from asyncua import ua from asyncua import ua
from .binary_server_asyncio import BinaryServer from .binary_server_asyncio import BinaryServer
...@@ -206,14 +206,14 @@ class Server: ...@@ -206,14 +206,14 @@ class Server:
return f"OPC UA Server({self.endpoint.geturl()})" return f"OPC UA Server({self.endpoint.geturl()})"
__repr__ = __str__ __repr__ = __str__
async def load_certificate(self, path: str, format: str = None): async def load_certificate(self, path_or_content: Union[str, bytes], format: str = None):
""" """
load server certificate from file, either pem or der load server certificate from file, either pem or der
""" """
self.certificate = await uacrypto.load_certificate(path, format) self.certificate = await uacrypto.load_certificate(path_or_content, format)
async def load_private_key(self, path, password=None, format=None): async def load_private_key(self, path_or_content: Union[str, bytes], password=None, format=None):
self.iserver.private_key = await uacrypto.load_private_key(path, password, format) self.iserver.private_key = await uacrypto.load_private_key(path_or_content, password, format)
def disable_clock(self, val: bool = True): def disable_clock(self, val: bool = True):
""" """
......
...@@ -3,6 +3,9 @@ import pytest ...@@ -3,6 +3,9 @@ import pytest
import asyncio import asyncio
from asyncio import TimeoutError from asyncio import TimeoutError
from asyncua.crypto.uacrypto import CertProperties
from asyncua import Client from asyncua import Client
from asyncua import Server from asyncua import Server
from asyncua import ua from asyncua import ua
...@@ -165,6 +168,25 @@ async def test_basic256_encrypt_success(srv_crypto_all_certs): ...@@ -165,6 +168,25 @@ async def test_basic256_encrypt_success(srv_crypto_all_certs):
assert await clt.nodes.objects.get_children() assert await clt.nodes.objects.get_children()
async def test_basic256_encrypt_use_certificate_bytes(srv_crypto_all_certs):
clt = Client(uri_crypto)
_, cert = srv_crypto_all_certs
with open(cert, 'rb') as server_cert, \
open(f"{EXAMPLE_PATH}certificate-example.der", 'rb') as user_cert, \
open(f"{EXAMPLE_PATH}private-key-example.pem", 'rb') as user_key:
await clt.set_security(
security_policies.SecurityPolicyBasic256Sha256,
user_cert.read(),
CertProperties(user_key.read(), extension="pem"),
None,
server_cert.read(),
ua.MessageSecurityMode.SignAndEncrypt
)
async with clt:
assert await clt.nodes.objects.get_children()
@pytest.mark.skip("# FIXME: how to make it fail???") @pytest.mark.skip("# FIXME: how to make it fail???")
async def test_basic256_encrypt_fail(srv_crypto_all_certs): async def test_basic256_encrypt_fail(srv_crypto_all_certs):
# FIXME: how to make it fail??? # FIXME: how to make it fail???
...@@ -232,7 +254,7 @@ async def test_encrypted_private_key_handling_success_with_cert_props(srv_crypto ...@@ -232,7 +254,7 @@ async def test_encrypted_private_key_handling_success_with_cert_props(srv_crypto
clt = Client(uri_crypto_cert) clt = Client(uri_crypto_cert)
user_cert = uacrypto.CertProperties(encrypted_private_key_peer_creds['certificate'], "DER") user_cert = uacrypto.CertProperties(encrypted_private_key_peer_creds['certificate'], "DER")
user_key = uacrypto.CertProperties( user_key = uacrypto.CertProperties(
path=encrypted_private_key_peer_creds['private_key'], path_or_content=encrypted_private_key_peer_creds['private_key'],
password=encrypted_private_key_peer_creds['password'], password=encrypted_private_key_peer_creds['password'],
extension="PEM", extension="PEM",
) )
...@@ -307,7 +329,7 @@ async def test_secure_channel_key_expiration(srv_crypto_one_cert, mocker): ...@@ -307,7 +329,7 @@ async def test_secure_channel_key_expiration(srv_crypto_one_cert, mocker):
clt.secure_channel_timeout = timeout * 1000 clt.secure_channel_timeout = timeout * 1000
user_cert = uacrypto.CertProperties(peer_creds['certificate'], "DER") user_cert = uacrypto.CertProperties(peer_creds['certificate'], "DER")
user_key = uacrypto.CertProperties( user_key = uacrypto.CertProperties(
path=peer_creds['private_key'], path_or_content=peer_creds['private_key'],
extension="PEM", extension="PEM",
) )
server_cert = uacrypto.CertProperties(cert) server_cert = uacrypto.CertProperties(cert)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment