Commit 2b64f0f1 authored by Alexander Schrode's avatar Alexander Schrode Committed by oroulet

support fields with "AllowSubtypes"

Allow fields in structs that are subtype of a class.
parent 0811f733
...@@ -13,7 +13,7 @@ from dataclasses import is_dataclass, fields ...@@ -13,7 +13,7 @@ from dataclasses import is_dataclass, fields
from asyncua import ua from asyncua import ua
from .uaerrors import UaError from .uaerrors import UaError
from ..common.utils import Buffer from ..common.utils import Buffer
from .uatypes import type_is_list, type_is_union, type_from_list, types_from_union from .uatypes import ExtensionObject, type_is_list, type_is_union, type_from_list, types_from_union, type_allow_subclass
logger = logging.getLogger('__name__') logger = logging.getLogger('__name__')
...@@ -326,6 +326,8 @@ def struct_to_binary(obj): ...@@ -326,6 +326,8 @@ def struct_to_binary(obj):
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def create_type_serializer(uatype): def create_type_serializer(uatype):
"""Create a binary serialization function for the given UA type""" """Create a binary serialization function for the given UA type"""
if type_allow_subclass(uatype):
return extensionobject_to_binary
if type_is_list(uatype): if type_is_list(uatype):
return create_list_serializer(type_from_list(uatype)) return create_list_serializer(type_from_list(uatype))
if hasattr(Primitives, uatype.__name__): if hasattr(Primitives, uatype.__name__):
...@@ -613,6 +615,7 @@ def _create_dataclass_deserializer(objtype): ...@@ -613,6 +615,7 @@ def _create_dataclass_deserializer(objtype):
for field in fields(objtype): for field in fields(objtype):
optional_enc_bit = 0 optional_enc_bit = 0
field_type = resolved_fieldtypes[field.name] field_type = resolved_fieldtypes[field.name]
subtypes = type_allow_subclass(field.type)
# if our member has a switch and it is not set we will need to skip it # if our member has a switch and it is not set we will need to skip it
if type_is_union(field_type): if type_is_union(field_type):
optional_enc_bit = 1 << enc_count optional_enc_bit = 1 << enc_count
...@@ -623,7 +626,7 @@ def _create_dataclass_deserializer(objtype): ...@@ -623,7 +626,7 @@ def _create_dataclass_deserializer(objtype):
def decode(data): def decode(data):
kwargs = {} kwargs = {}
enc = 0 enc = 0
for field, optional_enc_bit, deserialize_field in field_deserializers: for field, optional_enc_bit, deserialize_field, subtypes in field_deserializers:
if field.name == "Encoding": if field.name == "Encoding":
enc = deserialize_field(data) enc = deserialize_field(data)
elif optional_enc_bit == 0 or enc & optional_enc_bit: elif optional_enc_bit == 0 or enc & optional_enc_bit:
......
...@@ -52,6 +52,8 @@ def type_is_union(uatype): ...@@ -52,6 +52,8 @@ def type_is_union(uatype):
def type_is_list(uatype): def type_is_list(uatype):
return get_origin(uatype) == list return get_origin(uatype) == list
def type_allow_subclass(uatype):
return get_origin(uatype) not in [Union, list, None]
def types_from_union(uatype, origin=None): def types_from_union(uatype, origin=None):
if origin is None: if origin is None:
...@@ -68,12 +70,16 @@ def types_from_union(uatype, origin=None): ...@@ -68,12 +70,16 @@ def types_from_union(uatype, origin=None):
def type_from_list(uatype): def type_from_list(uatype):
return get_args(uatype)[0] return get_args(uatype)[0]
def type_from_allow_subtype(uatype):
return get_args(uatype)[0]
def type_string_from_type(uatype): def type_string_from_type(uatype):
if type_is_union(uatype): if type_is_union(uatype):
uatype = types_from_union(uatype)[0] uatype = types_from_union(uatype)[0]
elif type_is_list(uatype): elif type_is_list(uatype):
uatype = type_from_list(uatype) uatype = type_from_list(uatype)
elif type_allow_subclass(uatype):
uatype = type_from_allow_subtype(uatype)
return uatype.__name__ return uatype.__name__
......
...@@ -4,7 +4,7 @@ Generate address space code from xml file specification ...@@ -4,7 +4,7 @@ Generate address space code from xml file specification
from xml.etree import ElementTree from xml.etree import ElementTree
from logging import getLogger from logging import getLogger
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, List from typing import Any, List, Optional
import re import re
from pathlib import Path from pathlib import Path
...@@ -96,7 +96,7 @@ class Field: ...@@ -96,7 +96,7 @@ class Field:
@dataclass @dataclass
class Struct: class Struct:
name: str = None name: str = None
basetype: str = None basetype: Optional[str] = None
node_id: str = None node_id: str = None
doc: str = "" doc: str = ""
fields: List[Field] = field(default_factory=list) fields: List[Field] = field(default_factory=list)
...@@ -123,7 +123,7 @@ class Enum: ...@@ -123,7 +123,7 @@ class Enum:
data_type: str = None data_type: str = None
fields: List[Field] = field(default_factory=list) fields: List[Field] = field(default_factory=list)
doc: str = "" doc: str = ""
is_option_set: bool = False is_option_set: bool = False
...@@ -210,6 +210,8 @@ def reorder_structs(model): ...@@ -210,6 +210,8 @@ def reorder_structs(model):
ok = False ok = False
if ok: if ok:
_add_struct(s, newstructs, waiting_structs, types) _add_struct(s, newstructs, waiting_structs, types)
if len(model.structs) != len(newstructs): if len(model.structs) != len(newstructs):
_logger.warning('Error while reordering structs, some structs could not be reinserted: had %s structs, we now have %s structs', len(model.structs), len(newstructs)) _logger.warning('Error while reordering structs, some structs could not be reinserted: had %s structs, we now have %s structs', len(model.structs), len(newstructs))
s1 = set(model.structs) s1 = set(model.structs)
...@@ -230,6 +232,9 @@ def nodeid_to_names(model): ...@@ -230,6 +232,9 @@ def nodeid_to_names(model):
ids["22"] = "ExtensionObject" ids["22"] = "ExtensionObject"
for struct in model.structs: for struct in model.structs:
if struct.basetype is not None:
if struct.basetype.startswith("i="):
struct.basetype = ids[struct.basetype[2:]]
for sfield in struct.fields: for sfield in struct.fields:
if sfield.data_type.startswith("i="): if sfield.data_type.startswith("i="):
sfield.data_type = ids[sfield.data_type[2:]] sfield.data_type = ids[sfield.data_type[2:]]
...@@ -270,6 +275,16 @@ def split_requests(model): ...@@ -270,6 +275,16 @@ def split_requests(model):
model.structs = structs model.structs = structs
def get_basetypes(el) -> List[str]:
# return all basetypes
basetypes = []
for ref in el.findall("./{*}References/{*}Reference"):
if ref.get("ReferenceType") == "HasSubtype" and \
ref.get("IsForward", "true") == "false" and \
ref.text != "i=22":
basetypes.append(ref.text)
return basetypes
class Parser: class Parser:
def __init__(self, path): def __init__(self, path):
self.path = path self.path = path
...@@ -341,8 +356,14 @@ class Parser: ...@@ -341,8 +356,14 @@ class Parser:
doc=doc, doc=doc,
node_id=el.get("NodeId"), node_id=el.get("NodeId"),
) )
basetypes = get_basetypes(el)
if basetypes:
struct.basetype = basetypes[0]
if len(basetypes) > 1:
print(f'Error found mutliple basetypes for {struct} {basetypes}')
for sfield in el.findall("./{*}Definition/{*}Field"): for sfield in el.findall("./{*}Definition/{*}Field"):
opt = sfield.get("IsOptional", "false") opt = sfield.get("IsOptional", "false")
allow_subtypes = True if sfield.get("AllowSubTypes", "false") == 'true' else False
is_optional = True if opt == "true" else False is_optional = True if opt == "true" else False
f = Field( f = Field(
name=sfield.get("Name"), name=sfield.get("Name"),
...@@ -351,6 +372,7 @@ class Parser: ...@@ -351,6 +372,7 @@ class Parser:
array_dimensions=sfield.get("ArayDimensions"), array_dimensions=sfield.get("ArayDimensions"),
value=sfield.get("Value"), value=sfield.get("Value"),
is_optional=is_optional, is_optional=is_optional,
allow_subtypes=allow_subtypes
) )
if is_optional: if is_optional:
struct.has_optional = True struct.has_optional = True
......
...@@ -4,7 +4,7 @@ import datetime ...@@ -4,7 +4,7 @@ import datetime
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
IgnoredEnums = ["NodeIdType"] IgnoredEnums = ["NodeIdType"]
IgnoredStructs = ["QualifiedName", "NodeId", "ExpandedNodeId", "FilterOperand", "Variant", "DataValue", IgnoredStructs = ["QualifiedName", "NodeId", "ExpandedNodeId", "Variant", "DataValue",
"ExtensionObject", "XmlElement", "LocalizedText"] "ExtensionObject", "XmlElement", "LocalizedText"]
...@@ -61,7 +61,7 @@ class CodeGenerator: ...@@ -61,7 +61,7 @@ class CodeGenerator:
self.write('') self.write('')
self.write('from datetime import datetime') self.write('from datetime import datetime')
self.write('from enum import IntEnum, IntFlag') self.write('from enum import IntEnum, IntFlag')
self.write('from typing import Union, List, Optional') self.write('from typing import Union, List, Optional, Type')
self.write('from dataclasses import dataclass, field') self.write('from dataclasses import dataclass, field')
self.write('') self.write('')
self.write('from asyncua.ua.uatypes import FROZEN') self.write('from asyncua.ua.uatypes import FROZEN')
...@@ -110,7 +110,10 @@ class CodeGenerator: ...@@ -110,7 +110,10 @@ class CodeGenerator:
self.write('') self.write('')
self.iidx = 0 self.iidx = 0
self.write('@dataclass(frozen=FROZEN)') self.write('@dataclass(frozen=FROZEN)')
self.write(f'class {obj.name}:') if obj.basetype:
self.write(f'class {obj.name}({obj.basetype}):')
else:
self.write(f'class {obj.name}:')
self.iidx += 1 self.iidx += 1
self.write('"""') self.write('"""')
if obj.doc: if obj.doc:
...@@ -141,7 +144,8 @@ class CodeGenerator: ...@@ -141,7 +144,8 @@ class CodeGenerator:
typestring = f"List[{typestring}]" typestring = f"List[{typestring}]"
if field.is_optional: if field.is_optional:
typestring = f"Optional[{typestring}]" typestring = f"Optional[{typestring}]"
if field.allow_subtypes:
typestring = f"Type[{typestring}]"
if field.name == field.data_type: if field.name == field.data_type:
# variable name and type name are the same. Dataclass do not like it # variable name and type name are the same. Dataclass do not like it
hack_names.append(field.name) hack_names.append(field.name)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment