Commit cee4d4b8 authored by Alexander Schrode's avatar Alexander Schrode Committed by oroulet

add background task

parent 688e6079
......@@ -29,11 +29,7 @@ class Client:
use UaClient object, available as self.uaclient
which offers the raw OPC-UA services interface.
"""
_username = None
_password = None
def __init__(self, url: str, timeout: float = 4):
def __init__(self, url: str, timeout: float = 4, watchdog_intervall: float = 10):
"""
:param url: url of the server.
if you are unsure of url, write at least hostname
......@@ -41,6 +37,8 @@ class Client:
:param timeout:
Each request sent to the server expects an answer within this
time. The timeout is specified in seconds.
:param watchdog_intervall:
The time between checking if the server is still alive. The timeout is specified in seconds.
Some other client parameters can be changed by setting
attributes on the constructed object:
See the source code for the exhaustive list.
......@@ -65,7 +63,7 @@ class Client:
self.secure_channel_timeout = 3600000 # 1 hour
self.session_timeout = 3600000 # 1 hour
self._policy_ids = []
self.uaclient: UaClient = UaClient(timeout)
self.uaclient: UaClient = UaClient(timeout, self._check_tasks)
self.user_certificate = None
self.user_private_key = None
self._server_nonce = None
......@@ -74,7 +72,9 @@ class Client:
self.max_messagesize = 0 # No limits
self.max_chunkcount = 0 # No limits
self._renew_channel_task = None
self._watch_task = None
self._locale = ["en"]
self._watchdog_intervall = watchdog_intervall
async def __aenter__(self):
await self.connect()
......@@ -421,11 +421,39 @@ class Client:
_logger.warning("Requested session timeout to be %dms, got %dms instead", self.secure_channel_timeout, response.RevisedSessionTimeout)
self.session_timeout = response.RevisedSessionTimeout
self._renew_channel_task = asyncio.create_task(self._renew_channel_loop())
self._watch_task = asyncio.create_task(self._watchdog_loop())
return response
async def _check_tasks(self):
# Check if a background task has finished and if a exception is thrown rethrow it with result
# use half the session timeout or the supply value from constructor
if self._renew_channel_task is not None:
if self._renew_channel_task.done():
await self._renew_channel_task
if self._watch_task is not None:
if self._watch_task.done():
await self._watch_task
async def _watchdog_loop(self):
"""
Checks if the server is alive
"""
timeout = min(self.session_timeout / 1000 / 2, self._watchdog_intervall)
try:
while True:
await asyncio.sleep(timeout)
# @FIXME handle state change
_ = await self.nodes.server_state.read_value()
except asyncio.CancelledError:
pass
except Exception:
_logger.exception("Error in watchdog loop")
raise
async def _renew_channel_loop(self):
"""
Renew the SecureChannel before the SessionTimeout will happen.
Renew the SecureChannel before the SecureChannelTimeout will happen.
In theory we could do that only if no session activity
but it does not cost much..
"""
......@@ -549,6 +577,12 @@ class Client:
"""
Close session
"""
if self._watch_task:
self._watch_task.cancel()
try:
await self._watch_task
except Exception:
_logger.exception("Error while closing watch_task")
if self._renew_channel_task:
self._renew_channel_task.cancel()
try:
......
......@@ -3,7 +3,7 @@ Low level binary client
"""
import asyncio
import logging
from typing import Dict, List, Optional, Union
from typing import Awaitable, Callable, Dict, List, Optional, Union
from asyncua import ua
from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary
......@@ -20,10 +20,11 @@ class UASocketProtocol(asyncio.Protocol):
OPEN = 'open'
CLOSED = 'closed'
def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy()):
def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy(), before_request_hook: Callable[[], Awaitable[None]] = None):
"""
:param timeout: Timeout in seconds
:param security_policy: Security policy (optional)
:param before_request_hook: Hook for upperlayer tasks before a request is send (optional)
"""
self.logger = logging.getLogger(f"{__name__}.UASocketProtocol")
self.transport: Optional[asyncio.Transport] = None
......@@ -40,6 +41,7 @@ class UASocketProtocol(asyncio.Protocol):
# needed to pass params from asynchronous request to synchronous data receive callback, as well as
# passing back the processed response to the request so that it can return it.
self._open_secure_channel_exchange: Union[ua.OpenSecureChannelResponse, ua.OpenSecureChannelParameters, None] = None
self._before_request_hook = before_request_hook
def connection_made(self, transport: asyncio.Transport): # type: ignore
self.state = self.OPEN
......@@ -143,12 +145,15 @@ class UASocketProtocol(asyncio.Protocol):
Returns response object if no callback is provided.
"""
timeout = self.timeout if timeout is None else timeout
if self._before_request_hook:
# This will propagade exceptions from background tasks to the libary user before calling a request which will
# timeout then.
await self._before_request_hook()
try:
data = await asyncio.wait_for(self._send_request(request, timeout, message_type), timeout if timeout else None)
except Exception:
if self.state != self.OPEN:
raise ConnectionError("Connection is closed") from None
raise
self.check_answer(data, f" in response to {request.__class__.__name__}")
......@@ -242,9 +247,10 @@ class UaClient:
In this Python implementation most of the structures are defined in
uaprotocol_auto.py and uaprotocol_hand.py available under asyncua.ua
"""
def __init__(self, timeout: float = 1):
def __init__(self, timeout: float = 1, before_request_hook: Callable[[], Awaitable[None]] = None):
"""
:param timeout: Timout in seconds
:param before_request_hook: Hook to execute before a request
"""
self.logger = logging.getLogger(f'{__name__}.UaClient')
self._subscription_callbacks = {}
......@@ -252,12 +258,13 @@ class UaClient:
self.security_policy = ua.SecurityPolicy()
self.protocol: UASocketProtocol = None
self._publish_task = None
self._before_request_hook = before_request_hook
def set_security(self, policy: ua.SecurityPolicy):
self.security_policy = policy
def _make_protocol(self):
self.protocol = UASocketProtocol(self._timeout, security_policy=self.security_policy)
self.protocol = UASocketProtocol(self._timeout, security_policy=self.security_policy, before_request_hook=self._before_request_hook)
return self.protocol
async def connect_socket(self, host: str, port: int):
......
# coding: utf-8
import asyncio
import pytest
from asyncua import Client
from asyncua import Client, Server
from asyncua.ua.uaerrors import BadMaxConnectionsReached
from .conftest import port_num
from .conftest import port_num, find_free_port
pytestmark = pytest.mark.asyncio
......@@ -30,3 +31,17 @@ async def test_safe_disconnect():
await c.disconnect()
# second disconnect should be noop
await c.disconnect()
async def test_client_connection_lost():
# Test the disconnect behavoir
port = find_free_port()
srv = Server()
await srv.init()
srv.set_endpoint(f'opc.tcp://127.0.0.1:{port}')
await srv.start()
async with Client(f'opc.tcp://127.0.0.1:{port}', timeout=0.5, watchdog_intervall=1) as cl:
await srv.stop()
await asyncio.sleep(2)
with pytest.raises(ConnectionError):
await cl.get_namespace_array()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment