Commit 7d7841bf authored by Julien Prigent's avatar Julien Prigent Committed by oroulet

[Subscription creation] adjust MaxKeepAliveCount based on RevisedPublishingInterval

parent 9acce704
......@@ -14,7 +14,7 @@ from ..common.shortcuts import Shortcuts
from ..common.structures import load_type_definitions, load_enums
from ..common.structures104 import load_data_type_definitions
from ..common.utils import create_nonce
from ..common.ua_utils import value_to_datavalue
from ..common.ua_utils import value_to_datavalue, copy_dataclass_attr
from ..crypto import uacrypto, security_policies
_logger = logging.getLogger(__name__)
......@@ -531,7 +531,9 @@ class Client:
"""
return Node(self.uaclient, nodeid)
async def create_subscription(self, period, handler, publishing=True):
async def create_subscription(
self, period, handler, publishing=True
) -> Subscription:
"""
Create a subscription.
Returns a Subscription object which allows to subscribe to events or data changes on server.
......@@ -551,9 +553,54 @@ class Client:
params.PublishingEnabled = publishing
params.Priority = 0
subscription = Subscription(self.uaclient, params, handler)
await subscription.init()
results = await subscription.init()
new_params = self.get_subscription_revised_params(params, results)
if new_params:
results = await subscription.update(new_params)
_logger.info(f"Result from subscription update: {results}")
return subscription
def get_subscription_revised_params(
self,
params: ua.CreateSubscriptionParameters,
results: ua.CreateSubscriptionResult,
) -> None:
if (
results.RevisedPublishingInterval == params.RequestedPublishingInterval
and results.RevisedLifetimeCount == params.RequestedLifetimeCount
and results.RevisedMaxKeepAliveCount == params.RequestedMaxKeepAliveCount
):
return
_logger.warning(
f"Revised values returned differ from subscription values: {results}"
)
revised_interval = results.RevisedPublishingInterval
# Adjust the MaxKeepAliveCount based on the RevisedPublishInterval when necessary
new_keepalive_count = self.get_keepalive_count(revised_interval)
if (
revised_interval != params.RequestedPublishingInterval
and new_keepalive_count != params.RequestedMaxKeepAliveCount
):
_logger.info(
f"KeepAliveCount will be updated to {new_keepalive_count} "
f"for consistency with RevisedPublishInterval"
)
modified_params = ua.ModifySubscriptionParameters()
# copy the existing subscription parameters
copy_dataclass_attr(params, modified_params)
# then override with the revised values
modified_params.RequestedMaxKeepAliveCount = new_keepalive_count
modified_params.SubscriptionId = results.SubscriptionId
modified_params.RequestedPublishingInterval = (
results.RevisedPublishingInterval
)
modified_params.RequestedPublishingInterval = (
results.RevisedPublishingInterval
)
# update LifetimeCount but chances are it will be re-revised again
modified_params.RequestedLifetimeCount = results.RevisedLifetimeCount
return modified_params
def get_keepalive_count(self, period) -> int:
"""
We request the server to send a Keepalive notification when
......
......@@ -430,7 +430,9 @@ class UaClient:
response.ResponseHeader.ServiceResult.check()
return response.Results
async def create_subscription(self, params, callback):
async def create_subscription(
self, params, callback
) -> ua.CreateSubscriptionResult:
self.logger.debug("create_subscription")
request = ua.CreateSubscriptionRequest()
request.Parameters = params
......@@ -438,7 +440,10 @@ class UaClient:
response = struct_from_binary(ua.CreateSubscriptionResponse, data)
response.ResponseHeader.ServiceResult.check()
self._subscription_callbacks[response.Parameters.SubscriptionId] = callback
self.logger.info("create_subscription success SubscriptionId %s", response.Parameters.SubscriptionId)
self.logger.info(
"create_subscription success SubscriptionId %s",
response.Parameters.SubscriptionId
)
if not self._publish_task or self._publish_task.done():
# Start the publish loop if it is not yet running
# The current strategy is to have only one open publish request per UaClient. This might not be enough
......@@ -447,6 +452,20 @@ class UaClient:
self._publish_task = asyncio.create_task(self._publish_loop())
return response.Parameters
async def update_subscription(
self, params: ua.ModifySubscriptionParameters
) -> ua.ModifySubscriptionResult:
request = ua.ModifySubscriptionRequest()
request.Parameters = params
data = await self.protocol.send_request(request)
response = struct_from_binary(ua.ModifySubscriptionResponse, data)
response.ResponseHeader.ServiceResult.check()
self.logger.info(
"update_subscription success SubscriptionId %s",
params.SubscriptionId
)
return response.Parameters
async def delete_subscriptions(self, subscription_ids):
self.logger.debug("delete_subscriptions %r", subscription_ids)
request = ua.DeleteSubscriptionsRequest()
......
......@@ -4,7 +4,8 @@ high level interface to subscriptions
import asyncio
import logging
import collections.abc
from typing import Union, List, Iterable
from typing import Union, List, Iterable, Optional
from asyncua.common.ua_utils import copy_dataclass_attr
from asyncua import ua
from .events import Event, get_filter_from_event_type
......@@ -80,12 +81,24 @@ class Subscription:
self._handler: SubHandler = handler
self.parameters: ua.CreateSubscriptionParameters = params # move to data class
self._monitored_items = {}
self.subscription_id = None
self.subscription_id: Optional[int] = None
async def init(self):
response = await self.server.create_subscription(self.parameters, callback=self.publish_callback)
async def init(self) -> ua.CreateSubscriptionResult:
response = await self.server.create_subscription(
self.parameters, callback=self.publish_callback
)
self.subscription_id = response.SubscriptionId # move to data class
self.logger.info('Subscription created %s', self.subscription_id)
self.logger.info("Subscription created %s", self.subscription_id)
return response
async def update(
self, params: ua.ModifySubscriptionParameters
) -> ua.ModifySubscriptionResponse:
response = await self.server.update_subscription(params)
self.logger.info('Subscription updated %s', params.SubscriptionId)
# update the self.parameters attr with the updated values
copy_dataclass_attr(params, self.parameters)
return response
async def publish_callback(self, publish_result: ua.PublishResult):
"""
......
......@@ -297,3 +297,11 @@ def data_type_to_string(dtype):
else:
string = dtype.to_string()
return string
def copy_dataclass_attr(dc_source, dc_dest) -> None:
"""
Copy the common attributes of dc_source to dc_dest
"""
common_params = set(vars(dc_source)) & set(vars(dc_dest))
for c in common_params:
setattr(dc_dest, c, getattr(dc_source, c))
......@@ -357,29 +357,27 @@ async def test_subscription_data_change_many(opc):
await sub.delete()
await opc.opc.delete_nodes([v1, v2])
async def test_subscription_keepalive_count(mocker):
def test_get_keepalive_count(mocker):
"""
Check the subscription parameter MaxKeepAliveCount value
with various publishInterval and session_timeout values.
"""
mock_subscription = mocker.patch("asyncua.common.subscription.Subscription.init", new=CoroutineMock())
c = Client("opc.tcp://fake")
# session timeout < publish_interval
c.session_timeout = 30000 # ms
publish_interval = 1000 # ms
handler = 123
sub = await c.create_subscription(publish_interval, handler)
assert sub.parameters.RequestedMaxKeepAliveCount == 22
c.session_timeout = 30000 # ms
keepalive_count = c.get_keepalive_count(publish_interval)
assert keepalive_count == 22
# session_timeout > publish_interval
c.session_timeout = 30000
publish_interval = 75000
sub = await c.create_subscription(publish_interval, handler)
assert sub.parameters.RequestedMaxKeepAliveCount == 0
c.session_timeout = 30000
keepalive_count = c.get_keepalive_count(publish_interval)
assert keepalive_count == 0
# RequestedPublishingInterval == 0
publish_interval = 0
sub = await c.create_subscription(publish_interval, handler)
assert sub.parameters.RequestedMaxKeepAliveCount == 22
keepalive_count = c.get_keepalive_count(publish_interval)
assert keepalive_count == 22
async def test_subscribe_server_time(opc):
......@@ -863,3 +861,62 @@ async def test_internal_server_subscription(opc):
# Check that the results are not left un-acknowledged on internal Server Subscriptions.
assert len(internal_sub._not_acknowledged_results) == 0
await opc.opc.delete_nodes([sub_obj])
@pytest.mark.parametrize("opc", ["client"], indirect=True)
async def test_maxkeepalive_count(opc, mocker):
sub_handler = MySubHandler()
client, server = opc
period = 1
max_keepalive_count = client.get_keepalive_count(period)
mock_period = 500
mock_max_keepalive_count = client.get_keepalive_count(mock_period)
mock_response = ua.CreateSubscriptionResult(
SubscriptionId=78,
RevisedPublishingInterval=mock_period,
RevisedLifetimeCount=10000,
RevisedMaxKeepAliveCount=2700
)
mock_create_subscription = mocker.patch.object(
client.uaclient,
"create_subscription",
new=CoroutineMock(return_value=mock_response)
)
mock_update_subscription = mocker.patch.object(
client.uaclient,
"update_subscription",
new=CoroutineMock()
)
sub = await client.create_subscription(period, sub_handler)
assert sub.parameters.RequestedMaxKeepAliveCount == mock_max_keepalive_count
assert mock_max_keepalive_count != max_keepalive_count
# mock point to the object at its finale state,
# here the subscription params have already been updated
mock_create_subscription.assert_awaited_with(
ua.CreateSubscriptionParameters(
RequestedPublishingInterval=mock_period,
RequestedLifetimeCount=10000,
RequestedMaxKeepAliveCount=mock_max_keepalive_count,
MaxNotificationsPerPublish=10000,
PublishingEnabled=True,
Priority=0
),
callback=mocker.ANY
)
mock_update_subscription.assert_awaited_with(
ua.ModifySubscriptionParameters(
SubscriptionId=78,
RequestedPublishingInterval=mock_period,
RequestedLifetimeCount=10000,
RequestedMaxKeepAliveCount=mock_max_keepalive_count,
MaxNotificationsPerPublish=10000
)
)
# we don't update when sub params == revised params
mock_update_subscription.reset_mock()
mock_create_subscription.reset_mock()
sub = await client.create_subscription(mock_period, sub_handler)
mock_update_subscription.assert_not_called()
import pytest
from dataclasses import dataclass
from asyncua.common.ua_utils import copy_dataclass_attr
def test_copy_dataclass_attr():
@dataclass
class A:
x: int = 1
y: int = 2
@dataclass
class B:
y: int = 12
z: int = 13
b = B()
a = A()
assert a.y != b.y
copy_dataclass_attr(a, b)
assert a.y == b.y == 2
assert a.x == 1
assert b.z == 13
b.y = 9
copy_dataclass_attr(b, a)
assert a.y == b.y == 9
assert a.x == 1
assert b.z == 13
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment