Commit 758d89c0 authored by Alexander Schrode's avatar Alexander Schrode Committed by oroulet

implement encode/decode

parent d2b48b5a
...@@ -5,6 +5,7 @@ Binary protocol specific functions and constants ...@@ -5,6 +5,7 @@ Binary protocol specific functions and constants
import functools import functools
import struct import struct
import logging import logging
from types import NoneType
from typing import Any, Callable from typing import Any, Callable
import typing import typing
import uuid import uuid
...@@ -253,14 +254,19 @@ def _create_uatype_array_deserializer(vtype): ...@@ -253,14 +254,19 @@ def _create_uatype_array_deserializer(vtype):
return deserialize return deserialize
def field_serializer(ftype) -> Callable[[Any], bytes]: def field_serializer(ftype, dataclazz) -> Callable[[Any], bytes]:
is_optional = type_is_union(ftype) is_optional = type_is_union(ftype)
uatype = ftype uatype = ftype
if is_optional: if is_optional:
uatype = types_from_union(uatype)[0] uatype = types_from_union(uatype)[0]
if type_is_list(uatype): if type_is_list(uatype):
return create_list_serializer(type_from_list(uatype)) ft = type_from_list(uatype)
return create_list_serializer(ft, ft == dataclazz)
else: else:
if ftype == dataclazz:
if is_optional:
return lambda val: b'' if val is None else create_type_serializer(uatype)(val)
return lambda x: create_type_serializer(uatype)(x)
serializer = create_type_serializer(uatype) serializer = create_type_serializer(uatype)
if is_optional: if is_optional:
return lambda val: b'' if val is None else serializer(val) return lambda val: b'' if val is None else serializer(val)
...@@ -281,7 +287,7 @@ def create_dataclass_serializer(dataclazz): ...@@ -281,7 +287,7 @@ def create_dataclass_serializer(dataclazz):
if issubclass(dataclazz, ua.UaUnion): if issubclass(dataclazz, ua.UaUnion):
# Union is a class with Encoding and Value field # Union is a class with Encoding and Value field
# the value is depended of encoding # the value is depended of encoding
encoding_funcs = [field_serializer(t) for t in dataclazz._union_types] encoding_funcs = [field_serializer(t, dataclazz) for t in dataclazz._union_types]
def union_serialize(obj): def union_serialize(obj):
bin = Primitives.UInt32.pack(obj.Encoding) bin = Primitives.UInt32.pack(obj.Encoding)
...@@ -305,7 +311,7 @@ def create_dataclass_serializer(dataclazz): ...@@ -305,7 +311,7 @@ def create_dataclass_serializer(dataclazz):
enc |= enc_val enc |= enc_val
return enc return enc
encoding_functions = [(f.name, field_serializer(f.type)) for f in data_fields] encoding_functions = [(f.name, field_serializer(f.type, dataclazz)) for f in data_fields]
def serialize(obj): def serialize(obj):
return b''.join( return b''.join(
...@@ -361,7 +367,7 @@ def to_binary(uatype, val): ...@@ -361,7 +367,7 @@ def to_binary(uatype, val):
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def create_list_serializer(uatype) -> Callable[[Any], bytes]: def create_list_serializer(uatype, recursive: bool = False) -> Callable[[Any], bytes]:
""" """
Given a type, return a function that takes a list of instances Given a type, return a function that takes a list of instances
of that type and serializes it. of that type and serializes it.
...@@ -369,9 +375,16 @@ def create_list_serializer(uatype) -> Callable[[Any], bytes]: ...@@ -369,9 +375,16 @@ def create_list_serializer(uatype) -> Callable[[Any], bytes]:
if hasattr(Primitives1, uatype.__name__): if hasattr(Primitives1, uatype.__name__):
data_type = getattr(Primitives1, uatype.__name__) data_type = getattr(Primitives1, uatype.__name__)
return data_type.pack_array return data_type.pack_array
type_serializer = create_type_serializer(uatype)
none_val = Primitives.Int32.pack(-1) none_val = Primitives.Int32.pack(-1)
if recursive:
def recursive_serialize(val):
if val is None:
return none_val
data_size = Primitives.Int32.pack(len(val))
return data_size + b''.join(create_type_serializer(uatype)(el) for el in val)
return recursive_serialize
type_serializer = create_type_serializer(uatype)
def serialize(val): def serialize(val):
if val is None: if val is None:
return none_val return none_val
...@@ -554,7 +567,13 @@ def extensionobject_to_binary(obj): ...@@ -554,7 +567,13 @@ def extensionobject_to_binary(obj):
return b''.join(packet) return b''.join(packet)
def _create_list_deserializer(uatype): def _create_list_deserializer(uatype, recursive: bool = False):
if recursive:
def _deserialize(data):
size = Primitives.Int32.unpack(data)
return [_create_type_deserializer(uatype)(data) for _ in range(size)]
return _deserialize
element_deserializer = _create_type_deserializer(uatype) element_deserializer = _create_type_deserializer(uatype)
def _deserialize(data): def _deserialize(data):
...@@ -564,7 +583,7 @@ def _create_list_deserializer(uatype): ...@@ -564,7 +583,7 @@ def _create_list_deserializer(uatype):
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def _create_type_deserializer(uatype): def _create_type_deserializer(uatype, dataclazz = NoneType):
if type_is_union(uatype): if type_is_union(uatype):
return _create_type_deserializer(types_from_union(uatype)[0]) return _create_type_deserializer(types_from_union(uatype)[0])
if type_is_list(uatype): if type_is_list(uatype):
...@@ -573,7 +592,7 @@ def _create_type_deserializer(uatype): ...@@ -573,7 +592,7 @@ def _create_type_deserializer(uatype):
vtype = getattr(ua.VariantType, utype.__name__) vtype = getattr(ua.VariantType, utype.__name__)
return _create_uatype_array_deserializer(vtype) return _create_uatype_array_deserializer(vtype)
else: else:
return _create_list_deserializer(utype) return _create_list_deserializer(utype, utype == dataclazz)
if hasattr(ua.VariantType, uatype.__name__): if hasattr(ua.VariantType, uatype.__name__):
vtype = getattr(ua.VariantType, uatype.__name__) vtype = getattr(ua.VariantType, uatype.__name__)
return _create_uatype_deserializer(vtype) return _create_uatype_deserializer(vtype)
...@@ -608,7 +627,7 @@ def _create_dataclass_deserializer(objtype): ...@@ -608,7 +627,7 @@ def _create_dataclass_deserializer(objtype):
if issubclass(objtype, ua.UaUnion): if issubclass(objtype, ua.UaUnion):
# unions are just objects with encoding and value field # unions are just objects with encoding and value field
typefields = fields(objtype) typefields = fields(objtype)
field_deserializers = [_create_type_deserializer(t) for t in objtype._union_types] field_deserializers = [_create_type_deserializer(t, objtype) for t in objtype._union_types]
byte_decode = next(_create_type_deserializer(f.type) for f in typefields if f.name == "Encoding") byte_decode = next(_create_type_deserializer(f.type) for f in typefields if f.name == "Encoding")
def decode_union(data): def decode_union(data):
...@@ -640,7 +659,7 @@ def _create_dataclass_deserializer(objtype): ...@@ -640,7 +659,7 @@ def _create_dataclass_deserializer(objtype):
if subtypes: if subtypes:
deserialize_field = extensionobject_from_binary deserialize_field = extensionobject_from_binary
else: else:
deserialize_field = _create_type_deserializer(field_type) deserialize_field = _create_type_deserializer(field_type, objtype)
field_deserializers.append((field, optional_enc_bit, deserialize_field)) field_deserializers.append((field, optional_enc_bit, deserialize_field))
def decode(data): def decode(data):
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
</References> </References>
<Definition Name="1:MyParameterType"> <Definition Name="1:MyParameterType">
<Field DataType="MyParameterType" ValueRank="1" ArrayDimensions="0" Name="Subparameters"/> <Field DataType="MyParameterType" ValueRank="1" ArrayDimensions="0" Name="Subparameters"/>
<Field Name="Value" DataType="i=7"/>
</Definition> </Definition>
</UADataType> </UADataType>
......
...@@ -22,7 +22,7 @@ from asyncua.common.methods import call_method_full ...@@ -22,7 +22,7 @@ from asyncua.common.methods import call_method_full
from asyncua.common.copy_node_util import copy_node from asyncua.common.copy_node_util import copy_node
from asyncua.common.instantiate_util import instantiate from asyncua.common.instantiate_util import instantiate
from asyncua.common.structures104 import new_struct, new_enum, new_struct_field from asyncua.common.structures104 import new_struct, new_enum, new_struct_field
from asyncua.ua.ua_binary import struct_to_binary, struct_from_binary
pytestmark = pytest.mark.asyncio pytestmark = pytest.mark.asyncio
...@@ -1640,3 +1640,14 @@ async def test_custom_struct_with_strange_chars(opc): ...@@ -1640,3 +1640,14 @@ async def test_custom_struct_with_strange_chars(opc):
var = await opc.opc.nodes.objects.add_variable(idx, "my_siemens_struct", ua.Variant(mystruct, ua.VariantType.ExtensionObject)) var = await opc.opc.nodes.objects.add_variable(idx, "my_siemens_struct", ua.Variant(mystruct, ua.VariantType.ExtensionObject))
val = await var.read_value() val = await var.read_value()
assert val.My_UInt32 == [78, 79] assert val.My_UInt32 == [78, 79]
async def test_custom_struct_recursive_serialize(opc):
idx = 4
nodes = await opc.opc.import_xml("tests/custom_struct_recursive.xml")
await opc.opc.load_data_type_definitions()
param =ua.MyParameterType(Value=2)
param.Subparameters.append(ua.MyParameterType(Value=1))
bin = struct_to_binary(param)
res = struct_from_binary(ua.MyParameterType, ua.utils.Buffer(bin))
assert param == res
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment