Commit cee4d4b8 authored by Alexander Schrode's avatar Alexander Schrode Committed by oroulet

add background task

parent 688e6079
...@@ -29,11 +29,7 @@ class Client: ...@@ -29,11 +29,7 @@ class Client:
use UaClient object, available as self.uaclient use UaClient object, available as self.uaclient
which offers the raw OPC-UA services interface. which offers the raw OPC-UA services interface.
""" """
def __init__(self, url: str, timeout: float = 4, watchdog_intervall: float = 10):
_username = None
_password = None
def __init__(self, url: str, timeout: float = 4):
""" """
:param url: url of the server. :param url: url of the server.
if you are unsure of url, write at least hostname if you are unsure of url, write at least hostname
...@@ -41,6 +37,8 @@ class Client: ...@@ -41,6 +37,8 @@ class Client:
:param timeout: :param timeout:
Each request sent to the server expects an answer within this Each request sent to the server expects an answer within this
time. The timeout is specified in seconds. time. The timeout is specified in seconds.
:param watchdog_intervall:
The time between checking if the server is still alive. The timeout is specified in seconds.
Some other client parameters can be changed by setting Some other client parameters can be changed by setting
attributes on the constructed object: attributes on the constructed object:
See the source code for the exhaustive list. See the source code for the exhaustive list.
...@@ -65,7 +63,7 @@ class Client: ...@@ -65,7 +63,7 @@ class Client:
self.secure_channel_timeout = 3600000 # 1 hour self.secure_channel_timeout = 3600000 # 1 hour
self.session_timeout = 3600000 # 1 hour self.session_timeout = 3600000 # 1 hour
self._policy_ids = [] self._policy_ids = []
self.uaclient: UaClient = UaClient(timeout) self.uaclient: UaClient = UaClient(timeout, self._check_tasks)
self.user_certificate = None self.user_certificate = None
self.user_private_key = None self.user_private_key = None
self._server_nonce = None self._server_nonce = None
...@@ -74,7 +72,9 @@ class Client: ...@@ -74,7 +72,9 @@ class Client:
self.max_messagesize = 0 # No limits self.max_messagesize = 0 # No limits
self.max_chunkcount = 0 # No limits self.max_chunkcount = 0 # No limits
self._renew_channel_task = None self._renew_channel_task = None
self._watch_task = None
self._locale = ["en"] self._locale = ["en"]
self._watchdog_intervall = watchdog_intervall
async def __aenter__(self): async def __aenter__(self):
await self.connect() await self.connect()
...@@ -421,11 +421,39 @@ class Client: ...@@ -421,11 +421,39 @@ class Client:
_logger.warning("Requested session timeout to be %dms, got %dms instead", self.secure_channel_timeout, response.RevisedSessionTimeout) _logger.warning("Requested session timeout to be %dms, got %dms instead", self.secure_channel_timeout, response.RevisedSessionTimeout)
self.session_timeout = response.RevisedSessionTimeout self.session_timeout = response.RevisedSessionTimeout
self._renew_channel_task = asyncio.create_task(self._renew_channel_loop()) self._renew_channel_task = asyncio.create_task(self._renew_channel_loop())
self._watch_task = asyncio.create_task(self._watchdog_loop())
return response return response
async def _check_tasks(self):
# Check if a background task has finished and if a exception is thrown rethrow it with result
# use half the session timeout or the supply value from constructor
if self._renew_channel_task is not None:
if self._renew_channel_task.done():
await self._renew_channel_task
if self._watch_task is not None:
if self._watch_task.done():
await self._watch_task
async def _watchdog_loop(self):
"""
Checks if the server is alive
"""
timeout = min(self.session_timeout / 1000 / 2, self._watchdog_intervall)
try:
while True:
await asyncio.sleep(timeout)
# @FIXME handle state change
_ = await self.nodes.server_state.read_value()
except asyncio.CancelledError:
pass
except Exception:
_logger.exception("Error in watchdog loop")
raise
async def _renew_channel_loop(self): async def _renew_channel_loop(self):
""" """
Renew the SecureChannel before the SessionTimeout will happen. Renew the SecureChannel before the SecureChannelTimeout will happen.
In theory we could do that only if no session activity In theory we could do that only if no session activity
but it does not cost much.. but it does not cost much..
""" """
...@@ -549,6 +577,12 @@ class Client: ...@@ -549,6 +577,12 @@ class Client:
""" """
Close session Close session
""" """
if self._watch_task:
self._watch_task.cancel()
try:
await self._watch_task
except Exception:
_logger.exception("Error while closing watch_task")
if self._renew_channel_task: if self._renew_channel_task:
self._renew_channel_task.cancel() self._renew_channel_task.cancel()
try: try:
......
...@@ -3,7 +3,7 @@ Low level binary client ...@@ -3,7 +3,7 @@ Low level binary client
""" """
import asyncio import asyncio
import logging import logging
from typing import Dict, List, Optional, Union from typing import Awaitable, Callable, Dict, List, Optional, Union
from asyncua import ua from asyncua import ua
from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary
...@@ -20,10 +20,11 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -20,10 +20,11 @@ class UASocketProtocol(asyncio.Protocol):
OPEN = 'open' OPEN = 'open'
CLOSED = 'closed' CLOSED = 'closed'
def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy()): def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy(), before_request_hook: Callable[[], Awaitable[None]] = None):
""" """
:param timeout: Timeout in seconds :param timeout: Timeout in seconds
:param security_policy: Security policy (optional) :param security_policy: Security policy (optional)
:param before_request_hook: Hook for upperlayer tasks before a request is send (optional)
""" """
self.logger = logging.getLogger(f"{__name__}.UASocketProtocol") self.logger = logging.getLogger(f"{__name__}.UASocketProtocol")
self.transport: Optional[asyncio.Transport] = None self.transport: Optional[asyncio.Transport] = None
...@@ -40,6 +41,7 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -40,6 +41,7 @@ class UASocketProtocol(asyncio.Protocol):
# needed to pass params from asynchronous request to synchronous data receive callback, as well as # needed to pass params from asynchronous request to synchronous data receive callback, as well as
# passing back the processed response to the request so that it can return it. # passing back the processed response to the request so that it can return it.
self._open_secure_channel_exchange: Union[ua.OpenSecureChannelResponse, ua.OpenSecureChannelParameters, None] = None self._open_secure_channel_exchange: Union[ua.OpenSecureChannelResponse, ua.OpenSecureChannelParameters, None] = None
self._before_request_hook = before_request_hook
def connection_made(self, transport: asyncio.Transport): # type: ignore def connection_made(self, transport: asyncio.Transport): # type: ignore
self.state = self.OPEN self.state = self.OPEN
...@@ -143,12 +145,15 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -143,12 +145,15 @@ class UASocketProtocol(asyncio.Protocol):
Returns response object if no callback is provided. Returns response object if no callback is provided.
""" """
timeout = self.timeout if timeout is None else timeout timeout = self.timeout if timeout is None else timeout
if self._before_request_hook:
# This will propagade exceptions from background tasks to the libary user before calling a request which will
# timeout then.
await self._before_request_hook()
try: try:
data = await asyncio.wait_for(self._send_request(request, timeout, message_type), timeout if timeout else None) data = await asyncio.wait_for(self._send_request(request, timeout, message_type), timeout if timeout else None)
except Exception: except Exception:
if self.state != self.OPEN: if self.state != self.OPEN:
raise ConnectionError("Connection is closed") from None raise ConnectionError("Connection is closed") from None
raise raise
self.check_answer(data, f" in response to {request.__class__.__name__}") self.check_answer(data, f" in response to {request.__class__.__name__}")
...@@ -242,9 +247,10 @@ class UaClient: ...@@ -242,9 +247,10 @@ class UaClient:
In this Python implementation most of the structures are defined in In this Python implementation most of the structures are defined in
uaprotocol_auto.py and uaprotocol_hand.py available under asyncua.ua uaprotocol_auto.py and uaprotocol_hand.py available under asyncua.ua
""" """
def __init__(self, timeout: float = 1): def __init__(self, timeout: float = 1, before_request_hook: Callable[[], Awaitable[None]] = None):
""" """
:param timeout: Timout in seconds :param timeout: Timout in seconds
:param before_request_hook: Hook to execute before a request
""" """
self.logger = logging.getLogger(f'{__name__}.UaClient') self.logger = logging.getLogger(f'{__name__}.UaClient')
self._subscription_callbacks = {} self._subscription_callbacks = {}
...@@ -252,12 +258,13 @@ class UaClient: ...@@ -252,12 +258,13 @@ class UaClient:
self.security_policy = ua.SecurityPolicy() self.security_policy = ua.SecurityPolicy()
self.protocol: UASocketProtocol = None self.protocol: UASocketProtocol = None
self._publish_task = None self._publish_task = None
self._before_request_hook = before_request_hook
def set_security(self, policy: ua.SecurityPolicy): def set_security(self, policy: ua.SecurityPolicy):
self.security_policy = policy self.security_policy = policy
def _make_protocol(self): def _make_protocol(self):
self.protocol = UASocketProtocol(self._timeout, security_policy=self.security_policy) self.protocol = UASocketProtocol(self._timeout, security_policy=self.security_policy, before_request_hook=self._before_request_hook)
return self.protocol return self.protocol
async def connect_socket(self, host: str, port: int): async def connect_socket(self, host: str, port: int):
......
# coding: utf-8 # coding: utf-8
import asyncio
import pytest import pytest
from asyncua import Client from asyncua import Client, Server
from asyncua.ua.uaerrors import BadMaxConnectionsReached from asyncua.ua.uaerrors import BadMaxConnectionsReached
from .conftest import port_num from .conftest import port_num, find_free_port
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
...@@ -30,3 +31,17 @@ async def test_safe_disconnect(): ...@@ -30,3 +31,17 @@ async def test_safe_disconnect():
await c.disconnect() await c.disconnect()
# second disconnect should be noop # second disconnect should be noop
await c.disconnect() await c.disconnect()
async def test_client_connection_lost():
# Test the disconnect behavoir
port = find_free_port()
srv = Server()
await srv.init()
srv.set_endpoint(f'opc.tcp://127.0.0.1:{port}')
await srv.start()
async with Client(f'opc.tcp://127.0.0.1:{port}', timeout=0.5, watchdog_intervall=1) as cl:
await srv.stop()
await asyncio.sleep(2)
with pytest.raises(ConnectionError):
await cl.get_namespace_array()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment