Commit 93979db8 authored by Roman Yurchak's avatar Roman Yurchak Committed by oroulet

A few more type annotations fixes

parent 2892c7f8
......@@ -4,7 +4,7 @@ import socket
from cryptography import x509
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union, cast, Callable, Coroutine
from urllib.parse import urlparse, unquote
from urllib.parse import urlparse, unquote, ParseResult
from pathlib import Path
import asyncua
......@@ -36,9 +36,9 @@ class Client:
which offers the raw OPC-UA services interface.
"""
_username = None
_password = None
strip_url_credentials = True
_username: Optional[str] = None
_password: Optional[str] = None
strip_url_credentials: bool = True
def __init__(self, url: str, timeout: float = 4, watchdog_intervall: float = 1.0):
"""
......@@ -104,7 +104,7 @@ class Client:
__repr__ = __str__
@property
def server_url(self):
def server_url(self) -> ParseResult:
"""Return the server URL with stripped credentials
if self.strip_url_credentials is True. Disabling this
......@@ -681,7 +681,7 @@ class Client:
params.UserIdentityToken.EncryptionAlgorithm = uri
params.UserIdentityToken.PolicyId = self.server_policy_id(ua.UserTokenType.UserName, "username_basic256")
def _encrypt_password(self, password: str, policy_uri):
def _encrypt_password(self, password: str, policy_uri) -> Tuple[bytes, str]:
pubkey = uacrypto.x509_from_der(self.security_policy.peer_certificate).public_key()
# see specs part 4, 7.36.3: if the token is encrypted, password
# shall be converted to UTF-8 and serialized with server nonce
......@@ -758,7 +758,7 @@ class Client:
_logger.info("Result from subscription update: %s", results)
return subscription
def get_subscription_revised_params( # type: ignore
def get_subscription_revised_params(
self,
params: ua.CreateSubscriptionParameters,
results: ua.CreateSubscriptionResult,
......@@ -768,7 +768,7 @@ class Client:
and results.RevisedLifetimeCount == params.RequestedLifetimeCount
and results.RevisedMaxKeepAliveCount == params.RequestedMaxKeepAliveCount
):
return # type: ignore
return None
_logger.warning(
"Revised values returned differ from subscription values: %s", results
)
......@@ -795,6 +795,7 @@ class Client:
# update LifetimeCount but chances are it will be re-revised again
modified_params.RequestedLifetimeCount = results.RevisedLifetimeCount
return modified_params
return None
async def delete_subscriptions(self, subscription_ids: Iterable[int]) -> List[ua.StatusCode]:
"""
......
......@@ -10,7 +10,7 @@ from sortedcontainers import SortedDict # type: ignore
from asyncua import Node, ua, Client
from asyncua.client.ua_client import UASocketProtocol
from asyncua.ua.uaerrors import BadSessionClosed, BadSessionNotActivated
from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Type, Union
from typing import Dict, Generator, Iterable, List, Optional, Set, Tuple, Type, Union, Sequence
from .reconciliator import Reconciliator
from .common import ClientNotFound, event_wait
......@@ -169,7 +169,7 @@ class HaClient:
self.is_running = True
async def stop(self):
to_stop = chain(
to_stop: Sequence[Union[KeepAlive, HaManager, Reconciliator]] = chain(
self._keepalive_task, self._manager_task, self._reconciliator_task
)
stop = [p.stop() for p in to_stop]
......
......@@ -10,6 +10,7 @@ from asyncua import ua
from asyncua.common.session_interface import AbstractSession
from ..ua.ua_binary import struct_from_binary, uatcp_to_binary, struct_to_binary, nodeid_from_binary, header_from_binary
from ..ua.uaerrors import BadTimeout, BadNoSubscription, BadSessionClosed, BadUserAccessDenied, UaStructParsingError
from ..ua.uaprotocol_auto import OpenSecureChannelResult, SubscriptionAcknowledgement
from ..common.connection import SecureConnection, TransportLimits
......@@ -50,7 +51,7 @@ class UASocketProtocol(asyncio.Protocol):
# Hook for upper layer tasks before a request is sent (optional)
self.pre_request_hook: Optional[Callable[[], Awaitable[None]]] = None
def connection_made(self, transport: asyncio.Transport): # type: ignore
def connection_made(self, transport: asyncio.Transport): # type: ignore[override]
self.state = self.OPEN
self.transport = transport
......@@ -59,13 +60,13 @@ class UASocketProtocol(asyncio.Protocol):
self.state = self.CLOSED
self.transport = None
def data_received(self, data: bytes):
def data_received(self, data: bytes) -> None:
if self.receive_buffer:
data = self.receive_buffer + data
self.receive_buffer = None
self._process_received_data(data)
def _process_received_data(self, data: bytes):
def _process_received_data(self, data: bytes) -> None:
"""
Try to parse received data as asyncua message. Data may be chunked but will be in correct order.
See: https://docs.python.org/3/library/asyncio-protocol.html#asyncio.Protocol.data_received
......@@ -222,7 +223,7 @@ class UASocketProtocol(asyncio.Protocol):
self.transport.write(uatcp_to_binary(ua.MessageType.Hello, hello))
return await asyncio.wait_for(ack, self.timeout)
async def open_secure_channel(self, params):
async def open_secure_channel(self, params) -> OpenSecureChannelResult:
self.logger.info("open_secure_channel")
request = ua.OpenSecureChannelRequest()
request.Parameters = params
......@@ -230,7 +231,7 @@ class UASocketProtocol(asyncio.Protocol):
raise RuntimeError('Two Open Secure Channel requests can not happen too close to each other. ' 'The response must be processed and returned before the next request can be sent.')
self._open_secure_channel_exchange = params
await asyncio.wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout)
_return = self._open_secure_channel_exchange.Parameters
_return = self._open_secure_channel_exchange.Parameters # type: ignore[union-attr]
self._open_secure_channel_exchange = None
return _return
......@@ -587,7 +588,7 @@ class UaClient(AbstractSession):
Start a loop that sends a publish requests and waits for the publish responses.
Forward the `PublishResult` to the matching `Subscription` by callback.
"""
ack = None
ack: Optional[SubscriptionAcknowledgement] = None
while not self._closing:
try:
response = await self.publish([ack] if ack else [])
......
......@@ -5,6 +5,7 @@ Binary protocol specific functions and constants
import functools
import struct
import logging
from io import BytesIO
from typing import IO, Any, Callable, Optional, Sequence, Type, TypeVar, Union
import typing
import uuid
......@@ -437,7 +438,7 @@ def nodeid_to_binary(nodeid):
return data
def nodeid_from_binary(data):
def nodeid_from_binary(data: Union[BytesIO, Buffer]) -> Union[ua.NodeId, ua.ExpandedNodeId]:
encoding = ord(data.read(1))
nidtype = ua.NodeIdType(encoding & 0b00111111)
uri = None
......@@ -518,7 +519,7 @@ def _reshape(flat, dims):
return [_reshape(flat[i:i + subsize], subdims) for i in range(0, len(flat), subsize)]
def extensionobject_from_binary(data):
def extensionobject_from_binary(data: Buffer) -> ua.ExtensionObject:
"""
Convert binary-coded ExtensionObject to a Python object.
Returns an object, or None if TypeId is zero
......
......@@ -5,7 +5,7 @@ Date:2022-09-22 18:18:39.272455
from datetime import datetime, timezone
from enum import IntEnum, IntFlag
from typing import Union, List, Optional, Type
from typing import List, Optional, Type
from dataclasses import dataclass, field
from asyncua.ua.uatypes import FROZEN
......@@ -1673,8 +1673,8 @@ class ExceptionDeviationFormat(IntEnum):
Unknown = 4
@dataclass(frozen=FROZEN) # type: ignore
class Union: # type: ignore
@dataclass(frozen=FROZEN)
class Union:
"""
https://reference.opcfoundation.org/v105/Core/docs/Part5/12.2.12/#12.2.12.12
......
......@@ -14,5 +14,7 @@ check_untyped_defs = False
[mypy-asyncua.ua.uaprotocol_auto.*]
# Autogenerated file
disable_error_code = literal-required
[mypy-asyncua.client.*]
check_untyped_defs = True
[mypy-asynctest.*]
ignore_missing_imports = True
......@@ -740,7 +740,7 @@ async def test_value(opc):
assert 1.98 == await v.read_value()
dvar = ua.DataValue(var)
dv = await v.read_data_value()
assert ua.DataValue == type(dv)
assert ua.DataValue is type(dv)
assert dvar.Value == dv.Value
assert dvar.Value == var
await opc.opc.delete_nodes([v])
......
......@@ -519,7 +519,7 @@ def test_null_string():
def test_empty_extension_object():
obj = ua.ExtensionObject()
obj2 = extensionobject_from_binary(ua.utils.Buffer(extensionobject_to_binary(obj)))
assert type(obj) == type(obj2)
assert type(obj) is type(obj2)
assert obj == obj2
......@@ -528,12 +528,12 @@ def test_extension_object():
obj.UserName = "admin"
obj.Password = b"pass"
obj2 = extensionobject_from_binary(ua.utils.Buffer(extensionobject_to_binary(obj)))
assert type(obj) == type(obj2)
assert type(obj) is type(obj2)
assert obj.UserName == obj2.UserName
assert obj.Password == obj2.Password
v1 = ua.Variant(obj)
v2 = variant_from_binary(ua.utils.Buffer(variant_to_binary(v1)))
assert type(v1) == type(v2)
assert type(v1) is type(v2)
assert v1.VariantType == v2.VariantType
......@@ -545,7 +545,7 @@ def test_unknown_extension_object():
data = ua.utils.Buffer(extensionobject_to_binary(obj))
obj2 = extensionobject_from_binary(data)
assert type(obj2) == ua.ExtensionObject
assert type(obj2) is ua.ExtensionObject
assert obj2.TypeId == obj.TypeId
assert obj2.Body == b'example of data in custom format'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment