Commit 47469d2d authored by David S. Miller's avatar David S. Miller

Merge branch 'tools-ynl-byteorder'

Donald Hunter says:

====================
tools: ynl: Add byte-order support for struct members

This patchset adds support to ynl for handling byte-order in struct
members. The first patch is a refactor to use predefined Struct() objects
instead of generating byte-order specific formats on the fly. The second
patch adds byte-order handling for struct members.
====================
Reviewed-by: default avatarJakub Kicinski <kuba@kernel.org>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 59088b5a bddd2e56
...@@ -122,6 +122,8 @@ properties: ...@@ -122,6 +122,8 @@ properties:
enum: [ u8, u16, u32, u64, s8, s16, s32, s64, string ] enum: [ u8, u16, u32, u64, s8, s16, s32, s64, string ]
len: len:
$ref: '#/$defs/len-or-define' $ref: '#/$defs/len-or-define'
byte-order:
enum: [ little-endian, big-endian ]
# End genetlink-legacy # End genetlink-legacy
attribute-sets: attribute-sets:
......
...@@ -227,10 +227,12 @@ class SpecStructMember(SpecElement): ...@@ -227,10 +227,12 @@ class SpecStructMember(SpecElement):
Attributes: Attributes:
type string, type of the member attribute type string, type of the member attribute
byte_order string or None for native byte order
""" """
def __init__(self, family, yaml): def __init__(self, family, yaml):
super().__init__(family, yaml) super().__init__(family, yaml)
self.type = yaml['type'] self.type = yaml['type']
self.byte_order = yaml.get('byte-order')
class SpecStruct(SpecElement): class SpecStruct(SpecElement):
......
# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
from collections import namedtuple
import functools import functools
import os import os
import random import random
import socket import socket
import struct import struct
from struct import Struct
import yaml import yaml
from .nlspec import SpecFamily from .nlspec import SpecFamily
...@@ -76,10 +78,17 @@ class NlError(Exception): ...@@ -76,10 +78,17 @@ class NlError(Exception):
class NlAttr: class NlAttr:
type_formats = { 'u8' : ('B', 1), 's8' : ('b', 1), ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
'u16': ('H', 2), 's16': ('h', 2), type_formats = {
'u32': ('I', 4), 's32': ('i', 4), 'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
'u64': ('Q', 8), 's64': ('q', 8) } 's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
}
def __init__(self, raw, offset): def __init__(self, raw, offset):
self._len, self._type = struct.unpack("HH", raw[offset:offset + 4]) self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
...@@ -88,25 +97,17 @@ class NlAttr: ...@@ -88,25 +97,17 @@ class NlAttr:
self.full_len = (self.payload_len + 3) & ~3 self.full_len = (self.payload_len + 3) & ~3
self.raw = raw[offset + 4:offset + self.payload_len] self.raw = raw[offset + 4:offset + self.payload_len]
def format_byte_order(byte_order): @classmethod
def get_format(cls, attr_type, byte_order=None):
format = cls.type_formats[attr_type]
if byte_order: if byte_order:
return ">" if byte_order == "big-endian" else "<" return format.big if byte_order == "big-endian" \
return "" else format.little
return format.native
def as_u8(self): def as_scalar(self, attr_type, byte_order=None):
return struct.unpack("B", self.raw)[0] format = self.get_format(attr_type, byte_order)
return format.unpack(self.raw)[0]
def as_u16(self, byte_order=None):
endian = NlAttr.format_byte_order(byte_order)
return struct.unpack(f"{endian}H", self.raw)[0]
def as_u32(self, byte_order=None):
endian = NlAttr.format_byte_order(byte_order)
return struct.unpack(f"{endian}I", self.raw)[0]
def as_u64(self, byte_order=None):
endian = NlAttr.format_byte_order(byte_order)
return struct.unpack(f"{endian}Q", self.raw)[0]
def as_strz(self): def as_strz(self):
return self.raw.decode('ascii')[:-1] return self.raw.decode('ascii')[:-1]
...@@ -115,17 +116,17 @@ class NlAttr: ...@@ -115,17 +116,17 @@ class NlAttr:
return self.raw return self.raw
def as_c_array(self, type): def as_c_array(self, type):
format, _ = self.type_formats[type] format = self.get_format(type)
return list({ x[0] for x in struct.iter_unpack(format, self.raw) }) return [ x[0] for x in format.iter_unpack(self.raw) ]
def as_struct(self, members): def as_struct(self, members):
value = dict() value = dict()
offset = 0 offset = 0
for m in members: for m in members:
# TODO: handle non-scalar members # TODO: handle non-scalar members
format, size = self.type_formats[m.type] format = self.get_format(m.type, m.byte_order)
decoded = struct.unpack_from(format, self.raw, offset) decoded = format.unpack_from(self.raw, offset)
offset += size offset += format.size
value[m.name] = decoded[0] value[m.name] = decoded[0]
return value return value
...@@ -184,11 +185,11 @@ class NlMsg: ...@@ -184,11 +185,11 @@ class NlMsg:
if extack.type == Netlink.NLMSGERR_ATTR_MSG: if extack.type == Netlink.NLMSGERR_ATTR_MSG:
self.extack['msg'] = extack.as_strz() self.extack['msg'] = extack.as_strz()
elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE: elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
self.extack['miss-type'] = extack.as_u32() self.extack['miss-type'] = extack.as_scalar('u32')
elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST: elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
self.extack['miss-nest'] = extack.as_u32() self.extack['miss-nest'] = extack.as_scalar('u32')
elif extack.type == Netlink.NLMSGERR_ATTR_OFFS: elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
self.extack['bad-attr-offs'] = extack.as_u32() self.extack['bad-attr-offs'] = extack.as_scalar('u32')
else: else:
if 'unknown' not in self.extack: if 'unknown' not in self.extack:
self.extack['unknown'] = [] self.extack['unknown'] = []
...@@ -272,11 +273,11 @@ def _genl_load_families(): ...@@ -272,11 +273,11 @@ def _genl_load_families():
fam = dict() fam = dict()
for attr in gm.raw_attrs: for attr in gm.raw_attrs:
if attr.type == Netlink.CTRL_ATTR_FAMILY_ID: if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
fam['id'] = attr.as_u16() fam['id'] = attr.as_scalar('u16')
elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME: elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
fam['name'] = attr.as_strz() fam['name'] = attr.as_strz()
elif attr.type == Netlink.CTRL_ATTR_MAXATTR: elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
fam['maxattr'] = attr.as_u32() fam['maxattr'] = attr.as_scalar('u32')
elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS: elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
fam['mcast'] = dict() fam['mcast'] = dict()
for entry in NlAttrs(attr.raw): for entry in NlAttrs(attr.raw):
...@@ -286,7 +287,7 @@ def _genl_load_families(): ...@@ -286,7 +287,7 @@ def _genl_load_families():
if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME: if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
mcast_name = entry_attr.as_strz() mcast_name = entry_attr.as_strz()
elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID: elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
mcast_id = entry_attr.as_u32() mcast_id = entry_attr.as_scalar('u32')
if mcast_name and mcast_id is not None: if mcast_name and mcast_id is not None:
fam['mcast'][mcast_name] = mcast_id fam['mcast'][mcast_name] = mcast_id
if 'name' in fam and 'id' in fam: if 'name' in fam and 'id' in fam:
...@@ -304,9 +305,9 @@ class GenlMsg: ...@@ -304,9 +305,9 @@ class GenlMsg:
self.fixed_header_attrs = dict() self.fixed_header_attrs = dict()
for m in fixed_header_members: for m in fixed_header_members:
format, size = NlAttr.type_formats[m.type] format = NlAttr.get_format(m.type, m.byte_order)
decoded = struct.unpack_from(format, nl_msg.raw, offset) decoded = format.unpack_from(nl_msg.raw, offset)
offset += size offset += format.size
self.fixed_header_attrs[m.name] = decoded[0] self.fixed_header_attrs[m.name] = decoded[0]
self.raw = nl_msg.raw[offset:] self.raw = nl_msg.raw[offset:]
...@@ -381,21 +382,13 @@ class YnlFamily(SpecFamily): ...@@ -381,21 +382,13 @@ class YnlFamily(SpecFamily):
attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue) attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
elif attr["type"] == 'flag': elif attr["type"] == 'flag':
attr_payload = b'' attr_payload = b''
elif attr["type"] == 'u8':
attr_payload = struct.pack("B", int(value))
elif attr["type"] == 'u16':
endian = NlAttr.format_byte_order(attr.byte_order)
attr_payload = struct.pack(f"{endian}H", int(value))
elif attr["type"] == 'u32':
endian = NlAttr.format_byte_order(attr.byte_order)
attr_payload = struct.pack(f"{endian}I", int(value))
elif attr["type"] == 'u64':
endian = NlAttr.format_byte_order(attr.byte_order)
attr_payload = struct.pack(f"{endian}Q", int(value))
elif attr["type"] == 'string': elif attr["type"] == 'string':
attr_payload = str(value).encode('ascii') + b'\x00' attr_payload = str(value).encode('ascii') + b'\x00'
elif attr["type"] == 'binary': elif attr["type"] == 'binary':
attr_payload = value attr_payload = value
elif attr['type'] in NlAttr.type_formats:
format = NlAttr.get_format(attr['type'], attr.byte_order)
attr_payload = format.pack(int(value))
else: else:
raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}') raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
...@@ -434,22 +427,16 @@ class YnlFamily(SpecFamily): ...@@ -434,22 +427,16 @@ class YnlFamily(SpecFamily):
if attr_spec["type"] == 'nest': if attr_spec["type"] == 'nest':
subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes']) subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
decoded = subdict decoded = subdict
elif attr_spec['type'] == 'u8':
decoded = attr.as_u8()
elif attr_spec['type'] == 'u16':
decoded = attr.as_u16(attr_spec.byte_order)
elif attr_spec['type'] == 'u32':
decoded = attr.as_u32(attr_spec.byte_order)
elif attr_spec['type'] == 'u64':
decoded = attr.as_u64(attr_spec.byte_order)
elif attr_spec["type"] == 'string': elif attr_spec["type"] == 'string':
decoded = attr.as_strz() decoded = attr.as_strz()
elif attr_spec["type"] == 'binary': elif attr_spec["type"] == 'binary':
decoded = self._decode_binary(attr, attr_spec) decoded = self._decode_binary(attr, attr_spec)
elif attr_spec["type"] == 'flag': elif attr_spec["type"] == 'flag':
decoded = True decoded = True
elif attr_spec["type"] in NlAttr.type_formats:
decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
else: else:
raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}') raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
if not attr_spec.is_multi: if not attr_spec.is_multi:
rsp[attr_spec['name']] = decoded rsp[attr_spec['name']] = decoded
...@@ -555,8 +542,8 @@ class YnlFamily(SpecFamily): ...@@ -555,8 +542,8 @@ class YnlFamily(SpecFamily):
fixed_header_members = self.consts[op.fixed_header].members fixed_header_members = self.consts[op.fixed_header].members
for m in fixed_header_members: for m in fixed_header_members:
value = vals.pop(m.name) value = vals.pop(m.name)
format, _ = NlAttr.type_formats[m.type] format = NlAttr.get_format(m.type, m.byte_order)
msg += struct.pack(format, value) msg += format.pack(value)
for name, value in vals.items(): for name, value in vals.items():
msg += self._add_attr(op.attr_set.name, name, value) msg += self._add_attr(op.attr_set.name, name, value)
msg = _genl_msg_finalize(msg) msg = _genl_msg_finalize(msg)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment