Commit 4652a3a9 authored by ORD's avatar ORD Committed by GitHub

Merge pull request #206 from FreeOpcUa/fix

do not stop client for protocol error, but try again
parents 711bc184 cedc4a43
...@@ -10,6 +10,7 @@ from functools import partial ...@@ -10,6 +10,7 @@ from functools import partial
from opcua import ua from opcua import ua
from opcua.common import utils from opcua.common import utils
from opcua.common.uaerrors import UaError
class UASocketClient(object): class UASocketClient(object):
...@@ -60,7 +61,7 @@ class UASocketClient(object): ...@@ -60,7 +61,7 @@ class UASocketClient(object):
if callback: if callback:
future.add_done_callback(callback) future.add_done_callback(callback)
self._callbackmap[self._request_id] = future self._callbackmap[self._request_id] = future
msg = self._connection.message_to_binary(binreq, message_type, self._request_id) msg = self._connection.message_to_binary(binreq, message_type=message_type, request_id=self._request_id)
self._socket.write(msg) self._socket.write(msg)
return future return future
...@@ -94,6 +95,8 @@ class UASocketClient(object): ...@@ -94,6 +95,8 @@ class UASocketClient(object):
except ua.utils.SocketClosedException: except ua.utils.SocketClosedException:
self.logger.info("Socket has closed connection") self.logger.info("Socket has closed connection")
break break
except UaError:
self.logger.exception("Protocol Error")
self.logger.info("Thread ended") self.logger.info("Thread ended")
def _receive(self): def _receive(self):
...@@ -159,7 +162,7 @@ class UASocketClient(object): ...@@ -159,7 +162,7 @@ class UASocketClient(object):
response = ua.OpenSecureChannelResponse.from_binary(future.result(self.timeout)) response = ua.OpenSecureChannelResponse.from_binary(future.result(self.timeout))
response.ResponseHeader.ServiceResult.check() response.ResponseHeader.ServiceResult.check()
self._connection.set_security_token(response.Parameters.SecurityToken) self._connection.set_channel(response.Parameters)
return response.Parameters return response.Parameters
def close_secure_channel(self): def close_secure_channel(self):
......
...@@ -91,8 +91,7 @@ class BinaryServer(object): ...@@ -91,8 +91,7 @@ class BinaryServer(object):
return return
except Exception: except Exception:
logger.exception("Exception raised while parsing message from client, closing") logger.exception("Exception raised while parsing message from client, closing")
self.transport.close() return
break
coro = self.loop.create_server(OPCUAProtocol, self.hostname, self.port) coro = self.loop.create_server(OPCUAProtocol, self.hostname, self.port)
self._server = self.loop.run_coro_and_wait(coro) self._server = self.loop.run_coro_and_wait(coro)
......
...@@ -25,7 +25,6 @@ class UaProcessor(object): ...@@ -25,7 +25,6 @@ class UaProcessor(object):
self.name = socket.get_extra_info('peername') self.name = socket.get_extra_info('peername')
self.sockname = socket.get_extra_info('sockname') self.sockname = socket.get_extra_info('sockname')
self.session = None self.session = None
self.channel = None
self.socket = socket self.socket = socket
self._socketlock = Lock() self._socketlock = Lock()
self._datalock = RLock() self._datalock = RLock()
...@@ -39,18 +38,18 @@ class UaProcessor(object): ...@@ -39,18 +38,18 @@ class UaProcessor(object):
def send_response(self, requesthandle, algohdr, seqhdr, response, msgtype=ua.MessageType.SecureMessage): def send_response(self, requesthandle, algohdr, seqhdr, response, msgtype=ua.MessageType.SecureMessage):
with self._socketlock: with self._socketlock:
response.ResponseHeader.RequestHandle = requesthandle response.ResponseHeader.RequestHandle = requesthandle
data = self._connection.message_to_binary(response.to_binary(), msgtype, seqhdr.RequestId) data = self._connection.message_to_binary(response.to_binary(), message_type=msgtype, request_id=seqhdr.RequestId, algohdr=algohdr)
self.socket.write(data) self.socket.write(data)
def open_secure_channel(self, algohdr, seqhdr, body): def open_secure_channel(self, algohdr, seqhdr, body):
request = ua.OpenSecureChannelRequest.from_binary(body) request = ua.OpenSecureChannelRequest.from_binary(body)
self._connection.select_policy(algohdr.SecurityPolicyURI, algohdr.SenderCertificate, request.Parameters.SecurityMode) self._connection.select_policy(algohdr.SecurityPolicyURI, algohdr.SenderCertificate, request.Parameters.SecurityMode)
channel = self._open_secure_channel(request.Parameters) channel = self._connection.open(request.Parameters, self.iserver)
# send response # send response
response = ua.OpenSecureChannelResponse() response = ua.OpenSecureChannelResponse()
response.Parameters = channel response.Parameters = channel
self.send_response(request.RequestHeader.RequestHandle, algohdr, seqhdr, response, ua.MessageType.SecureOpen) self.send_response(request.RequestHeader.RequestHandle, None, seqhdr, response, ua.MessageType.SecureOpen)
def forward_publish_response(self, result): def forward_publish_response(self, result):
self.logger.info("forward publish response %s", result) self.logger.info("forward publish response %s", result)
...@@ -76,8 +75,7 @@ class UaProcessor(object): ...@@ -76,8 +75,7 @@ class UaProcessor(object):
self.open_secure_channel(msg.SecurityHeader(), msg.SequenceHeader(), msg.body()) self.open_secure_channel(msg.SecurityHeader(), msg.SequenceHeader(), msg.body())
elif header.MessageType == ua.MessageType.SecureClose: elif header.MessageType == ua.MessageType.SecureClose:
if not self.channel or header.ChannelId != self.channel.SecurityToken.ChannelId: self._connection.close()
self.logger.warning("Request to close channel %s which was not issued, current channel is %s", header.ChannelId, self.channel)
return False return False
elif header.MessageType == ua.MessageType.SecureMessage: elif header.MessageType == ua.MessageType.SecureMessage:
...@@ -386,9 +384,9 @@ class UaProcessor(object): ...@@ -386,9 +384,9 @@ class UaProcessor(object):
elif typeid == ua.NodeId(ua.ObjectIds.CloseSecureChannelRequest_Encoding_DefaultBinary): elif typeid == ua.NodeId(ua.ObjectIds.CloseSecureChannelRequest_Encoding_DefaultBinary):
self.logger.info("close secure channel request") self.logger.info("close secure channel request")
self._connection.close()
response = ua.CloseSecureChannelResponse() response = ua.CloseSecureChannelResponse()
self.send_response(requesthdr.RequestHandle, algohdr, seqhdr, response) self.send_response(requesthdr.RequestHandle, algohdr, seqhdr, response)
self.channel = None
return False return False
elif typeid == ua.NodeId(ua.ObjectIds.CallRequest_Encoding_DefaultBinary): elif typeid == ua.NodeId(ua.ObjectIds.CallRequest_Encoding_DefaultBinary):
...@@ -409,21 +407,6 @@ class UaProcessor(object): ...@@ -409,21 +407,6 @@ class UaProcessor(object):
return True return True
def _open_secure_channel(self, params):
self.logger.info("open secure channel")
if not self.channel or params.RequestType == ua.SecurityTokenRequestType.Issue:
self.channel = ua.OpenSecureChannelResult()
self.channel.SecurityToken.TokenId = 13 # random value
self.channel.SecurityToken.ChannelId = self.iserver.get_new_channel_id()
self.channel.SecurityToken.RevisedLifetime = params.RequestedLifetime
self.channel.SecurityToken.TokenId += 1
self.channel.SecurityToken.CreatedAt = datetime.utcnow()
self.channel.SecurityToken.RevisedLifetime = params.RequestedLifetime
self.channel.ServerNonce = utils.create_nonce(self._connection._security_policy.symmetric_key_size)
self._connection.set_security_token(self.channel.SecurityToken)
self._connection._security_policy.make_symmetric_key(self.channel.ServerNonce, params.ClientNonce)
return self.channel
def close(self): def close(self):
""" """
to be called when client has disconnected to ensure we really close to be called when client has disconnected to ensure we really close
......
import struct import struct
import logging import logging
import hashlib import hashlib
from enum import IntEnum from datetime import datetime
from opcua.ua import uaprotocol_auto as auto from opcua.ua import uaprotocol_auto as auto
from opcua.ua import uatypes from opcua.ua import uatypes
...@@ -484,9 +484,42 @@ class SecureConnection(object): ...@@ -484,9 +484,42 @@ class SecureConnection(object):
self._incoming_parts = [] self._incoming_parts = []
self._security_policy = security_policy self._security_policy = security_policy
self._policies = [] self._policies = []
self._security_token = auto.ChannelSecurityToken() self.channel = auto.OpenSecureChannelResult()
self._old_tokens = []
self._open = False
self._max_chunk_size = 65536 self._max_chunk_size = 65536
def set_channel(self, channel):
"""
Called on client side when getting secure channel data from server
"""
self.channel = channel
def open(self, params, server):
"""
called on server side to open secure channel
"""
if not self._open or params.RequestType == auto.SecurityTokenRequestType.Issue:
self._open = True
self.channel = auto.OpenSecureChannelResult()
self.channel.SecurityToken.TokenId = 13 # random value
self.channel.SecurityToken.ChannelId = server.get_new_channel_id()
self.channel.SecurityToken.RevisedLifetime = params.RequestedLifetime
else:
self._old_tokens.append(self.channel.SecurityToken.TokenId)
self.channel.SecurityToken.TokenId += 1
self.channel.SecurityToken.CreatedAt = datetime.utcnow()
self.channel.SecurityToken.RevisedLifetime = params.RequestedLifetime
self.channel.ServerNonce = utils.create_nonce(self._security_policy.symmetric_key_size)
self._security_policy.make_symmetric_key(self.channel.ServerNonce, params.ClientNonce)
return self.channel
def close(self):
self._open = False
def is_open(self):
return self._open
def set_policy_factories(self, policies): def set_policy_factories(self, policies):
""" """
Set a list of available security policies. Set a list of available security policies.
...@@ -507,9 +540,6 @@ class SecureConnection(object): ...@@ -507,9 +540,6 @@ class SecureConnection(object):
self._security_policy.Mode != mode): self._security_policy.Mode != mode):
raise UaError("No matching policy: {}, {}".format(uri, mode)) raise UaError("No matching policy: {}, {}".format(uri, mode))
def set_security_token(self, tok):
self._security_token = tok
def tcp_to_binary(self, message_type, message): def tcp_to_binary(self, message_type, message):
""" """
Convert OPC UA TCP message (see OPC UA specs Part 6, 7.1) to binary. Convert OPC UA TCP message (see OPC UA specs Part 6, 7.1) to binary.
...@@ -520,18 +550,22 @@ class SecureConnection(object): ...@@ -520,18 +550,22 @@ class SecureConnection(object):
header.body_size = len(binmsg) header.body_size = len(binmsg)
return header.to_binary() + binmsg return header.to_binary() + binmsg
def message_to_binary(self, message, def message_to_binary(self, message, message_type=MessageType.SecureMessage, request_id=0, algohdr=None):
message_type=MessageType.SecureMessage, request_id=0):
""" """
Convert OPC UA secure message to binary. Convert OPC UA secure message to binary.
The only supported types are SecureOpen, SecureMessage, SecureClose The only supported types are SecureOpen, SecureMessage, SecureClose
if message_type is SecureMessage, the AlgoritmHeader should be passed as arg
""" """
if algohdr is None:
token_id = self.channel.SecurityToken.TokenId
else:
token_id=algohdr.TokenId
chunks = MessageChunk.message_to_chunks( chunks = MessageChunk.message_to_chunks(
self._security_policy, message, self._max_chunk_size, self._security_policy, message, self._max_chunk_size,
message_type=message_type, message_type=message_type,
channel_id=self._security_token.ChannelId, channel_id=self.channel.SecurityToken.ChannelId,
request_id=request_id, request_id=request_id,
token_id=self._security_token.TokenId) token_id=token_id)
for chunk in chunks: for chunk in chunks:
self._sequence_number += 1 self._sequence_number += 1
if self._sequence_number >= (1 << 32): if self._sequence_number >= (1 << 32):
...@@ -544,14 +578,20 @@ class SecureConnection(object): ...@@ -544,14 +578,20 @@ class SecureConnection(object):
def _check_incoming_chunk(self, chunk): def _check_incoming_chunk(self, chunk):
assert isinstance(chunk, MessageChunk), "Expected chunk, got: {}".format(chunk) assert isinstance(chunk, MessageChunk), "Expected chunk, got: {}".format(chunk)
if chunk.MessageHeader.MessageType != MessageType.SecureOpen: if chunk.MessageHeader.MessageType != MessageType.SecureOpen:
if chunk.MessageHeader.ChannelId != self._security_token.ChannelId: if chunk.MessageHeader.ChannelId != self.channel.SecurityToken.ChannelId:
raise UaError("Wrong channel id {}, expected {}".format( raise UaError("Wrong channel id {}, expected {}".format(
chunk.MessageHeader.ChannelId, chunk.MessageHeader.ChannelId,
self._security_token.ChannelId)) self.channel.SecurityToken.ChannelId))
if chunk.SecurityHeader.TokenId != self._security_token.TokenId: if chunk.SecurityHeader.TokenId != self.channel.SecurityToken.TokenId:
if chunk.SecurityHeader.TokenId not in self._old_tokens:
raise UaError("Wrong token id {}, expected {}".format( raise UaError("Wrong token id {}, expected {}".format(
chunk.SecurityHeader.TokenId, chunk.SecurityHeader.TokenId,
self._security_token.TokenId)) self.channel.SecurityToken.TokenId))
else:
# Do some cleanup, spec says we can remove old tokens when new one are used
idx = self._old_tokens.index(chunk.SecurityHeader.TokenId)
if idx != 0:
self._old_tokens = self._old_tokens[idx:]
if self._incoming_parts: if self._incoming_parts:
if self._incoming_parts[0].SequenceHeader.RequestId != chunk.SequenceHeader.RequestId: if self._incoming_parts[0].SequenceHeader.RequestId != chunk.SequenceHeader.RequestId:
raise UaError("Wrong request id {}, expected {}".format( raise UaError("Wrong request id {}, expected {}".format(
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment