Commit 9be567cc authored by oroulet's avatar oroulet Committed by oroulet

yapf to make flake8 complain less

parent 5eab9c4b
......@@ -87,8 +87,7 @@ class Client:
"""
_logger.info("find_endpoint %r %r %r", endpoints, security_mode, policy_uri)
for ep in endpoints:
if (ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and ep.SecurityMode == security_mode and
ep.SecurityPolicyUri == policy_uri):
if (ep.EndpointUrl.startswith(ua.OPC_TCP_SCHEME) and ep.SecurityMode == security_mode and ep.SecurityPolicyUri == policy_uri):
return ep
raise ua.UaError(f"No matching endpoints: {security_mode}, {policy_uri}")
......@@ -137,16 +136,17 @@ class Client:
policy_class = getattr(security_policies, f"SecurityPolicy{parts[0]}")
mode = getattr(ua.MessageSecurityMode, parts[1])
return await self.set_security(policy_class, parts[2], parts[3], client_key_password,
parts[4] if len(parts) >= 5 else None, mode)
async def set_security(self,
policy: ua.SecurityPolicy,
certificate: Union[str, uacrypto.CertProperties],
private_key: Union[str, uacrypto.CertProperties],
private_key_password: Optional[Union[str, bytes]] = None,
server_certificate: Optional[Union[str, uacrypto.CertProperties]] = None,
mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt):
return await self.set_security(policy_class, parts[2], parts[3], client_key_password, parts[4] if len(parts) >= 5 else None, mode)
async def set_security(
self,
policy: ua.SecurityPolicy,
certificate: Union[str, uacrypto.CertProperties],
private_key: Union[str, uacrypto.CertProperties],
private_key_password: Optional[Union[str, bytes]] = None,
server_certificate: Optional[Union[str, uacrypto.CertProperties]] = None,
mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt,
):
"""
Set SecureConnection mode.
Call this before connect()
......@@ -164,17 +164,23 @@ class Client:
private_key = uacrypto.CertProperties(private_key, password=private_key_password)
return await self._set_security(policy, certificate, private_key, server_certificate, mode)
async def _set_security(self,
policy: ua.SecurityPolicy,
certificate: uacrypto.CertProperties,
private_key: uacrypto.CertProperties,
server_cert: uacrypto.CertProperties,
mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt):
async def _set_security(
self,
policy: ua.SecurityPolicy,
certificate: uacrypto.CertProperties,
private_key: uacrypto.CertProperties,
server_cert: uacrypto.CertProperties,
mode: ua.MessageSecurityMode = ua.MessageSecurityMode.SignAndEncrypt,
):
if isinstance(server_cert, uacrypto.CertProperties):
server_cert = await uacrypto.load_certificate(server_cert.path, server_cert.extension)
cert = await uacrypto.load_certificate(certificate.path, certificate.extension)
pk = await uacrypto.load_private_key(private_key.path, private_key.password, private_key.extension)
pk = await uacrypto.load_private_key(
private_key.path,
private_key.password,
private_key.extension,
)
self.security_policy = policy(server_cert, cert, pk, mode)
self.uaclient.set_security(self.security_policy)
......@@ -184,10 +190,7 @@ class Client:
"""
self.user_certificate = await uacrypto.load_certificate(path, extension)
async def load_private_key(self,
path: str,
password: Optional[Union[str, bytes]] = None,
extension: Optional[str] = None):
async def load_private_key(self, path: str, password: Optional[Union[str, bytes]] = None, extension: Optional[str] = None):
"""
Load user private key. This is used for authenticating using certificate
"""
......@@ -296,8 +299,7 @@ class Client:
params.ClientNonce = create_nonce(self.security_policy.symmetric_key_size)
result = await self.uaclient.open_secure_channel(params)
if self.secure_channel_timeout != result.SecurityToken.RevisedLifetime:
_logger.info("Requested secure channel timeout to be %dms, got %dms instead",
self.secure_channel_timeout, result.SecurityToken.RevisedLifetime)
_logger.info("Requested secure channel timeout to be %dms, got %dms instead", self.secure_channel_timeout, result.SecurityToken.RevisedLifetime)
self.secure_channel_timeout = result.SecurityToken.RevisedLifetime
async def close_secure_channel(self):
......@@ -384,8 +386,7 @@ class Client:
self._policy_ids = ep.UserIdentityTokens
# Actual maximum number of milliseconds that a Session shall remain open without activity
if self.session_timeout != response.RevisedSessionTimeout:
_logger.warning("Requested session timeout to be %dms, got %dms instead",
self.secure_channel_timeout, response.RevisedSessionTimeout)
_logger.warning("Requested session timeout to be %dms, got %dms instead", self.secure_channel_timeout, response.RevisedSessionTimeout)
self.session_timeout = response.RevisedSessionTimeout
self._renew_channel_task = self.loop.create_task(self._renew_channel_loop())
return response
......@@ -407,7 +408,7 @@ class Client:
_logger.debug("server state is: %s ", val)
except asyncio.CancelledError:
pass
except:
except Exception:
_logger.exception("Error while renewing session")
raise
......@@ -448,9 +449,7 @@ class Client:
if self.security_policy.AsymmetricSignatureURI:
params.ClientSignature.Algorithm = self.security_policy.AsymmetricSignatureURI
else:
params.ClientSignature.Algorithm = (
security_policies.SecurityPolicyBasic256.AsymmetricSignatureURI
)
params.ClientSignature.Algorithm = (security_policies.SecurityPolicyBasic256.AsymmetricSignatureURI)
params.ClientSignature.Signature = self.security_policy.asymmetric_cryptography.signature(challenge)
params.LocaleIds.append("en")
if not username and not certificate:
......
......@@ -74,9 +74,7 @@ class UASocketProtocol(asyncio.Protocol):
self.receive_buffer = data
return
if len(buf) < header.body_size:
self.logger.debug(
'We did not receive enough data from server. Need %s got %s', header.body_size, len(buf)
)
self.logger.debug('We did not receive enough data from server. Need %s got %s', header.body_size, len(buf))
self.receive_buffer = data
return
msg = self._connection.receive_from_header_and_body(header, buf)
......@@ -146,10 +144,7 @@ class UASocketProtocol(asyncio.Protocol):
"""
timeout = self.timeout if timeout is None else timeout
try:
data = await asyncio.wait_for(
self._send_request(request, timeout, message_type),
timeout if timeout else None
)
data = await asyncio.wait_for(self._send_request(request, timeout, message_type), timeout if timeout else None)
except Exception:
if self.state != self.OPEN:
raise ConnectionError("Connection is closed") from None
......@@ -173,9 +168,7 @@ class UASocketProtocol(asyncio.Protocol):
try:
self._callbackmap[request_id].set_result(body)
except KeyError:
raise ua.UaError(
f"No request found for request id: {request_id}, pending are {self._callbackmap.keys()}"
)
raise ua.UaError(f"No request found for request id: {request_id}, pending are {self._callbackmap.keys()}")
except asyncio.InvalidStateError:
if not self.closed:
raise ua.UaError(f"Future for request id {request_id} is already done")
......@@ -216,13 +209,9 @@ class UASocketProtocol(asyncio.Protocol):
request = ua.OpenSecureChannelRequest()
request.Parameters = params
if self._open_secure_channel_exchange is not None:
raise RuntimeError('Two Open Secure Channel requests can not happen too close to each other. '
'The response must be processed and returned before the next request can be sent.')
raise RuntimeError('Two Open Secure Channel requests can not happen too close to each other. ' 'The response must be processed and returned before the next request can be sent.')
self._open_secure_channel_exchange = params
await asyncio.wait_for(
self._send_request(request, message_type=ua.MessageType.SecureOpen),
self.timeout
)
await asyncio.wait_for(self._send_request(request, message_type=ua.MessageType.SecureOpen), self.timeout)
_return = self._open_secure_channel_exchange.Parameters
self._open_secure_channel_exchange = None
return _return
......@@ -253,7 +242,6 @@ class UaClient:
In this Python implementation most of the structures are defined in
uaprotocol_auto.py and uaprotocol_hand.py available under asyncua.ua
"""
def __init__(self, timeout=1, loop=None):
"""
:param timeout: Timout in seconds
......@@ -278,9 +266,7 @@ class UaClient:
"""Connect to server socket."""
self.logger.info("opening connection")
# Timeout the connection when the server isn't available
await asyncio.wait_for(
self.loop.create_connection(self._make_protocol, host, port), self._timeout
)
await asyncio.wait_for(self.loop.create_connection(self._make_protocol, host, port), self._timeout)
def disconnect_socket(self):
if self.protocol and self.protocol.state == UASocketProtocol.CLOSED:
......@@ -462,10 +448,7 @@ class UaClient:
request = ua.CreateSubscriptionRequest()
request.Parameters = params
data = await self.protocol.send_request(request)
response = struct_from_binary(
ua.CreateSubscriptionResponse,
data
)
response = struct_from_binary(ua.CreateSubscriptionResponse, data)
response.ResponseHeader.ServiceResult.check()
self._subscription_callbacks[response.Parameters.SubscriptionId] = callback
self.logger.info("create_subscription success SubscriptionId %s", response.Parameters.SubscriptionId)
......@@ -482,10 +465,7 @@ class UaClient:
request = ua.DeleteSubscriptionsRequest()
request.Parameters.SubscriptionIds = subscription_ids
data = await self.protocol.send_request(request)
response = struct_from_binary(
ua.DeleteSubscriptionsResponse,
data
)
response = struct_from_binary(ua.DeleteSubscriptionsResponse, data)
response.ResponseHeader.ServiceResult.check()
self.logger.info("remove subscription callbacks for %r", subscription_ids)
for sid in subscription_ids:
......@@ -549,10 +529,7 @@ class UaClient:
try:
callback = self._subscription_callbacks[subscription_id]
except KeyError:
self.logger.warning(
"Received data for unknown subscription %s active are %s", subscription_id,
self._subscription_callbacks.keys()
)
self.logger.warning("Received data for unknown subscription %s active are %s", subscription_id, self._subscription_callbacks.keys())
else:
try:
if asyncio.iscoroutinefunction(callback):
......
......@@ -13,8 +13,7 @@ class MessageChunk(ua.FrozenClass):
"""
Message Chunk, as described in OPC UA specs Part 6, 6.7.2.
"""
def __init__(self, security_policy, body=b'', msg_type=ua.MessageType.SecureMessage,
chunk_type=ua.ChunkType.Single):
def __init__(self, security_policy, body=b'', msg_type=ua.MessageType.SecureMessage, chunk_type=ua.ChunkType.Single):
self.MessageHeader = ua.Header(msg_type, chunk_type)
if msg_type in (ua.MessageType.SecureMessage, ua.MessageType.SecureClose):
self.SecurityHeader = ua.SymmetricAlgorithmHeader()
......@@ -53,8 +52,7 @@ class MessageChunk(ua.FrozenClass):
if signature_size > 0:
signature = decrypted[-signature_size:]
decrypted = decrypted[:-signature_size]
crypto.verify(header_to_binary(obj.MessageHeader) + struct_to_binary(obj.SecurityHeader) + decrypted,
signature)
crypto.verify(header_to_binary(obj.MessageHeader) + struct_to_binary(obj.SecurityHeader) + decrypted, signature)
data = ua.utils.Buffer(crypto.remove_padding(decrypted))
obj.SequenceHeader = struct_from_binary(ua.SequenceHeader, data)
obj.Body = data.read(len(data))
......@@ -83,8 +81,7 @@ class MessageChunk(ua.FrozenClass):
return max_plain_size - ua.SequenceHeader.max_size() - crypto.signature_size() - crypto.min_padding_size()
@staticmethod
def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage,
channel_id=1, request_id=1, token_id=1):
def message_to_chunks(security_policy, body, max_chunk_size, message_type=ua.MessageType.SecureMessage, channel_id=1, request_id=1, token_id=1):
"""
Pack message body (as binary string) into one or more chunks.
Size of each chunk will not exceed max_chunk_size.
......@@ -239,9 +236,15 @@ class SecureConnection:
The only supported types are SecureOpen, SecureMessage, SecureClose.
If message_type is SecureMessage, the AlgorithmHeader should be passed as arg.
"""
chunks = MessageChunk.message_to_chunks(self.security_policy, message, self._max_chunk_size,
message_type=message_type, channel_id=self.security_token.ChannelId,
request_id=request_id, token_id=self.security_token.TokenId)
chunks = MessageChunk.message_to_chunks(
self.security_policy,
message,
self._max_chunk_size,
message_type=message_type,
channel_id=self.security_token.ChannelId,
request_id=request_id,
token_id=self.security_token.TokenId,
)
for chunk in chunks:
self._sequence_number += 1
if self._sequence_number >= (1 << 32):
......@@ -270,10 +273,9 @@ class SecureConnection:
# Messages sent by the Server before the token expired are not rejected because of
# network delays.
timeout = self.prev_security_token.CreatedAt + \
timedelta(milliseconds=self.prev_security_token.RevisedLifetime * 1.25)
timedelta(milliseconds=self.prev_security_token.RevisedLifetime * 1.25)
if timeout < datetime.utcnow():
raise ua.UaError(f"Security token id {security_hdr.TokenId} has timed out "
f"({timeout} < {datetime.utcnow()})")
raise ua.UaError(f"Security token id {security_hdr.TokenId} has timed out " f"({timeout} < {datetime.utcnow()})")
return
expected_tokens = [self.security_token.TokenId, self.next_security_token.TokenId]
......@@ -286,12 +288,10 @@ class SecureConnection:
raise ValueError(f'Expected chunk, got: {chunk}')
if chunk.MessageHeader.MessageType != ua.MessageType.SecureOpen:
if chunk.MessageHeader.ChannelId != self.security_token.ChannelId:
raise ua.UaError(f'Wrong channel id {chunk.MessageHeader.ChannelId},'
f' expected {self.security_token.ChannelId}')
raise ua.UaError(f'Wrong channel id {chunk.MessageHeader.ChannelId},' f' expected {self.security_token.ChannelId}')
if self._incoming_parts:
if self._incoming_parts[0].SequenceHeader.RequestId != chunk.SequenceHeader.RequestId:
raise ua.UaError(f'Wrong request id {chunk.SequenceHeader.RequestId},'
f' expected {self._incoming_parts[0].SequenceHeader.RequestId}')
raise ua.UaError(f'Wrong request id {chunk.SequenceHeader.RequestId},' f' expected {self._incoming_parts[0].SequenceHeader.RequestId}')
# The sequence number must monotonically increase (but it can wrap around)
seq_num = chunk.SequenceHeader.SequenceNumber
if self._peer_sequence_number is not None:
......@@ -302,9 +302,7 @@ class SecureConnection:
logger.debug('Sequence number wrapped: %d -> %d', self._peer_sequence_number, seq_num)
else:
# Condition for monotonically increase is not met
raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:"
f" {self._peer_sequence_number}, received: {seq_num},"
f" spec says to close connection")
raise ua.UaError(f"Received chunk: {chunk} with wrong sequence expecting:" f" {self._peer_sequence_number}, received: {seq_num}," f" spec says to close connection")
self._peer_sequence_number = seq_num
def receive_from_header_and_body(self, header, body):
......
......@@ -45,7 +45,6 @@ class Node:
OPC-UA protocol. Feel free to look at the code of this class and call
directly UA services methods to optimize your code
"""
def __init__(self, server, nodeid):
self.server = server
self.nodeid = None
......@@ -58,8 +57,7 @@ class Node:
elif isinstance(nodeid, int):
self.nodeid = ua.NodeId(nodeid, 0)
else:
raise ua.UaError(f"argument to node must be a NodeId object or a string"
f" defining a nodeid found {nodeid} of type {type(nodeid)}")
raise ua.UaError(f"argument to node must be a NodeId object or a string" f" defining a nodeid found {nodeid} of type {type(nodeid)}")
self.basenodeid = None
def __eq__(self, other):
......@@ -138,10 +136,7 @@ class Node:
:param values: an iterable of EventNotifier enum values.
"""
event_notifier_bitfield = ua.EventNotifier.to_bitfield(values)
await self.write_attribute(
ua.AttributeIds.EventNotifier,
ua.DataValue(ua.Variant(event_notifier_bitfield, ua.VariantType.Byte))
)
await self.write_attribute(ua.AttributeIds.EventNotifier, ua.DataValue(ua.Variant(event_notifier_bitfield, ua.VariantType.Byte)))
async def read_node_class(self):
"""
......@@ -288,11 +283,9 @@ class Node:
result = await self.server.write(params)
result[0].check()
async def write_params(self, params):
result = await self.server.write(params)
return result
async def read_attribute(self, attr):
"""
......@@ -326,7 +319,7 @@ class Node:
async def read_params(self, params):
result = await self.server.read(params)
return result
async def get_children(self, refs=ua.ObjectIds.HierarchicalReferences, nodeclassmask=ua.NodeClass.Unspecified):
"""
Get all children of a node. By default hierarchical references and all node classes are returned.
......@@ -373,8 +366,7 @@ class Node:
"""
return self.get_children(refs=ua.ObjectIds.HasComponent, nodeclassmask=ua.NodeClass.Method)
async def get_children_descriptions(self, refs=ua.ObjectIds.HierarchicalReferences,
nodeclassmask=ua.NodeClass.Unspecified, includesubtypes=True):
async def get_children_descriptions(self, refs=ua.ObjectIds.HierarchicalReferences, nodeclassmask=ua.NodeClass.Unspecified, includesubtypes=True):
return await self.get_references(refs, ua.BrowseDirection.Forward, nodeclassmask, includesubtypes)
def get_encoding_refs(self):
......@@ -383,8 +375,7 @@ class Node:
def get_description_refs(self):
return self.get_referenced_nodes(ua.ObjectIds.HasDescription, ua.BrowseDirection.Forward)
async def get_references(self, refs=ua.ObjectIds.References, direction=ua.BrowseDirection.Both,
nodeclassmask=ua.NodeClass.Unspecified, includesubtypes=True):
async def get_references(self, refs=ua.ObjectIds.References, direction=ua.BrowseDirection.Both, nodeclassmask=ua.NodeClass.Unspecified, includesubtypes=True):
"""
returns references of the node based on specific filter defined with:
......@@ -418,8 +409,7 @@ class Node:
references.extend(results[0].References)
return references
async def get_referenced_nodes(self, refs=ua.ObjectIds.References, direction=ua.BrowseDirection.Both,
nodeclassmask=ua.NodeClass.Unspecified, includesubtypes=True):
async def get_referenced_nodes(self, refs=ua.ObjectIds.References, direction=ua.BrowseDirection.Both, nodeclassmask=ua.NodeClass.Unspecified, includesubtypes=True):
"""
returns referenced nodes based on specific filter
Paramters are the same as for get_references
......@@ -436,8 +426,7 @@ class Node:
"""
returns type definition of the node.
"""
references = await self.get_references(refs=ua.ObjectIds.HasTypeDefinition,
direction=ua.BrowseDirection.Forward)
references = await self.get_references(refs=ua.ObjectIds.HasTypeDefinition, direction=ua.BrowseDirection.Forward)
if len(references) == 0:
return None
return references[0].NodeId
......@@ -470,9 +459,7 @@ class Node:
path = []
node = self
while True:
refs = await node.get_references(
refs=ua.ObjectIds.HierarchicalReferences, direction=ua.BrowseDirection.Inverse
)
refs = await node.get_references(refs=ua.ObjectIds.HierarchicalReferences, direction=ua.BrowseDirection.Inverse)
if len(refs) > 0:
path.insert(0, refs[0])
node = Node(self.server, refs[0].NodeId)
......@@ -593,9 +580,7 @@ class Node:
result.StatusCode.check()
event_res = []
for res in result.HistoryData.Events:
event_res.append(
Event.from_event_fields(evfilter.SelectClauses, res.EventFields)
)
event_res.append(Event.from_event_fields(evfilter.SelectClauses, res.EventFields))
return event_res
async def history_read_events(self, details):
......
......@@ -8,7 +8,7 @@ import uuid
import logging
# The next two imports are for generated code
from datetime import datetime
from enum import Enum, IntEnum, EnumMeta
from enum import IntEnum, EnumMeta
from xml.etree import ElementTree as ET
from asyncua import ua
......@@ -25,6 +25,7 @@ class EnumType(object):
def __str__(self):
return f"EnumType({self.name, self.fields})"
__repr__ = __str__
def get_code(self):
......
......@@ -53,7 +53,7 @@ def make_structure_code(data_type, name, sdef):
given a StructureDefinition object, generate Python code
"""
if sdef.StructureType not in (ua.StructureType.Structure, ua.StructureType.StructureWithOptionalFields):
#if sdef.StructureType != ua.StructureType.Structure:
# if sdef.StructureType != ua.StructureType.Structure:
raise NotImplementedError(f"Only StructureType implemented, not {ua.StructureType(sdef.StructureType).name} for node {name} with DataTypdeDefinition {sdef}")
code = f"""
......@@ -81,7 +81,7 @@ class {name}:
code += ' ua_types = [\n'
if sdef.StructureType == ua.StructureType.StructureWithOptionalFields:
code += f" ('Encoding', 'Byte'),\n"
code += " ('Encoding', 'Byte'),\n"
uatypes = []
for field in sdef.Fields:
prefix = 'ListOf' if field.ValueRank >= 1 else ''
......@@ -92,8 +92,8 @@ class {name}:
elif field.DataType in ua.enums_by_datatype:
uatype = ua.enums_by_datatype[field.DataType].__name__
else:
#FIXME: we are probably missing many custom tyes here based on builtin types
#maybe we can use ua_utils.get_base_data_type()
# FIXME: we are probably missing many custom tyes here based on builtin types
# maybe we can use ua_utils.get_base_data_type()
raise RuntimeError(f"Unknown datatype for field: {field} in structure:{name}, please report")
if field.ValueRank >= 1 and uatype == 'Char':
uatype = 'String'
......@@ -112,7 +112,7 @@ class {name}:
if not sdef.Fields:
code += " pass"
if sdef.StructureType == ua.StructureType.StructureWithOptionalFields:
code += f" self.Encoding = 0\n"
code += " self.Encoding = 0\n"
for field, uatype in uatypes:
if field.ValueRank >= 1:
default_value = "[]"
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment