Commit bc4aea7e authored by Alexander Schrode's avatar Alexander Schrode Committed by GitHub

[Datatypes] allow optional arrays (#1238)

* allow optional arrays

* is list fix for py 3.8

* finish optional array support
parent 182d0252
......@@ -229,8 +229,9 @@ class {struct_name}{base_class}:
uatype = 'String'
elif sfield.ValueRank >= 1 or sfield.ArrayDimensions:
uatype = f"List[{uatype}]"
elif sfield.IsOptional:
if sfield.IsOptional:
uatype = f"Optional[{uatype}]"
default_value = 'None'
fields.append((fname, uatype, default_value))
if is_union:
# Generate getter and setter to mimic opc ua union access
......@@ -256,7 +257,6 @@ class {struct_name}{base_class}:
else:
for fname, uatype, default_value in fields:
code += f" {fname}: {uatype} = {default_value}\n"
return code
......
......@@ -5,7 +5,7 @@ Binary protocol specific functions and constants
import functools
import struct
import logging
from typing import Any, Callable
from typing import Any, Callable, Union
import typing
import uuid
from enum import Enum, IntFlag
......@@ -13,7 +13,7 @@ from dataclasses import is_dataclass, fields
from asyncua import ua
from .uaerrors import UaError
from ..common.utils import Buffer
from .uatypes import type_is_list, type_is_union, type_from_list, types_from_union, type_allow_subclass
from .uatypes import type_from_optional, type_is_list, type_is_union, type_from_list, types_or_list_from_union, type_allow_subclass
logger = logging.getLogger('__name__')
......@@ -257,9 +257,12 @@ def field_serializer(ftype, dataclazz) -> Callable[[Any], bytes]:
is_optional = type_is_union(ftype)
uatype = ftype
if is_optional:
uatype = types_from_union(uatype)[0]
# unpack optional because this we will handeled by the decoding
uatype = type_from_optional(uatype)
if type_is_list(uatype):
ft = type_from_list(uatype)
if is_optional:
return lambda val: b'' if val is None else create_list_serializer(ft, ft == dataclazz)(val)
return create_list_serializer(ft, ft == dataclazz)
else:
if ftype == dataclazz:
......@@ -588,7 +591,9 @@ def _create_list_deserializer(uatype, recursive: bool = False):
@functools.lru_cache(maxsize=None)
def _create_type_deserializer(uatype, dataclazz):
if type_is_union(uatype):
return _create_type_deserializer(types_from_union(uatype)[0], uatype)
array, uatype = types_or_list_from_union(uatype)
if not array:
return _create_type_deserializer(uatype, uatype)
if type_is_list(uatype):
utype = type_from_list(uatype)
if hasattr(ua.VariantType, utype.__name__):
......
......@@ -59,6 +59,27 @@ def type_is_list(uatype):
def type_allow_subclass(uatype):
return get_origin(uatype) not in [Union, list, None]
def types_or_list_from_union(uatype):
# returns the type of a union or the list of type if a list is inside the union
types = []
for subtype in get_args(uatype):
if hasattr(subtype, '_paramspec_tvars'):
# @hack how to check if a parameter is a list:
# check if have _paramspec_tvars works for type[X]
return True, subtype
elif hasattr(subtype, '_name'):
# @hack how to check if parameter is union or list
# if _name is not List, it is Union
if subtype._name == 'List':
return True, subtype
elif not isinstance(None, subtype):
types.append(subtype)
if not types:
raise ValueError(f"Union {uatype} does not seem to contain a valid type")
return False, types[0]
def types_from_union(uatype, origin=None):
if origin is None:
origin = get_origin(uatype)
......@@ -74,9 +95,15 @@ def types_from_union(uatype, origin=None):
def type_from_list(uatype):
return get_args(uatype)[0]
def type_from_optional(uatype):
return get_args(uatype)[0]
def type_from_allow_subtype(uatype):
return get_args(uatype)[0]
def type_string_from_type(uatype):
if type_is_union(uatype):
uatype = types_from_union(uatype)[0]
......
......@@ -27,31 +27,33 @@
<ua:ModelInfo Tool="UaModeler" Hash="l1kpQd8c2aTEMZi8xJvKfA==" Version="1.6.3"/>
</Extension>
</Extensions>
<UADataType NodeId="ns=1;i=3002" BrowseName="1:MyStruct">
<DisplayName>MyStruct</DisplayName>
<References>
<Reference ReferenceType="HasEncoding">ns=1;i=5001</Reference>
<Reference ReferenceType="HasEncoding">ns=1;i=5003</Reference>
<Reference ReferenceType="HasEncoding">ns=1;i=5002</Reference>
<Reference ReferenceType="HasSubtype" IsForward="false">i=22</Reference>
</References>
<Definition Name="1:MyStruct">
<Field DataType="Double" Name="toto"/>
</Definition>
</UADataType>
<UADataType NodeId="ns=1;i=3003" BrowseName="1:MySubstruct">
<DisplayName>MySubstruct</DisplayName>
<References>
<Reference ReferenceType="HasEncoding">ns=1;i=5004</Reference>
<Reference ReferenceType="HasEncoding">ns=1;i=5006</Reference>
<Reference ReferenceType="HasEncoding">ns=1;i=5005</Reference>
<Reference ReferenceType="HasSubtype" IsForward="false">ns=1;i=3002</Reference>
<Reference ReferenceType="HasSubtype" IsForward="false">i=22</Reference>
</References>
<Definition Name="1:MySubstruct">
<Field IsOptional="true" DataType="Double" Name="titi"/>
<Field IsOptional="true" DataType="Double" Name="opt_array" ValueRank="1" ArrayDimensions="0"/>
<Field DataType="MyStruct" ValueRank="1" ArrayDimensions="0" Name="structs"/>
</Definition>
</UADataType>
<UADataType NodeId="ns=1;i=3002" BrowseName="1:MyStruct">
<DisplayName>MyStruct</DisplayName>
<References>
<Reference ReferenceType="HasEncoding">ns=1;i=5001</Reference>
<Reference ReferenceType="HasEncoding">ns=1;i=5003</Reference>
<Reference ReferenceType="HasEncoding">ns=1;i=5002</Reference>
<Reference ReferenceType="HasSubtype" IsForward="false">i=22</Reference>
</References>
<Definition Name="1:MyStruct">
<Field DataType="Double" Name="toto"/>
</Definition>
</UADataType>
<UAObject SymbolicName="http___yourorganisation_org_testsubstruct_" NodeId="ns=1;i=5007" BrowseName="1:http://yourorganisation.org/testsubstruct/">
<DisplayName>http://yourorganisation.org/testsubstruct/</DisplayName>
<References>
......
......@@ -1146,12 +1146,13 @@ async def test_import_xml_data_type_definition(opc):
sdef = await datatype.read_data_type_definition()
assert isinstance(sdef, ua.StructureDefinition)
s = ua.MyStruct()
s.toto = 0.1
ss = ua.MySubstruct()
assert ss.titi == None
assert ss.titi is None
assert ss.opt_array is None
assert isinstance(ss.structs, list)
ss.titi = 1
ss.titi = 1.0
ss.structs.append(s)
ss.structs.append(s)
......@@ -1159,6 +1160,11 @@ async def test_import_xml_data_type_definition(opc):
s2 = await var.read_value()
assert s2.structs[1].toto == ss.structs[1].toto == 0.1
assert s2.opt_array is None
s2.opt_array = [1]
await var.write_value(s2)
s2 = await var.read_value()
assert s2.opt_array == [1]
await opc.opc.delete_nodes([datatype, var])
n = []
[n.append(opc.opc.get_node(node)) for node in nodes]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment