Commit d06a09b9 authored by Johannes Berg's avatar Johannes Berg Committed by David S. Miller

netlink: extend policy range validation

Using a pointer to a struct indicating the min/max values,
extend the ability to do range validation for arbitrary
values. Small values in the s16 range can be kept in the
policy directly.
Signed-off-by: default avatarJohannes Berg <johannes.berg@intel.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent d15da2a2
...@@ -189,11 +189,20 @@ enum { ...@@ -189,11 +189,20 @@ enum {
#define NLA_TYPE_MAX (__NLA_TYPE_MAX - 1) #define NLA_TYPE_MAX (__NLA_TYPE_MAX - 1)
struct netlink_range_validation {
u64 min, max;
};
struct netlink_range_validation_signed {
s64 min, max;
};
enum nla_policy_validation { enum nla_policy_validation {
NLA_VALIDATE_NONE, NLA_VALIDATE_NONE,
NLA_VALIDATE_RANGE, NLA_VALIDATE_RANGE,
NLA_VALIDATE_MIN, NLA_VALIDATE_MIN,
NLA_VALIDATE_MAX, NLA_VALIDATE_MAX,
NLA_VALIDATE_RANGE_PTR,
NLA_VALIDATE_FUNCTION, NLA_VALIDATE_FUNCTION,
}; };
...@@ -271,6 +280,22 @@ enum nla_policy_validation { ...@@ -271,6 +280,22 @@ enum nla_policy_validation {
* of s16 - do that as usual in the code instead. * of s16 - do that as usual in the code instead.
* Use the NLA_POLICY_MIN(), NLA_POLICY_MAX() and * Use the NLA_POLICY_MIN(), NLA_POLICY_MAX() and
* NLA_POLICY_RANGE() macros. * NLA_POLICY_RANGE() macros.
* NLA_U8,
* NLA_U16,
* NLA_U32,
* NLA_U64 If the validation_type field instead is set to
* NLA_VALIDATE_RANGE_PTR, `range' must be a pointer
* to a struct netlink_range_validation that indicates
* the min/max values.
* Use NLA_POLICY_FULL_RANGE().
* NLA_S8,
* NLA_S16,
* NLA_S32,
* NLA_S64 If the validation_type field instead is set to
* NLA_VALIDATE_RANGE_PTR, `range_signed' must be a
* pointer to a struct netlink_range_validation_signed
* that indicates the min/max values.
* Use NLA_POLICY_FULL_RANGE_SIGNED().
* All other Unused - but note that it's a union * All other Unused - but note that it's a union
* *
* Meaning of `validate' field, use via NLA_POLICY_VALIDATE_FN: * Meaning of `validate' field, use via NLA_POLICY_VALIDATE_FN:
...@@ -296,6 +321,8 @@ struct nla_policy { ...@@ -296,6 +321,8 @@ struct nla_policy {
const u32 bitfield32_valid; const u32 bitfield32_valid;
const char *reject_message; const char *reject_message;
const struct nla_policy *nested_policy; const struct nla_policy *nested_policy;
struct netlink_range_validation *range;
struct netlink_range_validation_signed *range_signed;
struct { struct {
s16 min, max; s16 min, max;
}; };
...@@ -342,6 +369,12 @@ struct nla_policy { ...@@ -342,6 +369,12 @@ struct nla_policy {
{ .type = NLA_BITFIELD32, .bitfield32_valid = valid } { .type = NLA_BITFIELD32, .bitfield32_valid = valid }
#define __NLA_ENSURE(condition) BUILD_BUG_ON_ZERO(!(condition)) #define __NLA_ENSURE(condition) BUILD_BUG_ON_ZERO(!(condition))
#define NLA_ENSURE_UINT_TYPE(tp) \
(__NLA_ENSURE(tp == NLA_U8 || tp == NLA_U16 || \
tp == NLA_U32 || tp == NLA_U64) + tp)
#define NLA_ENSURE_SINT_TYPE(tp) \
(__NLA_ENSURE(tp == NLA_S8 || tp == NLA_S16 || \
tp == NLA_S32 || tp == NLA_S64) + tp)
#define NLA_ENSURE_INT_TYPE(tp) \ #define NLA_ENSURE_INT_TYPE(tp) \
(__NLA_ENSURE(tp == NLA_S8 || tp == NLA_U8 || \ (__NLA_ENSURE(tp == NLA_S8 || tp == NLA_U8 || \
tp == NLA_S16 || tp == NLA_U16 || \ tp == NLA_S16 || tp == NLA_U16 || \
...@@ -360,6 +393,18 @@ struct nla_policy { ...@@ -360,6 +393,18 @@ struct nla_policy {
.max = _max \ .max = _max \
} }
#define NLA_POLICY_FULL_RANGE(tp, _range) { \
.type = NLA_ENSURE_UINT_TYPE(tp), \
.validation_type = NLA_VALIDATE_RANGE_PTR, \
.range = _range, \
}
#define NLA_POLICY_FULL_RANGE_SIGNED(tp, _range) { \
.type = NLA_ENSURE_SINT_TYPE(tp), \
.validation_type = NLA_VALIDATE_RANGE_PTR, \
.range_signed = _range, \
}
#define NLA_POLICY_MIN(tp, _min) { \ #define NLA_POLICY_MIN(tp, _min) { \
.type = NLA_ENSURE_INT_TYPE(tp), \ .type = NLA_ENSURE_INT_TYPE(tp), \
.validation_type = NLA_VALIDATE_MIN, \ .validation_type = NLA_VALIDATE_MIN, \
......
...@@ -111,17 +111,34 @@ static int nla_validate_array(const struct nlattr *head, int len, int maxtype, ...@@ -111,17 +111,34 @@ static int nla_validate_array(const struct nlattr *head, int len, int maxtype,
return 0; return 0;
} }
static int nla_validate_int_range(const struct nla_policy *pt, static int nla_validate_int_range_unsigned(const struct nla_policy *pt,
const struct nlattr *nla, const struct nlattr *nla,
struct netlink_ext_ack *extack) struct netlink_ext_ack *extack)
{ {
bool validate_min, validate_max; struct netlink_range_validation _range = {
s64 value; .min = 0,
.max = U64_MAX,
}, *range = &_range;
u64 value;
validate_min = pt->validation_type == NLA_VALIDATE_RANGE || WARN_ON_ONCE(pt->validation_type != NLA_VALIDATE_RANGE_PTR &&
pt->validation_type == NLA_VALIDATE_MIN; (pt->min < 0 || pt->max < 0));
validate_max = pt->validation_type == NLA_VALIDATE_RANGE ||
pt->validation_type == NLA_VALIDATE_MAX; switch (pt->validation_type) {
case NLA_VALIDATE_RANGE:
range->min = pt->min;
range->max = pt->max;
break;
case NLA_VALIDATE_RANGE_PTR:
range = pt->range;
break;
case NLA_VALIDATE_MIN:
range->min = pt->min;
break;
case NLA_VALIDATE_MAX:
range->max = pt->max;
break;
}
switch (pt->type) { switch (pt->type) {
case NLA_U8: case NLA_U8:
...@@ -133,6 +150,49 @@ static int nla_validate_int_range(const struct nla_policy *pt, ...@@ -133,6 +150,49 @@ static int nla_validate_int_range(const struct nla_policy *pt,
case NLA_U32: case NLA_U32:
value = nla_get_u32(nla); value = nla_get_u32(nla);
break; break;
case NLA_U64:
value = nla_get_u64(nla);
break;
default:
return -EINVAL;
}
if (value < range->min || value > range->max) {
NL_SET_ERR_MSG_ATTR(extack, nla,
"integer out of range");
return -ERANGE;
}
return 0;
}
static int nla_validate_int_range_signed(const struct nla_policy *pt,
const struct nlattr *nla,
struct netlink_ext_ack *extack)
{
struct netlink_range_validation_signed _range = {
.min = S64_MIN,
.max = S64_MAX,
}, *range = &_range;
s64 value;
switch (pt->validation_type) {
case NLA_VALIDATE_RANGE:
range->min = pt->min;
range->max = pt->max;
break;
case NLA_VALIDATE_RANGE_PTR:
range = pt->range_signed;
break;
case NLA_VALIDATE_MIN:
range->min = pt->min;
break;
case NLA_VALIDATE_MAX:
range->max = pt->max;
break;
}
switch (pt->type) {
case NLA_S8: case NLA_S8:
value = nla_get_s8(nla); value = nla_get_s8(nla);
break; break;
...@@ -145,22 +205,11 @@ static int nla_validate_int_range(const struct nla_policy *pt, ...@@ -145,22 +205,11 @@ static int nla_validate_int_range(const struct nla_policy *pt,
case NLA_S64: case NLA_S64:
value = nla_get_s64(nla); value = nla_get_s64(nla);
break; break;
case NLA_U64:
/* treat this one specially, since it may not fit into s64 */
if ((validate_min && nla_get_u64(nla) < pt->min) ||
(validate_max && nla_get_u64(nla) > pt->max)) {
NL_SET_ERR_MSG_ATTR(extack, nla,
"integer out of range");
return -ERANGE;
}
return 0;
default: default:
WARN_ON(1);
return -EINVAL; return -EINVAL;
} }
if ((validate_min && value < pt->min) || if (value < range->min || value > range->max) {
(validate_max && value > pt->max)) {
NL_SET_ERR_MSG_ATTR(extack, nla, NL_SET_ERR_MSG_ATTR(extack, nla,
"integer out of range"); "integer out of range");
return -ERANGE; return -ERANGE;
...@@ -169,6 +218,27 @@ static int nla_validate_int_range(const struct nla_policy *pt, ...@@ -169,6 +218,27 @@ static int nla_validate_int_range(const struct nla_policy *pt,
return 0; return 0;
} }
static int nla_validate_int_range(const struct nla_policy *pt,
const struct nlattr *nla,
struct netlink_ext_ack *extack)
{
switch (pt->type) {
case NLA_U8:
case NLA_U16:
case NLA_U32:
case NLA_U64:
return nla_validate_int_range_unsigned(pt, nla, extack);
case NLA_S8:
case NLA_S16:
case NLA_S32:
case NLA_S64:
return nla_validate_int_range_signed(pt, nla, extack);
default:
WARN_ON(1);
return -EINVAL;
}
}
static int validate_nla(const struct nlattr *nla, int maxtype, static int validate_nla(const struct nlattr *nla, int maxtype,
const struct nla_policy *policy, unsigned int validate, const struct nla_policy *policy, unsigned int validate,
struct netlink_ext_ack *extack, unsigned int depth) struct netlink_ext_ack *extack, unsigned int depth)
...@@ -348,6 +418,7 @@ static int validate_nla(const struct nlattr *nla, int maxtype, ...@@ -348,6 +418,7 @@ static int validate_nla(const struct nlattr *nla, int maxtype,
case NLA_VALIDATE_NONE: case NLA_VALIDATE_NONE:
/* nothing to do */ /* nothing to do */
break; break;
case NLA_VALIDATE_RANGE_PTR:
case NLA_VALIDATE_RANGE: case NLA_VALIDATE_RANGE:
case NLA_VALIDATE_MIN: case NLA_VALIDATE_MIN:
case NLA_VALIDATE_MAX: case NLA_VALIDATE_MAX:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment