Commit cd1645c1 authored by Olivier, committed by oroulet

format everything with ruff

parent 5ea8d17b
"""
Pure Python OPC-UA library
"""
import sys
if sys.version_info >= (3, 8):
from importlib import metadata
else:
......
......@@ -57,9 +57,9 @@ class Client:
"""
self._server_url = urlparse(url)
# take initial username and password from the url
userinfo, have_info, _ = self._server_url.netloc.rpartition('@')
userinfo, have_info, _ = self._server_url.netloc.rpartition("@")
if have_info:
username, have_password, password = userinfo.partition(':')
username, have_password, password = userinfo.partition(":")
self._username = unquote(username)
if have_password:
self._password = unquote(password)
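Note (illustration, not from the commit): on a hypothetical URL with percent-encoded credentials, the parsing above works like this:

from urllib.parse import urlparse, unquote

url = urlparse("opc.tcp://user:p%40ss@localhost:4840")  # hypothetical URL
userinfo, have_info, _ = url.netloc.rpartition("@")
username, have_password, password = userinfo.partition(":")
assert (unquote(username), unquote(password)) == ("user", "p@ss")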
......@@ -111,7 +111,7 @@ class Client:
is not recommended for security reasons.
"""
url = self._server_url
userinfo, have_info, hostinfo = url.netloc.rpartition('@')
userinfo, have_info, hostinfo = url.netloc.rpartition("@")
if have_info:
# remove credentials from url, preventing them from being sent unencrypted in e.g. send_hello
if self.strip_url_credentials:
......@@ -125,7 +125,7 @@ class Client:
"""
_logger.info("find_endpoint %r %r %r", endpoints, security_mode, policy_uri)
for ep in endpoints:
if (ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and ep.SecurityMode == security_mode and ep.SecurityPolicyUri == policy_uri):
if ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and ep.SecurityMode == security_mode and ep.SecurityPolicyUri == policy_uri:
return ep
raise ua.UaError(f"No matching endpoints: {security_mode}, {policy_uri}")
......@@ -171,8 +171,8 @@ class Client:
if len(parts) < 4:
raise ua.UaError(f"Wrong format: `{string}`, expected at least 4 comma-separated values")
if '::' in parts[3]: # if the filename contains a double colon, treat the field as key-path::password and split it
parts[3], client_key_password = parts[3].split('::')
if "::" in parts[3]: # if the filename contains a colon, assume it's a conjunction and parse it
parts[3], client_key_password = parts[3].split("::")
else:
client_key_password = None
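Note (illustration, not from the commit): a worked example of the double-colon convention handled above, with hypothetical file names:

parts = "Basic256Sha256,SignAndEncrypt,cert.pem,key.pem::secret".split(",")
if "::" in parts[3]:
    parts[3], client_key_password = parts[3].split("::")
else:
    client_key_password = None
assert parts[3] == "key.pem" and client_key_password == "secret"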
......@@ -205,7 +205,7 @@ class Client:
# this generates an error in our crypto part, so we strip everything after
# the server cert. To do this we read bytes 2:4 to get the length and add the 4 header bytes
cert_len_idx = 2
len_bytestr = endpoint.ServerCertificate[cert_len_idx:cert_len_idx + 2]
len_bytestr = endpoint.ServerCertificate[cert_len_idx : cert_len_idx + 2]
cert_len = int.from_bytes(len_bytestr, byteorder="big", signed=False) + 4
server_certificate = uacrypto.x509_from_der(endpoint.ServerCertificate[:cert_len])
elif not isinstance(server_certificate, uacrypto.CertProperties):
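Note (illustration, not from the commit): a minimal sketch of the length computation above, assuming a DER certificate with a two-byte long-form length field; bytes 2:4 hold the payload size and the 4 header bytes are added on top:

der_blob = b"\x30\x82\x03\xe8" + b"\x00" * 1000 + b"trailing junk"  # hypothetical data
cert_len_idx = 2
len_bytestr = der_blob[cert_len_idx : cert_len_idx + 2]
cert_len = int.from_bytes(len_bytestr, byteorder="big", signed=False) + 4
assert len(der_blob[:cert_len]) == 1004  # 4 header bytes + 0x03E8 payload bytes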
......@@ -224,7 +224,6 @@ class Client:
server_cert: uacrypto.CertProperties,
mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt,
) -> None:
if isinstance(server_cert, uacrypto.CertProperties):
server_cert = await uacrypto.load_certificate(server_cert.path_or_content, server_cert.extension)
cert = await uacrypto.load_certificate(certificate.path_or_content, certificate.extension)
......@@ -500,7 +499,7 @@ class Client:
# this generates an error in our crypto part, so we strip everything after
# the server cert. To do this we read bytes 2:4 to get the length and add the 4 header bytes
cert_len_idx = 2
len_bytestr = response.ServerCertificate[cert_len_idx:cert_len_idx + 2]
len_bytestr = response.ServerCertificate[cert_len_idx : cert_len_idx + 2]
cert_len = int.from_bytes(len_bytestr, byteorder="big", signed=False) + 4
server_certificate = response.ServerCertificate[:cert_len]
if not self.security_policy.peer_certificate:
......@@ -630,7 +629,7 @@ class Client:
if self.security_policy.AsymmetricSignatureURI:
params.ClientSignature.Algorithm = self.security_policy.AsymmetricSignatureURI
else:
params.ClientSignature.Algorithm = (security_policies.SecurityPolicyBasic256.AsymmetricSignatureURI)
params.ClientSignature.Algorithm = security_policies.SecurityPolicyBasic256.AsymmetricSignatureURI
params.ClientSignature.Signature = self.security_policy.asymmetric_cryptography.signature(challenge)
params.LocaleIds = self._locale
if not username and not user_certificate:
......@@ -729,9 +728,7 @@ class Client:
"""
return Node(self.uaclient, nodeid)
async def create_subscription(
self, period: Union[ua.CreateSubscriptionParameters, float], handler: SubscriptionHandler, publishing: bool = True
) -> Subscription:
async def create_subscription(self, period: Union[ua.CreateSubscriptionParameters, float], handler: SubscriptionHandler, publishing: bool = True) -> Subscription:
"""
Create a subscription.
Returns a Subscription object which allows to subscribe to events or data changes on server.
......@@ -763,35 +760,21 @@ class Client:
params: ua.CreateSubscriptionParameters,
results: ua.CreateSubscriptionResult,
) -> Optional[ua.ModifySubscriptionParameters]:
if (
results.RevisedPublishingInterval == params.RequestedPublishingInterval
and results.RevisedLifetimeCount == params.RequestedLifetimeCount
and results.RevisedMaxKeepAliveCount == params.RequestedMaxKeepAliveCount
):
if results.RevisedPublishingInterval == params.RequestedPublishingInterval and results.RevisedLifetimeCount == params.RequestedLifetimeCount and results.RevisedMaxKeepAliveCount == params.RequestedMaxKeepAliveCount:
return None
_logger.warning(
"Revised values returned differ from subscription values: %s", results
)
_logger.warning("Revised values returned differ from subscription values: %s", results)
revised_interval = results.RevisedPublishingInterval
# Adjust the MaxKeepAliveCount based on the RevisedPublishInterval when necessary
new_keepalive_count = self.get_keepalive_count(revised_interval)
if (
revised_interval != params.RequestedPublishingInterval
and new_keepalive_count != params.RequestedMaxKeepAliveCount
):
_logger.info(
"KeepAliveCount will be updated to %s "
"for consistency with RevisedPublishInterval", new_keepalive_count
)
if revised_interval != params.RequestedPublishingInterval and new_keepalive_count != params.RequestedMaxKeepAliveCount:
_logger.info("KeepAliveCount will be updated to %s " "for consistency with RevisedPublishInterval", new_keepalive_count)
modified_params = ua.ModifySubscriptionParameters()
# copy the existing subscription parameters
copy_dataclass_attr(params, modified_params)
# then override with the revised values
modified_params.RequestedMaxKeepAliveCount = new_keepalive_count
modified_params.SubscriptionId = results.SubscriptionId
modified_params.RequestedPublishingInterval = (
results.RevisedPublishingInterval
)
modified_params.RequestedPublishingInterval = results.RevisedPublishingInterval
# update LifetimeCount but chances are it will be re-revised again
modified_params.RequestedLifetimeCount = results.RevisedLifetimeCount
return modified_params
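Note (illustration, not from the commit): a hedged usage sketch of the subscription API whose revise-and-modify logic is reformatted above; the endpoint URL and node id are hypothetical:

import asyncio
from asyncua import Client

class SubHandler:
    def datachange_notification(self, node, val, data):
        print(node, val)

async def main():
    async with Client(url="opc.tcp://localhost:4840/") as client:  # hypothetical URL
        # if the server revises the 500 ms interval, the client follows up
        # with a ModifySubscriptionRequest carrying a recomputed keep-alive count
        sub = await client.create_subscription(500, SubHandler())
        await sub.subscribe_data_change(client.get_node("ns=2;i=2"))  # hypothetical node
        await asyncio.sleep(5)

asyncio.run(main())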
......
......@@ -87,7 +87,6 @@ class Reconciliator:
_logger.info("Starting Reconciliator loop, checking every %dsec", self.timer)
self.is_running = True
while self.is_running:
start = time.time()
async with self.ha_client._url_to_reset_lock:
await self.resubscribe()
......@@ -153,13 +152,9 @@ class Reconciliator:
if url not in real_map or digest_ideal != digest_real:
targets.add(url)
if not targets:
_logger.info(
"[PASS] No configuration difference for healthy targets: %s", valid_urls
)
_logger.info("[PASS] No configuration difference for healthy targets: %s", valid_urls)
return
_logger.info(
"[WORK] Configuration difference found for healthy targets: %s", targets
)
_logger.info("[WORK] Configuration difference found for healthy targets: %s", targets)
except (AttributeError, TypeError, PicklingError) as e:
_logger.warning("[WORK] Reconciliator performance impacted: %s", e)
targets = set(valid_urls)
......@@ -170,9 +165,7 @@ class Reconciliator:
# look for missing options (publish/monitoring) for existing subs
await self.update_subscription_modes(real_map, ideal_map, targets)
async def update_subscriptions(
self, real_map, ideal_map, targets: Set[str]
) -> None:
async def update_subscriptions(self, real_map, ideal_map, targets: Set[str]) -> None:
_logger.debug("In update_subscriptions")
tasks = []
for url in targets:
......@@ -180,9 +173,7 @@ class Reconciliator:
tasks.extend(self._subs_to_add(url, real_map, ideal_map))
await asyncio.gather(*tasks, return_exceptions=True)
def _subs_to_del(
self, url: str, real_map: SubMap, ideal_map: SubMap
) -> List[asyncio.Task]:
def _subs_to_del(self, url: str, real_map: SubMap, ideal_map: SubMap) -> List[asyncio.Task]:
to_del: List[asyncio.Task] = []
sub_to_del = set(real_map[url]) - set(ideal_map[url])
if sub_to_del:
......@@ -190,15 +181,11 @@ class Reconciliator:
for sub_name in sub_to_del:
sub_handle = self.name_to_subscription[url][sub_name]
task = asyncio.create_task(sub_handle.delete())
task.add_done_callback(
partial(self.del_from_map, url, Method.DEL_SUB, sub_name=sub_name)
)
task.add_done_callback(partial(self.del_from_map, url, Method.DEL_SUB, sub_name=sub_name))
to_del.append(task)
return to_del
def _subs_to_add(
self, url: str, real_map: SubMap, ideal_map: SubMap
) -> List[asyncio.Task]:
def _subs_to_add(self, url: str, real_map: SubMap, ideal_map: SubMap) -> List[asyncio.Task]:
to_add: List[asyncio.Task] = []
sub_to_add = set(ideal_map[url]) - set(real_map[url])
if sub_to_add:
......@@ -206,11 +193,7 @@ class Reconciliator:
client = self.ha_client.get_client_by_url(url)
for sub_name in sub_to_add:
vs = ideal_map[url][sub_name]
task = asyncio.create_task(
client.create_subscription(
vs.period, vs.handler, publishing=vs.publishing
)
)
task = asyncio.create_task(client.create_subscription(vs.period, vs.handler, publishing=vs.publishing))
task.add_done_callback(
partial(
self.add_to_map,
......@@ -226,9 +209,7 @@ class Reconciliator:
to_add.append(task)
return to_add
async def update_nodes(
self, real_map: SubMap, ideal_map: SubMap, targets: Set[str]
) -> None:
async def update_nodes(self, real_map: SubMap, ideal_map: SubMap, targets: Set[str]) -> None:
_logger.debug("In update_nodes")
tasks = []
for url in targets:
......@@ -237,17 +218,12 @@ class Reconciliator:
real_sub = self.name_to_subscription[url].get(sub_name)
# in case the previous create_subscription request failed
if not real_sub:
_logger.warning(
"Can't create nodes for %s since underlying "
"subscription for %s doesn't exist", url, sub_name
)
_logger.warning("Can't create nodes for %s since underlying " "subscription for %s doesn't exist", url, sub_name)
continue
vs_real = real_map[url][sub_name]
vs_ideal = ideal_map[url][sub_name]
tasks.extend(self._nodes_to_del(url, sub_name, vs_real, vs_ideal))
tasks.extend(
self._nodes_to_add(url, sub_name, client, vs_real, vs_ideal)
)
tasks.extend(self._nodes_to_add(url, sub_name, client, vs_real, vs_ideal))
await asyncio.gather(*tasks, return_exceptions=True)
def _nodes_to_add(
......@@ -293,9 +269,7 @@ class Reconciliator:
)
)
tasks.append(task)
self.hook_mi_request(
url=url, sub_name=sub_name, nodes=node_to_add, action=Method.ADD_MI
)
self.hook_mi_request(url=url, sub_name=sub_name, nodes=node_to_add, action=Method.ADD_MI)
return tasks
def _nodes_to_del(
......@@ -323,14 +297,10 @@ class Reconciliator:
)
)
to_del.append(task)
self.hook_mi_request(
url=url, sub_name=sub_name, nodes=node_to_del, action=Method.DEL_MI
)
self.hook_mi_request(url=url, sub_name=sub_name, nodes=node_to_del, action=Method.DEL_MI)
return to_del
async def update_subscription_modes(
self, real_map: SubMap, ideal_map: SubMap, targets: Set[str]
) -> None:
async def update_subscription_modes(self, real_map: SubMap, ideal_map: SubMap, targets: Set[str]) -> None:
_logger.debug("In update_subscription_modes")
modes = [Method.MONITORING, Method.PUBLISHING]
methods = [n.value for n in modes]
......@@ -340,9 +310,7 @@ class Reconciliator:
real_sub = self.name_to_subscription[url].get(sub_name)
# in case the previous create_subscription request failed
if not real_sub:
_logger.warning(
"Can't change modes for %s since underlying subscription for %s doesn't exist", url, sub_name
)
_logger.warning("Can't change modes for %s since underlying subscription for %s doesn't exist", url, sub_name)
continue
vs_real = real_map[url][sub_name]
vs_ideal = ideal_map[url][sub_name]
......@@ -404,18 +372,14 @@ class Reconciliator:
_logger.info("Node %s subscription failed: %s", node, handle)
# The node is invalid, remove it from both maps
if handle.name == "BadNodeIdUnknown":
_logger.warning(
"WARNING: Abandoning %s because it returned %s from %s", node, handle, url
)
_logger.warning("WARNING: Abandoning %s because it returned %s from %s", node, handle, url)
real_vs = self.ha_client.ideal_map[url][sub_name]
real_vs.unsubscribe([node])
continue
self.node_to_handle[url][node] = handle
self.hook_add_to_map(fut=fut, url=url, action=action, **kwargs)
def del_from_map(
self, url: str, action: Method, fut: asyncio.Task, **kwargs
) -> None:
def del_from_map(self, url: str, action: Method, fut: asyncio.Task, **kwargs) -> None:
if fut.exception():
# log exception but continues to delete local resources
_logger.warning("Can't %s on %s: %s", action.value, url, fut.exception())
......@@ -443,17 +407,13 @@ class Reconciliator:
_logger.debug(a)
def hook_mi_request(self, url: str, sub_name: str, nodes: Set[SortedDict], action: Method):
"""placeholder for easily superclass the HaClient and implement custom logic
"""
"""placeholder for easily superclass the HaClient and implement custom logic"""
def hook_add_to_map_error(self, url: str, action: Method, fut: asyncio.Task, **kwargs):
"""placeholder for easily superclass the HaClient and implement custom logic
"""
"""placeholder for easily superclass the HaClient and implement custom logic"""
def hook_add_to_map(self, fut: asyncio.Task, url: str, action: Method, **kwargs):
"""placeholder for easily superclass the HaClient and implement custom logic
"""
"""placeholder for easily superclass the HaClient and implement custom logic"""
def hook_del_from_map(self, fut: asyncio.Task, url: str, **kwargs):
"""placeholder for easily superclass the HaClient and implement custom logic
"""
"""placeholder for easily superclass the HaClient and implement custom logic"""
......@@ -24,9 +24,7 @@ class VirtualSubscription:
# see: https://github.com/grantjenks/python-sortedcontainers/pull/107
nodes: SortedDict = field(default_factory=SortedDict)
def subscribe_data_change(
self, nodes: Iterable[str], attr: ua.AttributeIds, queuesize: int
) -> None:
def subscribe_data_change(self, nodes: Iterable[str], attr: ua.AttributeIds, queuesize: int) -> None:
for node in nodes:
self.nodes[node] = NodeAttr(attr, queuesize)
......
"""
Low level binary client
"""
import asyncio
import copy
import logging
......@@ -20,9 +21,10 @@ class UASocketProtocol(asyncio.Protocol):
Handle socket connection and send ua messages.
Timeout is the timeout used while waiting for a UA answer from the server.
"""
INITIALIZED = 'initialized'
OPEN = 'open'
CLOSED = 'closed'
INITIALIZED = "initialized"
OPEN = "open"
CLOSED = "closed"
def __init__(self, timeout: float = 1, security_policy: ua.SecurityPolicy = ua.SecurityPolicy(), limits: TransportLimits = None):
"""
......@@ -52,7 +54,7 @@ class UASocketProtocol(asyncio.Protocol):
# Hook for upper layer tasks before a request is sent (optional)
self.pre_request_hook: Optional[Callable[[], Awaitable[None]]] = None
def connection_made(self, transport: asyncio.Transport): # type: ignore[override]
def connection_made(self, transport: asyncio.Transport): # type: ignore[override]
self.state = self.OPEN
self.transport = transport
......@@ -79,11 +81,11 @@ class UASocketProtocol(asyncio.Protocol):
try:
header = header_from_binary(buf)
except ua.utils.NotEnoughData:
self.logger.debug('Not enough data while parsing header from server, waiting for more')
self.logger.debug("Not enough data while parsing header from server, waiting for more")
self.receive_buffer = data
return
if len(buf) < header.body_size:
self.logger.debug('We did not receive enough data from server. Need %s got %s', header.body_size, len(buf))
self.logger.debug("We did not receive enough data from server. Need %s got %s", header.body_size, len(buf))
self.receive_buffer = data
return
msg = self._connection.receive_from_header_and_body(header, buf)
......@@ -99,11 +101,11 @@ class UASocketProtocol(asyncio.Protocol):
# Buffer still has bytes left, try to process again
data = bytes(buf)
except ua.UaStatusCodeError as e:
self.logger.error('Got error status from server: {}'.format(e))
self.logger.error("Got error status from server: {}".format(e))
self.disconnect_socket()
return
except Exception:
self.logger.exception('Exception raised while parsing message from server')
self.logger.exception("Exception raised while parsing message from server")
self.disconnect_socket()
return
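Note (illustration, not from the commit): a hedged sketch of the buffering pattern in data_received above; the 4-byte length prefix is a simplified stand-in for the real OPC UA header:

def feed(buffer: bytearray, chunk: bytes, on_message) -> None:
    buffer += chunk
    while True:
        if len(buffer) < 4:
            return  # not enough data for a header, wait for more
        body_size = int.from_bytes(buffer[:4], "little")
        if len(buffer) - 4 < body_size:
            return  # body incomplete, keep the buffer and wait
        on_message(bytes(buffer[4 : 4 + body_size]))
        del buffer[: 4 + body_size]  # bytes left over, try to process again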
......@@ -133,7 +135,7 @@ class UASocketProtocol(asyncio.Protocol):
:return: Future that resolves with the Response
"""
self._setup_request_header(request.RequestHeader, timeout)
self.logger.debug('Sending: %s', request)
self.logger.debug("Sending: %s", request)
try:
binreq = struct_to_binary(request)
except Exception:
......@@ -229,10 +231,10 @@ class UASocketProtocol(asyncio.Protocol):
request = ua.OpenSecureChannelRequest()
request.Parameters = params
if self._open_secure_channel_exchange is not None:
raise RuntimeError('Two Open Secure Channel requests can not happen too close to each other. ' 'The response must be processed and returned before the next request can be sent.')
raise RuntimeError("Two Open Secure Channel requests can not happen too close to each other. " "The response must be processed and returned before the next request can be sent.")
self._open_secure_channel_exchange = params
await wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout)
_return = self._open_secure_channel_exchange.Parameters # type: ignore[union-attr]
_return = self._open_secure_channel_exchange.Parameters # type: ignore[union-attr]
self._open_secure_channel_exchange = None
return _return
......@@ -267,7 +269,7 @@ class UaClient(AbstractSession):
"""
:param timeout: Timeout in seconds
"""
self.logger = logging.getLogger(f'{__name__}.UaClient')
self.logger = logging.getLogger(f"{__name__}.UaClient")
self._subscription_callbacks = {}
self._timeout = timeout
self.security_policy = ua.SecurityPolicy()
......@@ -499,7 +501,7 @@ class UaClient(AbstractSession):
response.ResponseHeader.ServiceResult.check()
return response.Results
async def create_subscription( # type: ignore[override]
async def create_subscription( # type: ignore[override]
self, params: ua.CreateSubscriptionParameters, callback
) -> ua.CreateSubscriptionResult:
self.logger.debug("create_subscription")
......@@ -509,10 +511,7 @@ class UaClient(AbstractSession):
response = struct_from_binary(ua.CreateSubscriptionResponse, data)
response.ResponseHeader.ServiceResult.check()
self._subscription_callbacks[response.Parameters.SubscriptionId] = callback
self.logger.info(
"create_subscription success SubscriptionId %s",
response.Parameters.SubscriptionId
)
self.logger.info("create_subscription success SubscriptionId %s", response.Parameters.SubscriptionId)
if not self._publish_task or self._publish_task.done():
# Start the publishing loop if it is not yet running
# The current strategy is to have only one open publish request per UaClient. This might not be enough
......@@ -523,16 +522,13 @@ class UaClient(AbstractSession):
async def inform_subscriptions(self, status: ua.StatusCode):
"""
Inform all current subscriptions with a status code. This calls the handler's status_change_notification
"""
status_message = ua.StatusChangeNotification(Status=status)
notification_message = ua.NotificationMessage(NotificationData=[status_message]) # type: ignore[list-item]
for subid, callback in self._subscription_callbacks.items():
try:
parameters = ua.PublishResult(
subid,
NotificationMessage_=notification_message
)
parameters = ua.PublishResult(subid, NotificationMessage_=notification_message)
if asyncio.iscoroutinefunction(callback):
await callback(parameters)
else:
......@@ -540,18 +536,13 @@ class UaClient(AbstractSession):
except Exception: # we call user code, catch everything!
self.logger.exception("Exception while calling user callback: %s")
async def update_subscription(
self, params: ua.ModifySubscriptionParameters
) -> ua.ModifySubscriptionResult:
async def update_subscription(self, params: ua.ModifySubscriptionParameters) -> ua.ModifySubscriptionResult:
request = ua.ModifySubscriptionRequest()
request.Parameters = params
data = await self.protocol.send_request(request)
response = struct_from_binary(ua.ModifySubscriptionResponse, data)
response.ResponseHeader.ServiceResult.check()
self.logger.info(
"update_subscription success SubscriptionId %s",
params.SubscriptionId
)
self.logger.info("update_subscription success SubscriptionId %s", params.SubscriptionId)
return response.Parameters
modify_subscription = update_subscription # legacy support
......@@ -572,7 +563,7 @@ class UaClient(AbstractSession):
"""
Send a PublishRequest to the server.
"""
self.logger.debug('publish %r', acks)
self.logger.debug("publish %r", acks)
request = ua.PublishRequest()
request.Parameters.SubscriptionAcknowledgements = acks if acks else []
data = await self.protocol.send_request(request, timeout=0)
......
......@@ -2,13 +2,12 @@ from asyncua import ua
class UaFile:
def __init__(self, file_node, open_mode):
self._file_node = file_node
self._handle = None
if open_mode == 'r':
if open_mode == "r":
self._init_open = ua.OpenFileMode.Read.value
elif open_mode == 'w':
elif open_mode == "w":
self._init_open = ua.OpenFileMode.Write.value
else:
raise ValueError("file mode is not supported")
......@@ -21,19 +20,19 @@ class UaFile:
return await self.close()
async def open(self, open_mode):
""" open file method """
"""open file method"""
open_node = await self._file_node.get_child("Open")
arg = ua.Variant(open_mode, ua.VariantType.Byte)
return await self._file_node.call_method(open_node, arg)
async def close(self):
""" close file method """
"""close file method"""
read_node = await self._file_node.get_child("Close")
arg1 = ua.Variant(self._handle, ua.VariantType.UInt32)
return await self._file_node.call_method(read_node, arg1)
async def read(self):
""" reads file contents """
"""reads file contents"""
size = await self.get_size()
read_node = await self._file_node.get_child("Read")
arg1 = ua.Variant(self._handle, ua.VariantType.UInt32)
......@@ -41,13 +40,13 @@ class UaFile:
return await self._file_node.call_method(read_node, arg1, arg2)
async def write(self, data: bytes):
""" writes file contents """
"""writes file contents"""
write_node = await self._file_node.get_child("Write")
arg1 = ua.Variant(self._handle, ua.VariantType.UInt32)
arg2 = ua.Variant(data, ua.VariantType.ByteString)
return await self._file_node.call_method(write_node, arg1, arg2)
async def get_size(self):
""" gets size of file """
"""gets size of file"""
size_node = await self._file_node.get_child("Size")
return await size_node.read_value()
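Note (illustration, not from the commit): a hedged usage sketch of this class, assuming it is used as an async context manager; the node id is hypothetical:

async def read_remote_file(client):
    file_node = client.get_node("ns=2;s=MyFile")  # hypothetical node
    async with UaFile(file_node, "r") as remote_file:
        return await remote_file.read()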
......@@ -8,6 +8,7 @@ OPC 10000-5: OPC Unified Architecture V1.04
Part 5: Information Model - Annex C (normative) File Transfer
https://reference.opcfoundation.org/Core/docs/Part5/C.1/
"""
import logging
from typing import Tuple
......@@ -22,6 +23,7 @@ class UaFile:
"""
Provides the functionality to work with "C.2 FileType".
"""
def __init__(self, file_node: Node, open_mode: OpenFileMode = OpenFileMode.Read.value):
"""
Initializes a new instance of the UaFile class.
......@@ -124,9 +126,7 @@ class UaFile:
self._set_position_node = await self._file_node.get_child("SetPosition")
arg1_file_handle = Variant(self._file_handle, VariantType.UInt32)
arg2_position = Variant(position, VariantType.UInt64)
return await self._file_node.call_method(self._set_position_node,
arg1_file_handle,
arg2_position)
return await self._file_node.call_method(self._set_position_node, arg1_file_handle, arg2_position)
async def get_size(self) -> int:
"""
......@@ -176,6 +176,7 @@ class UaDirectory:
"""
Provides the functionality to work with "C.3 File System".
"""
def __init__(self, directory_node):
self._directory_node = directory_node
......@@ -222,9 +223,7 @@ class UaDirectory:
create_file_node = await self._directory_node.get_child("CreateFile")
arg1_file_name = Variant(file_name, VariantType.String)
arg2_request_file_open = Variant(request_file_open, VariantType.Boolean)
return await self._directory_node.call_method(create_file_node,
arg1_file_name,
arg2_request_file_open)
return await self._directory_node.call_method(create_file_node, arg1_file_name, arg2_request_file_open)
async def delete(self, object_to_delete: NodeId) -> None:
"""
......@@ -237,11 +236,7 @@ class UaDirectory:
delete_node = await self._directory_node.get_child("Delete")
await self._directory_node.call_method(delete_node, object_to_delete)
async def move_or_copy(self,
object_to_move_or_copy: NodeId,
target_directory: NodeId,
create_copy: bool,
new_name: str) -> NodeId:
async def move_or_copy(self, object_to_move_or_copy: NodeId, target_directory: NodeId, create_copy: bool, new_name: str) -> NodeId:
"""
MoveOrCopy is used to move or copy a file or directory organized by this Object
to another directory or to rename a file or directory.
......@@ -256,18 +251,8 @@ class UaDirectory:
:return: The NodeId of the moved or copied object. Even if the Object is moved,
the Server may return a new NodeId.
"""
_logger.debug("Request to %s%s file system object %s from %s to %s, new name=%s",
'' if create_copy else 'move',
'copy' if create_copy else '',
object_to_move_or_copy,
self._directory_node,
target_directory,
new_name)
_logger.debug("Request to %s%s file system object %s from %s to %s, new name=%s", "" if create_copy else "move", "copy" if create_copy else "", object_to_move_or_copy, self._directory_node, target_directory, new_name)
move_or_copy_node = await self._directory_node.get_child("MoveOrCopy")
arg3_create_copy = Variant(create_copy, VariantType.Boolean)
arg4_new_name = Variant(new_name, VariantType.String)
return await self._directory_node.call_method(move_or_copy_node,
object_to_move_or_copy,
target_directory,
arg3_create_copy,
arg4_new_name)
return await self._directory_node.call_method(move_or_copy_node, object_to_move_or_copy, target_directory, arg3_create_copy, arg4_new_name)
......@@ -15,6 +15,7 @@ class CallbackType(Enum):
:ivar MonitoredItem:
"""
Null = 0
ItemSubscriptionCreated = 1
ItemSubscriptionModified = 2
......@@ -57,7 +58,7 @@ class CallbackService:
if event is None:
event = Callback()
elif not isinstance(event, Callback):
raise ValueError('Unexpected event type given')
raise ValueError("Unexpected event type given")
event.setName(eventName)
if eventName not in self._listeners:
return event
......@@ -91,7 +92,7 @@ class CallbackService:
def addSubscriber(self, subscriber):
if not isinstance(subscriber, CallbackSubscriberInterface):
raise ValueError('Unexpected subscriber type given')
raise ValueError("Unexpected subscriber type given")
for eventName, params in subscriber.getSubscribedEvents().items():
if isinstance(params, str):
self.addListener(eventName, getattr(subscriber, params))
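Note (illustration, not from the commit): a hedged sketch of a subscriber accepted by addSubscriber; the event name is hypothetical, the contract is the getSubscribedEvents mapping of event names to handler-method names:

class MySubscriber(CallbackSubscriberInterface):
    def getSubscribedEvents(self):
        return {"ItemSubscriptionCreated": "on_created"}  # hypothetical event name

    def on_created(self, event):
        print("monitored item subscription created:", event)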
......
......@@ -11,17 +11,20 @@ from ..ua.ua_binary import struct_from_binary, struct_to_binary, header_from_bin
try:
from ..crypto.uacrypto import InvalidSignature
except ImportError:
class InvalidSignature(Exception): # type: ignore
pass
_logger = logging.getLogger('asyncua.uaprotocol')
_logger = logging.getLogger("asyncua.uaprotocol")
@dataclass
class TransportLimits:
'''
Limits of the tcp transport layer to prevent excessive resource usage
'''
"""
Limits of the tcp transport layer to prevent excessive resource usage
"""
# Max size of a chunk we can receive
max_recv_buffer: int = 65535
# Max size of a chunk we can send
......@@ -84,7 +87,7 @@ class MessageChunk:
Message Chunk, as described in OPC UA specs Part 6, 6.7.2.
"""
def __init__(self, crypto, body=b'', msg_type=ua.MessageType.SecureMessage, chunk_type=ua.ChunkType.Single):
def __init__(self, crypto, body=b"", msg_type=ua.MessageType.SecureMessage, chunk_type=ua.ChunkType.Single):
self.MessageHeader = ua.Header(msg_type, chunk_type)
if msg_type in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
self.SecurityHeader = ua.SymmetricAlgorithmHeader()
......@@ -107,7 +110,7 @@ class MessageChunk:
@staticmethod
def from_header_and_body(security_policy, header, buf, use_prev_key=False):
if not len(buf) >= header.body_size:
raise ValueError('Full body expected here')
raise ValueError("Full body expected here")
data = buf.copy(header.body_size)
buf.skip(header.body_size)
if header.MessageType in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
......@@ -170,8 +173,7 @@ class MessageChunk:
if security_policy.host_certificate:
chunk.SecurityHeader.SenderCertificate = security_policy.host_certificate
if security_policy.peer_certificate:
chunk.SecurityHeader.ReceiverCertificateThumbPrint =\
hashlib.sha1(security_policy.peer_certificate).digest()
chunk.SecurityHeader.ReceiverCertificateThumbPrint = hashlib.sha1(security_policy.peer_certificate).digest()
chunk.MessageHeader.ChannelId = channel_id
chunk.SequenceHeader.RequestId = request_id
return [chunk]
......@@ -181,7 +183,7 @@ class MessageChunk:
chunks = []
for i in range(0, len(body), max_size):
part = body[i:i + max_size]
part = body[i : i + max_size]
if i + max_size >= len(body):
chunk_type = ua.ChunkType.Single
else:
......@@ -194,8 +196,7 @@ class MessageChunk:
return chunks
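Note (illustration, not from the commit): a minimal sketch of the splitting loop above; every max_size slice becomes an Intermediate ("C") chunk except the last, which is marked Single/final ("F"):

def split_chunks(body: bytes, max_size: int):
    chunks = []
    for i in range(0, len(body), max_size):
        final = i + max_size >= len(body)
        chunks.append(("F" if final else "C", body[i : i + max_size]))
    return chunks

assert split_chunks(b"abcdefgh", 3) == [("C", b"abc"), ("C", b"def"), ("F", b"gh")]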
def __str__(self):
return f"{self.__class__.__name__}({self.MessageHeader}, {self.SequenceHeader}," \
f" {self.SecurityHeader}, {len(self.Body)} bytes)"
return f"{self.__class__.__name__}({self.MessageHeader}, {self.SequenceHeader}," f" {self.SecurityHeader}, {len(self.Body)} bytes)"
__repr__ = __str__
......@@ -204,6 +205,7 @@ class SecureConnection:
"""
Common logic for client and server
"""
def __init__(self, security_policy, limits: TransportLimits):
self._sequence_number = 0
self._peer_sequence_number = None
......@@ -228,11 +230,7 @@ class SecureConnection:
self.local_nonce = client_nonce
self.remote_nonce = params.ServerNonce
self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
self.security_policy.make_remote_symmetric_key(
self.local_nonce,
self.remote_nonce,
self.security_token.RevisedLifetime
)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce, self.security_token.RevisedLifetime)
self._open = True
else:
self.next_security_token = params.SecurityToken
......@@ -261,11 +259,7 @@ class SecureConnection:
response.SecurityToken = self.security_token
self.security_policy.make_local_symmetric_key(self.remote_nonce, self.local_nonce)
self.security_policy.make_remote_symmetric_key(
self.local_nonce,
self.remote_nonce,
self.security_token.RevisedLifetime
)
self.security_policy.make_remote_symmetric_key(self.local_nonce, self.remote_nonce, self.security_token.RevisedLifetime)
else:
self.next_security_token = copy.deepcopy(self.security_token)
self.next_security_token.TokenId += 1
......@@ -356,8 +350,7 @@ class SecureConnection:
# expired SecurityToken for up to 25 % of the token lifetime. This should ensure that
# Messages sent by the Server before the token expired are not rejected because of
# network delays.
timeout = self.prev_security_token.CreatedAt + \
timedelta(milliseconds=self.prev_security_token.RevisedLifetime * 1.25)
timeout = self.prev_security_token.CreatedAt + timedelta(milliseconds=self.prev_security_token.RevisedLifetime * 1.25)
if timeout < datetime.now(timezone.utc):
raise ua.UaError(f"Security token id {security_hdr.TokenId} has timed out " f"({timeout} < {datetime.now(timezone.utc)})")
return
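Note (illustration, not from the commit): a worked example of the 25 % grace period computed above, with a hypothetical lifetime:

from datetime import datetime, timedelta, timezone

created_at = datetime.now(timezone.utc)
revised_lifetime_ms = 600_000  # hypothetical RevisedLifetime
deadline = created_at + timedelta(milliseconds=revised_lifetime_ms * 1.25)
assert deadline - created_at == timedelta(seconds=750)  # tokens tolerated for 750 s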
......@@ -369,13 +362,13 @@ class SecureConnection:
def _check_incoming_chunk(self, chunk):
if not isinstance(chunk, MessageChunk):
raise ValueError(f'Expected chunk, got: {chunk}')
raise ValueError(f"Expected chunk, got: {chunk}")
if chunk.MessageHeader.MessageType != ua.MessageType.SecureOpen:
if chunk.MessageHeader.ChannelId != self.security_token.ChannelId:
raise ua.UaError(f'Wrong channel id {chunk.MessageHeader.ChannelId},' f' expected {self.security_token.ChannelId}')
raise ua.UaError(f"Wrong channel id {chunk.MessageHeader.ChannelId}," f" expected {self.security_token.ChannelId}")
if self._incoming_parts:
if self._incoming_parts[0].SequenceHeader.RequestId != chunk.SequenceHeader.RequestId:
raise ua.UaError(f'Wrong request id {chunk.SequenceHeader.RequestId},' f' expected {self._incoming_parts[0].SequenceHeader.RequestId}')
raise ua.UaError(f"Wrong request id {chunk.SequenceHeader.RequestId}," f" expected {self._incoming_parts[0].SequenceHeader.RequestId}")
# The sequence number must monotonically increase (but it can wrap around)
seq_num = chunk.SequenceHeader.SequenceNumber
if self._peer_sequence_number is not None:
......@@ -383,7 +376,7 @@ class SecureConnection:
wrap_limit = (1 << 32) - 1024
if seq_num < 1024 and self._peer_sequence_number >= wrap_limit:
# The sequence number has wrapped around. See spec. part 6, 6.7.2
_logger.debug('Sequence number wrapped: %d -> %d', self._peer_sequence_number, seq_num)
_logger.debug("Sequence number wrapped: %d -> %d", self._peer_sequence_number, seq_num)
else:
# Condition for monotonically increase is not met
raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection")
......
......@@ -13,9 +13,7 @@ from .node_factory import make_node
_logger = logging.getLogger(__name__)
async def copy_node(
parent: asyncua.Node, node: asyncua.Node, nodeid: Optional[ua.NodeId] = None, recursive: bool = True
) -> List[asyncua.Node]:
async def copy_node(parent: asyncua.Node, node: asyncua.Node, nodeid: Optional[ua.NodeId] = None, recursive: bool = True) -> List[asyncua.Node]:
"""
Copy a node or node tree as child of parent node
"""
......@@ -42,17 +40,20 @@ async def _copy_node(session: AbstractSession, parent_nodeid: ua.NodeId, rdesc:
if recursive:
descs = await node_to_copy.get_children_descriptions()
for desc in descs:
nodes = await _copy_node(session, res.AddedNodeId, desc,
nodeid=ua.NodeId(NamespaceIndex=desc.NodeId.NamespaceIndex), recursive=True)
nodes = await _copy_node(session, res.AddedNodeId, desc, nodeid=ua.NodeId(NamespaceIndex=desc.NodeId.NamespaceIndex), recursive=True)
added_nodes.extend(nodes)
return added_nodes
async def _rdesc_from_node(parent: asyncua.Node, node: asyncua.Node) -> ua.ReferenceDescription:
results = await node.read_attributes([
ua.AttributeIds.NodeClass, ua.AttributeIds.BrowseName, ua.AttributeIds.DisplayName,
])
results = await node.read_attributes(
[
ua.AttributeIds.NodeClass,
ua.AttributeIds.BrowseName,
ua.AttributeIds.DisplayName,
]
)
variants: List[ua.Variant] = []
for res in results:
res.StatusCode.check()
......@@ -76,9 +77,20 @@ async def _rdesc_from_node(parent: asyncua.Node, node: asyncua.Node) -> ua.Refer
async def _read_and_copy_attrs(node_type: asyncua.Node, struct: Any, addnode: ua.AddNodesItem) -> None:
names = [name for name in struct.__dict__.keys() if not name.startswith("_") and name not in (
"BodyLength", "TypeId", "SpecifiedAttributes", "Encoding", "IsAbstract", "EventNotifier",
)]
names = [
name
for name in struct.__dict__.keys()
if not name.startswith("_")
and name
not in (
"BodyLength",
"TypeId",
"SpecifiedAttributes",
"Encoding",
"IsAbstract",
"EventNotifier",
)
]
attrs = [getattr(ua.AttributeIds, name) for name in names]
results = await node_type.read_attributes(attrs)
for idx, name in enumerate(names):
......@@ -91,6 +103,5 @@ async def _read_and_copy_attrs(node_type: asyncua.Node, struct: Any, addnode: ua
else:
setattr(struct, name, variant.Value)
else:
_logger.warning("Instantiate: while copying attributes from node type %s,"
" attribute %s, statuscode is %s", str(node_type), str(name), str(results[idx].StatusCode))
_logger.warning("Instantiate: while copying attributes from node type %s," " attribute %s, statuscode is %s", str(node_type), str(name), str(results[idx].StatusCode))
addnode.NodeAttributes = struct
......@@ -4,6 +4,7 @@ from asyncua import ua
import asyncua
from ..ua.uaerrors import UaError
from .ua_utils import get_node_subtypes, is_subtype
if TYPE_CHECKING:
from asyncua.common.node import Node
......@@ -38,9 +39,8 @@ class Event:
self.internal_properties = list(self.__dict__.keys())[:] + ["internal_properties"]
def __str__(self):
return "{0}({1})".format(
self.__class__.__name__,
[str(k) + ":" + str(v) for k, v in self.__dict__.items() if k not in self.internal_properties])
return "{0}({1})".format(self.__class__.__name__, [str(k) + ":" + str(v) for k, v in self.__dict__.items() if k not in self.internal_properties])
__repr__ = __str__
def add_property(self, name, val, datatype):
......@@ -136,7 +136,7 @@ class Event:
iter_paths = iter(browsePath)
next(iter_paths)
for path in iter_paths:
name += '/' + path.Name
name += "/" + path.Name
return name
......@@ -149,7 +149,7 @@ async def get_filter_from_event_type(eventtypes: List["Node"], where_clause_gene
async def _append_new_attribute_to_select_clauses(select_clauses: List[ua.SimpleAttributeOperand], already_selected: Dict[str, str], browse_path: List[ua.QualifiedName]):
string_path = '/'.join(map(str, browse_path))
string_path = "/".join(map(str, browse_path))
if string_path not in already_selected:
already_selected[string_path] = string_path
op = ua.SimpleAttributeOperand()
......@@ -228,9 +228,7 @@ async def select_event_attributes_from_type_node(node: "Node", attributeSelector
attributes.extend(await attributeSelector(curr_node))
if curr_node.nodeid == ua.NodeId(ua.ObjectIds.BaseEventType):
break
parents = await curr_node.get_referenced_nodes(
refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse
)
parents = await curr_node.get_referenced_nodes(refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse)
if len(parents) != 1: # Something went wrong
return None
curr_node = parents[0]
......@@ -260,7 +258,6 @@ async def get_event_obj_from_type_node(node):
parent_nodeid, parent_eventtype = await _find_parent_eventtype(node)
class CustomEvent(parent_eventtype):
def __init__(self):
parent_eventtype.__init__(self)
self.EventType = node.nodeid
......@@ -269,7 +266,7 @@ async def get_event_obj_from_type_node(node):
name = (await property.read_browse_name()).Name
if parent_variable:
parent_name = (await parent_variable.read_browse_name()).Name
name = f'{parent_name}/{name}'
name = f"{parent_name}/{name}"
val = await property.read_data_value()
self.add_property(name, val.Value.Value, val.Value.VariantType)
......@@ -301,10 +298,9 @@ async def get_event_obj_from_type_node(node):
async def _find_parent_eventtype(node):
"""
"""
""" """
parents = await node.get_referenced_nodes(refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse)
if len(parents) != 1: # Something went wrong
if len(parents) != 1: # Something went wrong
raise UaError("Parent of event type could not be found")
if parents[0].nodeid.NamespaceIndex == 0:
if parents[0].nodeid.Identifier in asyncua.common.event_objects.IMPLEMENTED_EVENTS.keys():
......
......@@ -51,27 +51,11 @@ async def instantiate(
elif isinstance(bname, str):
bname = ua.QualifiedName.from_string(bname)
nodeids = await _instantiate_node(
parent.session,
make_node(parent.session, rdesc.NodeId),
parent.nodeid,
rdesc,
nodeid,
bname,
dname=dname,
instantiate_optional=instantiate_optional)
nodeids = await _instantiate_node(parent.session, make_node(parent.session, rdesc.NodeId), parent.nodeid, rdesc, nodeid, bname, dname=dname, instantiate_optional=instantiate_optional)
return [make_node(parent.session, nid) for nid in nodeids]
async def _instantiate_node(session,
node_type,
parentid,
rdesc,
nodeid,
bname,
dname=None,
recursive=True,
instantiate_optional=True):
async def _instantiate_node(session, node_type, parentid, rdesc, nodeid, bname, dname=None, recursive=True, instantiate_optional=True):
"""
instantiate a node type under parent
"""
......@@ -116,41 +100,19 @@ async def _instantiate_node(session,
refs = await c_node_type.get_referenced_nodes(refs=ua.ObjectIds.HasModellingRule)
if not refs:
# spec says to ignore nodes without modelling rules
_logger.info(
"Instantiate: Skip node without modelling rule %s as part of %s",
c_rdesc.BrowseName, addnode.BrowseName
)
_logger.info("Instantiate: Skip node without modelling rule %s as part of %s", c_rdesc.BrowseName, addnode.BrowseName)
continue
# exclude nodes with optional ModellingRule if requested
if refs[0].nodeid in (ua.NodeId(ua.ObjectIds.ModellingRule_Optional), ua.NodeId(ua.ObjectIds.ModellingRule_OptionalPlaceholder)):
# instantiate optionals
if not instantiate_optional:
_logger.info(
"Instantiate: Skip optional node %s as part of %s",
c_rdesc.BrowseName, addnode.BrowseName
)
_logger.info("Instantiate: Skip optional node %s as part of %s", c_rdesc.BrowseName, addnode.BrowseName)
continue
# if root node being instantiated has a String NodeId, create the children with a String NodeId
if res.AddedNodeId.NodeIdType is ua.NodeIdType.String:
inst_nodeid = res.AddedNodeId.Identifier + "." + c_rdesc.BrowseName.Name
nodeids = await _instantiate_node(
session,
c_node_type,
res.AddedNodeId,
c_rdesc,
nodeid=ua.NodeId(Identifier=inst_nodeid, NamespaceIndex=res.AddedNodeId.NamespaceIndex),
bname=c_rdesc.BrowseName,
instantiate_optional=instantiate_optional
)
nodeids = await _instantiate_node(session, c_node_type, res.AddedNodeId, c_rdesc, nodeid=ua.NodeId(Identifier=inst_nodeid, NamespaceIndex=res.AddedNodeId.NamespaceIndex), bname=c_rdesc.BrowseName, instantiate_optional=instantiate_optional)
else:
nodeids = await _instantiate_node(
session,
c_node_type,
res.AddedNodeId,
c_rdesc,
nodeid=ua.NodeId(NamespaceIndex=res.AddedNodeId.NamespaceIndex),
bname=c_rdesc.BrowseName,
instantiate_optional=instantiate_optional
)
nodeids = await _instantiate_node(session, c_node_type, res.AddedNodeId, c_rdesc, nodeid=ua.NodeId(NamespaceIndex=res.AddedNodeId.NamespaceIndex), bname=c_rdesc.BrowseName, instantiate_optional=instantiate_optional)
added_nodes.extend(nodeids)
return added_nodes
"""
High level functions to create nodes
"""
from __future__ import annotations
import logging
......@@ -39,10 +40,7 @@ def _parse_nodeid_qname(*args):
except ua.UaError:
raise
except Exception as ex:
raise TypeError(
f"This method takes either a namespace index and a string as argument or a nodeid and a qualifiedname."
f" Received arguments {args} and got exception {ex}"
)
raise TypeError(f"This method takes either a namespace index and a string as argument or a nodeid and a qualifiedname." f" Received arguments {args} and got exception {ex}")
async def create_folder(parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int], bname: Union[ua.QualifiedName, str]) -> asyncua.Node:
......@@ -52,10 +50,7 @@ async def create_folder(parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int]
or namespace index, name
"""
nodeid, qname = _parse_nodeid_qname(nodeid, bname)
return make_node(
parent.session,
await _create_object(parent.session, parent.nodeid, nodeid, qname, ua.ObjectIds.FolderType)
)
return make_node(parent.session, await _create_object(parent.session, parent.nodeid, nodeid, qname, ua.ObjectIds.FolderType))
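Note (illustration, not from the commit): a hedged usage sketch of the two argument forms accepted via _parse_nodeid_qname; identifiers are hypothetical:

async def make_folders(parent):
    folder_a = await create_folder(parent, "ns=2;i=100", "2:MyFolder")  # nodeid + qname
    folder_b = await create_folder(parent, 2, "MyFolder")               # index + name
    return folder_a, folder_b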
async def create_object(
......@@ -78,10 +73,7 @@ async def create_object(
nodes = await instantiate(parent, objecttype_node, nodeid, bname=qname, dname=dname, instantiate_optional=instantiate_optional)
return nodes[0]
else:
return make_node(
parent.session,
await _create_object(parent.session, parent.nodeid, nodeid, qname, ua.ObjectIds.BaseObjectType)
)
return make_node(parent.session, await _create_object(parent.session, parent.nodeid, nodeid, qname, ua.ObjectIds.BaseObjectType))
async def create_property(
......@@ -103,10 +95,7 @@ async def create_property(
datatype = ua.NodeId(datatype, 0)
if datatype and not isinstance(datatype, ua.NodeId):
raise RuntimeError("datatype argument must be a nodeid or an int refering to a nodeid")
return make_node(
parent.session,
await _create_variable(parent.session, parent.nodeid, nodeid, qname, var, datatype=datatype, isproperty=True)
)
return make_node(parent.session, await _create_variable(parent.session, parent.nodeid, nodeid, qname, var, datatype=datatype, isproperty=True))
async def create_variable(
......@@ -129,15 +118,10 @@ async def create_variable(
if datatype and not isinstance(datatype, ua.NodeId):
raise RuntimeError("datatype argument must be a nodeid or an int refering to a nodeid")
return make_node(
parent.session,
await _create_variable(parent.session, parent.nodeid, nodeid, qname, var, datatype=datatype, isproperty=False)
)
return make_node(parent.session, await _create_variable(parent.session, parent.nodeid, nodeid, qname, var, datatype=datatype, isproperty=False))
async def create_variable_type(
parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int], bname: Union[ua.QualifiedName, str], datatype: Union[ua.NodeId, int]
) -> asyncua.Node:
async def create_variable_type(parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int], bname: Union[ua.QualifiedName, str], datatype: Union[ua.NodeId, int]) -> asyncua.Node:
"""
Create a new variable type
args are nodeid, browsename and datatype
......@@ -147,27 +131,18 @@ async def create_variable_type(
if datatype and isinstance(datatype, int):
datatype = ua.NodeId(datatype, 0)
if datatype and not isinstance(datatype, ua.NodeId):
raise RuntimeError(
f"Data type argument must be a nodeid or an int refering to a nodeid, received: {datatype}")
return make_node(
parent.session,
await _create_variable_type(parent.session, parent.nodeid, nodeid, qname, datatype)
)
raise RuntimeError(f"Data type argument must be a nodeid or an int refering to a nodeid, received: {datatype}")
return make_node(parent.session, await _create_variable_type(parent.session, parent.nodeid, nodeid, qname, datatype))
async def create_reference_type(
parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int], bname: Union[ua.QualifiedName, str, int], symmetric: bool = True, inversename: Optional[str] = None
) -> asyncua.Node:
async def create_reference_type(parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int], bname: Union[ua.QualifiedName, str, int], symmetric: bool = True, inversename: Optional[str] = None) -> asyncua.Node:
"""
Create a new reference type
args are nodeid and browsename
or idx and name
"""
nodeid, qname = _parse_nodeid_qname(nodeid, bname)
return make_node(
parent.session,
await _create_reference_type(parent.session, parent.nodeid, nodeid, qname, symmetric, inversename)
)
return make_node(parent.session, await _create_reference_type(parent.session, parent.nodeid, nodeid, qname, symmetric, inversename))
async def create_object_type(parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int], bname: Union[ua.QualifiedName, str]):
......@@ -190,7 +165,7 @@ async def create_method(parent: asyncua.Node, *args) -> asyncua.Node:
a callback is a method accepting the nodeid of the parent as first argument and variants after.
returns a list of variants
"""
_logger.info('create_method %r', parent)
_logger.info("create_method %r", parent)
nodeid, qname = _parse_nodeid_qname(*args[:2])
callback = args[2]
if len(args) > 3:
......@@ -337,9 +312,7 @@ async def _create_variable_type(session, parentnodeid, nodeid, qname, datatype,
return results[0].AddedNodeId
async def create_data_type(
parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int], bname: Union[ua.QualifiedName, str], description: Optional[str] = None
) -> asyncua.Node:
async def create_data_type(parent: asyncua.Node, nodeid: Union[ua.NodeId, str, int], bname: Union[ua.QualifiedName, str], description: Optional[str] = None) -> asyncua.Node:
"""
Create a new data type to be used in new variables, etc ..
arguments are nodeid, browsename
......@@ -431,24 +404,10 @@ async def _create_method(parent, nodeid, qname, callback, inputs, outputs):
results[0].StatusCode.check()
method = make_node(parent.session, results[0].AddedNodeId)
if inputs:
prob = await create_property(
method,
ua.NodeId(NamespaceIndex=method.nodeid.NamespaceIndex),
ua.QualifiedName("InputArguments", 0),
[_vtype_to_argument(vtype) for vtype in inputs],
varianttype=ua.VariantType.ExtensionObject,
datatype=ua.ObjectIds.Argument
)
prob = await create_property(method, ua.NodeId(NamespaceIndex=method.nodeid.NamespaceIndex), ua.QualifiedName("InputArguments", 0), [_vtype_to_argument(vtype) for vtype in inputs], varianttype=ua.VariantType.ExtensionObject, datatype=ua.ObjectIds.Argument)
await prob.set_modelling_rule(True)
if outputs:
prob = await create_property(
method,
ua.NodeId(NamespaceIndex=method.nodeid.NamespaceIndex),
ua.QualifiedName("OutputArguments", 0),
[_vtype_to_argument(vtype) for vtype in outputs],
varianttype=ua.VariantType.ExtensionObject,
datatype=ua.ObjectIds.Argument
)
prob = await create_property(method, ua.NodeId(NamespaceIndex=method.nodeid.NamespaceIndex), ua.QualifiedName("OutputArguments", 0), [_vtype_to_argument(vtype) for vtype in outputs], varianttype=ua.VariantType.ExtensionObject, datatype=ua.ObjectIds.Argument)
await prob.set_modelling_rule(True)
if hasattr(parent.session, "add_method_callback"):
parent.session.add_method_callback(method.nodeid, callback)
......@@ -498,9 +457,7 @@ def _guess_datatype(variant: ua.Variant):
return ua.NodeId(getattr(ua.ObjectIds, variant.VariantType.name))
async def delete_nodes(
session: AbstractSession, nodes: Iterable[asyncua.Node], recursive: bool = False, delete_target_references: bool = True
) -> Tuple[List[asyncua.Node], List[ua.StatusCode]]:
async def delete_nodes(session: AbstractSession, nodes: Iterable[asyncua.Node], recursive: bool = False, delete_target_references: bool = True) -> Tuple[List[asyncua.Node], List[ua.StatusCode]]:
"""
Delete specified nodes. Optionally delete recursively all nodes with a
downward hierarchic references to the node
......
......@@ -42,7 +42,7 @@ async def call_method_full(parent: asyncua.Node, methodid: Union[ua.NodeId, ua.Q
"""
if isinstance(methodid, (str, ua.uatypes.QualifiedName)):
methodid = (await parent.get_child(methodid)).nodeid
elif hasattr(methodid, 'nodeid'):
elif hasattr(methodid, "nodeid"):
methodid = methodid.nodeid
result = await _call_method(parent.session, parent.nodeid, methodid, to_variant(*args))
......@@ -78,6 +78,7 @@ def uamethod(func):
"""
if iscoroutinefunction(func):
@wraps(func)
async def wrapper(parent, *args):
func_args = _format_call_inputs(parent, *args)
......@@ -85,11 +86,13 @@ def uamethod(func):
return _format_call_outputs(result)
else:
@wraps(func)
def wrapper(parent, *args):
func_args = _format_call_inputs(parent, *args)
result = func(*func_args)
return _format_call_outputs(result)
return wrapper
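Note (illustration, not from the commit): a hedged usage sketch of the decorator above; both plain and async functions can back an OPC UA method, with the parent NodeId passed first and Variants converted on the way in and out:

from asyncua import uamethod

@uamethod
def multiply(parent, x, y):
    return x * y  # wrapped back into output Variants by the decorator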
......
......@@ -11,4 +11,5 @@ def make_node(session: AbstractSession, nodeid: ua.NodeId) -> asyncua.Node:
Needed to avoid a cyclical import of `Node`
"""
from .node import Node
return Node(session, nodeid)
......@@ -146,20 +146,19 @@ sqlite3_keywords = [
"WHERE",
"WINDOW",
"WITH",
"WITHOUT"
"WITHOUT",
]
class SqlInjectionError(Exception):
"""Raised, if a sql injection is detected."""
pass
def validate_table_name(table_name: str) -> None:
"""Checks wether the sql table name is valid or not."""
not_allowed_characters = [' ', ';', ',', '(', ')', '[', ']', '"', "'"]
not_allowed_characters = [" ", ";", ",", "(", ")", "[", "]", '"', "'"]
for character in table_name:
if character in not_allowed_characters:
raise SqlInjectionError(
f'table_name: {table_name} contains invalid character: {character}'
)
raise SqlInjectionError(f"table_name: {table_name} contains invalid character: {character}")
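Note (illustration, not from the commit): a worked example of the validation above:

validate_table_name("history_ns2_i42")  # passes silently
try:
    validate_table_name('events"; DROP TABLE data; --')
except SqlInjectionError as exc:
    print(exc)  # reports the first invalid character, here '"'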
......@@ -6,6 +6,7 @@ for custom structures
import uuid
import logging
# The next two imports are for generated code
from datetime import datetime, timezone
from enum import IntEnum, EnumMeta
......@@ -88,17 +89,17 @@ class {self.name}:
field.uatype = "UInt32"
self.fields = [field] + self.fields
for sfield in self.fields:
if sfield.name != 'SwitchField':
'''
if sfield.name != "SwitchField":
"""
SwitchField is the 'Encoding' field in OptionSets; to stay
compatible with 1.04 structs we add the 'Encoding' field
up front and skip the SwitchField field
'''
"""
uatype = f"'ua.{sfield.uatype}'"
if sfield.array:
uatype = f"List[{uatype}]"
if uatype == 'List[ua.Char]':
uatype = 'String'
if uatype == "List[ua.Char]":
uatype = "String"
if sfield.is_optional:
code += f" {sfield.name}: Optional[{uatype}] = None\n"
else:
......@@ -173,12 +174,12 @@ class StructGenerator:
_type = xmlfield.get("TypeName")
if ":" in _type:
_type = _type.split(":")[1]
if _type == 'Bit':
if _type == "Bit":
# Bits are used for bit fields and filler ignore
continue
field = Field(_clean_name)
field.uatype = clean_name(_type)
if xmlfield.get("SwitchField", '') != '':
if xmlfield.get("SwitchField", "") != "":
# Optional Field
field.is_optional = True
struct.option_counter += 1
......@@ -204,8 +205,7 @@ class StructGenerator:
for struct in self.model:
if isinstance(struct, EnumType):
continue # No registration required for enums
code += f"ua.register_extension_object('{struct.name}'," \
f" ua.NodeId.from_string('{struct.typeid}'), {struct.name})\n"
code += f"ua.register_extension_object('{struct.name}'," f" ua.NodeId.from_string('{struct.typeid}'), {struct.name})\n"
return code
def get_python_classes(self, env=None):
......@@ -297,22 +297,22 @@ def _generate_python_class(model, env=None):
env = ua.__dict__
# Add the required libraries to dict
if "ua" not in env:
env['ua'] = ua
env["ua"] = ua
if "datetime" not in env:
env['datetime'] = datetime
env['timezone'] = timezone
env["datetime"] = datetime
env["timezone"] = timezone
if "uuid" not in env:
env['uuid'] = uuid
env["uuid"] = uuid
if "enum" not in env:
env['IntEnum'] = IntEnum
env["IntEnum"] = IntEnum
if "dataclass" not in env:
env['dataclass'] = dataclass
env["dataclass"] = dataclass
if "field" not in env:
env['field'] = field
env["field"] = field
if "List" not in env:
env['List'] = List
if 'Optional' not in env:
env['Optional'] = Optional
env["List"] = List
if "Optional" not in env:
env["Optional"] = Optional
# generate classes one by one and add them to dict
for element in model:
code = element.get_code()
......
......@@ -133,7 +133,7 @@ def string_to_val(string, vtype):
elif vtype == ua.VariantType.Guid:
val = uuid.UUID(string)
elif issubclass(vtype, Enum):
enum_int = int(string.rsplit('_', 1)[1])
enum_int = int(string.rsplit("_", 1)[1])
val = vtype(enum_int)
else:
# FIXME: Some types are probably missing!
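Note (illustration, not from the commit): a worked example of the rsplit above, with a hypothetical enum whose member names carry the integer value after the last underscore:

from enum import Enum

class Severity(Enum):  # hypothetical enum following the Name_value convention
    Low_0 = 0
    High_2 = 2

enum_int = int("High_2".rsplit("_", 1)[1])
assert Severity(enum_int) is Severity.High_2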
......@@ -203,9 +203,7 @@ async def get_node_supertype(node):
"""
return node supertype or None
"""
supertypes = await node.get_referenced_nodes(
refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse
)
supertypes = await node.get_referenced_nodes(refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse)
if supertypes:
return supertypes[0]
return None
......@@ -292,10 +290,7 @@ async def get_nodes_of_namespace(server, namespaces=None):
namespace_indexes = [n if isinstance(n, int) else ns_available.index(n) for n in namespaces]
# filter node is based on the provided namespaces and convert the nodeid to a node
nodes = [
server.get_node(nodeid) for nodeid in server.iserver.aspace.keys()
if nodeid.NamespaceIndex != 0 and nodeid.NamespaceIndex in namespace_indexes
]
nodes = [server.get_node(nodeid) for nodeid in server.iserver.aspace.keys() if nodeid.NamespaceIndex != 0 and nodeid.NamespaceIndex in namespace_indexes]
return nodes
......
......@@ -2,6 +2,7 @@
Helper functions and classes that do not rely on the asyncua library.
Helper functions and classes depending on ua objects are in ua_utils.py
"""
import asyncio
import logging
import os
......@@ -16,7 +17,7 @@ _logger = logging.getLogger(__name__)
class ServiceError(UaError):
def __init__(self, code):
super().__init__('UA Service Error')
super().__init__("UA Service Error")
self.code = code
......@@ -43,6 +44,7 @@ class Buffer:
def __str__(self):
return f"Buffer(size:{self._size}, data:{self._data[self._cur_pos:self._cur_pos + self._size]})"
__repr__ = __str__
def __len__(self):
......@@ -53,7 +55,7 @@ class Buffer:
def __bytes__(self):
"""Return remains of buffer as bytes."""
return bytes(self._data[self._cur_pos:])
return bytes(self._data[self._cur_pos :])
def read(self, size):
"""
......@@ -64,7 +66,7 @@ class Buffer:
self._size -= size
pos = self._cur_pos
self._cur_pos += size
return self._data[pos:self._cur_pos]
return self._data[pos : self._cur_pos]
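Note (illustration, not from the commit): a minimal sketch of the Buffer contract shown above; read() consumes bytes and advances the cursor, while bytes() exposes the remainder without consuming it:

buf = Buffer(b"\x01\x02\x03\x04")
assert buf.read(2) == b"\x01\x02"  # consumes two bytes
assert bytes(buf) == b"\x03\x04"   # remainder left in the buffer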
def copy(self, size=-1):
"""
......@@ -113,17 +115,10 @@ def fields_with_resolved_types(
fields_ = fields(class_or_instance)
if sys.version_info.major == 3 and sys.version_info.minor <= 8:
resolved_fieldtypes = get_type_hints(
class_or_instance,
globalns=globalns,
localns=localns
)
resolved_fieldtypes = get_type_hints(class_or_instance, globalns=globalns, localns=localns)
else:
resolved_fieldtypes = get_type_hints( # type: ignore[call-arg]
class_or_instance,
globalns=globalns,
localns=localns,
include_extras=include_extras
class_or_instance, globalns=globalns, localns=localns, include_extras=include_extras
)
for field in fields_:
try:
......@@ -135,7 +130,7 @@ def fields_with_resolved_types(
return fields_
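The version gate above exists because `include_extras` was only added to `typing.get_type_hints` in Python 3.9; a minimal sketch of the same pattern:

```python
import sys
from dataclasses import dataclass
from typing import get_type_hints

@dataclass
class Sample:
    value: int

if sys.version_info >= (3, 9):
    # include_extras preserves typing.Annotated metadata
    hints = get_type_hints(Sample, include_extras=True)
else:
    hints = get_type_hints(Sample)
assert hints["value"] is int
```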
_T = TypeVar('_T')
_T = TypeVar("_T")
async def wait_for(aw: Awaitable[_T], timeout: Union[int, float, None]) -> _T:
......@@ -151,4 +146,5 @@ async def wait_for(aw: Awaitable[_T], timeout: Union[int, float, None]) -> _T:
return await asyncio.wait_for(aw, timeout)
import wait_for2
return await wait_for2.wait_for(aw, timeout)
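A small runnable sketch of the plain-asyncio path the helper delegates to; `wait_for2` is used as a drop-in replacement on interpreters where `asyncio.wait_for` has the known cancellation race (the gating condition is outside this hunk):

```python
import asyncio

async def slow() -> str:
    await asyncio.sleep(0.01)
    return "ok"

async def main() -> None:
    # Bounded wait: raises asyncio.TimeoutError when the deadline passes.
    print(await asyncio.wait_for(slow(), timeout=1.0))

asyncio.run(main())
```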
"""
Parse XML files from asyncua-spec
"""
import re
import asyncio
import base64
......@@ -34,7 +35,6 @@ def _to_bool(val):
class NodeData:
def __init__(self):
self.nodetype = None
self.nodeid = None
......@@ -89,7 +89,6 @@ class Field:
class RefStruct:
def __init__(self):
self.reftype = None
self.forward = True
......@@ -102,7 +101,6 @@ class RefStruct:
class ExtObj:
def __init__(self):
self.typeid = None
self.objname = None
......@@ -116,17 +114,11 @@ class ExtObj:
class XMLParser:
def __init__(self):
self.logger = logging.getLogger(__name__)
self._retag = re.compile(r"(\{.*\})(.*)")
self.root = None
self.ns = {
'base': "http://opcfoundation.org/UA/2011/03/UANodeSet.xsd",
'uax': "http://opcfoundation.org/UA/2008/02/Types.xsd",
'xsd': "http://www.w3.org/2001/XMLSchema",
'xsi': "http://www.w3.org/2001/XMLSchema-instance"
}
self.ns = {"base": "http://opcfoundation.org/UA/2011/03/UANodeSet.xsd", "uax": "http://opcfoundation.org/UA/2008/02/Types.xsd", "xsd": "http://www.w3.org/2001/XMLSchema", "xsi": "http://www.w3.org/2001/XMLSchema-instance"}
async def parse(self, xmlpath=None, xmlstring=None):
if xmlstring:
......@@ -149,7 +141,7 @@ class XMLParser:
namespaces_uris = []
for child in self.root:
tag = self._retag.match(child.tag).groups()[1]
if tag == 'NamespaceUris':
if tag == "NamespaceUris":
namespaces_uris = [ns_element.text for ns_element in child]
break
return namespaces_uris
......@@ -161,7 +153,7 @@ class XMLParser:
aliases = {}
for child in self.root:
tag = self._retag.match(child.tag).groups()[1]
if tag == 'Aliases':
if tag == "Aliases":
for el in child:
aliases[el.attrib["Alias"]] = el.text
break
......@@ -370,11 +362,11 @@ class XMLParser:
ext = ExtObj()
for extension_object_part in el:
ntag = self._retag.match(extension_object_part.tag).groups()[1]
if ntag == 'TypeId':
ntag = self._retag.match(extension_object_part.find('*').tag).groups()[1]
if ntag == "TypeId":
ntag = self._retag.match(extension_object_part.find("*").tag).groups()[1]
ext.typeid = self._get_text(extension_object_part)
elif ntag == 'Body':
ext.objname = self._retag.match(extension_object_part.find('*').tag).groups()[1]
elif ntag == "Body":
ext.objname = self._retag.match(extension_object_part.find("*").tag).groups()[1]
ext.body = self._parse_body(extension_object_part)
else:
self.logger.warning("Unknown ntag", ntag)
......@@ -452,11 +444,11 @@ class XMLParser:
Get all namespaces that are registered with version and date_time
"""
ns = []
for model in self.root.findall('base:Models/base:Model', self.ns):
uri = model.attrib.get('ModelUri')
for model in self.root.findall("base:Models/base:Model", self.ns):
uri = model.attrib.get("ModelUri")
if uri is not None:
version = model.attrib.get('Version', '')
date_time = model.attrib.get('PublicationDate')
version = model.attrib.get("Version", "")
date_time = model.attrib.get("PublicationDate")
if date_time is None:
date_time = ua.DateTime(1, 1, 1)
elif date_time is not None and date_time[-1] == "Z":
......
......@@ -55,11 +55,7 @@ class SimpleRoleRuleset(PermissionRuleset):
def __init__(self):
admin_ids = list(map(ua.NodeId, ADMIN_TYPES))
user_ids = list(map(ua.NodeId, USER_TYPES))
self._permission_dict = {
UserRole.Admin: set().union(admin_ids, user_ids),
UserRole.User: set().union(user_ids),
UserRole.Anonymous: set()
}
self._permission_dict = {UserRole.Admin: set().union(admin_ids, user_ids), UserRole.User: set().union(user_ids), UserRole.Anonymous: set()}
def check_validity(self, user, action_type_id, body):
if action_type_id in self._permission_dict[user.role]:
......
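A toy illustration of the lookup above, with plain strings standing in for UserRole members and node ids:

```python
perms = {"Admin": {1, 2}, "User": {2}, "Anonymous": set()}

def allowed(role: str, action_id: int) -> bool:
    # Mirrors check_validity: membership in the role's permitted set.
    return action_id in perms[role]

assert allowed("Admin", 1)
assert not allowed("Anonymous", 2)
```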
'''
"""
Functionality for checking a certificate based on:
- trusted (ca) certificates
- crl
Use of the cryptography module is preferred, but it doesn't provide truststore functionality yet, so for some checks we rely on pyOpenSSL.
'''
"""
from typing import List
from pathlib import Path
import re
......@@ -19,13 +20,13 @@ _logger = logging.getLogger("asyncuagds.validate")
class TrustStore:
'''
"""
TrustStore is used to validate certificates in two ways:
- Based on being absent in provided certificate revocation lists
- The certificate or its issuer being present in a list of trusted certificates
It doesn't check other content or extensions of the certificate
'''
"""
def __init__(self, trust_locations: List[Path], crl_locations: List[Path]):
"""Constructor of the TrustStore
......@@ -68,7 +69,7 @@ class TrustStore:
await self._load_crl_location(location)
def validate(self, certificate: x509.Certificate) -> bool:
""" Validates if a certificate is trusted, not revoked and lays in valid datarange
"""Validates if a certificate is trusted, not revoked and lays in valid datarange
Args:
certificate (x509.Certificate): Certificate to check
......@@ -80,7 +81,7 @@ class TrustStore:
return self.is_trusted(certificate) and self.is_revoked(certificate) is False and self.check_date_range(certificate)
def check_date_range(self, certificate: x509.Certificate) -> bool:
""" Checks if the certificate not_valid_before_utc and not_valid_after_utc are valid.
"""Checks if the certificate not_valid_before_utc and not_valid_after_utc are valid.
Args:
certificate (x509.Certificate): Certificate to check
......@@ -91,15 +92,15 @@ class TrustStore:
valid: bool = True
now = datetime.now(timezone.utc)
if certificate.not_valid_after_utc < now:
_logger.error('certificate is no longer valid: valid until %s', certificate.not_valid_after_utc)
_logger.error("certificate is no longer valid: valid until %s", certificate.not_valid_after_utc)
valid = False
if certificate.not_valid_before_utc > now:
_logger.error('certificate is not yet vaild: valid after %s', certificate.not_valid_before_utc)
_logger.error("certificate is not yet vaild: valid after %s", certificate.not_valid_before_utc)
valid = False
return valid
def is_revoked(self, certificate: x509.Certificate) -> bool:
""" Check if the provided certifcate is in the revocation lists
"""Check if the provided certifcate is in the revocation lists
when not CRLs are present it the certificate is considere not revoked.
......@@ -119,7 +120,7 @@ class TrustStore:
return is_revoked
def is_trusted(self, certificate: x509.Certificate) -> bool:
""" Check if the provided certifcate is considered trusted
"""Check if the provided certifcate is considered trusted
For a self-signed to be trusted is must be placed in the trusted location
Args:
certificate (x509.Certificate): Certificate to check
......@@ -132,7 +133,7 @@ class TrustStore:
store_ctx = crypto.X509StoreContext(self._trust_store, _certificate)
try:
store_ctx.verify_certificate()
_logger.debug('Use trusted certificate : \'%s\'', _certificate.get_subject().CN)
_logger.debug("Use trusted certificate : '%s'", _certificate.get_subject().CN)
return True
except crypto.X509StoreContextError as exp:
print(exp)
......@@ -145,10 +146,10 @@ class TrustStore:
Args:
location (Path): location to scan for certificates
"""
files = Path(location).glob('*.*')
files = Path(location).glob("*.*")
for file_name in files:
if re.match('.*(der|pem)', file_name.name.lower()):
_logger.debug('Add certificate to TrustStore : \'%s\'', file_name)
if re.match(".*(der|pem)", file_name.name.lower()):
_logger.debug("Add certificate to TrustStore : '%s'", file_name)
trusted_cert: crypto.X509 = crypto.X509.from_cryptography(await load_certificate(file_name))
self._trust_store.add_cert(trusted_cert)
......@@ -158,16 +159,16 @@ class TrustStore:
Args:
location (Path): location to scan for crls
"""
files = Path(location).glob('*.*')
files = Path(location).glob("*.*")
for file_name in files:
if re.match('.*(der|pem)', file_name.name.lower()):
_logger.debug('Add CRL to list : \'%s\'', file_name)
if re.match(".*(der|pem)", file_name.name.lower()):
_logger.debug("Add CRL to list : '%s'", file_name)
crl = await self._load_crl(file_name)
for revoked in crl:
self._revoked_list.append(revoked)
@ staticmethod
@staticmethod
async def _load_crl(crl_file_name: Path) -> x509.CertificateRevocationList:
"""Load a single crl from file
......@@ -178,7 +179,7 @@ class TrustStore:
x509.CertificateRevocationList: Return loaded CRL
"""
content = await get_content(crl_file_name)
if crl_file_name.suffix.lower() == '.der':
if crl_file_name.suffix.lower() == ".der":
return x509.load_der_x509_crl(content)
return x509.load_pem_x509_crl(content)
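A hedged usage sketch of the TrustStore above; the directory paths are hypothetical, and `load()` is assumed to be the public entry point that walks both location lists (only the private `_load_*` helpers appear in this diff):

```python
from pathlib import Path

async def is_cert_ok(cert) -> bool:
    store = TrustStore([Path("pki/trusted")], [Path("pki/crl")])
    await store.load()  # assumed public loader for both location lists
    # validate() combines is_trusted, is_revoked and check_date_range
    return store.validate(cert)
```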
......@@ -18,6 +18,7 @@ from cryptography.exceptions import InvalidSignature # noqa: F811
from dataclasses import dataclass
import logging
_logger = logging.getLogger(__name__)
......@@ -36,7 +37,7 @@ async def get_content(path_or_content: Union[str, bytes, Path]) -> bytes:
if isinstance(path_or_content, bytes):
return path_or_content
async with aiofiles.open(path_or_content, mode='rb') as f:
async with aiofiles.open(path_or_content, mode="rb") as f:
return await f.read()
......@@ -46,10 +47,10 @@ async def load_certificate(path_or_content: Union[bytes, str, Path], extension:
elif isinstance(path_or_content, Path):
ext = path_or_content.suffix
else:
ext = ''
ext = ""
content = await get_content(path_or_content)
if ext == ".pem" or extension == 'pem' or extension == 'PEM':
if ext == ".pem" or extension == "pem" or extension == "PEM":
return x509.load_pem_x509_certificate(content, default_backend())
else:
return x509.load_der_x509_certificate(content, default_backend())
......@@ -61,20 +62,18 @@ def x509_from_der(data):
return x509.load_der_x509_certificate(data, default_backend())
async def load_private_key(path_or_content: Union[str, Path, bytes],
password: Optional[Union[str, bytes]] = None,
extension: Optional[str] = None):
async def load_private_key(path_or_content: Union[str, Path, bytes], password: Optional[Union[str, bytes]] = None, extension: Optional[str] = None):
if isinstance(path_or_content, str):
ext = Path(path_or_content).suffix
elif isinstance(path_or_content, Path):
ext = path_or_content.suffix
else:
ext = ''
ext = ""
if isinstance(password, str):
password = password.encode('utf-8')
password = password.encode("utf-8")
content = await get_content(path_or_content)
if ext == ".pem" or extension == 'pem' or extension == 'PEM':
if ext == ".pem" or extension == "pem" or extension == "PEM":
return serialization.load_pem_private_key(content, password=password, backend=default_backend())
else:
return serialization.load_der_private_key(content, password=password, backend=default_backend())
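A self-contained illustration of the PEM/DER distinction the loaders above key off, using only documented `cryptography` APIs (the asyncua loaders additionally accept raw bytes and a pem/PEM extension hint):

```python
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
pem = key.private_bytes(
    serialization.Encoding.PEM,
    serialization.PrivateFormat.PKCS8,
    serialization.NoEncryption(),
)
der = key.private_bytes(
    serialization.Encoding.DER,
    serialization.PrivateFormat.PKCS8,
    serialization.NoEncryption(),
)
# Each encoding needs the matching loader, exactly as above.
assert serialization.load_pem_private_key(pem, password=None)
assert serialization.load_der_private_key(der, password=None)
```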
......@@ -99,131 +98,76 @@ def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes:
def sign_sha1(private_key, data):
return private_key.sign(
data,
padding.PKCS1v15(),
hashes.SHA1()
)
return private_key.sign(data, padding.PKCS1v15(), hashes.SHA1())
def sign_sha256(private_key, data):
return private_key.sign(
data,
padding.PKCS1v15(),
hashes.SHA256()
)
return private_key.sign(data, padding.PKCS1v15(), hashes.SHA256())
def sign_pss_sha256(private_key, data):
return private_key.sign(
data,
padding.PSS(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
salt_length=32
),
padding.PSS(mgf=padding.MGF1(algorithm=hashes.SHA256()), salt_length=32),
hashes.SHA256(),
)
def verify_sha1(certificate, data, signature):
certificate.public_key().verify(
signature,
data,
padding.PKCS1v15(),
hashes.SHA1()
)
certificate.public_key().verify(signature, data, padding.PKCS1v15(), hashes.SHA1())
def verify_sha256(certificate, data, signature):
certificate.public_key().verify(
signature,
data,
padding.PKCS1v15(),
hashes.SHA256())
certificate.public_key().verify(signature, data, padding.PKCS1v15(), hashes.SHA256())
def verify_pss_sha256(certificate, data, signature):
certificate.public_key().verify(
signature,
data,
padding.PSS(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
salt_length=32
),
padding.PSS(mgf=padding.MGF1(algorithm=hashes.SHA256()), salt_length=32),
hashes.SHA256(),
)
def encrypt_basic256(public_key, data):
ciphertext = public_key.encrypt(
data,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None)
)
ciphertext = public_key.encrypt(data, padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None))
return ciphertext
def encrypt_rsa_oaep(public_key, data):
ciphertext = public_key.encrypt(
data,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA1()),
algorithm=hashes.SHA1(),
label=None)
)
ciphertext = public_key.encrypt(data, padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA1()), algorithm=hashes.SHA1(), label=None))
return ciphertext
def encrypt_rsa_oaep_sha256(public_key, data):
ciphertext = public_key.encrypt(
data,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
),
padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
)
return ciphertext
def encrypt_rsa15(public_key, data):
ciphertext = public_key.encrypt(
data,
padding.PKCS1v15()
)
ciphertext = public_key.encrypt(data, padding.PKCS1v15())
return ciphertext
def decrypt_rsa_oaep(private_key, data):
text = private_key.decrypt(
bytes(data),
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA1()),
algorithm=hashes.SHA1(),
label=None)
)
text = private_key.decrypt(bytes(data), padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA1()), algorithm=hashes.SHA1(), label=None))
return text
def decrypt_rsa_oaep_sha256(private_key, data):
text = private_key.decrypt(
bytes(data),
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
),
padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None),
)
return text
def decrypt_rsa15(private_key, data):
text = private_key.decrypt(
bytes(data),
padding.PKCS1v15()
)
text = private_key.decrypt(bytes(data), padding.PKCS1v15())
return text
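An illustrative roundtrip for the PKCS#1 v1.5 signing helpers above; the key is generated in memory, and verification raises `InvalidSignature` on mismatch rather than returning False:

```python
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding, rsa

key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
# Same operation as sign_sha256 above
signature = key.sign(b"payload", padding.PKCS1v15(), hashes.SHA256())
# Same operation as verify_sha256, but against a bare public key
# instead of a certificate object.
key.public_key().verify(signature, b"payload", padding.PKCS1v15(), hashes.SHA256())
```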
......@@ -272,7 +216,7 @@ def p_sha1(secret, seed, sizes=()):
for size in sizes:
full_size += size
result = b''
result = b""
accum = seed
while len(result) < full_size:
accum = hmac_sha1(secret, accum)
......@@ -295,7 +239,7 @@ def p_sha256(secret, seed, sizes=()):
for size in sizes:
full_size += size
result = b''
result = b""
accum = seed
while len(result) < full_size:
accum = hmac_sha256(secret, accum)
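The visible loop chains `A(i) = HMAC(secret, A(i-1))`; the step that appends output bytes falls outside this hunk. A self-contained sketch of the same TLS-style P_hash construction, with the standard library standing in for the module's `hmac_sha256` helper:

```python
import hashlib
import hmac

def p_hash_sha256(secret: bytes, seed: bytes, length: int) -> bytes:
    # A(0) = seed; each round contributes HMAC(secret, A(i) + seed).
    out, accum = b"", seed
    while len(out) < length:
        accum = hmac.new(secret, accum, hashlib.sha256).digest()
        out += hmac.new(secret, accum + seed, hashlib.sha256).digest()
    return out[:length]

assert len(p_hash_sha256(b"secret", b"seed", 80)) == 80
```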
......@@ -310,7 +254,7 @@ def p_sha256(secret, seed, sizes=()):
def x509_name_to_string(name):
parts = [f"{attr.oid._name}={attr.value}" for attr in name]
return ', '.join(parts)
return ", ".join(parts)
def x509_to_string(cert):
......@@ -318,9 +262,9 @@ def x509_to_string(cert):
Convert x509 certificate to human-readable string
"""
if cert.subject == cert.issuer:
issuer = ' (self-signed)'
issuer = " (self-signed)"
else:
issuer = f', issuer: {x509_name_to_string(cert.issuer)}'
issuer = f", issuer: {x509_name_to_string(cert.issuer)}"
# TODO: show more information
return f"{x509_name_to_string(cert.subject)}{issuer}, {cert.not_valid_before_utc} - {cert.not_valid_after_utc}"
......@@ -341,14 +285,14 @@ def check_certificate(cert: x509.Certificate, application_uri: str, hostname: Op
san = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
san_uri = san.value.get_values_for_type(x509.UniformResourceIdentifier)
if application_uri not in san_uri:
_logger.warning('certificate does not contain the application uri (%s). Most applications will reject a connection without it.', application_uri)
_logger.warning("certificate does not contain the application uri (%s). Most applications will reject a connection without it.", application_uri)
err = True
if hostname is not None:
san_dns_names = san.value.get_values_for_type(x509.DNSName)
if hostname not in san_dns_names:
_logger.warning('certificate does not contain the hostname in DNSNames %s. Some applications will check this.', hostname)
_logger.warning("certificate does not contain the hostname in DNSNames %s. Some applications will check this.", hostname)
err = True
except x509.ExtensionNotFound:
_logger.warning('certificate has no SubjectAlternativeName this is need for application verification!')
_logger.warning("certificate has no SubjectAlternativeName this is need for application verification!")
err = True
return err
......@@ -14,6 +14,7 @@ _logger = logging.getLogger(__name__)
# Used for storing a method that can validate a certificate on a create_session
CertificateValidatorMethod = Callable[[x509.Certificate, ApplicationDescription], Awaitable[None]]
class CertificateValidatorOptions(Flag):
"""
Flags for which certificate validation should be performed
......@@ -23,6 +24,7 @@ class CertificateValidatorOptions(Flag):
- EXT_VALIDATION
- TRUSTED_VALIDATION
"""
TIME_RANGE = auto()
URI = auto()
KEY_USAGE = auto()
......@@ -56,12 +58,12 @@ class CertificateValidator:
self._trust_store: Optional[TrustStore] = trust_store
def set_validate_options(self, options: CertificateValidatorOptions):
""" Change the use validation options at runtime"""
"""Change the use validation options at runtime"""
self._options = options
async def validate(self, cert: x509.Certificate, app_description: ua.ApplicationDescription):
""" Validate if a certificate is valid based on the validation options.
"""Validate if a certificate is valid based on the validation options.
When not valid, it raises a ServiceError with a UA Result Code.
Args:
......@@ -90,12 +92,8 @@ class CertificateValidator:
if app_description.ApplicationUri not in san_uri:
raise ServiceError(ua.StatusCodes.BadCertificateUriInvalid)
if CertificateValidatorOptions.KEY_USAGE in self._options:
key_usage = cert.extensions.get_extension_for_class(x509.KeyUsage).value
if key_usage.data_encipherment is False or \
key_usage.digital_signature is False or \
key_usage.content_commitment is False or \
key_usage.key_encipherment is False:
if key_usage.data_encipherment is False or key_usage.digital_signature is False or key_usage.content_commitment is False or key_usage.key_encipherment is False:
raise ServiceError(ua.StatusCodes.BadCertificateUseNotAllowed)
if CertificateValidatorOptions.EXT_KEY_USAGE in self._options:
oid = ExtendedKeyUsageOID.SERVER_AUTH if CertificateValidatorOptions.PEER_SERVER in self._options else ExtendedKeyUsageOID.CLIENT_AUTH
......@@ -103,16 +101,13 @@ class CertificateValidator:
if oid not in cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value:
raise ServiceError(ua.StatusCodes.BadCertificateUseNotAllowed)
if CertificateValidatorOptions.PEER_SERVER in self._options and \
app_description.ApplicationType not in [ua.ApplicationType.Server, ua.ApplicationType.ClientAndServer]:
_logger.warning('mismatch between application type and certificate ExtendedKeyUsage')
if CertificateValidatorOptions.PEER_SERVER in self._options and app_description.ApplicationType not in [ua.ApplicationType.Server, ua.ApplicationType.ClientAndServer]:
_logger.warning("mismatch between application type and certificate ExtendedKeyUsage")
raise ServiceError(ua.StatusCodes.BadCertificateUseNotAllowed)
elif CertificateValidatorOptions.PEER_CLIENT in self._options and \
app_description.ApplicationType not in [ua.ApplicationType.Client, ua.ApplicationType.ClientAndServer]:
_logger.warning('mismatch between application type and certificate ExtendedKeyUsage')
elif CertificateValidatorOptions.PEER_CLIENT in self._options and app_description.ApplicationType not in [ua.ApplicationType.Client, ua.ApplicationType.ClientAndServer]:
_logger.warning("mismatch between application type and certificate ExtendedKeyUsage")
raise ServiceError(ua.StatusCodes.BadCertificateUseNotAllowed)
# if hostname is not None:
# san_dns_names = san.value.get_values_for_type(x509.DNSName)
# if hostname not in san_dns_names:
......@@ -121,7 +116,6 @@ class CertificateValidator:
raise ServiceError(ua.StatusCodes.BadCertificateInvalid) from exc
if CertificateValidatorOptions.TRUSTED in self._options or CertificateValidatorOptions.REVOKED in self._options:
if CertificateValidatorOptions.TRUSTED in self._options:
if self._trust_store and not self._trust_store.is_trusted(cert):
raise ServiceError(ua.StatusCodes.BadCertificateUntrusted)
......
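A hedged sketch of composing the options flags above; `Flag` members combine with bitwise-or, and the constructor signature is assumed from the attributes shown (`options` plus an optional `trust_store`):

```python
opts = (
    CertificateValidatorOptions.TIME_RANGE
    | CertificateValidatorOptions.URI
    | CertificateValidatorOptions.KEY_USAGE
)
validator = CertificateValidator(opts)  # assumed signature
# Later, per session: await validator.validate(cert, app_description)
```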
"""
Socket server forwarding request to internal server
"""
import logging
import asyncio
from typing import Optional
......@@ -23,7 +24,7 @@ class OPCUAProtocol(asyncio.Protocol):
self.peer_name = None
self.transport = None
self.processor = None
self._buffer = b''
self._buffer = b""
self.iserver: InternalServer = iserver
self.policies = policies
self.clients = clients
......@@ -33,13 +34,13 @@ class OPCUAProtocol(asyncio.Protocol):
self._task = None
def __str__(self):
return f'OPCUAProtocol({self.peer_name}, {self.processor.session})'
return f"OPCUAProtocol({self.peer_name}, {self.processor.session})"
__repr__ = __str__
def connection_made(self, transport):
self.peer_name = transport.get_extra_info('peername')
_logger.info('New connection from %s', self.peer_name)
self.peer_name = transport.get_extra_info("peername")
_logger.info("New connection from %s", self.peer_name)
self.transport = transport
self.processor = UaProcessor(self.iserver, self.transport, self.limits)
self.processor.set_policies(self.policies)
......@@ -48,7 +49,7 @@ class OPCUAProtocol(asyncio.Protocol):
self._task = asyncio.create_task(self._process_received_message_loop())
def connection_lost(self, ex):
_logger.info('Lost connection from %s, %s', self.peer_name, ex)
_logger.info("Lost connection from %s, %s", self.peer_name, ex)
self.transport.close()
self.iserver.asyncio_transports.remove(self.transport)
closing_task = asyncio.create_task(self.processor.close())
......@@ -72,18 +73,18 @@ class OPCUAProtocol(asyncio.Protocol):
return
if header.header_size + header.body_size <= header.header_size:
# malformed header; prevent invalid access of the buffer
_logger.error('Got malformed header %s', header)
_logger.error("Got malformed header %s", header)
self.transport.close()
return
else:
if len(buf) < header.body_size:
_logger.debug('We did not receive enough data from client. Need %s got %s', header.body_size, len(buf))
_logger.debug("We did not receive enough data from client. Need %s got %s", header.body_size, len(buf))
return
# we have a complete message
self.messages.put_nowait((header, buf))
self._buffer = self._buffer[(header.header_size + header.body_size):]
self._buffer = self._buffer[(header.header_size + header.body_size) :]
except Exception:
_logger.exception('Exception raised while parsing message from client')
_logger.exception("Exception raised while parsing message from client")
return
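The buffering rule above in isolation: a frame is consumed only once the buffer holds the full header plus body (a generic sketch, not asyncua's actual header parser):

```python
def split_frame(buf: bytes, header_size: int, body_size: int):
    total = header_size + body_size
    if len(buf) < total:
        return None, buf  # incomplete: keep buffering
    return buf[:total], buf[total:]  # complete frame, remainder

frame, rest = split_frame(b"HHBBBx", header_size=2, body_size=3)
assert frame == b"HHBBB" and rest == b"x"
```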
async def _process_received_message_loop(self):
......@@ -98,13 +99,13 @@ class OPCUAProtocol(asyncio.Protocol):
try:
await self._process_one_msg(header, buf)
except Exception:
_logger.exception('Exception raised while processing message from client')
_logger.exception("Exception raised while processing message from client")
async def _process_one_msg(self, header, buf):
_logger.debug('_process_received_message %s %s', header.body_size, len(buf))
_logger.debug("_process_received_message %s %s", header.body_size, len(buf))
ret = await self.processor.process(header, buf)
if not ret:
_logger.info('processor returned False, we close connection from %s', self.peer_name)
_logger.info("processor returned False, we close connection from %s", self.peer_name)
self.transport.close()
return
......@@ -127,13 +128,7 @@ class BinaryServer:
def _make_protocol(self):
"""Protocol Factory"""
return OPCUAProtocol(
iserver=self.iserver,
policies=self._policies,
clients=self.clients,
closing_tasks=self.closing_tasks,
limits=self.limits
)
return OPCUAProtocol(iserver=self.iserver, policies=self._policies, clients=self.clients, closing_tasks=self.closing_tasks, limits=self.limits)
async def start(self):
self._server = await asyncio.get_running_loop().create_server(self._make_protocol, self.hostname, self.port)
......@@ -145,11 +140,11 @@ class BinaryServer:
sockname = self._server.sockets[0].getsockname()
self.hostname = sockname[0]
self.port = sockname[1]
self.logger.info('Listening on %s:%s', self.hostname, self.port)
self.logger.info("Listening on %s:%s", self.hostname, self.port)
self.cleanup_task = asyncio.create_task(self._close_task_loop())
async def stop(self):
self.logger.info('Closing asyncio socket server')
self.logger.info("Closing asyncio socket server")
for transport in self.iserver.asyncio_transports:
transport.close()
......
......@@ -42,7 +42,7 @@ class EventGenerator:
self.event = await events.get_event_obj_from_type_node(node)
if isinstance(self.event, event_objects.Condition):
# Add ConditionId, which is not modelled as a component of the ConditionType
self.event.add_property('NodeId', None, ua.VariantType.NodeId)
self.event.add_property("NodeId", None, ua.VariantType.NodeId)
if isinstance(emitting_node, Node):
pass
......@@ -73,8 +73,7 @@ class EventGenerator:
result.check()
def __str__(self):
return f"EventGenerator(Type:{self.event.EventType}, Emitting Node:{self.event.emitting_node.to_string()}, " \
f"Time:{self.event.Time}, Message: {self.event.Message})"
return f"EventGenerator(Type:{self.event.EventType}, Emitting Node:{self.event.emitting_node.to_string()}, " f"Time:{self.event.Time}, Message: {self.event.Message})"
__repr__ = __str__
......@@ -82,7 +81,7 @@ class EventGenerator:
"""
Trigger the event. This will send a notification to all subscribed clients
"""
self.event.EventId = ua.Variant(uuid.uuid4().hex.encode('utf-8'), ua.VariantType.ByteString)
self.event.EventId = ua.Variant(uuid.uuid4().hex.encode("utf-8"), ua.VariantType.ByteString)
if time_attr:
self.event.Time = time_attr
else:
......
......@@ -23,6 +23,7 @@ class HistoryStorageInterface:
Interface of a history backend.
Must be implemented by backends
"""
def __init__(self, max_history_data_response_size=10000):
self.max_history_data_response_size = max_history_data_response_size
......@@ -132,31 +133,21 @@ class HistoryDict(HistoryStorageInterface):
if end is None:
end = ua.get_win_epoch()
if start == ua.get_win_epoch():
results = [
dv
for dv in reversed(self._datachanges[node_id])
if start <= dv.SourceTimestamp
]
results = [dv for dv in reversed(self._datachanges[node_id]) if start <= dv.SourceTimestamp]
elif end == ua.get_win_epoch():
results = [dv for dv in self._datachanges[node_id] if start <= dv.SourceTimestamp]
elif start > end:
results = [
dv
for dv in reversed(self._datachanges[node_id])
if end <= dv.SourceTimestamp <= start
]
results = [dv for dv in reversed(self._datachanges[node_id]) if end <= dv.SourceTimestamp <= start]
else:
results = [
dv for dv in self._datachanges[node_id] if start <= dv.SourceTimestamp <= end
]
results = [dv for dv in self._datachanges[node_id] if start <= dv.SourceTimestamp <= end]
if nb_values and len(results) > nb_values:
results = results[:nb_values]
if len(results) > self.max_history_data_response_size:
cont = results[self.max_history_data_response_size].SourceTimestamp
results = results[:self.max_history_data_response_size]
results = results[: self.max_history_data_response_size]
return results, cont
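The continuation-point rule above, reduced to its essentials: `limit` mirrors `max_history_data_response_size`, and the marker handed back to the client is the timestamp of the first value that did not fit (the DV class is a hypothetical stand-in for ua.DataValue):

```python
from dataclasses import dataclass

@dataclass
class DV:  # hypothetical stand-in for ua.DataValue
    SourceTimestamp: int

def paginate(results, limit):
    cont = None
    if len(results) > limit:
        cont = results[limit].SourceTimestamp  # resume marker
        results = results[:limit]
    return results, cont

page, cont = paginate([DV(1), DV(2), DV(3)], limit=2)
assert len(page) == 2 and cont == 3
```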
async def new_historized_event(self, source_id, evtypes, period, count=0):
......@@ -197,9 +188,7 @@ class HistoryDict(HistoryStorageInterface):
elif end == ua.get_win_epoch():
results = [ev for ev in self._events[source_id] if start <= ev.Time]
elif start > end:
results = [
ev for ev in reversed(self._events[source_id]) if end <= ev.Time <= start
]
results = [ev for ev in reversed(self._events[source_id]) if end <= ev.Time <= start]
else:
results = [ev for ev in self._events[source_id] if start <= ev.Time <= end]
......@@ -209,7 +198,7 @@ class HistoryDict(HistoryStorageInterface):
if len(results) > self.max_history_data_response_size:
cont = results[self.max_history_data_response_size].Time
results = results[:self.max_history_data_response_size]
results = results[: self.max_history_data_response_size]
return results, cont
async def stop(self):
......@@ -260,9 +249,7 @@ class HistoryManager:
Subscribe to the nodes' data changes and store the data in the active storage.
"""
if not self._sub:
self._sub = await self._create_subscription(
SubHandler(self.storage)
)
self._sub = await self._create_subscription(SubHandler(self.storage))
if node in self._handlers:
raise ua.UaError(f"Node {node} is already historized")
await self.storage.new_historized_node(node.nodeid, period, count)
......@@ -285,9 +272,7 @@ class HistoryManager:
must be deleted manually so that a new table with the custom event fields can be created.
"""
if not self._sub:
self._sub = await self._create_subscription(
SubHandler(self.storage)
)
self._sub = await self._create_subscription(SubHandler(self.storage))
if source in self._handlers:
raise ua.UaError(f"Events from {source} are already historized")
......@@ -364,9 +349,7 @@ class HistoryManager:
# send correctly with continuation point
starttime = ua.ua_binary.Primitives.DateTime.unpack(Buffer(rv.ContinuationPoint))
dv, cont = await self.storage.read_node_history(
rv.NodeId, starttime, details.EndTime, details.NumValuesPerNode
)
dv, cont = await self.storage.read_node_history(rv.NodeId, starttime, details.EndTime, details.NumValuesPerNode)
if cont:
cont = ua.ua_binary.Primitives.DateTime.pack(cont)
# rv.IndexRange
......@@ -382,9 +365,7 @@ class HistoryManager:
# send correctly with continuation point
starttime = ua.ua_binary.Primitives.DateTime.unpack(Buffer(rv.ContinuationPoint))
evts, cont = await self.storage.read_event_history(
rv.NodeId, starttime, details.EndTime, details.NumValuesPerNode, details.Filter
)
evts, cont = await self.storage.read_event_history(rv.NodeId, starttime, details.EndTime, details.NumValuesPerNode, details.Filter)
results = []
for ev in evts:
field_list = ua.HistoryEventFieldList()
......
......@@ -103,7 +103,7 @@ class InternalSubscription:
await asyncio.sleep(max(sleep_time, 0))
await self.publish_results()
except asyncio.CancelledError:
self.logger.info('exiting _subscription_loop for %s', self.data.SubscriptionId)
self.logger.info("exiting _subscription_loop for %s", self.data.SubscriptionId)
raise
except Exception:
# seems this except is necessary to log errors
......@@ -114,8 +114,7 @@ class InternalSubscription:
if self._startup or self._triggered_datachanges or self._triggered_events:
return True
if self._keep_alive_count > self.data.RevisedMaxKeepAliveCount:
self.logger.debug("keep alive count %s is > than max keep alive count %s, sending publish event",
self._keep_alive_count, self.data.RevisedMaxKeepAliveCount)
self.logger.debug("keep alive count %s is > than max keep alive count %s, sending publish event", self._keep_alive_count, self.data.RevisedMaxKeepAliveCount)
return True
self._keep_alive_count += 1
return False
......@@ -128,8 +127,7 @@ class InternalSubscription:
queued to be called back with publish request when one is available.
"""
if self._publish_cycles_count > self.data.RevisedLifetimeCount:
self.logger.warning("Subscription %s has expired, publish cycle count(%s) > lifetime count (%s)", self,
self._publish_cycles_count, self.data.RevisedLifetimeCount)
self.logger.warning("Subscription %s has expired, publish cycle count(%s) > lifetime count (%s)", self, self._publish_cycles_count, self.data.RevisedLifetimeCount)
# FIXME: this will never be sent since we do not have a publish request anyway
await self.monitored_item_srv.trigger_statuschange(ua.StatusCode(ua.StatusCodes.BadTimeout))
# Stop the subscription
......@@ -253,8 +251,7 @@ class InternalSubscription:
self._triggered_statuschanges.append(code)
await self._trigger_publish()
async def _enqueue_event(self, mid: int,
eventdata: Union[ua.MonitoredItemNotification, ua.EventFieldList], size: int, queue: dict):
async def _enqueue_event(self, mid: int, eventdata: Union[ua.MonitoredItemNotification, ua.EventFieldList], size: int, queue: dict):
if mid not in queue:
# New Monitored Item Id
queue[mid] = [eventdata]
......