Commit 01c7acf0 authored by Alexander Schrode's avatar Alexander Schrode Committed by GitHub

Check limits of messages (CVE-2022-25304) (#1040)

* check message limits on recv

* add ErrorMessage handling

* add to large chunk test

* client disconnect on ErrorMessage

* change default transport limits
parent c66f34cf
......@@ -2,13 +2,14 @@
Low level binary client
"""
import asyncio
import copy
import logging
from typing import Awaitable, Callable, Dict, List, Optional, Union
from asyncua import ua
from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary
from ..ua.uaerrors import BadTimeout, BadNoSubscription, BadSessionClosed, BadUserAccessDenied, UaStructParsingError
from ..common.connection import SecureConnection
from ..common.connection import SecureConnection, TransportLimits
class UASocketProtocol(asyncio.Protocol):
......@@ -20,7 +21,7 @@ class UASocketProtocol(asyncio.Protocol):
OPEN = 'open'
CLOSED = 'closed'
def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy()):
def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy(), limits: TransportLimits = None):
"""
:param timeout: Timeout in seconds
:param security_policy: Security policy (optional)
......@@ -34,7 +35,12 @@ class UASocketProtocol(asyncio.Protocol):
self._request_id = 0
self._request_handle = 0
self._callbackmap: Dict[int, asyncio.Future] = {}
self._connection = SecureConnection(security_policy)
if limits is None:
limits = TransportLimits(65535, 65535, 0, 0)
else:
limits = copy.deep_copy(limits) # Make a copy because the limits can change in the session
self._connection = SecureConnection(security_policy, limits)
self.state = self.INITIALIZED
self.closed: bool = False
# needed to pass params from asynchronous request to synchronous data receive callback, as well as
......@@ -103,7 +109,7 @@ class UASocketProtocol(asyncio.Protocol):
self._call_callback(0, msg)
elif isinstance(msg, ua.ErrorMessage):
self.logger.fatal("Received an error: %r", msg)
self._call_callback(0, ua.UaStatusCodeError(msg.Error.value))
self.disconnect_socket()
else:
raise ua.UaError(f"Unsupported message type: {msg}")
......
from dataclasses import dataclass
import hashlib
from datetime import datetime, timedelta
import logging
......@@ -15,6 +16,56 @@ except ImportError:
logger = logging.getLogger('asyncua.uaprotocol')
@dataclass
class TransportLimits:
'''
Limits of the tcp transport layer to prevent excessive resource usage
'''
max_recv_buffer: int = 65535
max_send_buffer: int = 65535
max_chunk_count: int = ((100 * 1024 * 1024) // 65535) + 1 # max_message_size / max_recv_buffer
max_message_size: int = 100 * 1024 * 1024 # 100mb
@staticmethod
def _select_limit(hint: ua.UInt32, limit: int) -> ua.UInt32:
if limit <= 0:
return hint
elif limit < hint:
return hint
return ua.UInt32(limit)
def check_max_msg_size(self, sz: int) -> bool:
if self.max_message_size == 0:
return True
return self.max_message_size <= sz
def check_max_chunk_count(self, sz: int) -> bool:
if self.max_chunk_count == 0:
return True
return self.max_chunk_count <= sz
def create_acknowledge_limits(self, msg: ua.Hello) -> ua.Acknowledge:
ack = ua.Acknowledge()
ack.ReceiveBufferSize = min(msg.ReceiveBufferSize, self.max_recv_buffer)
ack.SendBufferSize = min(msg.SendBufferSize, self.max_send_buffer)
ack.MaxChunkCount = self._select_limit(msg.MaxChunkCount, self.max_chunk_count)
ack.MaxMessageSize = self._select_limit(msg.MaxMessageSize, self.max_message_size)
self.update_limits(ack)
return ack
def create_hello_limits(self, msg: ua.Hello) -> ua.Hello:
msg.ReceiveBufferSize = self.max_recv_buffer
msg.SendBufferSize = self.max_send_buffer
msg.MaxChunkCount = self.max_chunk_count
msg.MaxMessageSize = self.max_chunk_count
def update_limits(self, msg: ua.Acknowledge) -> None:
self.max_chunk_count = msg.MaxChunkCount
self.max_recv_buffer = msg.ReceiveBufferSize
self.max_send_buffer = msg.SendBufferSize
self.max_message_size = msg.MaxMessageSize
class MessageChunk:
"""
Message Chunk, as described in OPC UA specs Part 6, 6.7.2.
......@@ -139,7 +190,7 @@ class SecureConnection:
"""
Common logic for client and server
"""
def __init__(self, security_policy):
def __init__(self, security_policy, limits: TransportLimits):
self._sequence_number = 0
self._peer_sequence_number = None
self._incoming_parts = []
......@@ -152,7 +203,7 @@ class SecureConnection:
self.local_nonce = 0
self.remote_nonce = 0
self._allow_prev_token = False
self._max_chunk_size = 65536
self._limits = limits
def set_channel(self, params, request_type, client_nonce):
"""
......@@ -257,7 +308,7 @@ class SecureConnection:
chunks = MessageChunk.message_to_chunks(
self.security_policy,
message,
self._max_chunk_size,
self._limits.max_send_buffer,
message_type=message_type,
channel_id=self.security_token.ChannelId,
request_id=request_id,
......@@ -353,11 +404,10 @@ class SecureConnection:
return self._receive(chunk)
if header.MessageType == ua.MessageType.Hello:
msg = struct_from_binary(ua.Hello, body)
self._max_chunk_size = msg.ReceiveBufferSize
return msg
if header.MessageType == ua.MessageType.Acknowledge:
msg = struct_from_binary(ua.Acknowledge, body)
self._max_chunk_size = msg.SendBufferSize
self._limits.update_limits(msg)
return msg
if header.MessageType == ua.MessageType.Error:
msg = struct_from_binary(ua.ErrorMessage, body)
......@@ -366,8 +416,14 @@ class SecureConnection:
raise ua.UaError(f"Unsupported message type {header.MessageType}")
def _receive(self, msg):
if msg.MessageHeader.packet_size > self._limits.max_recv_buffer:
self._incoming_parts = []
raise ua.UaStatusCodeError(ua.StatusCodes.BadRequestTooLarge)
self._check_incoming_chunk(msg)
self._incoming_parts.append(msg)
if not self._limits.check_max_chunk_count(len(self._incoming_parts)):
self._incoming_parts = []
raise ua.UaStatusCodeError(ua.StatusCodes.BadRequestTooLarge)
if msg.MessageHeader.ChunkType == ua.ChunkType.Intermediate:
return None
if msg.MessageHeader.ChunkType == ua.ChunkType.Abort:
......
......@@ -3,8 +3,10 @@ Socket server forwarding request to internal server
"""
import logging
import asyncio
import math
from typing import Optional
from ..common.connection import TransportLimits
from ..ua.ua_binary import header_from_binary
from ..common.utils import Buffer, NotEnoughData
from .uaprocessor import UaProcessor
......@@ -18,7 +20,7 @@ class OPCUAProtocol(asyncio.Protocol):
Instantiated for every connection.
"""
def __init__(self, iserver: InternalServer, policies, clients, closing_tasks):
def __init__(self, iserver: InternalServer, policies, clients, closing_tasks, limits: TransportLimits):
self.peer_name = None
self.transport = None
self.processor = None
......@@ -28,6 +30,7 @@ class OPCUAProtocol(asyncio.Protocol):
self.clients = clients
self.closing_tasks = closing_tasks
self.messages = asyncio.Queue()
self.limits = limits
self._task = None
def __str__(self):
......@@ -39,7 +42,7 @@ class OPCUAProtocol(asyncio.Protocol):
self.peer_name = transport.get_extra_info('peername')
logger.info('New connection from %s', self.peer_name)
self.transport = transport
self.processor = UaProcessor(self.iserver, self.transport)
self.processor = UaProcessor(self.iserver, self.transport, self.limits)
self.processor.set_policies(self.policies)
self.iserver.asyncio_transports.append(transport)
self.clients.append(self)
......@@ -119,6 +122,15 @@ class BinaryServer:
self.clients = []
self.closing_tasks = []
self.cleanup_task = None
# Use accectable limits
buffer_sz = 65535
max_msg_sz = 16 * 1024 * 1024 # 16mb simular to the opc ua c stack so this is a good default
self.limits = TransportLimits(
max_recv_buffer=buffer_sz,
max_send_buffer=buffer_sz,
max_chunk_count=math.ceil(buffer_sz / max_msg_sz), # Round up to allow max msg size
max_message_size=max_msg_sz
)
def set_policies(self, policies):
self._policies = policies
......@@ -130,6 +142,7 @@ class BinaryServer:
policies=self._policies,
clients=self.clients,
closing_tasks=self.closing_tasks,
limits=self.limits
)
async def start(self):
......
import copy
import time
import logging
from typing import Deque, Optional
......@@ -6,7 +7,7 @@ from collections import deque
from asyncua import ua
from ..ua.ua_binary import nodeid_from_binary, struct_from_binary, struct_to_binary, uatcp_to_binary
from .internal_server import InternalServer, InternalSession
from ..common.connection import SecureConnection
from ..common.connection import SecureConnection, TransportLimits
from ..common.utils import ServiceError
_logger = logging.getLogger(__name__)
......@@ -25,7 +26,7 @@ class UaProcessor:
Processor for OPC UA messages. Implements the OPC UA protocol for the server side.
"""
def __init__(self, internal_server: InternalServer, transport):
def __init__(self, internal_server: InternalServer, transport, limits: TransportLimits):
self.iserver: InternalServer = internal_server
self.name = transport.get_extra_info('peername')
self.sockname = transport.get_extra_info('sockname')
......@@ -35,7 +36,8 @@ class UaProcessor:
self._publish_requests: Deque[PublishRequestData] = deque()
# used when we need to wait for PublishRequest
self._publish_results: Deque[ua.PublishResult] = deque()
self._connection = SecureConnection(ua.SecurityPolicy())
self._limits = copy.deepcopy(limits) # Copy limits because they get overriden
self._connection = SecureConnection(ua.SecurityPolicy(), self._limits)
def set_policies(self, policies):
self._connection.set_policy_factories(policies)
......@@ -89,6 +91,12 @@ class UaProcessor:
async def process(self, header, body):
try:
msg = self._connection.receive_from_header_and_body(header, body)
except ua.uaerrors.BadRequestTooLarge as e:
_logger.warning("Recived request that exceed the transport limits")
err = ua.ErrorMessage(ua.StatusCode(e.code), str(e))
data = uatcp_to_binary(ua.MessageType.Error, err)
self._transport.write(data)
return True
except ua.uaerrors.BadUserAccessDenied:
_logger.warning("Unauthenticated user attempted to connect")
return False
......@@ -101,9 +109,7 @@ class UaProcessor:
elif header.MessageType == ua.MessageType.SecureMessage:
return await self.process_message(msg.SequenceHeader(), msg.body())
elif isinstance(msg, ua.Hello):
ack = ua.Acknowledge()
ack.ReceiveBufferSize = msg.ReceiveBufferSize
ack.SendBufferSize = msg.SendBufferSize
ack = self._limits.create_acknowledge_limits(msg)
data = uatcp_to_binary(ua.MessageType.Acknowledge, ack)
self._transport.write(data)
elif isinstance(msg, ua.ErrorMessage):
......
......@@ -665,6 +665,30 @@ async def test_server_read_write_attribute_value(server: Server):
assert dv.Value.Value == 5
await server.delete_nodes([node])
@pytest.fixture(scope="function")
def restore_transport_limits_server(server: Server):
# Restore limits after test
max_recv = server.bserver.limits.max_recv_buffer
max_chunk_count = server.bserver.limits.max_chunk_count
yield server
server.bserver.limits.max_recv_buffer = max_recv
server.bserver.limits.max_chunk_count = max_chunk_count
async def test_message_limits(restore_transport_limits_server: Server):
server = restore_transport_limits_server
server.bserver.limits.max_recv_buffer = 1024
server.bserver.limits.max_chunk_count = 10
client = Client(server.endpoint.geturl())
# This should trigger a timeout error because the message is to large
with pytest.raises(asyncio.TimeoutError):
async with client:
test_string = 'a' * (1024 * 1024 * 1024)
n = client.get_node(ua.NodeId())
await n.write_value(test_string, ua.VariantType.String)
"""
class TestServerCaching(unittest.TestCase):
def runTest(self):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment