Commit e3642e98 authored by vruge's avatar vruge Committed by oroulet

mypy fixes for union attr

parent 592d604a
...@@ -10,6 +10,7 @@ from functools import partial ...@@ -10,6 +10,7 @@ from functools import partial
from typing import TYPE_CHECKING, Dict, Set, Union, List, Optional from typing import TYPE_CHECKING, Dict, Set, Union, List, Optional
from sortedcontainers import SortedDict # type: ignore from sortedcontainers import SortedDict # type: ignore
from asyncua import ua, Client from asyncua import ua, Client
from asyncua.sync import Subscription
from pickle import PicklingError from pickle import PicklingError
from .common import batch, event_wait, get_digest from .common import batch, event_wait, get_digest
...@@ -256,7 +257,7 @@ class Reconciliator: ...@@ -256,7 +257,7 @@ class Reconciliator:
vs_ideal: VirtualSubscription, vs_ideal: VirtualSubscription,
) -> List[asyncio.Task]: ) -> List[asyncio.Task]:
tasks: List[asyncio.Task] = [] tasks: List[asyncio.Task] = []
real_sub = self.name_to_subscription[url].get(sub_name) real_sub: Subscription = self.name_to_subscription[url].get(sub_name)
monitoring = vs_real.monitoring monitoring = vs_real.monitoring
node_to_add = set(vs_ideal.nodes) - set(vs_real.nodes) node_to_add = set(vs_ideal.nodes) - set(vs_real.nodes)
if node_to_add: if node_to_add:
...@@ -304,7 +305,7 @@ class Reconciliator: ...@@ -304,7 +305,7 @@ class Reconciliator:
) -> List[asyncio.Task]: ) -> List[asyncio.Task]:
to_del: List[asyncio.Task] = [] to_del: List[asyncio.Task] = []
node_to_del = set(vs_real.nodes) - set(vs_ideal.nodes) node_to_del = set(vs_real.nodes) - set(vs_ideal.nodes)
real_sub = self.name_to_subscription[url].get(sub_name) real_sub: Subscription = self.name_to_subscription[url].get(sub_name)
if node_to_del: if node_to_del:
_logger.info(f"Removing {len(node_to_del)} Nodes") _logger.info(f"Removing {len(node_to_del)} Nodes")
for batch_nodes in batch(node_to_del, self.BATCH_MI_SIZE): for batch_nodes in batch(node_to_del, self.BATCH_MI_SIZE):
......
...@@ -80,8 +80,8 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -80,8 +80,8 @@ class UASocketProtocol(asyncio.Protocol):
if header.MessageType == ua.MessageType.SecureOpen: if header.MessageType == ua.MessageType.SecureOpen:
params = self._open_secure_channel_exchange params = self._open_secure_channel_exchange
self._open_secure_channel_exchange = struct_from_binary(ua.OpenSecureChannelResponse, msg.body()) self._open_secure_channel_exchange = struct_from_binary(ua.OpenSecureChannelResponse, msg.body())
self._open_secure_channel_exchange.ResponseHeader.ServiceResult.check() self._open_secure_channel_exchange.ResponseHeader.ServiceResult.check() # type: ignore
self._connection.set_channel(self._open_secure_channel_exchange.Parameters, params.RequestType, params.ClientNonce) self._connection.set_channel(self._open_secure_channel_exchange.Parameters, params.RequestType, params.ClientNonce) # type: ignore
if not buf: if not buf:
return return
# Buffer still has bytes left, try to process again # Buffer still has bytes left, try to process again
...@@ -131,7 +131,8 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -131,7 +131,8 @@ class UASocketProtocol(asyncio.Protocol):
self._connection.revolve_tokens() self._connection.revolve_tokens()
msg = self._connection.message_to_binary(binreq, message_type=message_type, request_id=self._request_id) msg = self._connection.message_to_binary(binreq, message_type=message_type, request_id=self._request_id)
self.transport.write(msg) if self.transport is not None:
self.transport.write(msg)
return future return future
async def send_request(self, request, timeout: Optional[float] = None, message_type=ua.MessageType.SecureMessage): async def send_request(self, request, timeout: Optional[float] = None, message_type=ua.MessageType.SecureMessage):
...@@ -200,7 +201,8 @@ class UASocketProtocol(asyncio.Protocol): ...@@ -200,7 +201,8 @@ class UASocketProtocol(asyncio.Protocol):
hello.MaxChunkCount = max_chunkcount hello.MaxChunkCount = max_chunkcount
ack = asyncio.Future() ack = asyncio.Future()
self._callbackmap[0] = ack self._callbackmap[0] = ack
self.transport.write(uatcp_to_binary(ua.MessageType.Hello, hello)) if self.transport is not None:
self.transport.write(uatcp_to_binary(ua.MessageType.Hello, hello))
return await asyncio.wait_for(ack, self.timeout) return await asyncio.wait_for(ack, self.timeout)
async def open_secure_channel(self, params): async def open_secure_channel(self, params):
...@@ -249,7 +251,7 @@ class UaClient: ...@@ -249,7 +251,7 @@ class UaClient:
self._subscription_callbacks = {} self._subscription_callbacks = {}
self._timeout = timeout self._timeout = timeout
self.security_policy = ua.SecurityPolicy() self.security_policy = ua.SecurityPolicy()
self.protocol: Optional[UASocketProtocol] = None self.protocol: UASocketProtocol = None
self._publish_task = None self._publish_task = None
def set_security(self, policy: ua.SecurityPolicy): def set_security(self, policy: ua.SecurityPolicy):
......
...@@ -24,7 +24,7 @@ from typing import Optional, Union, List ...@@ -24,7 +24,7 @@ from typing import Optional, Union, List
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
class State(object): class State:
''' '''
Helperclass for States (StateVariableType) Helperclass for States (StateVariableType)
https://reference.opcfoundation.org/v104/Core/docs/Part5/B.4.3/ https://reference.opcfoundation.org/v104/Core/docs/Part5/B.4.3/
...@@ -33,7 +33,7 @@ class State(object): ...@@ -33,7 +33,7 @@ class State(object):
id: "BaseVariableType" Id is a name which uniquely identifies the current state within the StateMachineType. A subtype may restrict the DataType. id: "BaseVariableType" Id is a name which uniquely identifies the current state within the StateMachineType. A subtype may restrict the DataType.
number: Number is an integer which uniquely identifies the current state within the StateMachineType. number: Number is an integer which uniquely identifies the current state within the StateMachineType.
''' '''
def __init__(self, id, name: str=None, number: int=None, node: Node=None): def __init__(self, id, name: str=None, number: int=None, node: Optional[Node]=None):
if id is not None: if id is not None:
self.id = ua.Variant(id) self.id = ua.Variant(id)
else: else:
...@@ -41,10 +41,10 @@ class State(object): ...@@ -41,10 +41,10 @@ class State(object):
self.name = name self.name = name
self.number = number self.number = number
self.effectivedisplayname = ua.LocalizedText(name, "en-US") self.effectivedisplayname = ua.LocalizedText(name, "en-US")
self.node = node #will be written from statemachine.add_state() or you need to overwrite it if the state is part of xml self.node: Node = node #will be written from statemachine.add_state() or you need to overwrite it if the state is part of xml
class Transition(object): class Transition:
''' '''
Helperclass for Transitions (TransitionVariableType) Helperclass for Transitions (TransitionVariableType)
https://reference.opcfoundation.org/v104/Core/docs/Part5/B.4.4/ https://reference.opcfoundation.org/v104/Core/docs/Part5/B.4.4/
...@@ -66,7 +66,7 @@ class Transition(object): ...@@ -66,7 +66,7 @@ class Transition(object):
self.name = name self.name = name
self.number = number self.number = number
self._transitiontime = datetime.datetime.utcnow() #will be overwritten from _write_transition() self._transitiontime = datetime.datetime.utcnow() #will be overwritten from _write_transition()
self.node = node #will be written from statemachine.add_state() or you need to overwrite it if the state is part of xml self.node: Node = node #will be written from statemachine.add_state() or you need to overwrite it if the state is part of xml
class StateMachine(object): class StateMachine(object):
...@@ -88,22 +88,22 @@ class StateMachine(object): ...@@ -88,22 +88,22 @@ class StateMachine(object):
self.locale = "en-US" self.locale = "en-US"
self._server = server self._server = server
self._parent = parent self._parent = parent
self._state_machine_node: Optional[Node] = None self._state_machine_node: Node = None
self._state_machine_type = ua.NodeId(2299, 0) #StateMachineType self._state_machine_type = ua.NodeId(2299, 0) #StateMachineType
self._name = name self._name = name
self._idx = idx self._idx = idx
self._optionals = False self._optionals = False
self._current_state_node: Optional[Node] = None self._current_state_node: Node = None
self._current_state_id_node = None self._current_state_id_node = None
self._current_state_name_node = None self._current_state_name_node = None
self._current_state_number_node = None self._current_state_number_node = None
self._current_state_effective_display_name_node = None self._current_state_effective_display_name_node = None
self._last_transition_node: Optional[Node] = None self._last_transition_node: Node = None
self._last_transition_id_node = None self._last_transition_id_node = None
self._last_transition_name_node = None self._last_transition_name_node = None
self._last_transition_number_node = None self._last_transition_number_node = None
self._last_transition_transitiontime_node = None self._last_transition_transitiontime_node = None
self._evgen: Optional[EventGenerator] = None self._evgen: EventGenerator = None
self.evtype = TransitionEvent() self.evtype = TransitionEvent()
self._current_state = State(None) self._current_state = State(None)
...@@ -185,12 +185,12 @@ class StateMachine(object): ...@@ -185,12 +185,12 @@ class StateMachine(object):
event_msg = ua.LocalizedText(event_msg, self.locale) event_msg = ua.LocalizedText(event_msg, self.locale)
if not isinstance(event_msg, ua.LocalizedText): if not isinstance(event_msg, ua.LocalizedText):
raise ValueError(f"Statemachine: {self._name} -> event_msg: {event_msg} is not a instance of LocalizedText") raise ValueError(f"Statemachine: {self._name} -> event_msg: {event_msg} is not a instance of LocalizedText")
self._evgen.event.Message = event_msg self._evgen.event.Message = event_msg # type: ignore
self._evgen.event.Severity = severity self._evgen.event.Severity = severity # type: ignore
self._evgen.event.ToState = ua.LocalizedText(state.name, self.locale) self._evgen.event.ToState = ua.LocalizedText(state.name, self.locale) # type: ignore
if transition: if transition:
self._evgen.event.Transition = ua.LocalizedText(transition.name, self.locale) self._evgen.event.Transition = ua.LocalizedText(transition.name, self.locale) # type: ignore
self._evgen.event.FromState = ua.LocalizedText(self._current_state.name) self._evgen.event.FromState = ua.LocalizedText(self._current_state.name) # type: ignore
await self._evgen.trigger() await self._evgen.trigger()
self._current_state = state self._current_state = state
...@@ -287,8 +287,8 @@ class FiniteStateMachine(StateMachine): ...@@ -287,8 +287,8 @@ class FiniteStateMachine(StateMachine):
if name is None: if name is None:
self._name = "FiniteStateMachine" self._name = "FiniteStateMachine"
self._state_machine_type = ua.NodeId(2771, 0) self._state_machine_type = ua.NodeId(2771, 0)
self._available_states_node: Optional[Node] = None self._available_states_node: Node = None
self._available_transitions_node: Optional[Node] = None self._available_transitions_node: Node = None
async def set_available_states(self, states: List[ua.NodeId]): async def set_available_states(self, states: List[ua.NodeId]):
if not self._available_states_node: if not self._available_states_node:
......
...@@ -362,7 +362,7 @@ class Subscription: ...@@ -362,7 +362,7 @@ class Subscription:
If you delete the subscription, you do not need to unsubscribe. If you delete the subscription, you do not need to unsubscribe.
:param handle: The handle that was returned when subscribing to the node/nodes :param handle: The handle that was returned when subscribing to the node/nodes
""" """
handles = [handle] if type(handle) is int else handle handles: List[int] = [handle] if isinstance(handle, int) else handle
if not handles: if not handles:
return return
params = ua.DeleteMonitoredItemsParameters() params = ua.DeleteMonitoredItemsParameters()
......
...@@ -24,7 +24,7 @@ class EventGenerator: ...@@ -24,7 +24,7 @@ class EventGenerator:
def __init__(self, isession): def __init__(self, isession):
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
self.isession = isession self.isession = isession
self.event: Optional[event_objects.BaseEvent] = None self.event: event_objects.BaseEvent = None
self.emitting_node: Optional[Node] = None self.emitting_node: Optional[Node] = None
async def init(self, etype=None, emitting_node=ua.ObjectIds.Server): async def init(self, etype=None, emitting_node=ua.ObjectIds.Server):
......
...@@ -950,7 +950,7 @@ class DataValue: ...@@ -950,7 +950,7 @@ class DataValue:
data_type = NodeId(25) data_type = NodeId(25)
Encoding: Byte = field(default=0, repr=False, init=False, compare=False) Encoding: Byte = field(default=0, repr=False, init=False, compare=False)
Value: Optional[Variant] = None Value: Variant = field(default_factory=Variant)
StatusCode_: Optional[StatusCode] = field(default_factory=StatusCode) StatusCode_: Optional[StatusCode] = field(default_factory=StatusCode)
SourceTimestamp: Optional[DateTime] = None # FIXME type DateType raises type hinting errors because datetime is assigned SourceTimestamp: Optional[DateTime] = None # FIXME type DateType raises type hinting errors because datetime is assigned
ServerTimestamp: Optional[DateTime] = None ServerTimestamp: Optional[DateTime] = None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment