Commit 93979db8 authored by Roman Yurchak's avatar Roman Yurchak Committed by oroulet

A few more type annotations fixes

parent 2892c7f8
...@@ -4,7 +4,7 @@ import socket ...@@ -4,7 +4,7 @@ import socket
from cryptography import x509 from cryptography import x509
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union, cast, Callable, Coroutine from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union, cast, Callable, Coroutine
from urllib.parse import urlparse, unquote from urllib.parse import urlparse, unquote, ParseResult
from pathlib import Path from pathlib import Path
import asyncua import asyncua
...@@ -36,9 +36,9 @@ class Client: ...@@ -36,9 +36,9 @@ class Client:
which offers the raw OPC-UA services interface. which offers the raw OPC-UA services interface.
""" """
_username = None _username: Optional[str] = None
_password = None _password: Optional[str] = None
strip_url_credentials = True strip_url_credentials: bool = True
def __init__(self, url: str, timeout: float = 4, watchdog_intervall: float = 1.0): def __init__(self, url: str, timeout: float = 4, watchdog_intervall: float = 1.0):
""" """
...@@ -104,7 +104,7 @@ class Client: ...@@ -104,7 +104,7 @@ class Client:
__repr__ = __str__ __repr__ = __str__
@property @property
def server_url(self): def server_url(self) -> ParseResult:
"""Return the server URL with stripped credentials """Return the server URL with stripped credentials
if self.strip_url_credentials is True. Disabling this if self.strip_url_credentials is True. Disabling this
...@@ -681,7 +681,7 @@ class Client: ...@@ -681,7 +681,7 @@ class Client:
params.UserIdentityToken.EncryptionAlgorithm = uri params.UserIdentityToken.EncryptionAlgorithm = uri
params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.UserName, "username_basic256") params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.UserName, "username_basic256")
def _encrypt_password(self, password: str, policy_uri): def _encrypt_password(self, password: str, policy_uri) -> Tuple[bytes, str]:
pubkey = uacrypto.x509_from_der(self.security_policy.peer_certificate).public_key() pubkey = uacrypto.x509_from_der(self.security_policy.peer_certificate).public_key()
# see specs part 4, 7.36.3: if the token is encrypted, password # see specs part 4, 7.36.3: if the token is encrypted, password
# shall be converted to UTF-8 and serialized with server nonce # shall be converted to UTF-8 and serialized with server nonce
...@@ -758,7 +758,7 @@ class Client: ...@@ -758,7 +758,7 @@ class Client:
_logger.info("Result from subscription update: %s", results) _logger.info("Result from subscription update: %s", results)
return subscription return subscription
def get_subscription_revised_params( # type: ignore def get_subscription_revised_params(
self, self,
params: ua.CreateSubscriptionParameters, params: ua.CreateSubscriptionParameters,
results: ua.CreateSubscriptionResult, results: ua.CreateSubscriptionResult,
...@@ -768,7 +768,7 @@ class Client: ...@@ -768,7 +768,7 @@ class Client:
and results.RevisedLifetimeCount == params.RequestedLifetimeCount and results.RevisedLifetimeCount == params.RequestedLifetimeCount
and results.RevisedMaxKeepAliveCount == params.RequestedMaxKeepAliveCount and results.RevisedMaxKeepAliveCount == params.RequestedMaxKeepAliveCount
): ):
return # type: ignore return None
_logger.warning( _logger.warning(
"Revised values returned differ from subscription values: %s", results "Revised values returned differ from subscription values: %s", results
) )
...@@ -795,6 +795,7 @@ class Client: ...@@ -795,6 +795,7 @@ class Client:
# update LifetimeCount but chances are it will be re-revised again # update LifetimeCount but chances are it will be re-revised again
modified_params.RequestedLifetimeCount = results.RevisedLifetimeCount modified_params.RequestedLifetimeCount = results.RevisedLifetimeCount
return modified_params return modified_params
return None
async def delete_subscriptions(self, subscription_ids: Iterable[int]) -> List[ua.StatusCode]: async def delete_subscriptions(self, subscription_ids: Iterable[int]) -> List[ua.StatusCode]:
""" """
......
...@@ -10,7 +10,7 @@ from sortedcontainers import SortedDict # type: ignore ...@@ -10,7 +10,7 @@ from sortedcontainers import SortedDict # type: ignore
from asyncua import Node, ua, Client from asyncua import Node, ua, Client
from asyncua.client.ua_client import UASocketProtocol from asyncua.client.ua_client import UASocketProtocol
from asyncua.ua.uaerrors import BadSessionClosed, BadSessionNotActivated from asyncua.ua.uaerrors import BadSessionClosed, BadSessionNotActivated
from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Type, Union from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Type, Union, Sequence
from .reconciliator import Reconciliator from .reconciliator import Reconciliator
from .common import ClientNotFound, event_wait from .common import ClientNotFound, event_wait
...@@ -169,7 +169,7 @@ class HaClient: ...@@ -169,7 +169,7 @@ class HaClient:
self.is_running = True self.is_running = True
async def stop(self): async def stop(self):
to_stop = chain( to_stop: Sequence[Union[KeepAlive, HaManager, Reconciliator]] = chain(
self._keepalive_task, self._manager_task, self._reconciliator_task self._keepalive_task, self._manager_task, self._reconciliator_task
) )
stop = [p.stop() for p in to_stop] stop = [p.stop() for p in to_stop]
......
...@@ -10,6 +10,7 @@ from asyncua import ua ...@@ -10,6 +10,7 @@ from asyncua import ua
from asyncua.common.session_interface import AbstractSession from asyncua.common.session_interface import AbstractSession
from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary
from ..ua.uaerrors import BadTimeout, BadNoSubscription, BadSessionClosed, BadUserAccessDenied, UaStructParsingError from ..ua.uaerrors import BadTimeout, BadNoSubscription, BadSessionClosed, BadUserAccessDenied, UaStructParsingError
from ..ua.uaprotocol_auto import OpenSecureChannelResult, SubscriptionAcknowledgement
from ..common.connection import SecureConnection, TransportLimits from ..common.connection import SecureConnection, TransportLimits
...@@ -50,7 +51,7 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -50,7 +51,7 @@ class UASocketProtocol(asyncio.Protocol):
# Hook for upper layer tasks before a request is sent (optional) # Hook for upper layer tasks before a request is sent (optional)
self.pre_request_hook: Optional[Callable[[], Awaitable[None]]] = None self.pre_request_hook: Optional[Callable[[], Awaitable[None]]] = None
def connection_made(self, transport: asyncio.Transport): # type: ignore def connection_made(self, transport: asyncio.Transport): # type: ignore[override]
self.state = self.OPEN self.state = self.OPEN
self.transport = transport self.transport = transport
...@@ -59,13 +60,13 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -59,13 +60,13 @@ class UASocketProtocol(asyncio.Protocol):
self.state = self.CLOSED self.state = self.CLOSED
self.transport = None self.transport = None
def data_received(self, data: bytes): def data_received(self, data: bytes) -> None:
if self.receive_buffer: if self.receive_buffer:
data = self.receive_buffer + data data = self.receive_buffer + data
self.receive_buffer = None self.receive_buffer = None
self._process_received_data(data) self._process_received_data(data)
def _process_received_data(self, data: bytes): def _process_received_data(self, data: bytes) -> None:
""" """
Try to parse received data as asyncua message. Data may be chunked but will be in correct order. Try to parse received data as asyncua message. Data may be chunked but will be in correct order.
See: https://docs.python.org/3/library/asyncio-protocol.html#asyncio.Protocol.data_received See: https://docs.python.org/3/library/asyncio-protocol.html#asyncio.Protocol.data_received
...@@ -222,7 +223,7 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -222,7 +223,7 @@ class UASocketProtocol(asyncio.Protocol):
self.transport.write(uatcp_to_binary(ua.MessageType.Hello, hello)) self.transport.write(uatcp_to_binary(ua.MessageType.Hello, hello))
return await asyncio.wait_for(ack, self.timeout) return await asyncio.wait_for(ack, self.timeout)
async def open_secure_channel(self, params): async def open_secure_channel(self, params) -> OpenSecureChannelResult:
self.logger.info("open_secure_channel") self.logger.info("open_secure_channel")
request = ua.OpenSecureChannelRequest() request = ua.OpenSecureChannelRequest()
request.Parameters = params request.Parameters = params
...@@ -230,7 +231,7 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -230,7 +231,7 @@ class UASocketProtocol(asyncio.Protocol):
raise RuntimeError('Two Open Secure Channel requests can not happen too close to each other. ' 'The response must be processed and returned before the next request can be sent.') raise RuntimeError('Two Open Secure Channel requests can not happen too close to each other. ' 'The response must be processed and returned before the next request can be sent.')
self._open_secure_channel_exchange = params self._open_secure_channel_exchange = params
await asyncio.wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout) await asyncio.wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout)
_return = self._open_secure_channel_exchange.Parameters _return = self._open_secure_channel_exchange.Parameters # type: ignore[union-attr]
self._open_secure_channel_exchange = None self._open_secure_channel_exchange = None
return _return return _return
...@@ -587,7 +588,7 @@ class UaClient(AbstractSession): ...@@ -587,7 +588,7 @@ class UaClient(AbstractSession):
Start a loop that sends a publish requests and waits for the publish responses. Start a loop that sends a publish requests and waits for the publish responses.
Forward the `PublishResult` to the matching `Subscription` by callback. Forward the `PublishResult` to the matching `Subscription` by callback.
""" """
ack = None ack: Optional[SubscriptionAcknowledgement] = None
while not self._closing: while not self._closing:
try: try:
response = await self.publish([ack] if ack else []) response = await self.publish([ack] if ack else [])
......
...@@ -5,6 +5,7 @@ Binary protocol specific functions and constants ...@@ -5,6 +5,7 @@ Binary protocol specific functions and constants
import functools import functools
import struct import struct
import logging import logging
from io import BytesIO
from typing import IO, Any, Callable, Optional, Sequence, Type, TypeVar, Union from typing import IO, Any, Callable, Optional, Sequence, Type, TypeVar, Union
import typing import typing
import uuid import uuid
...@@ -437,7 +438,7 @@ def nodeid_to_binary(nodeid): ...@@ -437,7 +438,7 @@ def nodeid_to_binary(nodeid):
return data return data
def nodeid_from_binary(data): def nodeid_from_binary(data: Union[BytesIO, Buffer]) -> Union[ua.NodeId, ua.ExpandedNodeId]:
encoding = ord(data.read(1)) encoding = ord(data.read(1))
nidtype = ua.NodeIdType(encoding & 0b00111111) nidtype = ua.NodeIdType(encoding & 0b00111111)
uri = None uri = None
...@@ -518,7 +519,7 @@ def _reshape(flat, dims): ...@@ -518,7 +519,7 @@ def _reshape(flat, dims):
return [_reshape(flat[i:i + subsize], subdims) for i in range(0, len(flat), subsize)] return [_reshape(flat[i:i + subsize], subdims) for i in range(0, len(flat), subsize)]
def extensionobject_from_binary(data): def extensionobject_from_binary(data: Buffer) -> ua.ExtensionObject:
""" """
Convert binary-coded ExtensionObject to a Python object. Convert binary-coded ExtensionObject to a Python object.
Returns an object, or None if TypeId is zero Returns an object, or None if TypeId is zero
......
...@@ -5,7 +5,7 @@ Date:2022-09-22 18:18:39.272455 ...@@ -5,7 +5,7 @@ Date:2022-09-22 18:18:39.272455
from datetime import datetime, timezone from datetime import datetime, timezone
from enum import IntEnum, IntFlag from enum import IntEnum, IntFlag
from typing import Union, List, Optional, Type from typing import List, Optional, Type
from dataclasses import dataclass, field from dataclasses import dataclass, field
from asyncua.ua.uatypes import FROZEN from asyncua.ua.uatypes import FROZEN
...@@ -1673,8 +1673,8 @@ class ExceptionDeviationFormat(IntEnum): ...@@ -1673,8 +1673,8 @@ class ExceptionDeviationFormat(IntEnum):
Unknown = 4 Unknown = 4
@dataclass(frozen=FROZEN) # type: ignore @dataclass(frozen=FROZEN)
class Union: # type: ignore class Union:
""" """
https://reference.opcfoundation.org/v105/Core/docs/Part5/12.2.12/#12.2.12.12 https://reference.opcfoundation.org/v105/Core/docs/Part5/12.2.12/#12.2.12.12
......
...@@ -14,5 +14,7 @@ check_untyped_defs = False ...@@ -14,5 +14,7 @@ check_untyped_defs = False
[mypy-asyncua.ua.uaprotocol_auto.*] [mypy-asyncua.ua.uaprotocol_auto.*]
# Autogenerated file # Autogenerated file
disable_error_code = literal-required disable_error_code = literal-required
[mypy-asyncua.client.*]
check_untyped_defs = True
[mypy-asynctest.*] [mypy-asynctest.*]
ignore_missing_imports = True ignore_missing_imports = True
...@@ -740,7 +740,7 @@ async def test_value(opc): ...@@ -740,7 +740,7 @@ async def test_value(opc):
assert 1.98 == await v.read_value() assert 1.98 == await v.read_value()
dvar = ua.DataValue(var) dvar = ua.DataValue(var)
dv = await v.read_data_value() dv = await v.read_data_value()
assert ua.DataValue == type(dv) assert ua.DataValue is type(dv)
assert dvar.Value == dv.Value assert dvar.Value == dv.Value
assert dvar.Value == var assert dvar.Value == var
await opc.opc.delete_nodes([v]) await opc.opc.delete_nodes([v])
......
...@@ -519,7 +519,7 @@ def test_null_string(): ...@@ -519,7 +519,7 @@ def test_null_string():
def test_empty_extension_object(): def test_empty_extension_object():
obj = ua.ExtensionObject() obj = ua.ExtensionObject()
obj2 = extensionobject_from_binary(ua.utils.Buffer(extensionobject_to_binary(obj))) obj2 = extensionobject_from_binary(ua.utils.Buffer(extensionobject_to_binary(obj)))
assert type(obj) == type(obj2) assert type(obj) is type(obj2)
assert obj == obj2 assert obj == obj2
...@@ -528,12 +528,12 @@ def test_extension_object(): ...@@ -528,12 +528,12 @@ def test_extension_object():
obj.UserName = "admin" obj.UserName = "admin"
obj.Password = b"pass" obj.Password = b"pass"
obj2 = extensionobject_from_binary(ua.utils.Buffer(extensionobject_to_binary(obj))) obj2 = extensionobject_from_binary(ua.utils.Buffer(extensionobject_to_binary(obj)))
assert type(obj) == type(obj2) assert type(obj) is type(obj2)
assert obj.UserName == obj2.UserName assert obj.UserName == obj2.UserName
assert obj.Password == obj2.Password assert obj.Password == obj2.Password
v1 = ua.Variant(obj) v1 = ua.Variant(obj)
v2 = variant_from_binary(ua.utils.Buffer(variant_to_binary(v1))) v2 = variant_from_binary(ua.utils.Buffer(variant_to_binary(v1)))
assert type(v1) == type(v2) assert type(v1) is type(v2)
assert v1.VariantType == v2.VariantType assert v1.VariantType == v2.VariantType
...@@ -545,7 +545,7 @@ def test_unknown_extension_object(): ...@@ -545,7 +545,7 @@ def test_unknown_extension_object():
data = ua.utils.Buffer(extensionobject_to_binary(obj)) data = ua.utils.Buffer(extensionobject_to_binary(obj))
obj2 = extensionobject_from_binary(data) obj2 = extensionobject_from_binary(data)
assert type(obj2) == ua.ExtensionObject assert type(obj2) is ua.ExtensionObject
assert obj2.TypeId == obj.TypeId assert obj2.TypeId == obj.TypeId
assert obj2.Body == b'example of data in custom format' assert obj2.Body == b'example of data in custom format'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment