Commit 9576f05f authored by Julien Prigent's avatar Julien Prigent Committed by oroulet

[InvalidSignature] Fix race condition with secure channel renewal

parent 4fedec4c
......@@ -286,6 +286,7 @@ class Client:
params.RequestType = ua.SecurityTokenRequestType.Issue
if renew:
params.RequestType = ua.SecurityTokenRequestType.Renew
_logger.warning("Renewing secure channel")
params.SecurityMode = self.security_policy.Mode
params.RequestedLifetime = self.secure_channel_timeout
# length should be equal to the length of key of symmetric encryption
......
......@@ -6,6 +6,12 @@ import copy
from asyncua import ua
from ..ua.ua_binary import struct_from_binary, struct_to_binary, header_from_binary, header_to_binary
try:
from ..crypto.uacrypto import InvalidSignature
except ImportError:
class InvalidSignature(Exception):
pass
logger = logging.getLogger('asyncua.uaprotocol')
......@@ -28,10 +34,13 @@ class MessageChunk(ua.FrozenClass):
@staticmethod
def from_binary(security_policy, data):
h = header_from_binary(data)
return MessageChunk.from_header_and_body(security_policy, h, data)
try:
return MessageChunk.from_header_and_body(security_policy, h, data)
except InvalidSignature:
return MessageChunk.from_header_and_body(security_policy, h, data, use_prev_key=True)
@staticmethod
def from_header_and_body(security_policy, header, buf):
def from_header_and_body(security_policy, header, buf, use_prev_key=False):
if not len(buf) >= header.body_size:
raise ValueError('Full body expected here')
data = buf.copy(header.body_size)
......@@ -44,6 +53,7 @@ class MessageChunk(ua.FrozenClass):
crypto = security_policy.asymmetric_cryptography
else:
raise ua.UaError(f"Unsupported message type: {header.MessageType}")
crypto.use_prev_key = use_prev_key
obj = MessageChunk(crypto)
obj.MessageHeader = header
obj.SecurityHeader = security_header
......@@ -152,8 +162,10 @@ class SecureConnection:
self.security_token = params.SecurityToken
self.local_nonce = client_nonce
self.remote_nonce = params.ServerNonce
logger.warning(f"params {params.SecurityToken.RevisedLifetime} security policy: {self.security_policy}")
revised_lifetime = self.security_token.RevisedLifetime
self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce, revised_lifetime)
self._open = True
else:
self.next_security_token = params.SecurityToken
......@@ -182,7 +194,7 @@ class SecureConnection:
response.SecurityToken = self.security_token
self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce, self.security_token.RevisedLifetime)
else:
self.next_security_token = copy.deepcopy(self.security_token)
self.next_security_token.TokenId += 1
......@@ -228,7 +240,7 @@ class SecureConnection:
self.security_token = self.next_security_token
self.next_security_token = ua.ChannelSecurityToken()
self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce, self.security_token.RevisedLifetime)
def message_to_binary(self, message, message_type=ua.MessageType.SecureMessage, request_id=0):
"""
......@@ -326,7 +338,12 @@ class SecureConnection:
self._check_sym_header(security_header)
if header.MessageType in (ua.MessageType.SecureMessage, ua.MessageType.SecureOpen, ua.MessageType.SecureClose):
chunk = MessageChunk.from_header_and_body(self.security_policy, header, body)
try:
pos = body.cur_pos
chunk = MessageChunk.from_header_and_body(self.security_policy, header, body, use_prev_key=False)
except InvalidSignature:
body.rewind(cur_pos=pos)
chunk = MessageChunk.from_header_and_body(self.security_policy, header, body, use_prev_key=True)
return self._receive(chunk)
if header.MessageType == ua.MessageType.Hello:
msg = struct_from_binary(ua.Hello, body)
......@@ -339,7 +356,6 @@ class SecureConnection:
if header.MessageType == ua.MessageType.Error:
msg = struct_from_binary(ua.ErrorMessage, body)
logger.warning(f"Received an error: {msg}")
return msg
raise ua.UaError(f"Unsupported message type {header.MessageType}")
def _receive(self, msg):
......
......@@ -79,6 +79,16 @@ class Buffer:
self._size -= size
self._cur_pos += size
@property
def cur_pos(self):
return self._cur_pos
def rewind(self, cur_pos=0):
"""
rewind the buffer
"""
self._cur_pos = cur_pos
self._size = len(self._data) - cur_pos
def create_nonce(size=32):
return os.urandom(size)
import logging
import struct
import time
from abc import ABCMeta, abstractmethod
from ..ua import CryptographyNone, SecurityPolicy, MessageSecurityMode, UaError
......@@ -12,6 +13,7 @@ except ImportError:
POLICY_NONE_URI = 'http://opcfoundation.org/UA/SecurityPolicy#None'
logger = logging.getLogger(__name__)
def require_cryptography(obj):
......@@ -103,8 +105,14 @@ class Cryptography(CryptographyNone):
def __init__(self, mode=MessageSecurityMode.Sign):
self.Signer = None
self.Verifier = None
self.Prev_Verifier = None
self.Encryptor = None
self.Decryptor = None
self.Prev_Decryptor = None
# we turn this flag on to fallback on previous key
self._use_prev_key = False
self.key_expiration = 0.0
self.prev_key_expiration = 0.0
if mode not in (MessageSecurityMode.Sign,
MessageSecurityMode.SignAndEncrypt):
raise ValueError(f"unknown security mode {mode}")
......@@ -159,7 +167,11 @@ class Cryptography(CryptographyNone):
return self.Verifier.signature_size()
def verify(self, data, sig):
self.Verifier.verify(data, sig)
if not self.use_prev_key:
self.Verifier.verify(data, sig)
else:
logger.warning(f"FALLBACK! Checking signature with previous secure_channel key")
self.Prev_Verifier.verify(data, sig)
def encrypt(self, data):
if self.is_encrypted:
......@@ -170,12 +182,44 @@ class Cryptography(CryptographyNone):
def decrypt(self, data):
if self.is_encrypted:
self.revolved_expired_key()
if self.use_prev_key:
logger.warning(f"FALLBACK! Decrypt with previous secure_channel key")
return self.Prev_Decryptor.decrypt(data)
return self.Decryptor.decrypt(data)
return data
def revolved_expired_key(self):
"""
Remove expired keys as soon as possible
"""
now = time.time()
if now > self.prev_key_expiration:
logger.info("Removing expired secure_channel key")
if getattr(self.Prev_Decryptor, "key", None):
self.Prev_Decryptor.key = None
self.Prev_Decryptor = None
if getattr(self.Prev_Verifier, "key", None):
self.Prev_Verifier.key = None
self.Prev_Verifier = None
@property
def use_prev_key(self):
if self._use_prev_key:
if self.Prev_Decryptor and self.Prev_Verifier:
return True
raise uacrypto.InvalidSignature("Previous key has expired")
else:
return False
@use_prev_key.setter
def use_prev_key(self, value: bool):
self._use_prev_key = value
def remove_padding(self, data):
decryptor = self.Decryptor if not self.use_prev_key else self.Prev_Decryptor
if self.is_encrypted:
if self.Decryptor.encrypted_block_size() > 256:
if decryptor.encrypted_block_size() > 256:
pad_size = struct.unpack('<h', data[-2:])[0] + 2
else:
pad_size = bytearray(data[-1:])[0] + 1
......@@ -284,6 +328,8 @@ class VerifierAesCbc(Verifier):
def verify(self, data, signature):
expected = uacrypto.hmac_sha1(self.key, data)
if signature != expected:
logger.warning(f"Actual signature: {signature}")
logger.warning(f"Expected signature: {expected}")
raise uacrypto.InvalidSignature
......@@ -372,6 +418,8 @@ class VerifierHMac256(Verifier):
def verify(self, data, signature):
expected = uacrypto.hmac_sha256(self.key, data)
if signature != expected:
logger.warning(f"Actual signature: {signature}")
logger.warning(f"Expected signature: {expected}")
raise uacrypto.InvalidSignature
......@@ -414,7 +462,6 @@ class SecurityPolicyBasic128Rsa15(SecurityPolicy):
return uacrypto.encrypt_rsa15(pubkey, data)
def __init__(self, server_cert, client_cert, client_pk, mode):
logger = logging.getLogger(__name__)
logger.warning("DEPRECATED! Do not use SecurityPolicyBasic128Rsa15 anymore!")
require_cryptography(self)
......@@ -442,10 +489,17 @@ class SecurityPolicyBasic128Rsa15(SecurityPolicy):
self.symmetric_cryptography.Signer = SignerAesCbc(sigkey)
self.symmetric_cryptography.Encryptor = EncryptorAesCbc(key, init_vec)
def make_remote_symmetric_key(self, secret, seed):
def make_remote_symmetric_key(self, secret, seed, lifetime):
key_sizes = (self.signature_key_size, self.symmetric_key_size, 16)
(sigkey, key, init_vec) = uacrypto.p_sha1(secret, seed, key_sizes)
if self.symmetric_cryptography.Verifier or self.symmetric_cryptography.Decryptor:
self.symmetric_cryptography.Prev_Verifier = self.symmetric_cryptography.Verifier
self.symmetric_cryptography.Prev_Decryptor = self.symmetric_cryptography.Decryptor
self.symmetric_cryptography.prev_key_expiration = self.symmetric_cryptography.key_expiration
# lifetime is in ms
self.symmetric_cryptography.key_expiration = time.time() + (lifetime * 0.001)
self.symmetric_cryptography.Verifier = VerifierAesCbc(sigkey)
self.symmetric_cryptography.Decryptor = DecryptorAesCbc(key, init_vec)
......@@ -489,7 +543,6 @@ class SecurityPolicyBasic256(SecurityPolicy):
return uacrypto.encrypt_rsa_oaep(pubkey, data)
def __init__(self, server_cert, client_cert, client_pk, mode):
logger = logging.getLogger(__name__)
logger.warning("DEPRECATED! Do not use SecurityPolicyBasic256 anymore!")
require_cryptography(self)
......@@ -518,12 +571,19 @@ class SecurityPolicyBasic256(SecurityPolicy):
self.symmetric_cryptography.Signer = SignerAesCbc(sigkey)
self.symmetric_cryptography.Encryptor = EncryptorAesCbc(key, init_vec)
def make_remote_symmetric_key(self, secret, seed):
def make_remote_symmetric_key(self, secret, seed, lifetime):
# specs part 6, 6.7.5
key_sizes = (self.signature_key_size, self.symmetric_key_size, 16)
(sigkey, key, init_vec) = uacrypto.p_sha1(secret, seed, key_sizes)
if self.symmetric_cryptography.Verifier or self.symmetric_cryptography.Decryptor:
self.symmetric_cryptography.Prev_Verifier = self.symmetric_cryptography.Verifier
self.symmetric_cryptography.Prev_Decryptor = self.symmetric_cryptography.Decryptor
self.symmetric_cryptography.prev_key_expiration = self.symmetric_cryptography.key_expiration
# lifetime is in ms
self.symmetric_cryptography.key_expiration = time.time() + (lifetime * 0.001)
self.symmetric_cryptography.Verifier = VerifierAesCbc(sigkey)
self.symmetric_cryptography.Decryptor = DecryptorAesCbc(key, init_vec)
......@@ -596,12 +656,19 @@ class SecurityPolicyBasic256Sha256(SecurityPolicy):
self.symmetric_cryptography.Signer = SignerHMac256(sigkey)
self.symmetric_cryptography.Encryptor = EncryptorAesCbc(key, init_vec)
def make_remote_symmetric_key(self, secret, seed):
def make_remote_symmetric_key(self, secret, seed, lifetime):
# specs part 6, 6.7.5
key_sizes = (self.signature_key_size, self.symmetric_key_size, 16)
(sigkey, key, init_vec) = uacrypto.p_sha256(secret, seed, key_sizes)
if self.symmetric_cryptography.Verifier or self.symmetric_cryptography.Decryptor:
self.symmetric_cryptography.Prev_Verifier = self.symmetric_cryptography.Verifier
self.symmetric_cryptography.Prev_Decryptor = self.symmetric_cryptography.Decryptor
self.symmetric_cryptography.prev_key_expiration = self.symmetric_cryptography.key_expiration
# lifetime is in ms
self.symmetric_cryptography.key_expiration = time.time() + (lifetime * 0.001)
self.symmetric_cryptography.Verifier = VerifierHMac256(sigkey)
self.symmetric_cryptography.Decryptor = DecryptorAesCbc(key, init_vec)
......
......@@ -246,7 +246,7 @@ class SecurityPolicy:
def make_local_symmetric_key(self, secret, seed):
pass
def make_remote_symmetric_key(self, secret, seed):
def make_remote_symmetric_key(self, secret, seed, lifetime):
pass
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment