Commit c8bebb29 authored by Yuta Okamoto's avatar Yuta Okamoto Committed by oroulet

fix lint errors and mypy warnings

parent a8091a44
...@@ -456,7 +456,7 @@ class UaClient(AbstractSession): ...@@ -456,7 +456,7 @@ class UaClient(AbstractSession):
self.logger.debug(response) self.logger.debug(response)
response.ResponseHeader.ServiceResult.check() response.ResponseHeader.ServiceResult.check()
# nothing to return for this service # nothing to return for this service
async def unregister_server(self, registered_server): async def unregister_server(self, registered_server):
self.logger.debug("unregister_server") self.logger.debug("unregister_server")
request = ua.RegisterServerRequest() request = ua.RegisterServerRequest()
...@@ -552,7 +552,7 @@ class UaClient(AbstractSession): ...@@ -552,7 +552,7 @@ class UaClient(AbstractSession):
) )
return response.Parameters return response.Parameters
modify_subscription = update_subscription # legacy support modify_subscription = update_subscription # legacy support
async def delete_subscriptions(self, subscription_ids): async def delete_subscriptions(self, subscription_ids):
self.logger.debug("delete_subscriptions %r", subscription_ids) self.logger.debug("delete_subscriptions %r", subscription_ids)
......
...@@ -75,5 +75,5 @@ async def _read_and_copy_attrs(node_type, struct, addnode): ...@@ -75,5 +75,5 @@ async def _read_and_copy_attrs(node_type, struct, addnode):
setattr(struct, name, results[idx].Value.Value) setattr(struct, name, results[idx].Value.Value)
else: else:
_logger.warning(f"Instantiate: while copying attributes from node type {str(node_type)}," _logger.warning(f"Instantiate: while copying attributes from node type {str(node_type)},"
f" attribute {str(name)}, statuscode is {str(results[idx].StatusCode)}") f" attribute {str(name)}, statuscode is {str(results[idx].StatusCode)}")
addnode.NodeAttributes = struct addnode.NodeAttributes = struct
...@@ -15,7 +15,6 @@ from .events import Event ...@@ -15,7 +15,6 @@ from .events import Event
__all__ = ["BaseEvent", "AuditEvent", "AuditSecurityEvent", "AuditChannelEvent", "AuditOpenSecureChannelEvent", "AuditSessionEvent", "AuditCreateSessionEvent", "AuditActivateSessionEvent", "AuditCancelEvent", "AuditCertificateEvent", "AuditCertificateDataMismatchEvent", "AuditCertificateExpiredEvent", "AuditCertificateInvalidEvent", "AuditCertificateUntrustedEvent", "AuditCertificateRevokedEvent", "AuditCertificateMismatchEvent", "AuditNodeManagementEvent", "AuditAddNodesEvent", "AuditDeleteNodesEvent", "AuditAddReferencesEvent", "AuditDeleteReferencesEvent", "AuditUpdateEvent", "AuditWriteUpdateEvent", "AuditHistoryUpdateEvent", "AuditUpdateMethodEvent", "SystemEvent", "DeviceFailureEvent", "BaseModelChangeEvent", "GeneralModelChangeEvent", "TransitionEvent", "AuditUpdateStateEvent", "ProgramTransitionEvent", "SemanticChangeEvent", "AuditUrlMismatchEvent", "Condition", "RefreshStartEvent", "RefreshEndEvent", "RefreshRequiredEvent", "AuditConditionEvent", "AuditConditionEnableEvent", "AuditConditionCommentEvent", "DialogCondition", "AcknowledgeableCondition", "AlarmCondition", "LimitAlarm", "AuditHistoryEventUpdateEvent", "AuditHistoryValueUpdateEvent", "AuditHistoryDeleteEvent", "AuditHistoryRawModifyDeleteEvent", "AuditHistoryAtTimeDeleteEvent", "AuditHistoryEventDeleteEvent", "EventQueueOverflowEvent", "ProgramTransitionAuditEvent", "AuditConditionRespondEvent", "AuditConditionAcknowledgeEvent", "AuditConditionConfirmEvent", "ExclusiveLimitAlarm", "ExclusiveLevelAlarm", "ExclusiveRateOfChangeAlarm", "ExclusiveDeviationAlarm", "NonExclusiveLimitAlarm", "NonExclusiveLevelAlarm", "NonExclusiveRateOfChangeAlarm", "NonExclusiveDeviationAlarm", "DiscreteAlarm", "OffNormalAlarm", "TripAlarm", "AuditConditionShelvingEvent", "ProgressEvent", "SystemStatusChangeEvent", "SystemOffNormalAlarm", "AuditProgramTransitionEvent", "TrustListUpdatedAuditEvent", "CertificateUpdatedAuditEvent", "CertificateExpirationAlarm", "AuditConditionResetEvent", "PubSubStatusEvent", "PubSubTransportLimitsExceedEvent", "PubSubCommunicationFailureEvent", "DiscrepancyAlarm", "AuditConditionSuppressionEvent", "AuditConditionSilenceEvent", "AuditConditionOutOfServiceEvent", "RoleMappingRuleChangedAuditEvent", "KeyCredentialAuditEvent", "KeyCredentialUpdatedAuditEvent", "KeyCredentialDeletedAuditEvent", "InstrumentDiagnosticAlarm", "SystemDiagnosticAlarm", "AuditHistoryAnnotationUpdateEvent", "TrustListOutOfDateAlarm", "AuditClientEvent", "AuditClientUpdateMethodResultEvent"] __all__ = ["BaseEvent", "AuditEvent", "AuditSecurityEvent", "AuditChannelEvent", "AuditOpenSecureChannelEvent", "AuditSessionEvent", "AuditCreateSessionEvent", "AuditActivateSessionEvent", "AuditCancelEvent", "AuditCertificateEvent", "AuditCertificateDataMismatchEvent", "AuditCertificateExpiredEvent", "AuditCertificateInvalidEvent", "AuditCertificateUntrustedEvent", "AuditCertificateRevokedEvent", "AuditCertificateMismatchEvent", "AuditNodeManagementEvent", "AuditAddNodesEvent", "AuditDeleteNodesEvent", "AuditAddReferencesEvent", "AuditDeleteReferencesEvent", "AuditUpdateEvent", "AuditWriteUpdateEvent", "AuditHistoryUpdateEvent", "AuditUpdateMethodEvent", "SystemEvent", "DeviceFailureEvent", "BaseModelChangeEvent", "GeneralModelChangeEvent", "TransitionEvent", "AuditUpdateStateEvent", "ProgramTransitionEvent", "SemanticChangeEvent", "AuditUrlMismatchEvent", "Condition", "RefreshStartEvent", "RefreshEndEvent", "RefreshRequiredEvent", "AuditConditionEvent", "AuditConditionEnableEvent", "AuditConditionCommentEvent", "DialogCondition", "AcknowledgeableCondition", "AlarmCondition", "LimitAlarm", "AuditHistoryEventUpdateEvent", "AuditHistoryValueUpdateEvent", "AuditHistoryDeleteEvent", "AuditHistoryRawModifyDeleteEvent", "AuditHistoryAtTimeDeleteEvent", "AuditHistoryEventDeleteEvent", "EventQueueOverflowEvent", "ProgramTransitionAuditEvent", "AuditConditionRespondEvent", "AuditConditionAcknowledgeEvent", "AuditConditionConfirmEvent", "ExclusiveLimitAlarm", "ExclusiveLevelAlarm", "ExclusiveRateOfChangeAlarm", "ExclusiveDeviationAlarm", "NonExclusiveLimitAlarm", "NonExclusiveLevelAlarm", "NonExclusiveRateOfChangeAlarm", "NonExclusiveDeviationAlarm", "DiscreteAlarm", "OffNormalAlarm", "TripAlarm", "AuditConditionShelvingEvent", "ProgressEvent", "SystemStatusChangeEvent", "SystemOffNormalAlarm", "AuditProgramTransitionEvent", "TrustListUpdatedAuditEvent", "CertificateUpdatedAuditEvent", "CertificateExpirationAlarm", "AuditConditionResetEvent", "PubSubStatusEvent", "PubSubTransportLimitsExceedEvent", "PubSubCommunicationFailureEvent", "DiscrepancyAlarm", "AuditConditionSuppressionEvent", "AuditConditionSilenceEvent", "AuditConditionOutOfServiceEvent", "RoleMappingRuleChangedAuditEvent", "KeyCredentialAuditEvent", "KeyCredentialUpdatedAuditEvent", "KeyCredentialDeletedAuditEvent", "InstrumentDiagnosticAlarm", "SystemDiagnosticAlarm", "AuditHistoryAnnotationUpdateEvent", "TrustListOutOfDateAlarm", "AuditClientEvent", "AuditClientUpdateMethodResultEvent"]
class BaseEvent(Event): class BaseEvent(Event):
""" """
BaseEvent: BaseEvent:
...@@ -1197,4 +1196,4 @@ IMPLEMENTED_EVENTS = { ...@@ -1197,4 +1196,4 @@ IMPLEMENTED_EVENTS = {
ua.ObjectIds.TrustListOutOfDateAlarmType: TrustListOutOfDateAlarm, ua.ObjectIds.TrustListOutOfDateAlarmType: TrustListOutOfDateAlarm,
ua.ObjectIds.AuditClientEventType: AuditClientEvent, ua.ObjectIds.AuditClientEventType: AuditClientEvent,
ua.ObjectIds.AuditClientUpdateMethodResultEventType: AuditClientUpdateMethodResultEvent, ua.ObjectIds.AuditClientUpdateMethodResultEventType: AuditClientUpdateMethodResultEvent,
} }
...@@ -7,7 +7,7 @@ import logging ...@@ -7,7 +7,7 @@ import logging
from datetime import datetime from datetime import datetime
from enum import Enum, IntEnum, IntFlag from enum import Enum, IntEnum, IntFlag
from dateutil import parser from dateutil import parser # type: ignore[attr-defined]
from asyncua import ua from asyncua import ua
...@@ -315,6 +315,7 @@ def data_type_to_string(dtype): ...@@ -315,6 +315,7 @@ def data_type_to_string(dtype):
string = dtype.to_string() string = dtype.to_string()
return string return string
def copy_dataclass_attr(dc_source, dc_dest) -> None: def copy_dataclass_attr(dc_source, dc_dest) -> None:
""" """
Copy the common attributes of dc_source to dc_dest Copy the common attributes of dc_source to dc_dest
......
...@@ -14,9 +14,11 @@ from ..ua.uaerrors import UaError ...@@ -14,9 +14,11 @@ from ..ua.uaerrors import UaError
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
def _parse_version(version_string: str) -> List[int]: def _parse_version(version_string: str) -> List[int]:
return [int(v) for v in version_string.split('.')] return [int(v) for v in version_string.split('.')]
class XmlImporter: class XmlImporter:
def __init__(self, server, strict_mode=True): def __init__(self, server, strict_mode=True):
...@@ -162,7 +164,7 @@ class XmlImporter: ...@@ -162,7 +164,7 @@ class XmlImporter:
ua.ObjectIds.FiniteTransitionVariableType, ua.ObjectIds.HasInterface} ua.ObjectIds.FiniteTransitionVariableType, ua.ObjectIds.HasInterface}
dangling_refs_to_missing_nodes = set(new_nodes) dangling_refs_to_missing_nodes = set(new_nodes)
RefSpecKey = Tuple[ua.NodeId, ua.NodeId, ua.NodeId] # (source_node_id, target_node_id, ref_type_id) RefSpecKey = Tuple[ua.NodeId, ua.NodeId, ua.NodeId] # (source_node_id, target_node_id, ref_type_id)
node_reference_map: Dict[RefSpecKey, ua.ReferenceDescription] = {} node_reference_map: Dict[RefSpecKey, ua.ReferenceDescription] = {}
for new_node_id in new_nodes: for new_node_id in new_nodes:
...@@ -189,8 +191,12 @@ class XmlImporter: ...@@ -189,8 +191,12 @@ class XmlImporter:
_logger.debug("Adding missing reference: %s <-> %s (%s)", target_node_id, source_node_id, ref.ReferenceTypeId) _logger.debug("Adding missing reference: %s <-> %s (%s)", target_node_id, source_node_id, ref.ReferenceTypeId)
new_ref = ua.AddReferencesItem(SourceNodeId=target_node_id, TargetNodeId=source_node_id, new_ref = ua.AddReferencesItem(
ReferenceTypeId=ref_type, IsForward=(not ref.IsForward)) SourceNodeId=target_node_id,
TargetNodeId=source_node_id,
ReferenceTypeId=ref_type,
IsForward=(not ref.IsForward)
)
reference_fixes.append(new_ref) reference_fixes.append(new_ref)
await self._add_references(reference_fixes) await self._add_references(reference_fixes)
......
...@@ -459,7 +459,7 @@ class XMLParser: ...@@ -459,7 +459,7 @@ class XMLParser:
date_time = model.attrib.get('PublicationDate') date_time = model.attrib.get('PublicationDate')
if date_time is None: if date_time is None:
date_time = ua.DateTime(1, 1, 1) date_time = ua.DateTime(1, 1, 1)
elif date_time is not None and date_time[-1]=="Z": elif date_time is not None and date_time[-1] == "Z":
date_time = ua.DateTime.strptime(date_time, "%Y-%m-%dT%H:%M:%SZ") date_time = ua.DateTime.strptime(date_time, "%Y-%m-%dT%H:%M:%SZ")
else: else:
date_time = ua.DateTime.strptime(date_time, "%Y-%m-%dT%H:%M:%S%z") date_time = ua.DateTime.strptime(date_time, "%Y-%m-%dT%H:%M:%S%z")
......
...@@ -57,7 +57,7 @@ class TrustStore: ...@@ -57,7 +57,7 @@ class TrustStore:
async def load_trust(self): async def load_trust(self):
"""(re)load the trusted certificates""" """(re)load the trusted certificates"""
self._trust_store: crypto.X509Store = crypto.X509Store() self._trust_store = crypto.X509Store()
for location in self._trust_locations: for location in self._trust_locations:
await self._load_trust_location(location) await self._load_trust_location(location)
...@@ -113,7 +113,7 @@ class TrustStore: ...@@ -113,7 +113,7 @@ class TrustStore:
for revoked in self._revoked_list: for revoked in self._revoked_list:
if revoked.serial_number == certificate.serial_number: if revoked.serial_number == certificate.serial_number:
subject_cn = certificate.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value subject_cn = certificate.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)[0].value
_logger.warning('Found revoked serial "%s" [CN=%s]', hex(certificate.serial_number), subject_cn) _logger.warning('Found revoked serial "%s" [CN=%s]', hex(certificate.serial_number), subject_cn)
is_revoked = True is_revoked = True
break break
return is_revoked return is_revoked
......
...@@ -85,6 +85,7 @@ def der_from_x509(certificate): ...@@ -85,6 +85,7 @@ def der_from_x509(certificate):
return b"" return b""
return certificate.public_bytes(serialization.Encoding.DER) return certificate.public_bytes(serialization.Encoding.DER)
def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes: def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes:
"""dumps a private key in PEM format """dumps a private key in PEM format
...@@ -96,6 +97,7 @@ def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes: ...@@ -96,6 +97,7 @@ def pem_from_key(private_key: rsa.RSAPrivateKey) -> bytes:
""" """
return private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption()) return private_key.private_bytes(encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption())
def sign_sha1(private_key, data): def sign_sha1(private_key, data):
return private_key.sign( return private_key.sign(
data, data,
......
...@@ -24,7 +24,7 @@ if TYPE_CHECKING: ...@@ -24,7 +24,7 @@ if TYPE_CHECKING:
VariableAttributes, VariableAttributes,
ObjectTypeAttributes, ObjectTypeAttributes,
ObjectAttributes ObjectAttributes
] # FIXME Check, if there are missing attribute types. ] # FIXME Check, if there are missing attribute types.
from asyncua import ua from asyncua import ua
...@@ -203,8 +203,8 @@ class ViewService(object): ...@@ -203,8 +203,8 @@ class ViewService(object):
return res return res
for nodeid in target_nodeids: for nodeid in target_nodeids:
target = ua.BrowsePathTarget() target = ua.BrowsePathTarget()
target.TargetId = nodeid # FIXME <<<< Type conflict target.TargetId = nodeid # FIXME <<<< Type conflict
target.RemainingPathIndex = ua.Index(4294967295) # FIXME: magic number, why not Index.MAX? target.RemainingPathIndex = ua.Index(4294967295) # FIXME: magic number, why not Index.MAX?
res.Targets.append(target) res.Targets.append(target)
# FIXME: might need to order these one way or another # FIXME: might need to order these one way or another
return res return res
...@@ -274,7 +274,7 @@ class NodeManagementService: ...@@ -274,7 +274,7 @@ class NodeManagementService:
# the namespace of the nodeid, this is an extention of the spec to allow # the namespace of the nodeid, this is an extention of the spec to allow
# to requests the server to generate a new nodeid in a specified namespace # to requests the server to generate a new nodeid in a specified namespace
# self.logger.debug("RequestedNewNodeId has null identifier, generating Identifier") # self.logger.debug("RequestedNewNodeId has null identifier, generating Identifier")
item.RequestedNewNodeId = self._aspace.generate_nodeid(item.RequestedNewNodeId.NamespaceIndex) # FIXME type conflict item.RequestedNewNodeId = self._aspace.generate_nodeid(item.RequestedNewNodeId.NamespaceIndex) # FIXME type conflict
else: else:
if item.RequestedNewNodeId in self._aspace: if item.RequestedNewNodeId in self._aspace:
self.logger.warning("AddNodesItem: Requested NodeId %s already exists", item.RequestedNewNodeId) self.logger.warning("AddNodesItem: Requested NodeId %s already exists", item.RequestedNewNodeId)
...@@ -365,26 +365,26 @@ class NodeManagementService: ...@@ -365,26 +365,26 @@ class NodeManagementService:
desc.BrowseName = item.BrowseName desc.BrowseName = item.BrowseName
desc.DisplayName = item.NodeAttributes.DisplayName desc.DisplayName = item.NodeAttributes.DisplayName
desc.TypeDefinition = item.TypeDefinition desc.TypeDefinition = item.TypeDefinition
desc.IsForward = True # FIXME in uaprotocol_auto.py desc.IsForward = True # FIXME in uaprotocol_auto.py
self._add_unique_reference(parentdata, desc) # FIXME return StatusCode is not evaluated self._add_unique_reference(parentdata, desc) # FIXME return StatusCode is not evaluated
def _add_ref_to_parent(self, nodedata: NodeData, item: ua.AddNodesItem, parentdata: NodeData): def _add_ref_to_parent(self, nodedata: NodeData, item: ua.AddNodesItem, parentdata: NodeData):
addref = ua.AddReferencesItem() addref = ua.AddReferencesItem()
addref.ReferenceTypeId = item.ReferenceTypeId addref.ReferenceTypeId = item.ReferenceTypeId
addref.SourceNodeId = nodedata.nodeid addref.SourceNodeId = nodedata.nodeid
addref.TargetNodeId = item.ParentNodeId addref.TargetNodeId = item.ParentNodeId
addref.TargetNodeClass = parentdata.attributes[ua.AttributeIds.NodeClass].value.Value.Value # type: ignore[union-attr] addref.TargetNodeClass = parentdata.attributes[ua.AttributeIds.NodeClass].value.Value.Value # type: ignore[union-attr]
addref.IsForward = False # FIXME in uaprotocol_auto.py addref.IsForward = False # FIXME in uaprotocol_auto.py
self._add_reference_no_check(nodedata, addref) # FIXME return StatusCode is not evaluated self._add_reference_no_check(nodedata, addref) # FIXME return StatusCode is not evaluated
def _add_type_definition(self, nodedata: NodeData, item: ua.AddNodesItem): def _add_type_definition(self, nodedata: NodeData, item: ua.AddNodesItem):
addref = ua.AddReferencesItem() addref = ua.AddReferencesItem()
addref.SourceNodeId = nodedata.nodeid addref.SourceNodeId = nodedata.nodeid
addref.IsForward = True # FIXME in uaprotocol_auto.py addref.IsForward = True # FIXME in uaprotocol_auto.py
addref.ReferenceTypeId = ua.NodeId(ua.ObjectIds.HasTypeDefinition) addref.ReferenceTypeId = ua.NodeId(ua.ObjectIds.HasTypeDefinition)
addref.TargetNodeId = item.TypeDefinition addref.TargetNodeId = item.TypeDefinition
addref.TargetNodeClass = ua.NodeClass.DataType addref.TargetNodeClass = ua.NodeClass.DataType
self._add_reference_no_check(nodedata, addref) # FIXME return StatusCode is not evaluated self._add_reference_no_check(nodedata, addref) # FIXME return StatusCode is not evaluated
def delete_nodes( def delete_nodes(
self, self,
...@@ -427,7 +427,7 @@ class NodeManagementService: ...@@ -427,7 +427,7 @@ class NodeManagementService:
"Error calling delete node callback callback %s, %s, %s", nodedata, ua.AttributeIds.Value, ex "Error calling delete node callback callback %s, %s, %s", nodedata, ua.AttributeIds.Value, ex
) )
def add_references(self, refs: List[ua.AddReferencesItem], user: User = User(role=UserRole.Admin)): # FIXME return type def add_references(self, refs: List[ua.AddReferencesItem], user: User = User(role=UserRole.Admin)): # FIXME return type
result = [self._add_reference(ref, user) for ref in refs] result = [self._add_reference(ref, user) for ref in refs]
return result return result
...@@ -612,7 +612,7 @@ class AddressSpace: ...@@ -612,7 +612,7 @@ class AddressSpace:
https://reference.opcfoundation.org/Core/docs/Part3/ https://reference.opcfoundation.org/Core/docs/Part3/
""" """
def __init__(self): def __init__(self) -> None:
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.force_server_timestamp: bool = True self.force_server_timestamp: bool = True
self._nodes: Dict[ua.NodeId, NodeData] = {} self._nodes: Dict[ua.NodeId, NodeData] = {}
...@@ -625,7 +625,7 @@ class AddressSpace: ...@@ -625,7 +625,7 @@ class AddressSpace:
return self._nodes.__getitem__(nodeid) return self._nodes.__getitem__(nodeid)
def get(self, nodeid: ua.NodeId) -> Union[NodeData, None]: def get(self, nodeid: ua.NodeId) -> Union[NodeData, None]:
return self._nodes.get(nodeid, None) # Fixme This is another behaviour than __getitem__ where an KeyError exception is thrown, right? return self._nodes.get(nodeid, None) # Fixme This is another behaviour than __getitem__ where an KeyError exception is thrown, right?
def __setitem__(self, nodeid: ua.NodeId, value: NodeData): def __setitem__(self, nodeid: ua.NodeId, value: NodeData):
return self._nodes.__setitem__(nodeid, value) return self._nodes.__setitem__(nodeid, value)
...@@ -801,11 +801,11 @@ class AddressSpace: ...@@ -801,11 +801,11 @@ class AddressSpace:
def _is_expected_variant_type(self, value: ua.DataValue, attval: AttributeValue, node: NodeData) -> bool: def _is_expected_variant_type(self, value: ua.DataValue, attval: AttributeValue, node: NodeData) -> bool:
# FIXME Type hinting reveals that it is possible that Value (Optional) is None which would raise an exception # FIXME Type hinting reveals that it is possible that Value (Optional) is None which would raise an exception
vtype = attval.value.Value.VariantType # type: ignore[union-attr] vtype = attval.value.Value.VariantType # type: ignore[union-attr]
if vtype == ua.VariantType.Null: if vtype == ua.VariantType.Null:
# Node had a null value, many nodes are initialized with that value # Node had a null value, many nodes are initialized with that value
# we should check what the real type is # we should check what the real type is
dtype = node.attributes[ua.AttributeIds.DataType].value.Value.Value # type: ignore[union-attr] dtype = node.attributes[ua.AttributeIds.DataType].value.Value.Value # type: ignore[union-attr]
if dtype.NamespaceIndex == 0 and dtype.Identifier <= 25: if dtype.NamespaceIndex == 0 and dtype.Identifier <= 25:
vtype = ua.VariantType(dtype.Identifier) vtype = ua.VariantType(dtype.Identifier)
else: else:
...@@ -814,8 +814,12 @@ class AddressSpace: ...@@ -814,8 +814,12 @@ class AddressSpace:
return True return True
if value.Value.VariantType == vtype: # type: ignore[union-attr] if value.Value.VariantType == vtype: # type: ignore[union-attr]
return True return True
_logger.warning("Write refused: Variant: %s with type %s does not have expected type: %s", _logger.warning(
value.Value, value.Value.VariantType, attval.value.Value.VariantType) # type: ignore[union-attr] "Write refused: Variant: %s with type %s does not have expected type: %s",
value.Value,
value.Value.VariantType if value.Value else None,
attval.value.Value.VariantType if attval.value.Value else None,
)
return False return False
def add_datachange_callback(self, nodeid: ua.NodeId, attr: ua.AttributeIds, callback: Callable) -> Tuple[ua.StatusCode, int]: def add_datachange_callback(self, nodeid: ua.NodeId, attr: ua.AttributeIds, callback: Callable) -> Tuple[ua.StatusCode, int]:
......
...@@ -5,6 +5,7 @@ import uuid ...@@ -5,6 +5,7 @@ import uuid
import sys import sys
from asyncua import ua from asyncua import ua
from asyncua.server.internal_session import InternalSession
from ..common import events, event_objects, Node from ..common import events, event_objects, Node
...@@ -20,7 +21,7 @@ class EventGenerator: ...@@ -20,7 +21,7 @@ class EventGenerator:
etype: The event type, either an objectId, a NodeId or a Node object etype: The event type, either an objectId, a NodeId or a Node object
""" """
def __init__(self, isession): def __init__(self, isession: InternalSession):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.isession = isession self.isession = isession
self.event: event_objects.BaseEvent = None self.event: event_objects.BaseEvent = None
...@@ -91,7 +92,7 @@ class EventGenerator: ...@@ -91,7 +92,7 @@ class EventGenerator:
self.event.LocalTime = ua.uaprotocol_auto.TimeZoneDataType() self.event.LocalTime = ua.uaprotocol_auto.TimeZoneDataType()
if sys.version_info.major > 2: if sys.version_info.major > 2:
localtime = time.localtime(self.event.Time.timestamp()) localtime = time.localtime(self.event.Time.timestamp())
self.event.LocalTime.Offset = localtime.tm_gmtoff//60 self.event.LocalTime.Offset = localtime.tm_gmtoff // 60
else: else:
localtime = time.localtime(time.mktime(self.event.Time.timetuple())) localtime = time.localtime(time.mktime(self.event.Time.timetuple()))
self.event.LocalTime.Offset = -(time.altzone if localtime.tm_isdst else time.timezone) self.event.LocalTime.Offset = -(time.altzone if localtime.tm_isdst else time.timezone)
......
...@@ -22,7 +22,7 @@ class HistorySQLite(HistoryStorageInterface): ...@@ -22,7 +22,7 @@ class HistorySQLite(HistoryStorageInterface):
note that PARSE_DECLTYPES is active so certain data types (such as datetime) will not be BLOBs note that PARSE_DECLTYPES is active so certain data types (such as datetime) will not be BLOBs
""" """
def __init__(self, path="history.db", max_history_data_response_size=10000): def __init__(self, path="history.db", max_history_data_response_size=10000) -> None:
self.max_history_data_response_size = max_history_data_response_size self.max_history_data_response_size = max_history_data_response_size
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self._datachanges_period = {} self._datachanges_period = {}
......
...@@ -10,7 +10,6 @@ from struct import unpack_from ...@@ -10,7 +10,6 @@ from struct import unpack_from
from pathlib import Path from pathlib import Path
import logging import logging
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Coroutine
from asyncua import ua from asyncua import ua
from .user_managers import PermissiveUserManager, UserManager from .user_managers import PermissiveUserManager, UserManager
...@@ -151,7 +150,7 @@ class InternalServer: ...@@ -151,7 +150,7 @@ class InternalServer:
# path was supplied, but file doesn't exist - create one for next start up # path was supplied, but file doesn't exist - create one for next start up
await asyncio.get_running_loop().run_in_executor(None, self.aspace.make_aspace_shelf, shelf_file) await asyncio.get_running_loop().run_in_executor(None, self.aspace.make_aspace_shelf, shelf_file)
async def _address_space_fixes(self) -> Coroutine: # type: ignore async def _address_space_fixes(self): # type: ignore
""" """
Looks like the xml definition of address space has some error. This is a good place to fix them Looks like the xml definition of address space has some error. This is a good place to fix them
""" """
......
...@@ -11,6 +11,7 @@ from asyncua import ua ...@@ -11,6 +11,7 @@ from asyncua import ua
from .monitored_item_service import MonitoredItemService from .monitored_item_service import MonitoredItemService
from .address_space import AddressSpace from .address_space import AddressSpace
class InternalSubscription: class InternalSubscription:
""" """
Server internal subscription. Server internal subscription.
......
...@@ -22,7 +22,7 @@ class MonitoredItemData: ...@@ -22,7 +22,7 @@ class MonitoredItemData:
class MonitoredItemValues: class MonitoredItemValues:
def __init__(self): def __init__(self) -> None:
self.current_dvalue: Optional[ua.DataValue] = None self.current_dvalue: Optional[ua.DataValue] = None
self.old_dvalue: Optional[ua.DataValue] = None self.old_dvalue: Optional[ua.DataValue] = None
...@@ -117,8 +117,8 @@ class MonitoredItemService: ...@@ -117,8 +117,8 @@ class MonitoredItemService:
params.ItemToMonitor.AttributeId) params.ItemToMonitor.AttributeId)
result, mdata = self._make_monitored_item_common(params) result, mdata = self._make_monitored_item_common(params)
ev_notify_byte = self.aspace.read_attribute_value(params.ItemToMonitor.NodeId, # type: ignore[union-attr] ev_notify_byte = self.aspace.read_attribute_value(params.ItemToMonitor.NodeId, # type: ignore[union-attr]
ua.AttributeIds.EventNotifier).Value.Value ua.AttributeIds.EventNotifier).Value.Value
if ev_notify_byte is None or not ua.ua_binary.test_bit(ev_notify_byte, ua.EventNotifier.SubscribeToEvents): if ev_notify_byte is None or not ua.ua_binary.test_bit(ev_notify_byte, ua.EventNotifier.SubscribeToEvents):
result.StatusCode = ua.StatusCode(ua.StatusCodes.BadServiceUnsupported) result.StatusCode = ua.StatusCode(ua.StatusCodes.BadServiceUnsupported)
...@@ -189,13 +189,12 @@ class MonitoredItemService: ...@@ -189,13 +189,12 @@ class MonitoredItemService:
if old.StatusCode != current.StatusCode: if old.StatusCode != current.StatusCode:
return True return True
if trg in [ua.DataChangeTrigger.StatusValue,ua.DataChangeTrigger.StatusValueTimestamp ] and \ if trg in [ua.DataChangeTrigger.StatusValue, ua.DataChangeTrigger.StatusValueTimestamp] and old.Value != current.Value:
old.Value != current.Value:
return True return True
if trg == ua.DataChangeTrigger.StatusValueTimestamp and \ if trg == ua.DataChangeTrigger.StatusValueTimestamp and (
(old.SourceTimestamp != current.SourceTimestamp or old.SourceTimestamp != current.SourceTimestamp or old.SourcePicoseconds != current.SourcePicoseconds
old.SourcePicoseconds != current.SourcePicoseconds): ):
return True return True
return False return False
...@@ -213,8 +212,9 @@ class MonitoredItemService: ...@@ -213,8 +212,9 @@ class MonitoredItemService:
mdata = self._monitored_items[mid] mdata = self._monitored_items[mid]
mdata.mvalue.set_current_datavalue(value) mdata.mvalue.set_current_datavalue(value)
if mdata.filter: if mdata.filter:
deadband_flag_pass = self._is_data_changed(mdata.mvalue, mdata.filter.Trigger) and \ deadband_flag_pass = self._is_data_changed(
self._is_deadband_exceeded(mdata.mvalue, mdata.filter) mdata.mvalue, mdata.filter.Trigger
) and self._is_deadband_exceeded(mdata.mvalue, mdata.filter)
else: else:
# Trigger defaults to StatusValue # Trigger defaults to StatusValue
deadband_flag_pass = self._is_data_changed(mdata.mvalue, ua.DataChangeTrigger.StatusValue) deadband_flag_pass = self._is_data_changed(mdata.mvalue, ua.DataChangeTrigger.StatusValue)
......
...@@ -8,7 +8,7 @@ import math ...@@ -8,7 +8,7 @@ import math
from datetime import timedelta, datetime from datetime import timedelta, datetime
import socket import socket
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Coroutine, Optional, Tuple, Union from typing import Optional, Tuple, Union
from pathlib import Path from pathlib import Path
from asyncua import ua from asyncua import ua
...@@ -124,7 +124,7 @@ class Server: ...@@ -124,7 +124,7 @@ class Server:
await self.set_application_uri(self._application_uri) await self.set_application_uri(self._application_uri)
sa_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_ServerArray)) sa_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_ServerArray))
await sa_node.write_value([self._application_uri]) await sa_node.write_value([self._application_uri])
#TODO: ServiceLevel is 255 default, should be calculated in later Versions # TODO: ServiceLevel is 255 default, should be calculated in later Versions
sl_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_ServiceLevel)) sl_node = self.get_node(ua.NodeId(ua.ObjectIds.Server_ServiceLevel))
await sl_node.write_value(ua.Variant(255, ua.VariantType.Byte)) await sl_node.write_value(ua.Variant(255, ua.VariantType.Byte))
...@@ -165,7 +165,7 @@ class Server: ...@@ -165,7 +165,7 @@ class Server:
product_name, product_name,
software_version, software_version,
build_number build_number
]): ]):
raise TypeError(f"""Expected all str got raise TypeError(f"""Expected all str got
product_uri: {type(product_uri)}, product_uri: {type(product_uri)},
manufacturer_name: {type(manufacturer_name)}, manufacturer_name: {type(manufacturer_name)},
...@@ -306,7 +306,7 @@ class Server: ...@@ -306,7 +306,7 @@ class Server:
async def _renew_registration(self): async def _renew_registration(self):
for client in self._discovery_clients.values(): for client in self._discovery_clients.values():
await client.connect_sessionless() await client.connect_sessionless()
await client.register_server(self) #FIXME discovery_configuration? await client.register_server(self) # FIXME discovery_configuration?
await client.disconnect_sessionless() await client.disconnect_sessionless()
def allow_remote_admin(self, allow): def allow_remote_admin(self, allow):
...@@ -318,7 +318,7 @@ class Server: ...@@ -318,7 +318,7 @@ class Server:
def set_endpoint(self, url): def set_endpoint(self, url):
self.endpoint = urlparse(url) self.endpoint = urlparse(url)
async def get_endpoints(self) -> Coroutine: async def get_endpoints(self):
return await self.iserver.get_endpoints() return await self.iserver.get_endpoints()
def set_security_policy(self, security_policy, permission_ruleset=None): def set_security_policy(self, security_policy, permission_ruleset=None):
...@@ -414,9 +414,8 @@ class Server: ...@@ -414,9 +414,8 @@ class Server:
ua.MessageSecurityMode.SignAndEncrypt, self.certificate, self.iserver.private_key, ua.MessageSecurityMode.SignAndEncrypt, self.certificate, self.iserver.private_key,
permission_ruleset=self._permission_ruleset)) permission_ruleset=self._permission_ruleset))
@staticmethod @staticmethod
def lookup_security_level_for_policy_type( security_policy_type: ua.SecurityPolicyType ) -> ua.Byte: def lookup_security_level_for_policy_type(security_policy_type: ua.SecurityPolicyType) -> ua.Byte:
"""Returns the security level for an ua.SecurityPolicyType. """Returns the security level for an ua.SecurityPolicyType.
This is endpoint & server implementation specific! This is endpoint & server implementation specific!
...@@ -426,20 +425,19 @@ class Server: ...@@ -426,20 +425,19 @@ class Server:
""" """
return ua.Byte({ return ua.Byte({
ua.SecurityPolicyType.NoSecurity : 0, ua.SecurityPolicyType.NoSecurity: 0,
ua.SecurityPolicyType.Basic128Rsa15_Sign : 1, ua.SecurityPolicyType.Basic128Rsa15_Sign: 1,
ua.SecurityPolicyType.Basic128Rsa15_SignAndEncrypt : 2, ua.SecurityPolicyType.Basic128Rsa15_SignAndEncrypt: 2,
ua.SecurityPolicyType.Basic256_Sign : 11, ua.SecurityPolicyType.Basic256_Sign: 11,
ua.SecurityPolicyType.Basic256_SignAndEncrypt : 21, ua.SecurityPolicyType.Basic256_SignAndEncrypt: 21,
ua.SecurityPolicyType.Basic256Sha256_Sign : 50, ua.SecurityPolicyType.Basic256Sha256_Sign: 50,
ua.SecurityPolicyType.Basic256Sha256_SignAndEncrypt : 70, ua.SecurityPolicyType.Basic256Sha256_SignAndEncrypt: 70,
ua.SecurityPolicyType.Aes128Sha256RsaOaep_Sign : 55, ua.SecurityPolicyType.Aes128Sha256RsaOaep_Sign: 55,
ua.SecurityPolicyType.Aes128Sha256RsaOaep_SignAndEncrypt : 75 ua.SecurityPolicyType.Aes128Sha256RsaOaep_SignAndEncrypt: 75
}[security_policy_type]) }[security_policy_type])
@staticmethod @staticmethod
def determine_security_level(security_policy_uri:str, security_mode: ua.MessageSecurityMode) -> ua.Byte: def determine_security_level(security_policy_uri: str, security_mode: ua.MessageSecurityMode) -> ua.Byte:
"""Determine the security level of an EndPoint. """Determine the security level of an EndPoint.
The security level indicates how secure an EndPoint is, compared to other EndPoints of the same server. The security level indicates how secure an EndPoint is, compared to other EndPoints of the same server.
Value 0 is a special value; EndPoint isn't recommended, typical for ua.MessageSecurityMode.None_. Value 0 is a special value; EndPoint isn't recommended, typical for ua.MessageSecurityMode.None_.
...@@ -569,7 +567,7 @@ class Server: ...@@ -569,7 +567,7 @@ class Server:
""" """
return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.ObjectsFolder)) return self.get_node(ua.TwoByteNodeId(ua.ObjectIds.ObjectsFolder))
def get_node(self, nodeid): def get_node(self, nodeid: Union[Node, ua.NodeId, str, int]) -> Node:
""" """
Get a specific node using NodeId object or a string representing a NodeId Get a specific node using NodeId object or a string representing a NodeId
""" """
...@@ -629,7 +627,7 @@ class Server: ...@@ -629,7 +627,7 @@ class Server:
return ev_gen return ev_gen
async def create_custom_data_type(self, idx, name, basetype=ua.ObjectIds.BaseDataType, async def create_custom_data_type(self, idx, name, basetype=ua.ObjectIds.BaseDataType,
properties=None, description=None) -> Coroutine: properties=None, description=None):
if properties is None: if properties is None:
properties = [] properties = []
base_t = _get_node(self.iserver.isession, basetype) base_t = _get_node(self.iserver.isession, basetype)
...@@ -644,7 +642,7 @@ class Server: ...@@ -644,7 +642,7 @@ class Server:
return custom_t return custom_t
async def create_custom_event_type(self, idx, name, async def create_custom_event_type(self, idx, name,
basetype=ua.ObjectIds.BaseEventType, properties=None) -> Coroutine: basetype=ua.ObjectIds.BaseEventType, properties=None):
if properties is None: if properties is None:
properties = [] properties = []
return await self._create_custom_type(idx, name, basetype, properties, [], []) return await self._create_custom_type(idx, name, basetype, properties, [], [])
...@@ -655,7 +653,7 @@ class Server: ...@@ -655,7 +653,7 @@ class Server:
basetype=ua.ObjectIds.BaseObjectType, basetype=ua.ObjectIds.BaseObjectType,
properties=None, properties=None,
variables=None, variables=None,
methods=None) -> Coroutine: methods=None):
if properties is None: if properties is None:
properties = [] properties = []
if variables is None: if variables is None:
...@@ -673,7 +671,7 @@ class Server: ...@@ -673,7 +671,7 @@ class Server:
basetype=ua.ObjectIds.BaseVariableType, basetype=ua.ObjectIds.BaseVariableType,
properties=None, properties=None,
variables=None, variables=None,
methods=None) -> Coroutine: methods=None):
if properties is None: if properties is None:
properties = [] properties = []
if variables is None: if variables is None:
...@@ -701,7 +699,7 @@ class Server: ...@@ -701,7 +699,7 @@ class Server:
await custom_t.add_method(idx, method[0], method[1], method[2], method[3]) await custom_t.add_method(idx, method[0], method[1], method[2], method[3])
return custom_t return custom_t
async def import_xml(self, path=None, xmlstring=None, strict_mode=True) -> Coroutine: async def import_xml(self, path=None, xmlstring=None, strict_mode=True):
""" """
Import nodes defined in xml Import nodes defined in xml
""" """
...@@ -731,7 +729,7 @@ class Server: ...@@ -731,7 +729,7 @@ class Server:
nodes = await get_nodes_of_namespace(self, namespaces) nodes = await get_nodes_of_namespace(self, namespaces)
await self.export_xml(nodes, path, export_values=export_values) await self.export_xml(nodes, path, export_values=export_values)
async def delete_nodes(self, nodes, recursive=False) -> Coroutine: async def delete_nodes(self, nodes, recursive=False):
return await delete_nodes(self.iserver.isession, nodes, recursive) return await delete_nodes(self.iserver.isession, nodes, recursive)
async def historize_node_data_change(self, node, period=timedelta(days=7), count=0): async def historize_node_data_change(self, node, period=timedelta(days=7), count=0):
...@@ -789,7 +787,7 @@ class Server: ...@@ -789,7 +787,7 @@ class Server:
""" """
self.iserver.isession.add_method_callback(node.nodeid, callback) self.iserver.isession.add_method_callback(node.nodeid, callback)
async def load_type_definitions(self, nodes=None) -> Coroutine: async def load_type_definitions(self, nodes=None):
""" """
load custom structures from our server. load custom structures from our server.
Server side this can be used to create python objects from custom structures Server side this can be used to create python objects from custom structures
...@@ -806,7 +804,7 @@ class Server: ...@@ -806,7 +804,7 @@ class Server:
""" """
return await load_data_type_definitions(self, node) return await load_data_type_definitions(self, node)
async def load_enums(self) -> Coroutine: async def load_enums(self):
""" """
load UA structures and generate python Enums in ua module for custom enums in server load UA structures and generate python Enums in ua module for custom enums in server
""" """
......
...@@ -12,4 +12,6 @@ show_error_codes = True ...@@ -12,4 +12,6 @@ show_error_codes = True
check_untyped_defs = False check_untyped_defs = False
[mypy-asyncua.ua.uaprotocol_auto.*] [mypy-asyncua.ua.uaprotocol_auto.*]
# Autogenerated file # Autogenerated file
disable_error_code = literal-required disable_error_code = literal-required
\ No newline at end of file [mypy-asynctest.*]
ignore_missing_imports = True
...@@ -19,14 +19,11 @@ from asyncua.server.history_sql import HistorySQLite ...@@ -19,14 +19,11 @@ from asyncua.server.history_sql import HistorySQLite
from .test_common import add_server_methods from .test_common import add_server_methods
from .util_enum_struct import add_server_custom_enum_struct from .util_enum_struct import add_server_custom_enum_struct
RETRY = 20 RETRY = 20
SLEEP = 0.4 SLEEP = 0.4
PORTS_USED = set() PORTS_USED = set()
Opc = namedtuple('opc', ['opc', 'server']) Opc = namedtuple('Opc', ['opc', 'server'])
def find_free_port(): def find_free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
...@@ -39,6 +36,7 @@ def find_free_port(): ...@@ -39,6 +36,7 @@ def find_free_port():
else: else:
return find_free_port() return find_free_port()
port_num = find_free_port() port_num = find_free_port()
port_num1 = find_free_port() port_num1 = find_free_port()
port_discovery = find_free_port() port_discovery = find_free_port()
...@@ -266,6 +264,7 @@ async def history_server(request): ...@@ -266,6 +264,7 @@ async def history_server(request):
yield srv yield srv
await srv.srv.stop() await srv.srv.stop()
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def client_key_and_cert(request): def client_key_and_cert(request):
base_dir = Path(__file__).parent.parent base_dir = Path(__file__).parent.parent
......
from pathlib import Path from pathlib import Path
from typing import Tuple
import pytest import pytest
import asyncio import asyncio
...@@ -107,6 +108,7 @@ async def srv_crypto_one_cert(request): ...@@ -107,6 +108,7 @@ async def srv_crypto_one_cert(request):
# stop the server # stop the server
await srv.stop() await srv.stop()
@pytest.fixture(params=srv_crypto_params) @pytest.fixture(params=srv_crypto_params)
async def srv_crypto_all_cert_basic128rsa15(request): async def srv_crypto_all_cert_basic128rsa15(request):
# start our own server # start our own server
...@@ -479,6 +481,7 @@ async def test_anonymous_rejection(): ...@@ -479,6 +481,7 @@ async def test_anonymous_rejection():
await clt.connect() await clt.connect()
await srv.stop() await srv.stop()
async def test_security_level_all(): async def test_security_level_all():
assert Server.determine_security_level(ua.SecurityPolicy.URI, ua.MessageSecurityMode.None_) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.NoSecurity) assert Server.determine_security_level(ua.SecurityPolicy.URI, ua.MessageSecurityMode.None_) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.NoSecurity)
...@@ -495,7 +498,8 @@ async def test_security_level_all(): ...@@ -495,7 +498,8 @@ async def test_security_level_all():
assert Server.determine_security_level(security_policies.SecurityPolicyBasic256.URI, ua.MessageSecurityMode.Sign) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.Basic256_Sign) assert Server.determine_security_level(security_policies.SecurityPolicyBasic256.URI, ua.MessageSecurityMode.Sign) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.Basic256_Sign)
assert Server.determine_security_level(security_policies.SecurityPolicyBasic256.URI, ua.MessageSecurityMode.SignAndEncrypt) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.Basic256_SignAndEncrypt) assert Server.determine_security_level(security_policies.SecurityPolicyBasic256.URI, ua.MessageSecurityMode.SignAndEncrypt) == Server.lookup_security_level_for_policy_type(ua.SecurityPolicyType.Basic256_SignAndEncrypt)
async def test_security_level_endpoints(srv_crypto_all_certs):
async def test_security_level_endpoints(srv_crypto_all_certs: Tuple[Server, str]):
srv = srv_crypto_all_certs[0] srv = srv_crypto_all_certs[0]
end_points: list[ua.EndpointDescription] = await srv.get_endpoints() end_points: list[ua.EndpointDescription] = await srv.get_endpoints()
......
...@@ -9,7 +9,7 @@ from cryptography.x509.extensions import _key_identifier_from_public_key as key_ ...@@ -9,7 +9,7 @@ from cryptography.x509.extensions import _key_identifier_from_public_key as key_
from asyncua.crypto.cert_gen import generate_private_key, generate_app_certificate_signing_request, generate_self_signed_app_certificate, sign_certificate_request from asyncua.crypto.cert_gen import generate_private_key, generate_app_certificate_signing_request, generate_self_signed_app_certificate, sign_certificate_request
async def test_create_self_signed_app_certificate(): async def test_create_self_signed_app_certificate() -> None:
""" Checks if the self signed certificate complies to OPC 10000-6 6.2.2""" """ Checks if the self signed certificate complies to OPC 10000-6 6.2.2"""
hostname = socket.gethostname() hostname = socket.gethostname()
...@@ -52,7 +52,7 @@ async def test_create_self_signed_app_certificate(): ...@@ -52,7 +52,7 @@ async def test_create_self_signed_app_certificate():
# check valid time range # check valid time range
assert dt_before_generation <= cert.not_valid_before <= dt_after_generation assert dt_before_generation <= cert.not_valid_before <= dt_after_generation
assert (dt_before_generation+timedelta(days_valid)) <= cert.not_valid_after <= (dt_after_generation+timedelta(days_valid)) assert (dt_before_generation + timedelta(days_valid)) <= cert.not_valid_after <= (dt_after_generation + timedelta(days_valid))
# check issuer # check issuer
assert cert.issuer.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value == f"myserver@{hostname}" assert cert.issuer.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value == f"myserver@{hostname}"
...@@ -63,6 +63,8 @@ async def test_create_self_signed_app_certificate(): ...@@ -63,6 +63,8 @@ async def test_create_self_signed_app_certificate():
# check Authority Key Identifier # check Authority Key Identifier
auth_key_identifier: x509.AuthorityKeyIdentifier = cert.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier).value auth_key_identifier: x509.AuthorityKeyIdentifier = cert.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier).value
assert auth_key_identifier assert auth_key_identifier
assert isinstance(auth_key_identifier.authority_cert_issuer, list)
assert len(auth_key_identifier.authority_cert_issuer) > 0
issuer: x509.Name = auth_key_identifier.authority_cert_issuer[0].value issuer: x509.Name = auth_key_identifier.authority_cert_issuer[0].value
assert issuer.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value == f"myserver@{hostname}" assert issuer.get_attributes_for_oid(NameOID.COMMON_NAME)[0].value == f"myserver@{hostname}"
assert auth_key_identifier.authority_cert_serial_number == cert.serial_number assert auth_key_identifier.authority_cert_serial_number == cert.serial_number
...@@ -89,7 +91,7 @@ async def test_create_self_signed_app_certificate(): ...@@ -89,7 +91,7 @@ async def test_create_self_signed_app_certificate():
assert cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value == x509.ExtendedKeyUsage(extended) assert cert.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value == x509.ExtendedKeyUsage(extended)
async def test_app_create_certificate_signing_request(): async def test_app_create_certificate_signing_request() -> None:
""" Checks if the self signed certificate complies to OPC 10000-6 6.2.2""" """ Checks if the self signed certificate complies to OPC 10000-6 6.2.2"""
hostname = socket.gethostname() hostname = socket.gethostname()
...@@ -139,7 +141,7 @@ async def test_app_create_certificate_signing_request(): ...@@ -139,7 +141,7 @@ async def test_app_create_certificate_signing_request():
assert csr.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value == x509.ExtendedKeyUsage(extended) assert csr.extensions.get_extension_for_class(x509.ExtendedKeyUsage).value == x509.ExtendedKeyUsage(extended)
async def test_app_sign_certificate_request(): async def test_app_sign_certificate_request() -> None:
"""Check the correct signing of certificate signing request""" """Check the correct signing of certificate signing request"""
hostname = socket.gethostname() hostname = socket.gethostname()
...@@ -180,6 +182,8 @@ async def test_app_sign_certificate_request(): ...@@ -180,6 +182,8 @@ async def test_app_sign_certificate_request():
# check authority Key Identifier # check authority Key Identifier
auth_key_identifier: x509.AuthorityKeyIdentifier = cert.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier).value auth_key_identifier: x509.AuthorityKeyIdentifier = cert.extensions.get_extension_for_class(x509.AuthorityKeyIdentifier).value
assert auth_key_identifier assert auth_key_identifier
assert isinstance(auth_key_identifier.authority_cert_issuer, list)
assert len(auth_key_identifier.authority_cert_issuer) > 0
assert auth_key_identifier.authority_cert_issuer[0].value == issuer.subject assert auth_key_identifier.authority_cert_issuer[0].value == issuer.subject
assert auth_key_identifier.authority_cert_serial_number == issuer.serial_number assert auth_key_identifier.authority_cert_serial_number == issuer.serial_number
assert auth_key_identifier.key_identifier == key_identifier_from_public_key(key_ca.public_key()) assert auth_key_identifier.key_identifier == key_identifier_from_public_key(key_ca.public_key())
......
...@@ -189,6 +189,7 @@ async def test_references_for_added_nodes_method(server): ...@@ -189,6 +189,7 @@ async def test_references_for_added_nodes_method(server):
assert await m.get_parent() == o assert await m.get_parent() == o
await server.delete_nodes([o]) await server.delete_nodes([o])
async def test_get_event_from_type_node_BaseEvent(server): async def test_get_event_from_type_node_BaseEvent(server):
""" """
This should work for following BaseEvent tests to work This should work for following BaseEvent tests to work
...@@ -299,6 +300,7 @@ async def test_eventgenerator_sourceMyObject(server): ...@@ -299,6 +300,7 @@ async def test_eventgenerator_sourceMyObject(server):
await check_event_generator_object(evgen, o) await check_event_generator_object(evgen, o)
await server.delete_nodes([o]) await server.delete_nodes([o])
async def test_eventgenerator_source_collision(server): async def test_eventgenerator_source_collision(server):
objects = server.nodes.objects objects = server.nodes.objects
o = await objects.add_object(3, 'MyObject') o = await objects.add_object(3, 'MyObject')
...@@ -308,6 +310,7 @@ async def test_eventgenerator_source_collision(server): ...@@ -308,6 +310,7 @@ async def test_eventgenerator_source_collision(server):
await check_event_generator_object(evgen, o, emitting_node=asyncua.Node(server.iserver.isession, ua.ObjectIds.Server)) await check_event_generator_object(evgen, o, emitting_node=asyncua.Node(server.iserver.isession, ua.ObjectIds.Server))
await server.delete_nodes([o]) await server.delete_nodes([o])
async def test_eventgenerator_inherited_event(server): async def test_eventgenerator_inherited_event(server):
evgen = await server.get_event_generator(ua.ObjectIds.AuditEventType) evgen = await server.get_event_generator(ua.ObjectIds.AuditEventType)
await check_eventgenerator_source_server(evgen, server) await check_eventgenerator_source_server(evgen, server)
...@@ -402,8 +405,7 @@ async def test_create_custom_event_type_node_id(server): ...@@ -402,8 +405,7 @@ async def test_create_custom_event_type_node_id(server):
async def test_create_custom_event_type_node(server): async def test_create_custom_event_type_node(server):
etype = await server.create_custom_event_type(2, 'MyEvent1', asyncua.Node(server.iserver.isession, etype = await server.create_custom_event_type(2, 'MyEvent1', asyncua.Node(server.iserver.isession, ua.NodeId(ua.ObjectIds.BaseEventType)),
ua.NodeId(ua.ObjectIds.BaseEventType)),
[('PropertyNum', ua.VariantType.Int32), [('PropertyNum', ua.VariantType.Int32),
('PropertyString', ua.VariantType.String)]) ('PropertyString', ua.VariantType.String)])
await check_custom_type(etype, ua.ObjectIds.BaseEventType, server) await check_custom_type(etype, ua.ObjectIds.BaseEventType, server)
...@@ -440,8 +442,7 @@ async def test_eventgenerator_custom_event_with_variables(server): ...@@ -440,8 +442,7 @@ async def test_eventgenerator_custom_event_with_variables(server):
('PropertyString', ua.VariantType.String)] ('PropertyString', ua.VariantType.String)]
variables = [('VariableString', ua.VariantType.String), variables = [('VariableString', ua.VariantType.String),
('MyEnumVar', ua.VariantType.Int32, ua.NodeId(ua.ObjectIds.ApplicationType))] ('MyEnumVar', ua.VariantType.Int32, ua.NodeId(ua.ObjectIds.ApplicationType))]
etype = await server.create_custom_object_type(2, 'MyEvent33', ua.ObjectIds.BaseEventType, etype = await server.create_custom_object_type(2, 'MyEvent33', ua.ObjectIds.BaseEventType, properties, variables)
properties, variables)
evgen = await server.get_event_generator(etype, ua.ObjectIds.Server) evgen = await server.get_event_generator(etype, ua.ObjectIds.Server)
check_eventgenerator_custom_event(evgen, etype, server) check_eventgenerator_custom_event(evgen, etype, server)
await check_eventgenerator_source_server(evgen, server) await check_eventgenerator_source_server(evgen, server)
...@@ -555,8 +556,11 @@ async def test_get_node_by_ns(server): ...@@ -555,8 +556,11 @@ async def test_get_node_by_ns(server):
async def test_load_enum_strings(server): async def test_load_enum_strings(server):
dt = await server.nodes.enum_data_type.add_data_type(0, "MyStringEnum") dt = await server.nodes.enum_data_type.add_data_type(0, "MyStringEnum")
await dt.add_property(0, "EnumStrings", [ua.LocalizedText("e1"), ua.LocalizedText("e2"), ua.LocalizedText("e3"), await dt.add_property(
ua.LocalizedText("e 4")]) 0,
"EnumStrings",
[ua.LocalizedText("e1"), ua.LocalizedText("e2"), ua.LocalizedText("e3"), ua.LocalizedText("e 4")]
)
await server.load_enums() await server.load_enums()
e = getattr(ua, "MyStringEnum") e = getattr(ua, "MyStringEnum")
assert isinstance(e, EnumMeta) assert isinstance(e, EnumMeta)
...@@ -567,18 +571,9 @@ async def test_load_enum_strings(server): ...@@ -567,18 +571,9 @@ async def test_load_enum_strings(server):
async def test_load_enum_values(server): async def test_load_enum_values(server):
dt = await server.nodes.enum_data_type.add_data_type(0, "MyValuesEnum") dt = await server.nodes.enum_data_type.add_data_type(0, "MyValuesEnum")
v1 = ua.EnumValueType( v1 = ua.EnumValueType(DisplayName=ua.LocalizedText("v1"), Value=2)
DisplayName=ua.LocalizedText("v1"), v2 = ua.EnumValueType(DisplayName=ua.LocalizedText("v2"), Value=3)
Value=2, v3 = ua.EnumValueType(DisplayName=ua.LocalizedText("v 3 "), Value=4.)
)
v2 = ua.EnumValueType(
DisplayName=ua.LocalizedText("v2"),
Value=3,
)
v3 = ua.EnumValueType(
DisplayName=ua.LocalizedText("v 3 "),
Value=4.
)
await dt.add_property(0, "EnumValues", [v1, v2, v3]) await dt.add_property(0, "EnumValues", [v1, v2, v3])
await server.load_enums() await server.load_enums()
e = getattr(ua, "MyValuesEnum") e = getattr(ua, "MyValuesEnum")
...@@ -645,8 +640,7 @@ def check_custom_event(ev, etype): ...@@ -645,8 +640,7 @@ def check_custom_event(ev, etype):
async def check_custom_type(ntype, base_type, server: Server, node_class=None): async def check_custom_type(ntype, base_type, server: Server, node_class=None):
base = asyncua.Node(server.iserver.isession, ua.NodeId(base_type)) base = asyncua.Node(server.iserver.isession, ua.NodeId(base_type))
assert ntype in await base.get_children() assert ntype in await base.get_children()
nodes = await ntype.get_referenced_nodes(refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse, nodes = await ntype.get_referenced_nodes(refs=ua.ObjectIds.HasSubtype, direction=ua.BrowseDirection.Inverse, includesubtypes=True)
includesubtypes=True)
assert base == nodes[0] assert base == nodes[0]
if node_class: if node_class:
assert node_class == await ntype.read_node_class() assert node_class == await ntype.read_node_class()
...@@ -658,6 +652,7 @@ async def check_custom_type(ntype, base_type, server: Server, node_class=None): ...@@ -658,6 +652,7 @@ async def check_custom_type(ntype, base_type, server: Server, node_class=None):
assert await ntype.get_child("2:PropertyString") in properties assert await ntype.get_child("2:PropertyString") in properties
assert (await(await ntype.get_child("2:PropertyString")).read_data_value()).Value.VariantType == ua.VariantType.String assert (await(await ntype.get_child("2:PropertyString")).read_data_value()).Value.VariantType == ua.VariantType.String
async def test_server_read_write_attribute_value(server: Server): async def test_server_read_write_attribute_value(server: Server):
node = await server.get_objects_node().add_variable(0, "0:TestVar", 0, varianttype=ua.VariantType.Int64) node = await server.get_objects_node().add_variable(0, "0:TestVar", 0, varianttype=ua.VariantType.Int64)
dv = server.read_attribute_value(node.nodeid, attr=ua.AttributeIds.Value) dv = server.read_attribute_value(node.nodeid, attr=ua.AttributeIds.Value)
...@@ -672,6 +667,7 @@ async def test_server_read_write_attribute_value(server: Server): ...@@ -672,6 +667,7 @@ async def test_server_read_write_attribute_value(server: Server):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def restore_transport_limits_server(server: Server): def restore_transport_limits_server(server: Server):
# Restore limits after test # Restore limits after test
assert server.bserver is not None
max_recv = server.bserver.limits.max_recv_buffer max_recv = server.bserver.limits.max_recv_buffer
max_chunk_count = server.bserver.limits.max_chunk_count max_chunk_count = server.bserver.limits.max_chunk_count
yield server yield server
...@@ -681,6 +677,7 @@ def restore_transport_limits_server(server: Server): ...@@ -681,6 +677,7 @@ def restore_transport_limits_server(server: Server):
async def test_message_limits_fail_write(restore_transport_limits_server: Server): async def test_message_limits_fail_write(restore_transport_limits_server: Server):
server = restore_transport_limits_server server = restore_transport_limits_server
assert server.bserver is not None
server.bserver.limits.max_recv_buffer = 1024 server.bserver.limits.max_recv_buffer = 1024
server.bserver.limits.max_send_buffer = 10240000 server.bserver.limits.max_send_buffer = 10240000
server.bserver.limits.max_chunk_count = 10 server.bserver.limits.max_chunk_count = 10
...@@ -698,6 +695,7 @@ async def test_message_limits_fail_write(restore_transport_limits_server: Server ...@@ -698,6 +695,7 @@ async def test_message_limits_fail_write(restore_transport_limits_server: Server
async def test_message_limits_fail_read(restore_transport_limits_server: Server): async def test_message_limits_fail_read(restore_transport_limits_server: Server):
server = restore_transport_limits_server server = restore_transport_limits_server
assert server.bserver is not None
server.bserver.limits.max_recv_buffer = 10240000 server.bserver.limits.max_recv_buffer = 10240000
server.bserver.limits.max_send_buffer = 1024 server.bserver.limits.max_send_buffer = 1024
server.bserver.limits.max_chunk_count = 10 server.bserver.limits.max_chunk_count = 10
...@@ -716,6 +714,7 @@ async def test_message_limits_fail_read(restore_transport_limits_server: Server) ...@@ -716,6 +714,7 @@ async def test_message_limits_fail_read(restore_transport_limits_server: Server)
async def test_message_limits_works(restore_transport_limits_server: Server): async def test_message_limits_works(restore_transport_limits_server: Server):
server = restore_transport_limits_server server = restore_transport_limits_server
# server.bserver.limits.max_recv_buffer = 1024 # server.bserver.limits.max_recv_buffer = 1024
assert server.bserver is not None
server.bserver.limits.max_send_buffer = 1024 server.bserver.limits.max_send_buffer = 1024
server.bserver.limits.max_chunk_count = 10 server.bserver.limits.max_chunk_count = 10
n = await server.nodes.objects.add_variable(1, "MyLimitVariable2", "t") n = await server.nodes.objects.add_variable(1, "MyLimitVariable2", "t")
...@@ -729,7 +728,6 @@ async def test_message_limits_works(restore_transport_limits_server: Server): ...@@ -729,7 +728,6 @@ async def test_message_limits_works(restore_transport_limits_server: Server):
await n.read_value() await n.read_value()
""" """
class TestServerCaching(unittest.TestCase): class TestServerCaching(unittest.TestCase):
def runTest(self): def runTest(self):
...@@ -755,6 +753,7 @@ class TestServerCaching(unittest.TestCase): ...@@ -755,6 +753,7 @@ class TestServerCaching(unittest.TestCase):
""" """
async def test_null_auth(server): async def test_null_auth(server):
""" """
OPC-UA Specification Part 4, 5.6.3 specifies that a: OPC-UA Specification Part 4, 5.6.3 specifies that a:
...@@ -763,6 +762,7 @@ async def test_null_auth(server): ...@@ -763,6 +762,7 @@ async def test_null_auth(server):
Ensure a Null token is accepted as an anonymous connection token. Ensure a Null token is accepted as an anonymous connection token.
""" """
client = Client(server.endpoint.geturl()) client = Client(server.endpoint.geturl())
# Modify the authentication creation in the client request # Modify the authentication creation in the client request
def _add_null_auth(self, params): def _add_null_auth(self, params):
params.UserIdentityToken = ua.ExtensionObject(ua.NodeId(ua.ObjectIds.Null)) params.UserIdentityToken = ua.ExtensionObject(ua.NodeId(ua.ObjectIds.Null))
...@@ -772,11 +772,11 @@ async def test_null_auth(server): ...@@ -772,11 +772,11 @@ async def test_null_auth(server):
pass pass
async def test_start_server_when_port_is_in_use(server: str): async def test_start_server_when_port_is_in_use(server: Server):
server2 = Server() server2 = Server()
await server2.init() await server2.init()
url = server.endpoint.geturl() url = server.endpoint.geturl()
server2.set_endpoint(url) # try to bind on the same endpoint as an already running server server2.set_endpoint(url) # try to bind on the same endpoint as an already running server
with pytest.raises(OSError): with pytest.raises(OSError):
await server2.start() await server2.start()
# now it should still be possible to stop the server with exceptions # now it should still be possible to stop the server with exceptions
......
...@@ -8,10 +8,12 @@ from asyncua.common.subscription import Subscription ...@@ -8,10 +8,12 @@ from asyncua.common.subscription import Subscription
try: try:
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
except ImportError: except ImportError:
from asynctest import CoroutineMock as AsyncMock from asynctest import CoroutineMock as AsyncMock # type: ignore[no-redef]
import asyncua import asyncua
from asyncua import ua, Client from asyncua import ua, Client
from .conftest import Opc
class MySubHandler: class MySubHandler:
""" """
...@@ -245,7 +247,8 @@ async def test_subscription_data_change(opc): ...@@ -245,7 +247,8 @@ async def test_subscription_data_change(opc):
await sub.unsubscribe(handle1) # sub does not exist anymore await sub.unsubscribe(handle1) # sub does not exist anymore
await opc.opc.delete_nodes([v1]) await opc.opc.delete_nodes([v1])
async def test_subscription_monitored_item(opc):
async def test_subscription_monitored_item(opc: Opc):
""" """
test subscriptions with a monitored item with a datachange filter. test subscriptions with a monitored item with a datachange filter.
...@@ -258,12 +261,11 @@ async def test_subscription_monitored_item(opc): ...@@ -258,12 +261,11 @@ async def test_subscription_monitored_item(opc):
v1 = await o.add_variable(3, 'SubscriptionVariableV1', startv1) v1 = await o.add_variable(3, 'SubscriptionVariableV1', startv1)
sub: Subscription = await opc.opc.create_subscription(100, myhandler) sub: Subscription = await opc.opc.create_subscription(100, myhandler)
mfilter = ua.DataChangeFilter(Trigger=ua.DataChangeTrigger.StatusValueTimestamp) mfilter = ua.DataChangeFilter(Trigger=ua.DataChangeTrigger.StatusValueTimestamp)
#For creating monitor items create_monitored_items is availablem, but that one is not very easy in use. # For creating monitor items create_monitored_items is availablem, but that one is not very easy in use.
#So use the internal function instead. # So use the internal function instead.
#TODO: Should there be an easy shorthand for making monitored items with filter? # TODO: Should there be an easy shorthand for making monitored items with filter?
handles = await sub._subscribe(nodes=v1, mfilter=mfilter) handles = await sub._subscribe(nodes=v1, mfilter=mfilter)
# # Now check we get the start value # # Now check we get the start value
...@@ -1043,6 +1045,7 @@ async def test_publish(opc, mocker): ...@@ -1043,6 +1045,7 @@ async def test_publish(opc, mocker):
publish_event = asyncio.Event() publish_event = asyncio.Event()
publish_org = client.uaclient.publish publish_org = client.uaclient.publish
async def publish(acks): async def publish(acks):
await publish_event.wait() await publish_event.wait()
publish_event.clear() publish_event.clear()
......
...@@ -132,12 +132,12 @@ def trust_store(tmp_path) -> TrustStore: ...@@ -132,12 +132,12 @@ def trust_store(tmp_path) -> TrustStore:
return _trust_store return _trust_store
async def test_selfsigned_not_in_trust_store(cert_files, trust_store): async def test_selfsigned_not_in_trust_store(cert_files, trust_store) -> None:
cert_self_signed: x509.Certificate = await load_certificate(cert_files / SERVER_CERT_SELF_SIGNED_FILE) cert_self_signed: x509.Certificate = await load_certificate(cert_files / SERVER_CERT_SELF_SIGNED_FILE)
assert trust_store.is_trusted(cert_self_signed) is False assert trust_store.is_trusted(cert_self_signed) is False
async def test_selfsigned_in_trust_store(cert_files, trust_store): async def test_selfsigned_in_trust_store(cert_files, trust_store) -> None:
shutil.copyfile(cert_files / SERVER_CERT_SELF_SIGNED_FILE, trust_store.trust_locations[0] / SERVER_CERT_SELF_SIGNED_FILE) shutil.copyfile(cert_files / SERVER_CERT_SELF_SIGNED_FILE, trust_store.trust_locations[0] / SERVER_CERT_SELF_SIGNED_FILE)
await trust_store.load() await trust_store.load()
...@@ -145,12 +145,12 @@ async def test_selfsigned_in_trust_store(cert_files, trust_store): ...@@ -145,12 +145,12 @@ async def test_selfsigned_in_trust_store(cert_files, trust_store):
assert trust_store.is_trusted(cert_self_signed) is True assert trust_store.is_trusted(cert_self_signed) is True
async def test_ca_not_in_trust_store(cert_files, trust_store): async def test_ca_not_in_trust_store(cert_files, trust_store) -> None:
cert_self_signed: x509.Certificate = await load_certificate(cert_files / SERVER_CERT_SELF_SIGNED_FILE) cert_self_signed: x509.Certificate = await load_certificate(cert_files / SERVER_CERT_SELF_SIGNED_FILE)
assert trust_store.is_trusted(cert_self_signed) is False assert trust_store.is_trusted(cert_self_signed) is False
async def test_ca_in_trust_store(cert_files, trust_store): async def test_ca_in_trust_store(cert_files, trust_store) -> None:
shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE) shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE)
await trust_store.load() await trust_store.load()
...@@ -158,7 +158,7 @@ async def test_ca_in_trust_store(cert_files, trust_store): ...@@ -158,7 +158,7 @@ async def test_ca_in_trust_store(cert_files, trust_store):
assert trust_store.is_trusted(cert_server) is True assert trust_store.is_trusted(cert_server) is True
async def test_empty_crl(cert_files, trust_store): async def test_empty_crl(cert_files, trust_store) -> None:
shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE) shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE)
shutil.copyfile(cert_files / 'ca_empty_crl.der', trust_store.crl_locations[0] / 'ca_empty_crl.der') shutil.copyfile(cert_files / 'ca_empty_crl.der', trust_store.crl_locations[0] / 'ca_empty_crl.der')
await trust_store.load() await trust_store.load()
...@@ -172,7 +172,7 @@ async def test_empty_crl(cert_files, trust_store): ...@@ -172,7 +172,7 @@ async def test_empty_crl(cert_files, trust_store):
assert trust_store.validate(cert_server) is True assert trust_store.validate(cert_server) is True
async def test_cert_in_crl(cert_files, trust_store): async def test_cert_in_crl(cert_files, trust_store) -> None:
shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE) shutil.copyfile(cert_files / CA_CERT_FILE, trust_store.trust_locations[0] / CA_CERT_FILE)
shutil.copyfile(cert_files / 'ca_crl.der', trust_store.crl_locations[0] / 'ca_crl.der') shutil.copyfile(cert_files / 'ca_crl.der', trust_store.crl_locations[0] / 'ca_crl.der')
await trust_store.load() await trust_store.load()
......
...@@ -2,7 +2,7 @@ from dataclasses import dataclass ...@@ -2,7 +2,7 @@ from dataclasses import dataclass
from asyncua.common.ua_utils import copy_dataclass_attr from asyncua.common.ua_utils import copy_dataclass_attr
def test_copy_dataclass_attr(): def test_copy_dataclass_attr() -> None:
@dataclass @dataclass
class A: class A:
x: int = 1 x: int = 1
......
...@@ -11,7 +11,7 @@ import pytest ...@@ -11,7 +11,7 @@ import pytest
import logging import logging
from datetime import datetime from datetime import datetime
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, List from typing import Optional, List, cast
from asyncua import ua from asyncua import ua
from asyncua.ua import ua_binary from asyncua.ua import ua_binary
...@@ -773,7 +773,7 @@ def test_where_clause(): ...@@ -773,7 +773,7 @@ def test_where_clause():
op.BrowsePath.append(ua.QualifiedName("property", 2)) op.BrowsePath.append(ua.QualifiedName("property", 2))
el.FilterOperands.append(op) el.FilterOperands.append(op)
for i in range(10): for i in range(10):
op = ua.LiteralOperand(Value = ua.Variant(i)) op = ua.LiteralOperand(Value=ua.Variant(i))
el.FilterOperands.append(op) el.FilterOperands.append(op)
el.FilterOperator = ua.FilterOperator.InList el.FilterOperator = ua.FilterOperator.InList
cf.Elements.append(el) cf.Elements.append(el)
...@@ -858,7 +858,7 @@ def test_expandedNodeId(): ...@@ -858,7 +858,7 @@ def test_expandedNodeId():
assert nid.Identifier == 85 assert nid.Identifier == 85
def test_struct_104(): def test_struct_104() -> None:
@dataclass @dataclass
class MyStruct: class MyStruct:
Encoding: ua.Byte = field(default=0, repr=False, init=False) Encoding: ua.Byte = field(default=0, repr=False, init=False)
...@@ -872,7 +872,7 @@ def test_struct_104(): ...@@ -872,7 +872,7 @@ def test_struct_104():
m2 = struct_from_binary(MyStruct, ua.utils.Buffer(data)) m2 = struct_from_binary(MyStruct, ua.utils.Buffer(data))
assert m == m2 assert m == m2
m = MyStruct(a=4, b=5, c="lkjkæl", l=["a", "b", "c"]) m = MyStruct(a=4, b=5, c="lkjkæl", l=[cast(ua.String, "a"), cast(ua.String, "b"), cast(ua.String, "c")])
data = struct_to_binary(m) data = struct_to_binary(m)
m2 = struct_from_binary(MyStruct, ua.utils.Buffer(data)) m2 = struct_from_binary(MyStruct, ua.utils.Buffer(data))
assert m == m2 assert m == m2
......
...@@ -30,8 +30,8 @@ async def add_server_custom_enum_struct(server: Server): ...@@ -30,8 +30,8 @@ async def add_server_custom_enum_struct(server: Server):
uatypes.register_enum('ExampleEnum', ua.NodeId(3002, ns), ExampleEnum) uatypes.register_enum('ExampleEnum', ua.NodeId(3002, ns), ExampleEnum)
uatypes.register_extension_object('ExampleStruct', ua.NodeId(5001, ns), ExampleStruct) uatypes.register_extension_object('ExampleStruct', ua.NodeId(5001, ns), ExampleStruct)
await server.import_xml(TEST_DIR / "enum_struct_test_nodes.xml"), await server.import_xml(TEST_DIR / "enum_struct_test_nodes.xml"),
val = ua.ExampleStruct() val = ExampleStruct()
val.IntVal1 = 242 val.IntVal1 = 242
val.EnumVal = ua.ExampleEnum.EnumVal2 val.EnumVal = ExampleEnum.EnumVal2
myvar = server.get_node(ua.NodeId(6009, ns)) myvar = server.get_node(ua.NodeId(6009, ns))
await myvar.write_value(val) await myvar.write_value(val)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment