Commit c8bebb29 authored by Yuta Okamoto's avatar Yuta Okamoto Committed by oroulet

fix lint errors and mypy warnings

parent a8091a44
......@@ -15,7 +15,6 @@ from .events import Event
__all__ = ["BaseEvent", "AuditEvent", "AuditSecurityEvent", "AuditChannelEvent", "AuditOpenSecureChannelEvent", "AuditSessionEvent", "AuditCreateSessionEvent", "AuditActivateSessionEvent", "AuditCancelEvent", "AuditCertificateEvent", "AuditCertificateDataMismatchEvent", "AuditCertificateExpiredEvent", "AuditCertificateInvalidEvent", "AuditCertificateUntrustedEvent", "AuditCertificateRevokedEvent", "AuditCertificateMismatchEvent", "AuditNodeManagementEvent", "AuditAddNodesEvent", "AuditDeleteNodesEvent", "AuditAddReferencesEvent", "AuditDeleteReferencesEvent", "AuditUpdateEvent", "AuditWriteUpdateEvent", "AuditHistoryUpdateEvent", "AuditUpdateMethodEvent", "SystemEvent", "DeviceFailureEvent", "BaseModelChangeEvent", "GeneralModelChangeEvent", "TransitionEvent", "AuditUpdateStateEvent", "ProgramTransitionEvent", "SemanticChangeEvent", "AuditUrlMismatchEvent", "Condition", "RefreshStartEvent", "RefreshEndEvent", "RefreshRequiredEvent", "AuditConditionEvent", "AuditConditionEnableEvent", "AuditConditionCommentEvent", "DialogCondition", "AcknowledgeableCondition", "AlarmCondition", "LimitAlarm", "AuditHistoryEventUpdateEvent", "AuditHistoryValueUpdateEvent", "AuditHistoryDeleteEvent", "AuditHistoryRawModifyDeleteEvent", "AuditHistoryAtTimeDeleteEvent", "AuditHistoryEventDeleteEvent", "EventQueueOverflowEvent", "ProgramTransitionAuditEvent", "AuditConditionRespondEvent", "AuditConditionAcknowledgeEvent", "AuditConditionConfirmEvent", "ExclusiveLimitAlarm", "ExclusiveLevelAlarm", "ExclusiveRateOfChangeAlarm", "ExclusiveDeviationAlarm", "NonExclusiveLimitAlarm", "NonExclusiveLevelAlarm", "NonExclusiveRateOfChangeAlarm", "NonExclusiveDeviationAlarm", "DiscreteAlarm", "OffNormalAlarm", "TripAlarm", "AuditConditionShelvingEvent", "ProgressEvent", "SystemStatusChangeEvent", "SystemOffNormalAlarm", "AuditProgramTransitionEvent", "TrustListUpdatedAuditEvent", "CertificateUpdatedAuditEvent", "CertificateExpirationAlarm", "AuditConditionResetEvent", "PubSubStatusEvent", "PubSubTransportLimitsExceedEvent", "PubSubCommunicationFailureEvent", "DiscrepancyAlarm", "AuditConditionSuppressionEvent", "AuditConditionSilenceEvent", "AuditConditionOutOfServiceEvent", "RoleMappingRuleChangedAuditEvent", "KeyCredentialAuditEvent", "KeyCredentialUpdatedAuditEvent", "KeyCredentialDeletedAuditEvent", "InstrumentDiagnosticAlarm", "SystemDiagnosticAlarm", "AuditHistoryAnnotationUpdateEvent", "TrustListOutOfDateAlarm", "AuditClientEvent", "AuditClientUpdateMethodResultEvent"]
class BaseEvent(Event):
"""
BaseEvent:
......@@ -1197,4 +1196,4 @@ IMPLEMENTED_EVENTS = {
ua.ObjectIds.TrustListOutOfDateAlarmType: TrustListOutOfDateAlarm,
ua.ObjectIds.AuditClientEventType: AuditClientEvent,
ua.ObjectIds.AuditClientUpdateMethodResultEventType: AuditClientUpdateMethodResultEvent,
}
}
......@@ -7,7 +7,7 @@ import logging
from datetime import datetime
from enum import Enum, IntEnum, IntFlag
from dateutil import parser
from dateutil import parser # type: ignore[attr-defined]
from asyncua import ua
......@@ -315,6 +315,7 @@ def data_type_to_string(dtype):
string = dtype.to_string()
return string
def copy_dataclass_attr(dc_source, dc_dest) -> None:
"""
Copy the common attributes of dc_source to dc_dest
......
......@@ -14,9 +14,11 @@ from ..ua.uaerrors import UaError
_logger = logging.getLogger(__name__)
def _parse_version(version_string: str) -> List[int]:
return [int(v) for v in version_string.split('.')]
class XmlImporter:
def __init__(self, server, strict_mode=True):
......@@ -189,8 +191,12 @@ class XmlImporter:
_logger.debug("Adding missing reference: %s <-> %s (%s)", target_node_id, source_node_id, ref.ReferenceTypeId)
new_ref = ua.AddReferencesItem(SourceNodeId=target_node_id, TargetNodeId=source_node_id,
ReferenceTypeId=ref_type, IsForward=(not ref.IsForward))
new_ref = ua.AddReferencesItem(
SourceNodeId=target_node_id,
TargetNodeId=source_node_id,
ReferenceTypeId=ref_type,
IsForward=(not ref.IsForward)
)
reference_fixes.append(new_ref)
await self._add_references(reference_fixes)
......
......@@ -459,7 +459,7 @@ class XMLParser:
date_time = model.attrib.get('PublicationDate')
if date_time is None:
date_time = ua.DateTime(1, 1, 1)
elif date_time is not None and date_time[-1]=="Z":
elif date_time is not None and date_time[-1] == "Z":
date_time = ua.DateTime.strptime(date_time, "%Y-%m-%dT%H:%M:%SZ")
else:
date_time = ua.DateTime.strptime(date_time, "%Y-%m-%dT%H:%M:%S%z")
......
......@@ -57,7 +57,7 @@ class TrustStore:
async def load_trust(self):
"""(re)load the trusted certificates"""
self._trust_store: crypto.X509Store = crypto.X509Store()
self._trust_store = crypto.X509Store()
for location in self._trust_locations:
await self._load_trust_location(location)
......
......@@ -85,6 +85,7 @@ def der_from_x509(certificate):
return b""
return certificate.public_bytes(serialization.Encoding.DER)
def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes:
"""dumps a private key in PEM format
......@@ -96,6 +97,7 @@ def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes:
"""
return private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption())
def sign_sha1(private_key, data):
return private_key.sign(
data,
......
......@@ -612,7 +612,7 @@ class AddressSpace:
https://reference.opcfoundation.org/Core/docs/Part3/
"""
def __init__(self):
def __init__(self) -> None:
self.logger = logging.getLogger(__name__)
self.force_server_timestamp: bool = True
self._nodes: Dict[ua.NodeId, NodeData] = {}
......@@ -814,8 +814,12 @@ class AddressSpace:
return True
if value.Value.VariantType == vtype: # type: ignore[union-attr]
return True
_logger.warning("Write refused: Variant: %s with type %s does not have expected type: %s",
value.Value, value.Value.VariantType, attval.value.Value.VariantType) # type: ignore[union-attr]
_logger.warning(
"Write refused: Variant: %s with type %s does not have expected type: %s",
value.Value,
value.Value.VariantType if value.Value else None,
attval.value.Value.VariantType if attval.value.Value else None,
)
return False
def add_datachange_callback(self, nodeid: ua.NodeId, attr: ua.AttributeIds, callback: Callable) -> Tuple[ua.StatusCode, int]:
......
......@@ -5,6 +5,7 @@ import uuid
import sys
from asyncua import ua
from asyncua.server.internal_session import InternalSession
from ..common import events, event_objects, Node
......@@ -20,7 +21,7 @@ class EventGenerator:
etype: The event type, either an objectId, a NodeId or a Node object
"""
def __init__(self, isession):
def __init__(self, isession: InternalSession):
self.logger = logging.getLogger(__name__)
self.isession = isession
self.event: event_objects.BaseEvent = None
......@@ -91,7 +92,7 @@ class EventGenerator:
self.event.LocalTime = ua.uaprotocol_auto.TimeZoneDataType()
if sys.version_info.major > 2:
localtime = time.localtime(self.event.Time.timestamp())
self.event.LocalTime.Offset = localtime.tm_gmtoff//60
self.event.LocalTime.Offset = localtime.tm_gmtoff // 60
else:
localtime = time.localtime(time.mktime(self.event.Time.timetuple()))
self.event.LocalTime.Offset = -(time.altzone if localtime.tm_isdst else time.timezone)
......
......@@ -22,7 +22,7 @@ class HistorySQLite(HistoryStorageInterface):
note that PARSE_DECLTYPES is active so certain data types (such as datetime) will not be BLOBs
"""
def __init__(self, path="history.db", max_history_data_response_size=10000):
def __init__(self, path="history.db", max_history_data_response_size=10000) -> None:
self.max_history_data_response_size = max_history_data_response_size
self.logger = logging.getLogger(__name__)
self._datachanges_period = {}
......
......@@ -10,7 +10,6 @@ from struct import unpack_from
from pathlib import Path
import logging
from urllib.parse import urlparse
from typing import Coroutine
from asyncua import ua
from .user_managers import PermissiveUserManager, UserManager
......@@ -151,7 +150,7 @@ class InternalServer:
# path was supplied, but file doesn't exist - create one for next start up
await asyncio.get_running_loop().run_in_executor(None, self.aspace.make_aspace_shelf, shelf_file)
async def _address_space_fixes(self) -> Coroutine: # type: ignore
async def _address_space_fixes(self): # type: ignore
"""
Looks like the xml definition of address space has some error. This is a good place to fix them
"""
......
......@@ -11,6 +11,7 @@ from asyncua import ua
from .monitored_item_service import MonitoredItemService
from .address_space import AddressSpace
class InternalSubscription:
"""
Server internal subscription.
......
......@@ -22,7 +22,7 @@ class MonitoredItemData:
class MonitoredItemValues:
def __init__(self):
def __init__(self) -> None:
self.current_dvalue: Optional[ua.DataValue] = None
self.old_dvalue: Optional[ua.DataValue] = None
......@@ -189,13 +189,12 @@ class MonitoredItemService:
if old.StatusCode != current.StatusCode:
return True
if trg in [ua.DataChangeTrigger.StatusValue,ua.DataChangeTrigger.StatusValueTimestamp ] and \
old.Value != current.Value:
if trg in [ua.DataChangeTrigger.StatusValue, ua.DataChangeTrigger.StatusValueTimestamp] and old.Value != current.Value:
return True
if trg == ua.DataChangeTrigger.StatusValueTimestamp and \
(old.SourceTimestamp != current.SourceTimestamp or
old.SourcePicoseconds != current.SourcePicoseconds):
if trg == ua.DataChangeTrigger.StatusValueTimestamp and (
old.SourceTimestamp != current.SourceTimestamp or old.SourcePicoseconds != current.SourcePicoseconds
):
return True
return False
......@@ -213,8 +212,9 @@ class MonitoredItemService:
mdata = self._monitored_items[mid]
mdata.mvalue.set_current_datavalue(value)
if mdata.filter:
deadband_flag_pass = self._is_data_changed(mdata.mvalue, mdata.filter.Trigger) and \
self._is_deadband_exceeded(mdata.mvalue, mdata.filter)
deadband_flag_pass = self._is_data_changed(
mdata.mvalue, mdata.filter.Trigger
) and self._is_deadband_exceeded(mdata.mvalue, mdata.filter)
else:
# Trigger defaults to StatusValue
deadband_flag_pass = self._is_data_changed(mdata.mvalue, ua.DataChangeTrigger.StatusValue)
......
......@@ -8,7 +8,7 @@ import math
from datetime import timedelta, datetime
import socket
from urllib.parse import urlparse
from typing import Coroutine, Optional, Tuple, Union
from typing import Optional, Tuple, Union
from pathlib import Path
from asyncua import ua
......@@ -124,7 +124,7 @@ class Server:
await self.set_application_uri(self._application_uri)
sa_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_ServerArray))
await sa_node.write_value([self._application_uri])
#TODO: ServiceLevel is 255 default, should be calculated in later Versions
# TODO: ServiceLevel is 255 default, should be calculated in later Versions
sl_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_ServiceLevel))
await sl_node.write_value(ua.Variant(255, ua.VariantType.Byte))
......@@ -306,7 +306,7 @@ class Server:
async def _renew_registration(self):
for client in self._discovery_clients.values():
await client.connect_sessionless()
await client.register_server(self) #FIXME discovery_configuration?
await client.register_server(self) # FIXME discovery_configuration?
await client.disconnect_sessionless()
def allow_remote_admin(self, allow):
......@@ -318,7 +318,7 @@ class Server:
def set_endpoint(self, url):
self.endpoint = urlparse(url)
async def get_endpoints(self) -> Coroutine:
async def get_endpoints(self):
return await self.iserver.get_endpoints()
def set_security_policy(self, security_policy, permission_ruleset=None):
......@@ -414,9 +414,8 @@ class Server:
ua.MessageSecurityMode.SignAndEncrypt, self.certificate, self.iserver.private_key,
permission_ruleset=self._permission_ruleset))
@staticmethod
def lookup_security_level_for_policy_type( security_policy_type: ua.SecurityPolicyType ) -> ua.Byte:
def lookup_security_level_for_policy_type(security_policy_type: ua.SecurityPolicyType) -> ua.Byte:
"""Returns the security level for an ua.SecurityPolicyType.
This is endpoint & server implementation specific!
......@@ -426,20 +425,19 @@ class Server:
"""
return ua.Byte({
ua.SecurityPolicyType.NoSecurity : 0,
ua.SecurityPolicyType.Basic128Rsa15_Sign : 1,
ua.SecurityPolicyType.Basic128Rsa15_SignAndEncrypt : 2,
ua.SecurityPolicyType.Basic256_Sign : 11,
ua.SecurityPolicyType.Basic256_SignAndEncrypt : 21,
ua.SecurityPolicyType.Basic256Sha256_Sign : 50,
ua.SecurityPolicyType.Basic256Sha256_SignAndEncrypt : 70,
ua.SecurityPolicyType.Aes128Sha256RsaOaep_Sign : 55,
ua.SecurityPolicyType.Aes128Sha256RsaOaep_SignAndEncrypt : 75
ua.SecurityPolicyType.NoSecurity: 0,
ua.SecurityPolicyType.Basic128Rsa15_Sign: 1,
ua.SecurityPolicyType.Basic128Rsa15_SignAndEncrypt: 2,
ua.SecurityPolicyType.Basic256_Sign: 11,
ua.SecurityPolicyType.Basic256_SignAndEncrypt: 21,
ua.SecurityPolicyType.Basic256Sha256_Sign: 50,
ua.SecurityPolicyType.Basic256Sha256_SignAndEncrypt: 70,
ua.SecurityPolicyType.Aes128Sha256RsaOaep_Sign: 55,
ua.SecurityPolicyType.Aes128Sha256RsaOaep_SignAndEncrypt: 75
}[security_policy_type])
@staticmethod
def determine_security_level(security_policy_uri:str, security_mode: ua.MessageSecurityMode) -> ua.Byte:
def determine_security_level(security_policy_uri: str, security_mode: ua.MessageSecurityMode) -> ua.Byte:
"""Determine the security level of an EndPoint.
The security level indicates how secure an EndPoint is, compared to other EndPoints of the same server.
Value 0 is a special value; EndPoint isn't recommended, typical for ua.MessageSecurityMode.None_.
......@@ -569,7 +567,7 @@ class Server:
"""
return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.ObjectsFolder))
def get_node(self, nodeid):
def get_node(self, nodeid: Union[Node, ua.NodeId, str, int]) -> Node:
"""
Get a specific node using NodeId object or a string representing a NodeId
"""
......@@ -629,7 +627,7 @@ class Server:
return ev_gen
async def create_custom_data_type(self, idx, name, basetype=ua.ObjectIds.BaseDataType,
properties=None, description=None) -> Coroutine:
properties=None, description=None):
if properties is None:
properties = []
base_t = _get_node(self.iserver.isession, basetype)
......@@ -644,7 +642,7 @@ class Server:
return custom_t
async def create_custom_event_type(self, idx, name,
basetype=ua.ObjectIds.BaseEventType, properties=None) -> Coroutine:
basetype=ua.ObjectIds.BaseEventType, properties=None):
if properties is None:
properties = []
return await self._create_custom_type(idx, name, basetype, properties, [], [])
......@@ -655,7 +653,7 @@ class Server:
basetype=ua.ObjectIds.BaseObjectType,
properties=None,
variables=None,
methods=None) -> Coroutine:
methods=None):
if properties is None:
properties = []
if variables is None:
......@@ -673,7 +671,7 @@ class Server:
basetype=ua.ObjectIds.BaseVariableType,
properties=None,
variables=None,
methods=None) -> Coroutine:
methods=None):
if properties is None:
properties = []
if variables is None:
......@@ -701,7 +699,7 @@ class Server:
await custom_t.add_method(idx, method[0], method[1], method[2], method[3])
return custom_t
async def import_xml(self, path=None, xmlstring=None, strict_mode=True) -> Coroutine:
async def import_xml(self, path=None, xmlstring=None, strict_mode=True):
"""
Import nodes defined in xml
"""
......@@ -731,7 +729,7 @@ class Server:
nodes = await get_nodes_of_namespace(self, namespaces)
await self.export_xml(nodes, path, export_values=export_values)
async def delete_nodes(self, nodes, recursive=False) -> Coroutine:
async def delete_nodes(self, nodes, recursive=False):
return await delete_nodes(self.iserver.isession, nodes, recursive)
async def historize_node_data_change(self, node, period=timedelta(days=7), count=0):
......@@ -789,7 +787,7 @@ class Server:
"""
self.iserver.isession.add_method_callback(node.nodeid, callback)
async def load_type_definitions(self, nodes=None) -> Coroutine:
async def load_type_definitions(self, nodes=None):
"""
load custom structures from our server.
Server side this can be used to create python objects from custom structures
......@@ -806,7 +804,7 @@ class Server:
"""
return await load_data_type_definitions(self, node)
async def load_enums(self) -> Coroutine:
async def load_enums(self):
"""
load UA structures and generate python Enums in ua module for custom enums in server
"""
......
......@@ -13,3 +13,5 @@ check_untyped_defs = False
[mypy-asyncua.ua.uaprotocol_auto.*]
# Autogenerated file
disable_error_code = literal-required
[mypy-asynctest.*]
ignore_missing_imports = True
......@@ -19,14 +19,11 @@ from asyncua.server.history_sql import HistorySQLite
from .test_common import add_server_methods
from .util_enum_struct import add_server_custom_enum_struct
RETRY = 20
SLEEP = 0.4
PORTS_USED = set()
Opc = namedtuple('opc', ['opc', 'server'])
Opc = namedtuple('Opc', ['opc', 'server'])
def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
......@@ -39,6 +36,7 @@ def find_free_port():
else:
return find_free_port()
port_num = find_free_port()
port_num1 = find_free_port()
port_discovery = find_free_port()
......@@ -266,6 +264,7 @@ async def history_server(request):
yield srv
await srv.srv.stop()
@pytest.fixture(scope="session")
def client_key_and_cert(request):
base_dir = Path(__file__).parent.parent
......
from pathlib import Path
from typing import Tuple
import pytest
import asyncio
......@@ -107,6 +108,7 @@ async def srv_crypto_one_cert(request):
# stop the server
await srv.stop()
@pytest.fixture(params=srv_crypto_params)
async def srv_crypto_all_cert_basic128rsa15(request):
# start our own server
......@@ -479,6 +481,7 @@ async def test_anonymous_rejection():
await clt.connect()
await srv.stop()
async def test_security_level_all():
assert Server.determine_security_level(ua.SecurityPolicy.URI, ua.MessageSecurityMode.None_) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.NoSecurity)
......@@ -495,7 +498,8 @@ async def test_security_level_all():
assert Server.determine_security_level(security_policies.SecurityPolicyBasic256.URI, ua.MessageSecurityMode.Sign) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.Basic256_Sign)
assert Server.determine_security_level(security_policies.SecurityPolicyBasic256.URI, ua.MessageSecurityMode.SignAndEncrypt) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.Basic256_SignAndEncrypt)
async def test_security_level_endpoints(srv_crypto_all_certs):
async def test_security_level_endpoints(srv_crypto_all_certs: Tuple[Server, str]):
srv = srv_crypto_all_certs[0]
end_points: list[ua.EndpointDescription] = await srv.get_endpoints()
......
......@@ -9,7 +9,7 @@ from cryptography.x509.extensions import _key_identifier_from_public_key as key_
from asyncua.crypto.cert_gen import generate_private_key, generate_app_certificate_signing_request, generate_self_signed_app_certificate, sign_certificate_request
async def test_create_self_signed_app_certificate():
async def test_create_self_signed_app_certificate() -> None:
""" Checks if the self signed certificate complies to OPC 10000-6 6.2.2"""
hostname = socket.gethostname()
......@@ -52,7 +52,7 @@ async def test_create_self_signed_app_certificate():
# check valid time range
assert dt_before_generation <= cert.not_valid_before <= dt_after_generation
assert (dt_before_generation+timedelta(days_valid)) <= cert.not_valid_after <= (dt_after_generation+timedelta(days_valid))
assert (dt_before_generation + timedelta(days_valid)) <= cert.not_valid_after <= (dt_after_generation + timedelta(days_valid))
# check issuer
assert cert.issuer.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value == f"myserver@{hostname}"
......@@ -63,6 +63,8 @@ async def test_create_self_signed_app_certificate():
# check Authority Key Identifier
auth_key_identifier: x509.AuthorityKeyIdentifier = cert.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier).value
assert auth_key_identifier
assert isinstance(auth_key_identifier.authority_cert_issuer, list)
assert len(auth_key_identifier.authority_cert_issuer) > 0
issuer: x509.Name = auth_key_identifier.authority_cert_issuer[0].value
assert issuer.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value == f"myserver@{hostname}"
assert auth_key_identifier.authority_cert_serial_number == cert.serial_number
......@@ -89,7 +91,7 @@ async def test_create_self_signed_app_certificate():
assert cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value == x509.ExtendedKeyUsage(extended)
async def test_app_create_certificate_signing_request():
async def test_app_create_certificate_signing_request() -> None:
""" Checks if the self signed certificate complies to OPC 10000-6 6.2.2"""
hostname = socket.gethostname()
......@@ -139,7 +141,7 @@ async def test_app_create_certificate_signing_request():
assert csr.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value == x509.ExtendedKeyUsage(extended)
async def test_app_sign_certificate_request():
async def test_app_sign_certificate_request() -> None:
"""Check the correct signing of certificate signing request"""
hostname = socket.gethostname()
......@@ -180,6 +182,8 @@ async def test_app_sign_certificate_request():
# check authority Key Identifier
auth_key_identifier: x509.AuthorityKeyIdentifier = cert.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier).value
assert auth_key_identifier
assert isinstance(auth_key_identifier.authority_cert_issuer, list)
assert len(auth_key_identifier.authority_cert_issuer) > 0
assert auth_key_identifier.authority_cert_issuer[0].value == issuer.subject
assert auth_key_identifier.authority_cert_serial_number == issuer.serial_number
assert auth_key_identifier.key_identifier == key_identifier_from_public_key(key_ca.public_key())
......
......@@ -189,6 +189,7 @@ async def test_references_for_added_nodes_method(server):
assert await m.get_parent() == o
await server.delete_nodes([o])
async def test_get_event_from_type_node_BaseEvent(server):
"""
This should work for following BaseEvent tests to work
......@@ -299,6 +300,7 @@ async def test_eventgenerator_sourceMyObject(server):
await check_event_generator_object(evgen, o)
await server.delete_nodes([o])
async def test_eventgenerator_source_collision(server):
objects = server.nodes.objects
o = await objects.add_object(3, 'MyObject')
......@@ -308,6 +310,7 @@ async def test_eventgenerator_source_collision(server):
await check_event_generator_object(evgen, o, emitting_node=asyncua.Node(server.iserver.isession, ua.ObjectIds.Server))
await server.delete_nodes([o])
async def test_eventgenerator_inherited_event(server):
evgen = await server.get_event_generator(ua.ObjectIds.AuditEventType)
await check_eventgenerator_source_server(evgen, server)
......@@ -402,8 +405,7 @@ async def test_create_custom_event_type_node_id(server):
async def test_create_custom_event_type_node(server):
etype = await server.create_custom_event_type(2, 'MyEvent1', asyncua.Node(server.iserver.isession,
ua.NodeId(ua.ObjectIds.BaseEventType)),
etype = await server.create_custom_event_type(2, 'MyEvent1', asyncua.Node(server.iserver.isession, ua.NodeId(ua.ObjectIds.BaseEventType)),
[('PropertyNum', ua.VariantType.Int32),
('PropertyString', ua.VariantType.String)])
await check_custom_type(etype, ua.ObjectIds.BaseEventType, server)
......@@ -440,8 +442,7 @@ async def test_eventgenerator_custom_event_with_variables(server):
('PropertyString', ua.VariantType.String)]
variables = [('VariableString', ua.VariantType.String),
('MyEnumVar', ua.VariantType.Int32, ua.NodeId(ua.ObjectIds.ApplicationType))]
etype = await server.create_custom_object_type(2, 'MyEvent33', ua.ObjectIds.BaseEventType,
properties, variables)
etype = await server.create_custom_object_type(2, 'MyEvent33', ua.ObjectIds.BaseEventType, properties, variables)
evgen = await server.get_event_generator(etype, ua.ObjectIds.Server)
check_eventgenerator_custom_event(evgen, etype, server)
await check_eventgenerator_source_server(evgen, server)
......@@ -555,8 +556,11 @@ async def test_get_node_by_ns(server):
async def test_load_enum_strings(server):
dt = await server.nodes.enum_data_type.add_data_type(0, "MyStringEnum")
await dt.add_property(0, "EnumStrings", [ua.LocalizedText("e1"), ua.LocalizedText("e2"), ua.LocalizedText("e3"),
ua.LocalizedText("e 4")])
await dt.add_property(
0,
"EnumStrings",
[ua.LocalizedText("e1"), ua.LocalizedText("e2"), ua.LocalizedText("e3"), ua.LocalizedText("e 4")]
)
await server.load_enums()
e = getattr(ua, "MyStringEnum")
assert isinstance(e, EnumMeta)
......@@ -567,18 +571,9 @@ async def test_load_enum_strings(server):
async def test_load_enum_values(server):
dt = await server.nodes.enum_data_type.add_data_type(0, "MyValuesEnum")
v1 = ua.EnumValueType(
DisplayName=ua.LocalizedText("v1"),
Value=2,
)
v2 = ua.EnumValueType(
DisplayName=ua.LocalizedText("v2"),
Value=3,
)
v3 = ua.EnumValueType(
DisplayName=ua.LocalizedText("v 3 "),
Value=4.
)
v1 = ua.EnumValueType(DisplayName=ua.LocalizedText("v1"), Value=2)
v2 = ua.EnumValueType(DisplayName=ua.LocalizedText("v2"), Value=3)
v3 = ua.EnumValueType(DisplayName=ua.LocalizedText("v 3 "), Value=4.)
await dt.add_property(0, "EnumValues", [v1, v2, v3])
await server.load_enums()
e = getattr(ua, "MyValuesEnum")
......@@ -645,8 +640,7 @@ def check_custom_event(ev, etype):
async def check_custom_type(ntype, base_type, server: Server, node_class=None):
base = asyncua.Node(server.iserver.isession, ua.NodeId(base_type))
assert ntype in await base.get_children()
nodes = await ntype.get_referenced_nodes(refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse,
includesubtypes=True)
nodes = await ntype.get_referenced_nodes(refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse, includesubtypes=True)
assert base == nodes[0]
if node_class:
assert node_class == await ntype.read_node_class()
......@@ -658,6 +652,7 @@ async def check_custom_type(ntype, base_type, server: Server, node_class=None):
assert await ntype.get_child("2:PropertyString") in properties
assert (await(await ntype.get_child("2:PropertyString")).read_data_value()).Value.VariantType == ua.VariantType.String
async def test_server_read_write_attribute_value(server: Server):
node = await server.get_objects_node().add_variable(0, "0:TestVar", 0, varianttype=ua.VariantType.Int64)
dv = server.read_attribute_value(node.nodeid, attr=ua.AttributeIds.Value)
......@@ -672,6 +667,7 @@ async def test_server_read_write_attribute_value(server: Server):
@pytest.fixture(scope="function")
def restore_transport_limits_server(server: Server):
# Restore limits after test
assert server.bserver is not None
max_recv = server.bserver.limits.max_recv_buffer
max_chunk_count = server.bserver.limits.max_chunk_count
yield server
......@@ -681,6 +677,7 @@ def restore_transport_limits_server(server: Server):
async def test_message_limits_fail_write(restore_transport_limits_server: Server):
server = restore_transport_limits_server
assert server.bserver is not None
server.bserver.limits.max_recv_buffer = 1024
server.bserver.limits.max_send_buffer = 10240000
server.bserver.limits.max_chunk_count = 10
......@@ -698,6 +695,7 @@ async def test_message_limits_fail_write(restore_transport_limits_server: Server
async def test_message_limits_fail_read(restore_transport_limits_server: Server):
server = restore_transport_limits_server
assert server.bserver is not None
server.bserver.limits.max_recv_buffer = 10240000
server.bserver.limits.max_send_buffer = 1024
server.bserver.limits.max_chunk_count = 10
......@@ -716,6 +714,7 @@ async def test_message_limits_fail_read(restore_transport_limits_server: Server)
async def test_message_limits_works(restore_transport_limits_server: Server):
server = restore_transport_limits_server
# server.bserver.limits.max_recv_buffer = 1024
assert server.bserver is not None
server.bserver.limits.max_send_buffer = 1024
server.bserver.limits.max_chunk_count = 10
n = await server.nodes.objects.add_variable(1, "MyLimitVariable2", "t")
......@@ -729,7 +728,6 @@ async def test_message_limits_works(restore_transport_limits_server: Server):
await n.read_value()
"""
class TestServerCaching(unittest.TestCase):
def runTest(self):
......@@ -755,6 +753,7 @@ class TestServerCaching(unittest.TestCase):
"""
async def test_null_auth(server):
"""
OPC-UA Specification Part 4, 5.6.3 specifies that a:
......@@ -763,6 +762,7 @@ async def test_null_auth(server):
Ensure a Null token is accepted as an anonymous connection token.
"""
client = Client(server.endpoint.geturl())
# Modify the authentication creation in the client request
def _add_null_auth(self, params):
params.UserIdentityToken = ua.ExtensionObject(ua.NodeId(ua.ObjectIds.Null))
......@@ -772,7 +772,7 @@ async def test_null_auth(server):
pass
async def test_start_server_when_port_is_in_use(server: str):
async def test_start_server_when_port_is_in_use(server: Server):
server2 = Server()
await server2.init()
url = server.endpoint.geturl()
......
......@@ -8,10 +8,12 @@ from asyncua.common.subscription import Subscription
try:
from unittest.mock import AsyncMock
except ImportError:
from asynctest import CoroutineMock as AsyncMock
from asynctest import CoroutineMock as AsyncMock # type: ignore[no-redef]
import asyncua
from asyncua import ua, Client
from .conftest import Opc
class MySubHandler:
"""
......@@ -245,7 +247,8 @@ async def test_subscription_data_change(opc):
await sub.unsubscribe(handle1) # sub does not exist anymore
await opc.opc.delete_nodes([v1])
async def test_subscription_monitored_item(opc):
async def test_subscription_monitored_item(opc: Opc):
"""
test subscriptions with a monitored item with a datachange filter.
......@@ -258,12 +261,11 @@ async def test_subscription_monitored_item(opc):
v1 = await o.add_variable(3, 'SubscriptionVariableV1', startv1)
sub: Subscription = await opc.opc.create_subscription(100, myhandler)
mfilter = ua.DataChangeFilter(Trigger=ua.DataChangeTrigger.StatusValueTimestamp)
#For creating monitor items create_monitored_items is availablem, but that one is not very easy in use.
#So use the internal function instead.
#TODO: Should there be an easy shorthand for making monitored items with filter?
# For creating monitor items create_monitored_items is availablem, but that one is not very easy in use.
# So use the internal function instead.
# TODO: Should there be an easy shorthand for making monitored items with filter?
handles = await sub._subscribe(nodes=v1, mfilter=mfilter)
# # Now check we get the start value
......@@ -1043,6 +1045,7 @@ async def test_publish(opc, mocker):
publish_event = asyncio.Event()
publish_org = client.uaclient.publish
async def publish(acks):
await publish_event.wait()
publish_event.clear()
......
......@@ -132,12 +132,12 @@ def trust_store(tmp_path) -> TrustStore:
return _trust_store
async def test_selfsigned_not_in_trust_store(cert_files, trust_store):
async def test_selfsigned_not_in_trust_store(cert_files, trust_store) -> None:
cert_self_signed: x509.Certificate = await load_certificate(cert_files / SERVER_CERT_SELF_SIGNED_FILE)
assert trust_store.is_trusted(cert_self_signed) is False
async def test_selfsigned_in_trust_store(cert_files, trust_store):
async def test_selfsigned_in_trust_store(cert_files, trust_store) -> None:
shutil.copyfile(cert_files / SERVER_CERT_SELF_SIGNED_FILE, trust_store.trust_locations[0] / SERVER_CERT_SELF_SIGNED_FILE)
await trust_store.load()
......@@ -145,12 +145,12 @@ async def test_selfsigned_in_trust_store(cert_files, trust_store):
assert trust_store.is_trusted(cert_self_signed) is True
async def test_ca_not_in_trust_store(cert_files, trust_store):
async def test_ca_not_in_trust_store(cert_files, trust_store) -> None:
cert_self_signed: x509.Certificate = await load_certificate(cert_files / SERVER_CERT_SELF_SIGNED_FILE)
assert trust_store.is_trusted(cert_self_signed) is False
async def test_ca_in_trust_store(cert_files, trust_store):
async def test_ca_in_trust_store(cert_files, trust_store) -> None:
shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE)
await trust_store.load()
......@@ -158,7 +158,7 @@ async def test_ca_in_trust_store(cert_files, trust_store):
assert trust_store.is_trusted(cert_server) is True
async def test_empty_crl(cert_files, trust_store):
async def test_empty_crl(cert_files, trust_store) -> None:
shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE)
shutil.copyfile(cert_files / 'ca_empty_crl.der', trust_store.crl_locations[0] / 'ca_empty_crl.der')
await trust_store.load()
......@@ -172,7 +172,7 @@ async def test_empty_crl(cert_files, trust_store):
assert trust_store.validate(cert_server) is True
async def test_cert_in_crl(cert_files, trust_store):
async def test_cert_in_crl(cert_files, trust_store) -> None:
shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE)
shutil.copyfile(cert_files / 'ca_crl.der', trust_store.crl_locations[0] / 'ca_crl.der')
await trust_store.load()
......
......@@ -2,7 +2,7 @@ from dataclasses import dataclass
from asyncua.common.ua_utils import copy_dataclass_attr
def test_copy_dataclass_attr():
def test_copy_dataclass_attr() -> None:
@dataclass
class A:
x: int = 1
......
......@@ -11,7 +11,7 @@ import pytest
import logging
from datetime import datetime
from dataclasses import dataclass, field
from typing import Optional, List
from typing import Optional, List, cast
from asyncua import ua
from asyncua.ua import ua_binary
......@@ -773,7 +773,7 @@ def test_where_clause():
op.BrowsePath.append(ua.QualifiedName("property", 2))
el.FilterOperands.append(op)
for i in range(10):
op = ua.LiteralOperand(Value = ua.Variant(i))
op = ua.LiteralOperand(Value=ua.Variant(i))
el.FilterOperands.append(op)
el.FilterOperator = ua.FilterOperator.InList
cf.Elements.append(el)
......@@ -858,7 +858,7 @@ def test_expandedNodeId():
assert nid.Identifier == 85
def test_struct_104():
def test_struct_104() -> None:
@dataclass
class MyStruct:
Encoding: ua.Byte = field(default=0, repr=False, init=False)
......@@ -872,7 +872,7 @@ def test_struct_104():
m2 = struct_from_binary(MyStruct, ua.utils.Buffer(data))
assert m == m2
m = MyStruct(a=4, b=5, c="lkjkæl", l=["a", "b", "c"])
m = MyStruct(a=4, b=5, c="lkjkæl", l=[cast(ua.String, "a"), cast(ua.String, "b"), cast(ua.String, "c")])
data = struct_to_binary(m)
m2 = struct_from_binary(MyStruct, ua.utils.Buffer(data))
assert m == m2
......
......@@ -30,8 +30,8 @@ async def add_server_custom_enum_struct(server: Server):
uatypes.register_enum('ExampleEnum', ua.NodeId(3002, ns), ExampleEnum)
uatypes.register_extension_object('ExampleStruct', ua.NodeId(5001, ns), ExampleStruct)
await server.import_xml(TEST_DIR / "enum_struct_test_nodes.xml"),
val = ua.ExampleStruct()
val = ExampleStruct()
val.IntVal1 = 242
val.EnumVal = ua.ExampleEnum.EnumVal2
val.EnumVal = ExampleEnum.EnumVal2
myvar = server.get_node(ua.NodeId(6009, ns))
await myvar.write_value(val)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment