Commit 4606b79c authored by Christian Bergmiller's avatar Christian Bergmiller

[ADD] refactored client receive buffer

[ADD] wip
parent b0168f54
......@@ -452,6 +452,7 @@ class Client(object):
return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.RootFolder))
def get_objects_node(self):
self.logger.info('get_objects_node')
return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.ObjectsFolder))
def get_server_node(self):
......
......@@ -6,7 +6,7 @@ import logging
from functools import partial
from opcua import ua
from opcua.ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary
from opcua.ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary
from opcua.ua.uaerrors import UaError, BadTimeout, BadNoSubscription, BadSessionClosed
from opcua.common.connection import SecureConnection
......@@ -21,7 +21,7 @@ class UASocketProtocol(asyncio.Protocol):
self.logger = logging.getLogger(__name__ + ".UASocketProtocol")
self.loop = asyncio.get_event_loop()
self.transport = None
self.receive_buffer = asyncio.Queue()
self.receive_buffer: bytes = None
self.is_receiving = False
self.timeout = timeout
self.authentication_token = ua.NodeId()
......@@ -29,7 +29,6 @@ class UASocketProtocol(asyncio.Protocol):
self._request_handle = 0
self._callbackmap = {}
self._connection = SecureConnection(security_policy)
self._leftover_chunk = None
def connection_made(self, transport: asyncio.Transport):
self.transport = transport
......@@ -39,36 +38,45 @@ class UASocketProtocol(asyncio.Protocol):
self.transport = None
def data_received(self, data: bytes):
self.receive_buffer.put_nowait(data)
if not self.is_receiving:
self.is_receiving = True
self.loop.create_task(self._receive())
async def read(self, size: int):
"""Receive up to size bytes from socket."""
data = b''
self.logger.debug('read %s bytes from socket', size)
while size > 0:
self.logger.debug('data is now %s, waiting for %s bytes', len(data), size)
# ToDo: abort on timeout, socket close
# raise SocketClosedException("Server socket has closed")
if self._leftover_chunk:
self.logger.debug('leftover bytes %s', len(self._leftover_chunk))
# use leftover chunk first
chunk = self._leftover_chunk
self._leftover_chunk = None
else:
chunk = await self.receive_buffer.get()
self.logger.debug('got chunk %s needed_length is %s', len(chunk), size)
if len(chunk) <= size:
_chunk = chunk
else:
# chunk is too big
_chunk = chunk[:size]
self._leftover_chunk = chunk[size:]
data += _chunk
size -= len(_chunk)
return data
if self.receive_buffer:
data = self.receive_buffer + data
self.receive_buffer = None
self._process_received_data(data)
def _process_received_data(self, data: bytes):
"""Try to parse a opcua message"""
buf = ua.utils.Buffer(data)
while True:
try:
try:
header = header_from_binary(buf)
except ua.utils.NotEnoughData:
self.logger.debug('Not enough data while parsing header from server, waiting for more')
self.receive_buffer = data
return
if len(buf) < header.body_size:
self.logger.debug('We did not receive enough data from server. Need %s got %s', header.body_size, len(buf))
self.receive_buffer = data
return
msg = self._connection.receive_from_header_and_body(header, buf)
self._process_received_message(msg)
if len(buf) == 0:
return
except Exception:
self.logger.exception('Exception raised while parsing message from client')
return
def _process_received_message(self, msg):
if msg is None:
pass
elif isinstance(msg, ua.Message):
self._call_callback(msg.request_id(), msg.body())
elif isinstance(msg, ua.Acknowledge):
self._call_callback(0, msg)
elif isinstance(msg, ua.ErrorMessage):
self.logger.warning("Received an error: %r", msg)
else:
raise ua.UaError("Unsupported message type: %s", msg)
def _send_request(self, request, callback=None, timeout=1000, message_type=ua.MessageType.SecureMessage):
"""
......@@ -117,24 +125,6 @@ class UASocketProtocol(asyncio.Protocol):
return False
return True
async def _receive(self):
msg = await self._connection.receive_from_socket(self)
if msg is None:
pass
elif isinstance(msg, ua.Message):
self._call_callback(msg.request_id(), msg.body())
elif isinstance(msg, ua.Acknowledge):
self._call_callback(0, msg)
elif isinstance(msg, ua.ErrorMessage):
self.logger.warning("Received an error: %r", msg)
else:
raise ua.UaError("Unsupported message type: %s", msg)
if self._leftover_chunk or not self.receive_buffer.empty():
# keep receiving
self.loop.create_task(self._receive())
else:
self.is_receiving = False
def _call_callback(self, request_id, body):
future = self._callbackmap.pop(request_id, None)
if future is None:
......
......@@ -292,21 +292,6 @@ class SecureConnection(object):
else:
raise ua.UaError("Unsupported message type {0}".format(header.MessageType))
async def receive_from_socket(self, protocol):
"""
Convert binary stream to OPC UA TCP message (see OPC UA
specs Part 6, 7.1: Hello, Acknowledge or ErrorMessage), or a Message
object, or None (if intermediate chunk is received)
"""
logger.debug("Waiting for header")
header = await header_from_binary(protocol)
logger.debug("Received header: %s", header)
body = await protocol.read(header.body_size)
if len(body) != header.body_size:
# ToDo: should never happen since UASocketProtocol.read() waits until `size` bytes are received. Remove?
raise ua.UaError("{0} bytes expected, {1} available".format(header.body_size, len(body)))
return self.receive_from_header_and_body(header, ua.utils.Buffer(body))
def _receive(self, msg):
self._check_incoming_chunk(msg)
self._incoming_parts.append(msg)
......
......@@ -470,28 +470,23 @@ class AddressSpace(object):
def __init__(self):
self.logger = logging.getLogger(__name__)
self._nodes = {}
self._lock = RLock() # FIXME: should use multiple reader, one writter pattern
self._datachange_callback_counter = 200
self._handle_to_attribute_map = {}
self._default_idx = 2
self._nodeid_counter = {0: 20000, 1: 2000}
def __getitem__(self, nodeid):
with self._lock:
if nodeid in self._nodes:
return self._nodes.__getitem__(nodeid)
if nodeid in self._nodes:
return self._nodes.__getitem__(nodeid)
def __setitem__(self, nodeid, value):
with self._lock:
return self._nodes.__setitem__(nodeid, value)
return self._nodes.__setitem__(nodeid, value)
def __contains__(self, nodeid):
with self._lock:
return self._nodes.__contains__(nodeid)
return self._nodes.__contains__(nodeid)
def __delitem__(self, nodeid):
with self._lock:
self._nodes.__delitem__(nodeid)
self._nodes.__delitem__(nodeid)
def generate_nodeid(self, idx=None):
if idx is None:
......@@ -501,23 +496,18 @@ class AddressSpace(object):
else:
self._nodeid_counter[idx] = 1
nodeid = ua.NodeId(self._nodeid_counter[idx], idx)
with self._lock: # OK since reentrant lock
while True:
if nodeid in self._nodes:
nodeid = self.generate_nodeid(idx)
else:
return nodeid
while True:
if nodeid in self._nodes:
nodeid = self.generate_nodeid(idx)
else:
return nodeid
def keys(self):
with self._lock:
return self._nodes.keys()
return self._nodes.keys()
def empty(self):
"""
Delete all nodes in address space
"""
with self._lock:
self._nodes = {}
"""Delete all nodes in address space"""
self._nodes = {}
def dump(self, path):
"""
......@@ -602,41 +592,39 @@ class AddressSpace(object):
self._nodes = LazyLoadingDict(shelve.open(path, "r"))
def get_attribute_value(self, nodeid, attr):
with self._lock:
self.logger.debug("get attr val: %s %s", nodeid, attr)
if nodeid not in self._nodes:
dv = ua.DataValue()
dv.StatusCode = ua.StatusCode(ua.StatusCodes.BadNodeIdUnknown)
return dv
node = self._nodes[nodeid]
if attr not in node.attributes:
dv = ua.DataValue()
dv.StatusCode = ua.StatusCode(ua.StatusCodes.BadAttributeIdInvalid)
return dv
attval = node.attributes[attr]
if attval.value_callback:
return attval.value_callback()
return attval.value
# self.logger.debug("get attr val: %s %s", nodeid, attr)
if nodeid not in self._nodes:
dv = ua.DataValue()
dv.StatusCode = ua.StatusCode(ua.StatusCodes.BadNodeIdUnknown)
return dv
node = self._nodes[nodeid]
if attr not in node.attributes:
dv = ua.DataValue()
dv.StatusCode = ua.StatusCode(ua.StatusCodes.BadAttributeIdInvalid)
return dv
attval = node.attributes[attr]
if attval.value_callback:
return attval.value_callback()
return attval.value
def set_attribute_value(self, nodeid, attr, value):
with self._lock:
self.logger.debug("set attr val: %s %s %s", nodeid, attr, value)
if nodeid not in self._nodes:
return ua.StatusCode(ua.StatusCodes.BadNodeIdUnknown)
node = self._nodes[nodeid]
if attr not in node.attributes:
return ua.StatusCode(ua.StatusCodes.BadAttributeIdInvalid)
if not value.SourceTimestamp:
value.SourceTimestamp = datetime.utcnow()
if not value.ServerTimestamp:
value.ServerTimestamp = datetime.utcnow()
attval = node.attributes[attr]
old = attval.value
attval.value = value
cbs = []
if old.Value != value.Value: # only send call callback when a value change has happend
cbs = list(attval.datachange_callbacks.items())
# self.logger.debug("set attr val: %s %s %s", nodeid, attr, value)
if nodeid not in self._nodes:
return ua.StatusCode(ua.StatusCodes.BadNodeIdUnknown)
node = self._nodes[nodeid]
if attr not in node.attributes:
return ua.StatusCode(ua.StatusCodes.BadAttributeIdInvalid)
if not value.SourceTimestamp:
value.SourceTimestamp = datetime.utcnow()
if not value.ServerTimestamp:
value.ServerTimestamp = datetime.utcnow()
attval = node.attributes[attr]
old = attval.value
attval.value = value
cbs = []
if old.Value != value.Value: # only send call callback when a value change has happend
cbs = list(attval.datachange_callbacks.items())
for k, v in cbs:
try:
......@@ -647,27 +635,24 @@ class AddressSpace(object):
return ua.StatusCode()
def add_datachange_callback(self, nodeid, attr, callback):
with self._lock:
self.logger.debug("set attr callback: %s %s %s", nodeid, attr, callback)
if nodeid not in self._nodes:
return ua.StatusCode(ua.StatusCodes.BadNodeIdUnknown), 0
node = self._nodes[nodeid]
if attr not in node.attributes:
return ua.StatusCode(ua.StatusCodes.BadAttributeIdInvalid), 0
attval = node.attributes[attr]
self._datachange_callback_counter += 1
handle = self._datachange_callback_counter
attval.datachange_callbacks[handle] = callback
self._handle_to_attribute_map[handle] = (nodeid, attr)
return ua.StatusCode(), handle
self.logger.debug("set attr callback: %s %s %s", nodeid, attr, callback)
if nodeid not in self._nodes:
return ua.StatusCode(ua.StatusCodes.BadNodeIdUnknown), 0
node = self._nodes[nodeid]
if attr not in node.attributes:
return ua.StatusCode(ua.StatusCodes.BadAttributeIdInvalid), 0
attval = node.attributes[attr]
self._datachange_callback_counter += 1
handle = self._datachange_callback_counter
attval.datachange_callbacks[handle] = callback
self._handle_to_attribute_map[handle] = (nodeid, attr)
return ua.StatusCode(), handle
def delete_datachange_callback(self, handle):
with self._lock:
if handle in self._handle_to_attribute_map:
nodeid, attr = self._handle_to_attribute_map.pop(handle)
self._nodes[nodeid].attributes[attr].datachange_callbacks.pop(handle)
if handle in self._handle_to_attribute_map:
nodeid, attr = self._handle_to_attribute_map.pop(handle)
self._nodes[nodeid].attributes[attr].datachange_callbacks.pop(handle)
def add_method_callback(self, methodid, callback):
with self._lock:
node = self._nodes[methodid]
node.call = callback
node = self._nodes[methodid]
node.call = callback
......@@ -18,12 +18,13 @@ class OPCUAProtocol(asyncio.Protocol):
to the internal server object
FIXME: find another solution
"""
def __init__(self, iserver=None, policies=None, clients=None):
self.loop = asyncio.get_event_loop()
self.peer_name = None
self.transport = None
self.processor = None
self.data = b''
self.receive_buffer = b''
self.iserver = iserver
self.policies = policies
self.clients = clients
......@@ -51,38 +52,40 @@ class OPCUAProtocol(asyncio.Protocol):
self.clients.remove(self)
def data_received(self, data):
logger.debug('received %s bytes from socket', len(data))
if self.data:
data = self.data + data
self.data = b''
self.loop.create_task(self._process_data(data))
if self.receive_buffer:
data = self.receive_buffer + data
self.receive_buffer = b''
self._process_received_data(data)
async def _process_data(self, data):
def _process_received_data(self, data: bytes):
logger.info('_process_received_data %s', len(data))
buf = ua.utils.Buffer(data)
while True:
try:
try:
backup_buf = buf.copy()
try:
hdr = await uabin.header_from_binary(buf)
except ua.utils.NotEnoughData:
logger.info('We did not receive enough data from client, waiting for more')
self.data = backup_buf.read(len(backup_buf))
return
if len(buf) < hdr.body_size:
logger.info('We did not receive enough data from client, waiting for more')
self.data = backup_buf.read(len(backup_buf))
return
ret = self.processor.process(hdr, buf)
if not ret:
logger.info('processor returned False, we close connection from %s', self.peer_name)
self.transport.close()
return
if len(buf) == 0:
return
except Exception:
logger.exception('Exception raised while parsing message from client, closing')
header = uabin.header_from_binary(buf)
except ua.utils.NotEnoughData:
logger.debug('Not enough data while parsing header from client, waiting for more')
self.receive_buffer = data + self.receive_buffer
return
if len(buf) < header.body_size:
logger.debug('We did not receive enough data from client. Need %s got %s', header.body_size, len(buf))
self.receive_buffer = data + self.receive_buffer
return
self.loop.create_task(self._process_received_message(header, buf))
except Exception:
logger.exception('Exception raised while parsing message from client')
return
async def _process_received_message(self, header, buf):
logger.debug('_process_received_message %s %s', header.body_size, len(buf))
ret = await self.processor.process(header, buf)
if not ret:
logger.info('processor returned False, we close connection from %s', self.peer_name)
self.transport.close()
return
if len(buf) != 0:
# There is data left in the buffer - process it
self._process_received_data(buf)
class BinaryServer:
......@@ -113,7 +116,7 @@ class BinaryServer:
sockname = self._server.sockets[0].getsockname()
self.hostname = sockname[0]
self.port = sockname[1]
self.logger.warning('Listening on {0}:{1}'.format(self.hostname, self.port))
self.logger.info('Listening on {0}:{1}'.format(self.hostname, self.port))
async def stop(self):
self.logger.info('Closing asyncio socket server')
......
......@@ -285,7 +285,8 @@ class Server:
"""
Stop server
"""
await asyncio.wait([client.disconnect() for client in self._discovery_clients.values()])
if self._discovery_clients:
await asyncio.wait([client.disconnect() for client in self._discovery_clients.values()])
await self.bserver.stop()
self.iserver.stop()
......
This diff is collapsed.
......@@ -515,13 +515,13 @@ def header_to_binary(hdr):
return b"".join(b)
async def header_from_binary(data):
def header_from_binary(data):
hdr = ua.Header()
hdr.MessageType, hdr.ChunkType, hdr.packet_size = struct.unpack("<3scI", await data.read(8))
hdr.MessageType, hdr.ChunkType, hdr.packet_size = struct.unpack("<3scI", data.read(8))
hdr.body_size = hdr.packet_size - 8
if hdr.MessageType in (ua.MessageType.SecureOpen, ua.MessageType.SecureClose, ua.MessageType.SecureMessage):
hdr.body_size -= 4
hdr.ChannelId = Primitives.UInt32.unpack(ua.utils.Buffer(await data.read(4)))
hdr.ChannelId = Primitives.UInt32.unpack(data)
return hdr
......
[pytest]
log_cli=False
log_print=True
log_level=INFO
log_level=DEBUG
import logging
import pytest
from opcua import Client
......@@ -9,9 +11,10 @@ from .tests_common import CommonTests, add_server_methods
from .tests_xml import XmlTests
port_num1 = 48510
_logger = logging.getLogger(__name__)
pytestmark = pytest.mark.asyncio
@pytest.yield_fixture()
@pytest.fixture()
async def admin_client():
# start admin client
# long timeout since travis (automated testing) can be really slow
......@@ -21,7 +24,7 @@ async def admin_client():
await clt.disconnect()
@pytest.yield_fixture()
@pytest.fixture()
async def client():
# start anonymous client
ro_clt = Client(f'opc.tcp://127.0.0.1:{port_num1}')
......@@ -30,7 +33,7 @@ async def client():
await ro_clt.disconnect()
@pytest.yield_fixture()
@pytest.fixture()
async def server():
# start our own server
srv = Server()
......@@ -43,7 +46,6 @@ async def server():
await srv.stop()
@pytest.mark.asyncio
async def test_service_fault(server, admin_client):
request = ua.ReadRequest()
request.TypeId = ua.FourByteNodeId(999) # bad type!
......@@ -51,49 +53,45 @@ async def test_service_fault(server, admin_client):
await admin_client.uaclient.protocol.send_request(request)
@pytest.mark.asyncio
async def test_objects_anonymous(server, client):
objects = client.get_objects_node()
with pytest.raises(ua.UaStatusCodeError):
objects.set_attribute(ua.AttributeIds.WriteMask, ua.DataValue(999))
await objects.set_attribute(ua.AttributeIds.WriteMask, ua.DataValue(999))
with pytest.raises(ua.UaStatusCodeError):
f = objects.add_folder(3, 'MyFolder')
await objects.add_folder(3, 'MyFolder')
@pytest.mark.asyncio
async def test_folder_anonymous(server, client):
objects = client.get_objects_node()
f = objects.add_folder(3, 'MyFolderRO')
f = await objects.add_folder(3, 'MyFolderRO')
f_ro = client.get_node(f.nodeid)
assert f == f_ro
with pytest.raises(ua.UaStatusCodeError):
f2 = f_ro.add_folder(3, 'MyFolder2')
await f_ro.add_folder(3, 'MyFolder2')
@pytest.mark.asyncio
async def test_variable_anonymous(server, admin_client, client):
objects = admin_client.get_objects_node()
v = objects.add_variable(3, 'MyROVariable', 6)
v.set_value(4) # this should work
v = await objects.add_variable(3, 'MyROVariable', 6)
await v.set_value(4) # this should work
v_ro = client.get_node(v.nodeid)
with pytest.raises(ua.UaStatusCodeError):
v_ro.set_value(2)
await v_ro.set_value(2)
assert await v_ro.get_value() == 4
v.set_writable(True)
v_ro.set_value(2) # now it should work
await v.set_writable(True)
await v_ro.set_value(2) # now it should work
assert await v_ro.get_value() == 2
v.set_writable(False)
await v.set_writable(False)
with pytest.raises(ua.UaStatusCodeError):
v_ro.set_value(9)
await v_ro.set_value(9)
assert await v_ro.get_value() == 2
@pytest.mark.asyncio
async def test_context_manager(server):
"""Context manager calls connect() and disconnect()"""
state = [0]
def increment_state(*args, **kwargs):
async def increment_state(*args, **kwargs):
state[0] += 1
# create client and replace instance methods with dummy methods
......@@ -102,14 +100,13 @@ async def test_context_manager(server):
client.disconnect = increment_state.__get__(client)
assert state[0] == 0
with client:
async with client:
# test if client connected
assert state[0] == 1
# test if client disconnected
assert state[0] == 2
@pytest.mark.asyncio
async def test_enumstrings_getvalue(server, client):
"""
The real exception is server side, but is detected by using a client.
......@@ -117,4 +114,4 @@ async def test_enumstrings_getvalue(server, client):
The client only 'sees' an TimeoutError
"""
nenumstrings = client.get_node(ua.ObjectIds.AxisScaleEnumeration_EnumStrings)
value = ua.Variant(nenumstrings.get_value())
value = ua.Variant(await nenumstrings.get_value())
import unittest
import pytest
import os
import shelve
import time
from tests_common import CommonTests, add_server_methods
from tests_xml import XmlTests
from tests_subscriptions import SubscriptionTests
from .tests_common import CommonTests, add_server_methods
from .tests_xml import XmlTests
from .tests_subscriptions import SubscriptionTests
from datetime import timedelta, datetime
from tempfile import NamedTemporaryFile
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment