Commit abb582de authored by Julien Prigent's avatar Julien Prigent Committed by oroulet

[InvalidSignature] Clean code and add tests

parent 9576f05f
......@@ -286,7 +286,6 @@ class Client:
params.RequestType = ua.SecurityTokenRequestType.Issue
if renew:
params.RequestType = ua.SecurityTokenRequestType.Renew
_logger.warning("Renewing secure channel")
params.SecurityMode = self.security_policy.Mode
params.RequestedLifetime = self.secure_channel_timeout
# length should be equal to the length of key of symmetric encryption
......
......@@ -19,7 +19,8 @@ class MessageChunk(ua.FrozenClass):
"""
Message Chunk, as described in OPC UA specs Part 6, 6.7.2.
"""
def __init__(self, security_policy, body=b'', msg_type=ua.MessageType.SecureMessage, chunk_type=ua.ChunkType.Single):
def __init__(self, security_policy, body=b'', msg_type=ua.MessageType.SecureMessage,
chunk_type=ua.ChunkType.Single):
self.MessageHeader = ua.Header(msg_type, chunk_type)
if msg_type in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
self.SecurityHeader = ua.SymmetricAlgorithmHeader()
......@@ -62,7 +63,10 @@ class MessageChunk(ua.FrozenClass):
if signature_size > 0:
signature = decrypted[-signature_size:]
decrypted = decrypted[:-signature_size]
crypto.verify(header_to_binary(obj.MessageHeader) + struct_to_binary(obj.SecurityHeader) + decrypted, signature)
crypto.verify(
header_to_binary(obj.MessageHeader) + struct_to_binary(obj.SecurityHeader) + decrypted,
signature
)
data = ua.utils.Buffer(crypto.remove_padding(decrypted))
obj.SequenceHeader = struct_from_binary(ua.SequenceHeader, data)
obj.Body = data.read(len(data))
......@@ -91,7 +95,8 @@ class MessageChunk(ua.FrozenClass):
return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size()
@staticmethod
def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1):
def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage,
channel_id=1, request_id=1, token_id=1):
"""
Pack message body (as binary string) into one or more chunks.
Size of each chunk will not exceed max_chunk_size.
......@@ -162,10 +167,12 @@ class SecureConnection:
self.security_token = params.SecurityToken
self.local_nonce = client_nonce
self.remote_nonce = params.ServerNonce
logger.warning(f"params {params.SecurityToken.RevisedLifetime} security policy: {self.security_policy}")
revised_lifetime = self.security_token.RevisedLifetime
self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce, revised_lifetime)
self.security_policy.make_remote_symmetric_key(
self.local_nonce,
self.remote_nonce,
self.security_token.RevisedLifetime
)
self._open = True
else:
self.next_security_token = params.SecurityToken
......@@ -194,7 +201,11 @@ class SecureConnection:
response.SecurityToken = self.security_token
self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce, self.security_token.RevisedLifetime)
self.security_policy.make_remote_symmetric_key(
self.local_nonce,
self.remote_nonce,
self.security_token.RevisedLifetime
)
else:
self.next_security_token = copy.deepcopy(self.security_token)
self.next_security_token.TokenId += 1
......@@ -248,15 +259,9 @@ class SecureConnection:
The only supported types are SecureOpen, SecureMessage, SecureClose.
If message_type is SecureMessage, the AlgorithmHeader should be passed as arg.
"""
chunks = MessageChunk.message_to_chunks(
self.security_policy,
message,
self._max_chunk_size,
message_type=message_type,
channel_id=self.security_token.ChannelId,
request_id=request_id,
token_id=self.security_token.TokenId,
)
chunks = MessageChunk.message_to_chunks(self.security_policy, message, self._max_chunk_size,
message_type=message_type, channel_id=self.security_token.ChannelId,
request_id=request_id, token_id=self.security_token.TokenId)
for chunk in chunks:
self._sequence_number += 1
if self._sequence_number >= (1 << 32):
......@@ -287,7 +292,8 @@ class SecureConnection:
timeout = self.prev_security_token.CreatedAt + \
timedelta(milliseconds=self.prev_security_token.RevisedLifetime * 1.25)
if timeout < datetime.utcnow():
raise ua.UaError(f"Security token id {security_hdr.TokenId} has timed out " f"({timeout} < {datetime.utcnow()})")
raise ua.UaError(f"Security token id {security_hdr.TokenId} has timed out "
f"({timeout} < {datetime.utcnow()})")
return
expected_tokens = [self.security_token.TokenId, self.next_security_token.TokenId]
......@@ -300,10 +306,12 @@ class SecureConnection:
raise ValueError(f'Expected chunk, got: {chunk}')
if chunk.MessageHeader.MessageType != ua.MessageType.SecureOpen:
if chunk.MessageHeader.ChannelId != self.security_token.ChannelId:
raise ua.UaError(f'Wrong channel id {chunk.MessageHeader.ChannelId},' f' expected {self.security_token.ChannelId}')
raise ua.UaError(f'Wrong channel id {chunk.MessageHeader.ChannelId},'
f' expected {self.security_token.ChannelId}')
if self._incoming_parts:
if self._incoming_parts[0].SequenceHeader.RequestId != chunk.SequenceHeader.RequestId:
raise ua.UaError(f'Wrong request id {chunk.SequenceHeader.RequestId},' f' expected {self._incoming_parts[0].SequenceHeader.RequestId}')
raise ua.UaError(f'Wrong request id {chunk.SequenceHeader.RequestId},'
f' expected {self._incoming_parts[0].SequenceHeader.RequestId}')
# The sequence number must monotonically increase (but it can wrap around)
seq_num = chunk.SequenceHeader.SequenceNumber
if self._peer_sequence_number is not None:
......@@ -314,7 +322,9 @@ class SecureConnection:
logger.debug('Sequence number wrapped: %d -> %d', self._peer_sequence_number, seq_num)
else:
# Condition for monotonically increase is not met
raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection")
raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:"
f" {self._peer_sequence_number}, received: {seq_num},"
f" spec says to close connection")
self._peer_sequence_number = seq_num
def receive_from_header_and_body(self, header, body):
......@@ -356,6 +366,7 @@ class SecureConnection:
if header.MessageType == ua.MessageType.Error:
msg = struct_from_binary(ua.ErrorMessage, body)
logger.warning(f"Received an error: {msg}")
return msg
raise ua.UaError(f"Unsupported message type {header.MessageType}")
def _receive(self, msg):
......
......@@ -56,6 +56,11 @@ class Verifier(object):
def verify(self, data, signature):
pass
def reset(self):
attrs = self.__dict__
for k in attrs:
attrs[k] = None
class Encryptor(object):
"""
......@@ -96,6 +101,11 @@ class Decryptor(object):
def decrypt(self, data):
pass
def reset(self):
attrs = self.__dict__
for k in attrs:
attrs[k] = None
class Cryptography(CryptographyNone):
"""
......@@ -170,7 +180,7 @@ class Cryptography(CryptographyNone):
if not self.use_prev_key:
self.Verifier.verify(data, sig)
else:
logger.warning(f"FALLBACK! Checking signature with previous secure_channel key")
logger.debug(f"Message verification fallback: trying with previous secure channel key")
self.Prev_Verifier.verify(data, sig)
def encrypt(self, data):
......@@ -184,7 +194,6 @@ class Cryptography(CryptographyNone):
if self.is_encrypted:
self.revolved_expired_key()
if self.use_prev_key:
logger.warning(f"FALLBACK! Decrypt with previous secure_channel key")
return self.Prev_Decryptor.decrypt(data)
return self.Decryptor.decrypt(data)
return data
......@@ -195,20 +204,19 @@ class Cryptography(CryptographyNone):
"""
now = time.time()
if now > self.prev_key_expiration:
logger.info("Removing expired secure_channel key")
if getattr(self.Prev_Decryptor, "key", None):
self.Prev_Decryptor.key = None
self.Prev_Decryptor = None
if getattr(self.Prev_Verifier, "key", None):
self.Prev_Verifier.key = None
self.Prev_Verifier = None
if self.Prev_Decryptor and self.Prev_Verifier:
self.Prev_Decryptor.reset()
self.Prev_Decryptor = None
self.Prev_Verifier.reset()
self.Prev_Verifier = None
logger.debug(f"Expired secure_channel keys removed")
@property
def use_prev_key(self):
if self._use_prev_key:
if self.Prev_Decryptor and self.Prev_Verifier:
return True
raise uacrypto.InvalidSignature("Previous key has expired")
raise uacrypto.InvalidSignature
else:
return False
......@@ -328,8 +336,6 @@ class VerifierAesCbc(Verifier):
def verify(self, data, signature):
expected = uacrypto.hmac_sha1(self.key, data)
if signature != expected:
logger.warning(f"Actual signature: {signature}")
logger.warning(f"Expected signature: {expected}")
raise uacrypto.InvalidSignature
......@@ -418,8 +424,6 @@ class VerifierHMac256(Verifier):
def verify(self, data, signature):
expected = uacrypto.hmac_sha256(self.key, data)
if signature != expected:
logger.warning(f"Actual signature: {signature}")
logger.warning(f"Expected signature: {expected}")
raise uacrypto.InvalidSignature
......
......@@ -558,9 +558,9 @@ async def _uaserver():
server = Server()
server.set_endpoint(args.url)
if args.certificate:
await server.load_certificate(args.certificate)
server.load_certificate(args.certificate)
if args.private_key:
await server.load_private_key(args.private_key)
server.load_private_key(args.private_key)
server.disable_clock(args.disable_clock)
server.set_server_name("FreeOpcUa Example Server")
if args.xml:
......
import os
import pytest
import sys
import asyncio
if sys.version_info >= (3, 6):
from asyncio import TimeoutError
......@@ -11,6 +12,7 @@ from asyncua import Client
from asyncua import Server
from asyncua import ua
from asyncua.server.user_managers import CertificateUserManager
from asyncua.crypto.security_policies import Verifier, Decryptor
try:
from asyncua.crypto import uacrypto
......@@ -282,3 +284,43 @@ async def test_certificate_handling_mismatched_creds(srv_crypto_one_cert):
)
async with clt:
assert await clt.get_objects_node().get_children()
async def test_secure_channel_key_expiration(srv_crypto_one_cert, mocker):
timeout = 1
_, cert = srv_crypto_one_cert
clt = Client(uri_crypto_cert)
clt.secure_channel_timeout = timeout * 1000
user_cert = uacrypto.CertProperties(peer_creds['certificate'], "DER")
user_key = uacrypto.CertProperties(
path=peer_creds['private_key'],
extension="PEM",
)
server_cert = uacrypto.CertProperties(cert)
await clt.set_security(
security_policies.SecurityPolicyBasic256Sha256,
user_cert,
user_key,
server_certificate=server_cert,
mode=ua.MessageSecurityMode.SignAndEncrypt
)
async with clt:
assert clt.uaclient.security_policy.symmetric_cryptography.Prev_Verifier is None
assert clt.uaclient.security_policy.symmetric_cryptography.Prev_Decryptor is None
await asyncio.sleep(timeout)
prev_verifier = clt.uaclient.security_policy.symmetric_cryptography.Prev_Verifier
prev_decryptor = clt.uaclient.security_policy.symmetric_cryptography.Prev_Decryptor
assert isinstance(prev_verifier, Verifier)
assert isinstance(prev_decryptor, Decryptor)
mock_decry_reset = mocker.patch.object(prev_verifier, "reset", wraps=prev_verifier.reset)
mock_verif_reset = mocker.patch.object(prev_decryptor, "reset", wraps=prev_decryptor.reset)
assert mock_decry_reset.call_count == 0
assert mock_verif_reset.call_count == 0
await asyncio.sleep(timeout*0.3)
assert await clt.get_objects_node().get_children()
assert mock_decry_reset.call_count == 1
assert mock_verif_reset.call_count == 1
assert clt.uaclient.security_policy.symmetric_cryptography.Prev_Verifier is None
assert clt.uaclient.security_policy.symmetric_cryptography.Prev_Decryptor is None
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment