Commit a80a7be6 authored by ORD's avatar ORD

Merge pull request #79 from alkor/message-chunks

Split messages into MessageChunks
parents 558fddd6 7105528f
...@@ -24,7 +24,7 @@ class UASocketClient(object): ...@@ -24,7 +24,7 @@ class UASocketClient(object):
handle socket connection and send ua messages handle socket connection and send ua messages
timeout is the timeout used while waiting for an ua answer from server timeout is the timeout used while waiting for an ua answer from server
""" """
def __init__(self, timeout=1): def __init__(self, timeout=1, security_policy=ua.SecurityPolicy()):
self.logger = logging.getLogger(__name__ + "Socket") self.logger = logging.getLogger(__name__ + "Socket")
self._thread = None self._thread = None
self._lock = Lock() self._lock = Lock()
...@@ -37,6 +37,8 @@ class UASocketClient(object): ...@@ -37,6 +37,8 @@ class UASocketClient(object):
self._request_id = 0 self._request_id = 0
self._request_handle = 0 self._request_handle = 0
self._callbackmap = {} self._callbackmap = {}
self._security_policy = security_policy
self._max_chunk_size = 65536
def start(self): def start(self):
""" """
...@@ -47,28 +49,43 @@ class UASocketClient(object): ...@@ -47,28 +49,43 @@ class UASocketClient(object):
self._thread = Thread(target=self._run) self._thread = Thread(target=self._run)
self._thread.start() self._thread.start()
def send_request(self, request, callback=None, timeout=1000): def _send_request(self, request, callback=None, timeout=1000, message_type=ua.MessageType.SecureMessage):
""" """
send request to server. send request to server, lower-level method
timeout is the timeout written in ua header timeout is the timeout written in ua header
returns future
""" """
with self._lock: with self._lock:
request.RequestHeader = self._create_request_header(timeout) request.RequestHeader = self._create_request_header(timeout)
try: try:
cachedreq = CachedRequest(request.to_binary()) binreq = request.to_binary()
except: except:
# reset reqeust handle if any error # reset reqeust handle if any error
# see self._create_request_header # see self._create_request_header
self._request_handle -= 1 self._request_handle -= 1
raise raise
hdr = ua.Header(ua.MessageType.SecureMessage, ua.ChunkType.Single, self._security_token.ChannelId) self._request_id += 1
symhdr = self._create_sym_algo_header()
seqhdr = self._create_sequence_header()
future = Future() future = Future()
if callback: if callback:
future.add_done_callback(callback) future.add_done_callback(callback)
self._callbackmap[seqhdr.RequestId] = future self._callbackmap[self._request_id] = future
self._write_socket(hdr, symhdr, seqhdr, cachedreq) for chunk in ua.MessageChunk.message_to_chunks(self._security_policy, binreq, self._max_chunk_size,
message_type=message_type,
channel_id=self._security_token.ChannelId,
request_id=self._request_id,
token_id=self._security_token.TokenId):
self._sequence_number += 1
chunk.SequenceHeader.SequenceNumber = self._sequence_number
self._socket.write(chunk.to_binary())
return future
def send_request(self, request, callback=None, timeout=1000, message_type=ua.MessageType.SecureMessage):
"""
send request to server.
timeout is the timeout written in ua header
returns response object if no callback is provided
"""
future = self._send_request(request, callback, timeout, message_type)
if not callback: if not callback:
data = future.result(self.timeout) data = future.result(self.timeout)
self.check_answer(data, " in response to " + request.__class__.__name__) self.check_answer(data, " in response to " + request.__class__.__name__)
...@@ -179,19 +196,6 @@ class UASocketClient(object): ...@@ -179,19 +196,6 @@ class UASocketClient(object):
hdr.TimeoutHint = timeout hdr.TimeoutHint = timeout
return hdr return hdr
def _create_sym_algo_header(self):
hdr = ua.SymmetricAlgorithmHeader()
hdr.TokenId = self._security_token.TokenId
return hdr
def _create_sequence_header(self):
hdr = ua.SequenceHeader()
self._sequence_number += 1
hdr.SequenceNumber = self._sequence_number
self._request_id += 1
hdr.RequestId = self._request_id
return hdr
def connect_socket(self, host, port): def connect_socket(self, host, port):
""" """
connect to server socket and start receiving thread connect to server socket and start receiving thread
...@@ -215,22 +219,15 @@ class UASocketClient(object): ...@@ -215,22 +219,15 @@ class UASocketClient(object):
with self._lock: with self._lock:
self._callbackmap[0] = future self._callbackmap[0] = future
self._write_socket(header, hello) self._write_socket(header, hello)
return ua.Acknowledge.from_binary(future.result(self.timeout)) ack = ua.Acknowledge.from_binary(future.result(self.timeout))
self._max_chunk_size = ack.SendBufferSize # client shouldn't send chunks larger than this
return ack
def open_secure_channel(self, params): def open_secure_channel(self, params):
self.logger.info("open_secure_channel") self.logger.info("open_secure_channel")
with self._lock:
request = ua.OpenSecureChannelRequest() request = ua.OpenSecureChannelRequest()
request.Parameters = params request.Parameters = params
request.RequestHeader = self._create_request_header() future = self._send_request(request, message_type=ua.MessageType.SecureOpen)
hdr = ua.Header(ua.MessageType.SecureOpen, ua.ChunkType.Single, self._security_token.ChannelId)
asymhdr = ua.AsymmetricAlgorithmHeader()
seqhdr = self._create_sequence_header()
future = Future()
self._callbackmap[seqhdr.RequestId] = future
self._write_socket(hdr, asymhdr, seqhdr, request)
response = ua.OpenSecureChannelResponse.from_binary(future.result(self.timeout)) response = ua.OpenSecureChannelResponse.from_binary(future.result(self.timeout))
response.ResponseHeader.ServiceResult.check() response.ResponseHeader.ServiceResult.check()
...@@ -240,17 +237,16 @@ class UASocketClient(object): ...@@ -240,17 +237,16 @@ class UASocketClient(object):
def close_secure_channel(self): def close_secure_channel(self):
""" """
close secure channel. It seems to trigger a shutdown of socket close secure channel. It seems to trigger a shutdown of socket
in most servers, so be prepare to reconnect in most servers, so be prepare to reconnect.
OPC UA specs Part 6, 7.1.4 say that Server does not send a CloseSecureChannel response and should just close socket
""" """
self.logger.info("get_endpoint") self.logger.info("close_secure_channel")
with self._lock:
request = ua.CloseSecureChannelRequest() request = ua.CloseSecureChannelRequest()
request.RequestHeader = self._create_request_header() future = self._send_request(request, message_type=ua.MessageType.SecureClose)
with self._lock:
hdr = ua.Header(ua.MessageType.SecureClose, ua.ChunkType.Single, self._security_token.ChannelId) # don't expect any more answers
symhdr = self._create_sym_algo_header() future.cancel()
seqhdr = self._create_sequence_header() self._callbackmap.clear()
self._write_socket(hdr, symhdr, seqhdr, request)
# some servers send a response here, most do not ... so we ignore # some servers send a response here, most do not ... so we ignore
......
...@@ -28,17 +28,20 @@ class UAProcessor(object): ...@@ -28,17 +28,20 @@ class UAProcessor(object):
self._socketlock = Lock() self._socketlock = Lock()
self._datalock = Lock() self._datalock = Lock()
self._publishdata_queue = [] self._publishdata_queue = []
self._seq_number = 1 self._seq_number = 0
self._security_policy = ua.SecurityPolicy()
self._max_chunk_size = 65536
def send_response(self, requesthandle, algohdr, seqhdr, response, msgtype=ua.MessageType.SecureMessage): def send_response(self, requesthandle, algohdr, seqhdr, response, msgtype=ua.MessageType.SecureMessage):
with self._socketlock: with self._socketlock:
response.ResponseHeader.RequestHandle = requesthandle response.ResponseHeader.RequestHandle = requesthandle
seqhdr.SequenceNumber = self._seq_number for chunk in ua.MessageChunk.message_to_chunks(self._security_policy, response.to_binary(), self._max_chunk_size, msgtype,
channel_id=self.channel.SecurityToken.ChannelId,
token_id=self.channel.SecurityToken.TokenId,
request_id=seqhdr.RequestId):
self._seq_number += 1 self._seq_number += 1
hdr = ua.Header(msgtype, ua.ChunkType.Single, self.channel.SecurityToken.ChannelId) chunk.SequenceHeader.SequenceNumber = self._seq_number
if isinstance(algohdr, ua.SymmetricAlgorithmHeader): self.socket.write(chunk.to_binary())
algohdr.TokenId = self.channel.SecurityToken.TokenId
self._write_socket(hdr, algohdr, seqhdr, response)
def _write_socket(self, hdr, *args): def _write_socket(self, hdr, *args):
alle = [] alle = []
...@@ -81,6 +84,7 @@ class UAProcessor(object): ...@@ -81,6 +84,7 @@ class UAProcessor(object):
hello = ua.Hello.from_binary(body) hello = ua.Hello.from_binary(body)
hdr = ua.Header(ua.MessageType.Acknowledge, ua.ChunkType.Single) hdr = ua.Header(ua.MessageType.Acknowledge, ua.ChunkType.Single)
ack = ua.Acknowledge() ack = ua.Acknowledge()
self._max_chunk_size = hello.ReceiveBufferSize
ack.ReceiveBufferSize = hello.ReceiveBufferSize ack.ReceiveBufferSize = hello.ReceiveBufferSize
ack.SendBufferSize = hello.SendBufferSize ack.SendBufferSize = hello.SendBufferSize
self._write_socket(hdr, ack) self._write_socket(hdr, ack)
......
import struct import struct
import logging import logging
import hashlib
import opcua.uaprotocol_auto as auto import opcua.uaprotocol_auto as auto
import opcua.uatypes as uatypes import opcua.uatypes as uatypes
...@@ -70,7 +71,7 @@ class ChunkType(object): ...@@ -70,7 +71,7 @@ class ChunkType(object):
Invalid = b"0" # FIXME check Invalid = b"0" # FIXME check
Single = b"F" Single = b"F"
Intermediate = b"C" Intermediate = b"C"
Final = b"A" Abort = b"A" # when an error occurred and the Message is aborted (body is ErrorMessage)
class Header(uatypes.FrozenClass): class Header(uatypes.FrozenClass):
...@@ -107,6 +108,10 @@ class Header(uatypes.FrozenClass): ...@@ -107,6 +108,10 @@ class Header(uatypes.FrozenClass):
hdr.ChannelId = uatype_UInt32.unpack(data.read(4))[0] hdr.ChannelId = uatype_UInt32.unpack(data.read(4))[0]
return hdr return hdr
@staticmethod
def max_size():
return struct.calcsize("<3scII")
def __str__(self): def __str__(self):
return "Header(type:{}, chunk_type:{}, body_size:{}, channel:{})".format(self.MessageType, self.ChunkType, self.body_size, self.ChannelId) return "Header(type:{}, chunk_type:{}, body_size:{}, channel:{})".format(self.MessageType, self.ChunkType, self.body_size, self.ChannelId)
__repr__ = __str__ __repr__ = __str__
...@@ -207,6 +212,10 @@ class SymmetricAlgorithmHeader(uatypes.FrozenClass): ...@@ -207,6 +212,10 @@ class SymmetricAlgorithmHeader(uatypes.FrozenClass):
def to_binary(self): def to_binary(self):
return uatype_UInt32.pack(self.TokenId) return uatype_UInt32.pack(self.TokenId)
@staticmethod
def max_size():
return struct.calcsize("<I")
def __str__(self): def __str__(self):
return "{}(TokenId:{} )".format(self.__class__.__name__, self.TokenId) return "{}(TokenId:{} )".format(self.__class__.__name__, self.TokenId)
__repr__ = __str__ __repr__ = __str__
...@@ -232,10 +241,198 @@ class SequenceHeader(uatypes.FrozenClass): ...@@ -232,10 +241,198 @@ class SequenceHeader(uatypes.FrozenClass):
b.append(uatype_UInt32.pack(self.RequestId)) b.append(uatype_UInt32.pack(self.RequestId))
return b"".join(b) return b"".join(b)
@staticmethod
def max_size():
return struct.calcsize("<II")
def __str__(self): def __str__(self):
return "{}(SequenceNumber:{}, RequestId:{} )".format(self.__class__.__name__, self.SequenceNumber, self.RequestId) return "{}(SequenceNumber:{}, RequestId:{} )".format(self.__class__.__name__, self.SequenceNumber, self.RequestId)
__repr__ = __str__ __repr__ = __str__
class CryptographyNone:
"""
Base class for symmetric/asymmetric cryprography
"""
def __init__(self, mode=auto.MessageSecurityMode.None_):
pass
def plain_block_size(self):
"""
Size of plain text block for block cipher.
"""
return 1
def encrypted_block_size(self):
"""
Size of encrypted text block for block cipher.
"""
return 1
def padding(self, size):
"""
Create padding for a block of given size.
plain_size = size + len(padding) + signature_size()
plain_size = N * plain_block_size()
"""
return b''
def min_padding_size(self):
return 0
def signature_size(self):
return 0
def signature(self, data):
return b''
def encrypt(self, data):
return data
def decrypt(self, data):
return data
def vsignature_size(self):
return 0
def verify(self, data, signature):
"""
Verify signature and raise exception if signature is invalid
"""
pass
def remove_padding(self, data):
return data
class SecurityPolicy:
"""
Base class for security policy
"""
def __init__(self):
self.asymmetric_cryptography = CryptographyNone()
self.symmetric_cryptography = CryptographyNone()
self.Mode = auto.MessageSecurityMode.None_
self.URI = "http://opcfoundation.org/UA/SecurityPolicy#None"
self.server_certificate = b""
self.client_certificate = b""
def symmetric_key_size(self):
return 0
def make_symmetric_key(self, a, b):
pass
class MessageChunk(uatypes.FrozenClass):
"""
Message Chunk, as described in OPC UA specs Part 6, 6.7.2.
"""
def __init__(self, security_policy, body=b'', msg_type=MessageType.SecureMessage, chunk_type=ChunkType.Single):
self.MessageHeader = Header(msg_type, chunk_type)
if msg_type in (MessageType.SecureMessage, MessageType.SecureClose):
self.SecurityHeader = SymmetricAlgorithmHeader()
elif msg_type == MessageType.SecureOpen:
self.SecurityHeader = AsymmetricAlgorithmHeader()
else:
raise Exception("Unsupported message type: {}".format(msg_type))
self.SequenceHeader = SequenceHeader()
self.Body = body
self._security_policy = security_policy
@staticmethod
def from_binary(security_policy, data):
h = Header.from_string(data)
return MessageChunk.from_header_and_body(security_policy, h, data)
@staticmethod
def from_header_and_body(security_policy, header, data):
if header.MessageType in (MessageType.SecureMessage, MessageType.SecureClose):
security_header = SymmetricAlgorithmHeader.from_binary(data)
crypto = security_policy.symmetric_cryptography
elif header.MessageType == MessageType.SecureOpen:
security_header = AsymmetricAlgorithmHeader.from_binary(data)
crypto = security_policy.asymmetric_cryptography
else:
raise Exception("Unsupported message type: {}".format(header.MessageType))
obj = MessageChunk(crypto)
obj.MessageHeader = header
obj.SecurityHeader = security_header
decrypted = crypto.decrypt(data.read(len(data)))
signature_size = crypto.vsignature_size()
if signature_size > 0:
signature = decrypted[-signature_size:]
decrypted = decrypted[:-signature_size]
crypto.verify(obj.MessageHeader.to_binary() + obj.SecurityHeader.to_binary() + decrypted, signature)
data = utils.Buffer(crypto.remove_padding(decrypted))
obj.SequenceHeader = SequenceHeader.from_binary(data)
obj.Body = data.read(len(data))
return obj
def encrypted_size(self, plain_size):
size = plain_size + self._security_policy.signature_size()
pbs = self._security_policy.plain_block_size()
assert(size % pbs == 0)
return size // pbs * self._security_policy.encrypted_block_size()
def to_binary(self):
security = self.SecurityHeader.to_binary()
encrypted_part = self.SequenceHeader.to_binary() + self.Body
encrypted_part += self._security_policy.padding(len(encrypted_part))
self.MessageHeader.body_size = len(security) + self.encrypted_size(len(encrypted_part))
header = self.MessageHeader.to_binary()
encrypted_part += self._security_policy.signature(header + security + encrypted_part)
return header + security + self._security_policy.encrypt(encrypted_part)
@staticmethod
def max_body_size(crypto, max_chunk_size):
max_encrypted_size = max_chunk_size - Header.max_size() - SymmetricAlgorithmHeader.max_size()
max_plain_size = (max_encrypted_size // crypto.encrypted_block_size()) * crypto.plain_block_size()
return max_plain_size - SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size()
@staticmethod
def message_to_chunks(security_policy, body, max_chunk_size, message_type=MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1):
"""
Pack message body (as binary string) into one or more chunks.
Size of each chunk will not exceed max_chunk_size.
Returns a list of MessageChunks. SequenceNumber is not initialized here,
it must be set by Secure Channel driver.
"""
if message_type == MessageType.SecureOpen:
# SecureOpen message must be in a single chunk (specs, Part 6, 6.7.2)
chunk = MessageChunk(security_policy.asymmetric_cryptography, body, message_type, ChunkType.Single)
chunk.SecurityHeader.SecurityPolicyURI = security_policy.URI
if security_policy.client_certificate:
chunk.SecurityHeader.SenderCertificate = security_policy.client_certificate
if security_policy.server_certificate:
chunk.SecurityHeader.ReceiverCertificateThumbPrint = hashlib.sha1(security_policy.server_certificate).digest()
chunk.MessageHeader.ChannelId = channel_id
chunk.SequenceHeader.RequestId = request_id
return [chunk]
crypto = security_policy.symmetric_cryptography
max_size = MessageChunk.max_body_size(crypto, max_chunk_size)
chunks = []
for i in range(0, len(body), max_size):
part = body[i:i+max_size]
if i+max_size >= len(body):
chunk_type = ChunkType.Single
else:
chunk_type = ChunkType.Intermediate
chunk = MessageChunk(crypto, part, message_type, chunk_type)
chunk.SecurityHeader.TokenId = token_id
chunk.MessageHeader.ChannelId = channel_id
chunk.SequenceHeader.RequestId = request_id
chunks.append(chunk)
return chunks
def __str__(self):
return "{}({}, {}, {}, {} bytes)".format(self.__class__.__name__,
self.MessageHeader, self.SequenceHeader, self.SecurityHeader, len(self.Body))
__repr__ = __str__
# FIXES for missing switchfield in NodeAttributes classes # FIXES for missing switchfield in NodeAttributes classes
ana = auto.NodeAttributesMask ana = auto.NodeAttributesMask
......
...@@ -266,6 +266,29 @@ class Unit(unittest.TestCase): ...@@ -266,6 +266,29 @@ class Unit(unittest.TestCase):
t4 = ua.LocalizedText.from_binary(ua.utils.Buffer(t1.to_binary())) t4 = ua.LocalizedText.from_binary(ua.utils.Buffer(t1.to_binary()))
self.assertEqual(t1, t4) self.assertEqual(t1, t4)
def test_message_chunk(self):
pol = ua.SecurityPolicy()
chunks = ua.MessageChunk.message_to_chunks(pol, b'123', 65536)
self.assertEqual(len(chunks), 1)
seq = 0
for chunk in chunks:
seq += 1
chunk.SequenceHeader.SequenceNumber = seq
chunk2 = ua.MessageChunk.from_binary(pol, ua.utils.Buffer(chunks[0].to_binary()))
self.assertEqual(chunks[0].to_binary(), chunk2.to_binary())
# for policy None, MessageChunk overhead is 12+4+8 = 24 bytes
# Let's pack 11 bytes into 28-byte chunks. The message must be split as 4+4+3
chunks = ua.MessageChunk.message_to_chunks(pol, b'12345678901', 28)
self.assertEqual(len(chunks), 3)
self.assertEqual(chunks[0].Body, b'1234')
self.assertEqual(chunks[1].Body, b'5678')
self.assertEqual(chunks[2].Body, b'901')
for chunk in chunks:
seq += 1
chunk.SequenceHeader.SequenceNumber = seq
self.assertTrue(len(chunk.to_binary()) <= 28)
class CommonTests(object): class CommonTests(object):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment