Commit 812285fa authored by Alexei Starovoitov's avatar Alexei Starovoitov

Merge branch 'bpf_sk_storage_via_inet_diag'

Martin KaFai Lau says:

====================
The bpf_prog can store specific info to a sk by using bpf_sk_storage.
In other words, a sk can be extended by a bpf_prog.

This series is to support providing bpf_sk_storage data during inet_diag's
dump.  The primary target is the usage like iproute2's "ss".

The first two patches are refactoring works in inet_diag to make
adding bpf_sk_storage support easier.  The next two patches do
the actual work.

Please see individual patch for details.

v2:
- Add commit message for u16 to u32 change in min_dump_alloc in Patch 4 (Song)
- Add comment to explain the !skb->len check in __inet_diag_dump in Patch 4.
- Do the map->map_type check earlier in Patch 3 for readability.
====================
Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents d7f10df8 085c20ca
...@@ -1023,6 +1023,7 @@ void __bpf_free_used_maps(struct bpf_prog_aux *aux, ...@@ -1023,6 +1023,7 @@ void __bpf_free_used_maps(struct bpf_prog_aux *aux,
void bpf_prog_free_id(struct bpf_prog *prog, bool do_idr_lock); void bpf_prog_free_id(struct bpf_prog *prog, bool do_idr_lock);
void bpf_map_free_id(struct bpf_map *map, bool do_idr_lock); void bpf_map_free_id(struct bpf_map *map, bool do_idr_lock);
struct bpf_map *bpf_map_get(u32 ufd);
struct bpf_map *bpf_map_get_with_uref(u32 ufd); struct bpf_map *bpf_map_get_with_uref(u32 ufd);
struct bpf_map *__bpf_map_get(struct fd f); struct bpf_map *__bpf_map_get(struct fd f);
void bpf_map_inc(struct bpf_map *map); void bpf_map_inc(struct bpf_map *map);
......
...@@ -15,11 +15,9 @@ struct netlink_callback; ...@@ -15,11 +15,9 @@ struct netlink_callback;
struct inet_diag_handler { struct inet_diag_handler {
void (*dump)(struct sk_buff *skb, void (*dump)(struct sk_buff *skb,
struct netlink_callback *cb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, const struct inet_diag_req_v2 *r);
struct nlattr *bc);
int (*dump_one)(struct sk_buff *in_skb, int (*dump_one)(struct netlink_callback *cb,
const struct nlmsghdr *nlh,
const struct inet_diag_req_v2 *req); const struct inet_diag_req_v2 *req);
void (*idiag_get_info)(struct sock *sk, void (*idiag_get_info)(struct sock *sk,
...@@ -40,18 +38,25 @@ struct inet_diag_handler { ...@@ -40,18 +38,25 @@ struct inet_diag_handler {
__u16 idiag_info_size; __u16 idiag_info_size;
}; };
struct bpf_sk_storage_diag;
struct inet_diag_dump_data {
struct nlattr *req_nlas[__INET_DIAG_REQ_MAX];
#define inet_diag_nla_bc req_nlas[INET_DIAG_REQ_BYTECODE]
#define inet_diag_nla_bpf_stgs req_nlas[INET_DIAG_REQ_SK_BPF_STORAGES]
struct bpf_sk_storage_diag *bpf_stg_diag;
};
struct inet_connection_sock; struct inet_connection_sock;
int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
struct sk_buff *skb, const struct inet_diag_req_v2 *req, struct sk_buff *skb, struct netlink_callback *cb,
struct user_namespace *user_ns, const struct inet_diag_req_v2 *req,
u32 pid, u32 seq, u16 nlmsg_flags, u16 nlmsg_flags, bool net_admin);
const struct nlmsghdr *unlh, bool net_admin);
void inet_diag_dump_icsk(struct inet_hashinfo *h, struct sk_buff *skb, void inet_diag_dump_icsk(struct inet_hashinfo *h, struct sk_buff *skb,
struct netlink_callback *cb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, const struct inet_diag_req_v2 *r);
struct nlattr *bc);
int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
struct sk_buff *in_skb, const struct nlmsghdr *nlh, struct netlink_callback *cb,
const struct inet_diag_req_v2 *req); const struct inet_diag_req_v2 *req);
struct sock *inet_diag_find_one_icsk(struct net *net, struct sock *inet_diag_find_one_icsk(struct net *net,
......
...@@ -188,10 +188,10 @@ struct netlink_callback { ...@@ -188,10 +188,10 @@ struct netlink_callback {
struct module *module; struct module *module;
struct netlink_ext_ack *extack; struct netlink_ext_ack *extack;
u16 family; u16 family;
u16 min_dump_alloc;
bool strict_check;
u16 answer_flags; u16 answer_flags;
u32 min_dump_alloc;
unsigned int prev_seq, seq; unsigned int prev_seq, seq;
bool strict_check;
union { union {
u8 ctx[48]; u8 ctx[48];
......
...@@ -10,14 +10,41 @@ void bpf_sk_storage_free(struct sock *sk); ...@@ -10,14 +10,41 @@ void bpf_sk_storage_free(struct sock *sk);
extern const struct bpf_func_proto bpf_sk_storage_get_proto; extern const struct bpf_func_proto bpf_sk_storage_get_proto;
extern const struct bpf_func_proto bpf_sk_storage_delete_proto; extern const struct bpf_func_proto bpf_sk_storage_delete_proto;
struct bpf_sk_storage_diag;
struct sk_buff;
struct nlattr;
struct sock;
#ifdef CONFIG_BPF_SYSCALL #ifdef CONFIG_BPF_SYSCALL
int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk); int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk);
struct bpf_sk_storage_diag *
bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs);
void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag);
int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
struct sock *sk, struct sk_buff *skb,
int stg_array_type,
unsigned int *res_diag_size);
#else #else
static inline int bpf_sk_storage_clone(const struct sock *sk, static inline int bpf_sk_storage_clone(const struct sock *sk,
struct sock *newsk) struct sock *newsk)
{ {
return 0; return 0;
} }
static inline struct bpf_sk_storage_diag *
bpf_sk_storage_diag_alloc(const struct nlattr *nla)
{
return NULL;
}
static inline void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
{
}
static inline int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
struct sock *sk, struct sk_buff *skb,
int stg_array_type,
unsigned int *res_diag_size)
{
return 0;
}
#endif #endif
#endif /* _BPF_SK_STORAGE_H */ #endif /* _BPF_SK_STORAGE_H */
...@@ -64,9 +64,11 @@ struct inet_diag_req_raw { ...@@ -64,9 +64,11 @@ struct inet_diag_req_raw {
enum { enum {
INET_DIAG_REQ_NONE, INET_DIAG_REQ_NONE,
INET_DIAG_REQ_BYTECODE, INET_DIAG_REQ_BYTECODE,
INET_DIAG_REQ_SK_BPF_STORAGES,
__INET_DIAG_REQ_MAX,
}; };
#define INET_DIAG_REQ_MAX INET_DIAG_REQ_BYTECODE #define INET_DIAG_REQ_MAX (__INET_DIAG_REQ_MAX - 1)
/* Bytecode is sequence of 4 byte commands followed by variable arguments. /* Bytecode is sequence of 4 byte commands followed by variable arguments.
* All the commands identified by "code" are conditional jumps forward: * All the commands identified by "code" are conditional jumps forward:
...@@ -154,6 +156,7 @@ enum { ...@@ -154,6 +156,7 @@ enum {
INET_DIAG_CLASS_ID, /* request as INET_DIAG_TCLASS */ INET_DIAG_CLASS_ID, /* request as INET_DIAG_TCLASS */
INET_DIAG_MD5SIG, INET_DIAG_MD5SIG,
INET_DIAG_ULP_INFO, INET_DIAG_ULP_INFO,
INET_DIAG_SK_BPF_STORAGES,
__INET_DIAG_MAX, __INET_DIAG_MAX,
}; };
......
...@@ -36,4 +36,30 @@ enum sknetlink_groups { ...@@ -36,4 +36,30 @@ enum sknetlink_groups {
}; };
#define SKNLGRP_MAX (__SKNLGRP_MAX - 1) #define SKNLGRP_MAX (__SKNLGRP_MAX - 1)
enum {
SK_DIAG_BPF_STORAGE_REQ_NONE,
SK_DIAG_BPF_STORAGE_REQ_MAP_FD,
__SK_DIAG_BPF_STORAGE_REQ_MAX,
};
#define SK_DIAG_BPF_STORAGE_REQ_MAX (__SK_DIAG_BPF_STORAGE_REQ_MAX - 1)
enum {
SK_DIAG_BPF_STORAGE_REP_NONE,
SK_DIAG_BPF_STORAGE,
__SK_DIAG_BPF_STORAGE_REP_MAX,
};
#define SK_DIAB_BPF_STORAGE_REP_MAX (__SK_DIAG_BPF_STORAGE_REP_MAX - 1)
enum {
SK_DIAG_BPF_STORAGE_NONE,
SK_DIAG_BPF_STORAGE_PAD,
SK_DIAG_BPF_STORAGE_MAP_ID,
SK_DIAG_BPF_STORAGE_MAP_VALUE,
__SK_DIAG_BPF_STORAGE_MAX,
};
#define SK_DIAG_BPF_STORAGE_MAX (__SK_DIAG_BPF_STORAGE_MAX - 1)
#endif /* _UAPI__SOCK_DIAG_H__ */ #endif /* _UAPI__SOCK_DIAG_H__ */
...@@ -902,6 +902,21 @@ void bpf_map_inc_with_uref(struct bpf_map *map) ...@@ -902,6 +902,21 @@ void bpf_map_inc_with_uref(struct bpf_map *map)
} }
EXPORT_SYMBOL_GPL(bpf_map_inc_with_uref); EXPORT_SYMBOL_GPL(bpf_map_inc_with_uref);
struct bpf_map *bpf_map_get(u32 ufd)
{
struct fd f = fdget(ufd);
struct bpf_map *map;
map = __bpf_map_get(f);
if (IS_ERR(map))
return map;
bpf_map_inc(map);
fdput(f);
return map;
}
struct bpf_map *bpf_map_get_with_uref(u32 ufd) struct bpf_map *bpf_map_get_with_uref(u32 ufd)
{ {
struct fd f = fdget(ufd); struct fd f = fdget(ufd);
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <linux/bpf.h> #include <linux/bpf.h>
#include <net/bpf_sk_storage.h> #include <net/bpf_sk_storage.h>
#include <net/sock.h> #include <net/sock.h>
#include <uapi/linux/sock_diag.h>
#include <uapi/linux/btf.h> #include <uapi/linux/btf.h>
static atomic_t cache_idx; static atomic_t cache_idx;
...@@ -606,6 +607,14 @@ static void bpf_sk_storage_map_free(struct bpf_map *map) ...@@ -606,6 +607,14 @@ static void bpf_sk_storage_map_free(struct bpf_map *map)
kfree(map); kfree(map);
} }
/* U16_MAX is much more than enough for sk local storage
* considering a tcp_sock is ~2k.
*/
#define MAX_VALUE_SIZE \
min_t(u32, \
(KMALLOC_MAX_SIZE - MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem)), \
(U16_MAX - sizeof(struct bpf_sk_storage_elem)))
static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr) static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
{ {
if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK || if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
...@@ -619,12 +628,7 @@ static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr) ...@@ -619,12 +628,7 @@ static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
if (!capable(CAP_SYS_ADMIN)) if (!capable(CAP_SYS_ADMIN))
return -EPERM; return -EPERM;
if (attr->value_size >= KMALLOC_MAX_SIZE - if (attr->value_size > MAX_VALUE_SIZE)
MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem) ||
/* U16_MAX is much more than enough for sk local storage
* considering a tcp_sock is ~2k.
*/
attr->value_size > U16_MAX - sizeof(struct bpf_sk_storage_elem))
return -E2BIG; return -E2BIG;
return 0; return 0;
...@@ -910,3 +914,270 @@ const struct bpf_func_proto bpf_sk_storage_delete_proto = { ...@@ -910,3 +914,270 @@ const struct bpf_func_proto bpf_sk_storage_delete_proto = {
.arg1_type = ARG_CONST_MAP_PTR, .arg1_type = ARG_CONST_MAP_PTR,
.arg2_type = ARG_PTR_TO_SOCKET, .arg2_type = ARG_PTR_TO_SOCKET,
}; };
struct bpf_sk_storage_diag {
u32 nr_maps;
struct bpf_map *maps[];
};
/* The reply will be like:
* INET_DIAG_BPF_SK_STORAGES (nla_nest)
* SK_DIAG_BPF_STORAGE (nla_nest)
* SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
* SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
* SK_DIAG_BPF_STORAGE (nla_nest)
* SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
* SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
* ....
*/
static int nla_value_size(u32 value_size)
{
/* SK_DIAG_BPF_STORAGE (nla_nest)
* SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
* SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
*/
return nla_total_size(0) + nla_total_size(sizeof(u32)) +
nla_total_size_64bit(value_size);
}
void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
{
u32 i;
if (!diag)
return;
for (i = 0; i < diag->nr_maps; i++)
bpf_map_put(diag->maps[i]);
kfree(diag);
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
const struct bpf_map *map)
{
u32 i;
for (i = 0; i < diag->nr_maps; i++) {
if (diag->maps[i] == map)
return true;
}
return false;
}
struct bpf_sk_storage_diag *
bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
{
struct bpf_sk_storage_diag *diag;
struct nlattr *nla;
u32 nr_maps = 0;
int rem, err;
/* bpf_sk_storage_map is currently limited to CAP_SYS_ADMIN as
* the map_alloc_check() side also does.
*/
if (!capable(CAP_SYS_ADMIN))
return ERR_PTR(-EPERM);
nla_for_each_nested(nla, nla_stgs, rem) {
if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
nr_maps++;
}
diag = kzalloc(sizeof(*diag) + sizeof(diag->maps[0]) * nr_maps,
GFP_KERNEL);
if (!diag)
return ERR_PTR(-ENOMEM);
nla_for_each_nested(nla, nla_stgs, rem) {
struct bpf_map *map;
int map_fd;
if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
continue;
map_fd = nla_get_u32(nla);
map = bpf_map_get(map_fd);
if (IS_ERR(map)) {
err = PTR_ERR(map);
goto err_free;
}
if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
bpf_map_put(map);
err = -EINVAL;
goto err_free;
}
if (diag_check_dup(diag, map)) {
bpf_map_put(map);
err = -EEXIST;
goto err_free;
}
diag->maps[diag->nr_maps++] = map;
}
return diag;
err_free:
bpf_sk_storage_diag_free(diag);
return ERR_PTR(err);
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
static int diag_get(struct bpf_sk_storage_data *sdata, struct sk_buff *skb)
{
struct nlattr *nla_stg, *nla_value;
struct bpf_sk_storage_map *smap;
/* It cannot exceed max nlattr's payload */
BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < MAX_VALUE_SIZE);
nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
if (!nla_stg)
return -EMSGSIZE;
smap = rcu_dereference(sdata->smap);
if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
goto errout;
nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
smap->map.value_size,
SK_DIAG_BPF_STORAGE_PAD);
if (!nla_value)
goto errout;
if (map_value_has_spin_lock(&smap->map))
copy_map_value_locked(&smap->map, nla_data(nla_value),
sdata->data, true);
else
copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
nla_nest_end(skb, nla_stg);
return 0;
errout:
nla_nest_cancel(skb, nla_stg);
return -EMSGSIZE;
}
static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
int stg_array_type,
unsigned int *res_diag_size)
{
/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
unsigned int diag_size = nla_total_size(0);
struct bpf_sk_storage *sk_storage;
struct bpf_sk_storage_elem *selem;
struct bpf_sk_storage_map *smap;
struct nlattr *nla_stgs;
unsigned int saved_len;
int err = 0;
rcu_read_lock();
sk_storage = rcu_dereference(sk->sk_bpf_storage);
if (!sk_storage || hlist_empty(&sk_storage->list)) {
rcu_read_unlock();
return 0;
}
nla_stgs = nla_nest_start(skb, stg_array_type);
if (!nla_stgs)
/* Continue to learn diag_size */
err = -EMSGSIZE;
saved_len = skb->len;
hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
smap = rcu_dereference(SDATA(selem)->smap);
diag_size += nla_value_size(smap->map.value_size);
if (nla_stgs && diag_get(SDATA(selem), skb))
/* Continue to learn diag_size */
err = -EMSGSIZE;
}
rcu_read_unlock();
if (nla_stgs) {
if (saved_len == skb->len)
nla_nest_cancel(skb, nla_stgs);
else
nla_nest_end(skb, nla_stgs);
}
if (diag_size == nla_total_size(0)) {
*res_diag_size = 0;
return 0;
}
*res_diag_size = diag_size;
return err;
}
int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
struct sock *sk, struct sk_buff *skb,
int stg_array_type,
unsigned int *res_diag_size)
{
/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
unsigned int diag_size = nla_total_size(0);
struct bpf_sk_storage *sk_storage;
struct bpf_sk_storage_data *sdata;
struct nlattr *nla_stgs;
unsigned int saved_len;
int err = 0;
u32 i;
*res_diag_size = 0;
/* No map has been specified. Dump all. */
if (!diag->nr_maps)
return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
res_diag_size);
rcu_read_lock();
sk_storage = rcu_dereference(sk->sk_bpf_storage);
if (!sk_storage || hlist_empty(&sk_storage->list)) {
rcu_read_unlock();
return 0;
}
nla_stgs = nla_nest_start(skb, stg_array_type);
if (!nla_stgs)
/* Continue to learn diag_size */
err = -EMSGSIZE;
saved_len = skb->len;
for (i = 0; i < diag->nr_maps; i++) {
sdata = __sk_storage_lookup(sk_storage,
(struct bpf_sk_storage_map *)diag->maps[i],
false);
if (!sdata)
continue;
diag_size += nla_value_size(diag->maps[i]->value_size);
if (nla_stgs && diag_get(sdata, skb))
/* Continue to learn diag_size */
err = -EMSGSIZE;
}
rcu_read_unlock();
if (nla_stgs) {
if (saved_len == skb->len)
nla_nest_cancel(skb, nla_stgs);
else
nla_nest_end(skb, nla_stgs);
}
if (diag_size == nla_total_size(0)) {
*res_diag_size = 0;
return 0;
}
*res_diag_size = diag_size;
return err;
}
EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);
...@@ -46,16 +46,15 @@ static void dccp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, ...@@ -46,16 +46,15 @@ static void dccp_diag_get_info(struct sock *sk, struct inet_diag_msg *r,
} }
static void dccp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, static void dccp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, struct nlattr *bc) const struct inet_diag_req_v2 *r)
{ {
inet_diag_dump_icsk(&dccp_hashinfo, skb, cb, r, bc); inet_diag_dump_icsk(&dccp_hashinfo, skb, cb, r);
} }
static int dccp_diag_dump_one(struct sk_buff *in_skb, static int dccp_diag_dump_one(struct netlink_callback *cb,
const struct nlmsghdr *nlh,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
return inet_diag_dump_one_icsk(&dccp_hashinfo, in_skb, nlh, req); return inet_diag_dump_one_icsk(&dccp_hashinfo, cb, req);
} }
static const struct inet_diag_handler dccp_diag_handler = { static const struct inet_diag_handler dccp_diag_handler = {
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include <net/inet_hashtables.h> #include <net/inet_hashtables.h>
#include <net/inet_timewait_sock.h> #include <net/inet_timewait_sock.h>
#include <net/inet6_hashtables.h> #include <net/inet6_hashtables.h>
#include <net/bpf_sk_storage.h>
#include <net/netlink.h> #include <net/netlink.h>
#include <linux/inet.h> #include <linux/inet.h>
...@@ -156,26 +157,28 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb, ...@@ -156,26 +157,28 @@ int inet_diag_msg_attrs_fill(struct sock *sk, struct sk_buff *skb,
} }
EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill); EXPORT_SYMBOL_GPL(inet_diag_msg_attrs_fill);
#define MAX_DUMP_ALLOC_SIZE (KMALLOC_MAX_SIZE - SKB_DATA_ALIGN(sizeof(struct skb_shared_info)))
int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
struct sk_buff *skb, const struct inet_diag_req_v2 *req, struct sk_buff *skb, struct netlink_callback *cb,
struct user_namespace *user_ns, const struct inet_diag_req_v2 *req,
u32 portid, u32 seq, u16 nlmsg_flags, u16 nlmsg_flags, bool net_admin)
const struct nlmsghdr *unlh,
bool net_admin)
{ {
const struct tcp_congestion_ops *ca_ops; const struct tcp_congestion_ops *ca_ops;
const struct inet_diag_handler *handler; const struct inet_diag_handler *handler;
struct inet_diag_dump_data *cb_data;
int ext = req->idiag_ext; int ext = req->idiag_ext;
struct inet_diag_msg *r; struct inet_diag_msg *r;
struct nlmsghdr *nlh; struct nlmsghdr *nlh;
struct nlattr *attr; struct nlattr *attr;
void *info = NULL; void *info = NULL;
cb_data = cb->data;
handler = inet_diag_table[req->sdiag_protocol]; handler = inet_diag_table[req->sdiag_protocol];
BUG_ON(!handler); BUG_ON(!handler);
nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r), nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
nlmsg_flags); cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
if (!nlh) if (!nlh)
return -EMSGSIZE; return -EMSGSIZE;
...@@ -187,7 +190,9 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, ...@@ -187,7 +190,9 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
r->idiag_timer = 0; r->idiag_timer = 0;
r->idiag_retrans = 0; r->idiag_retrans = 0;
if (inet_diag_msg_attrs_fill(sk, skb, r, ext, user_ns, net_admin)) if (inet_diag_msg_attrs_fill(sk, skb, r, ext,
sk_user_ns(NETLINK_CB(cb->skb).sk),
net_admin))
goto errout; goto errout;
if (ext & (1 << (INET_DIAG_MEMINFO - 1))) { if (ext & (1 << (INET_DIAG_MEMINFO - 1))) {
...@@ -302,6 +307,48 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, ...@@ -302,6 +307,48 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
goto errout; goto errout;
} }
/* Keep it at the end for potential retry with a larger skb,
* or else do best-effort fitting, which is only done for the
* first_nlmsg.
*/
if (cb_data->bpf_stg_diag) {
bool first_nlmsg = ((unsigned char *)nlh == skb->data);
unsigned int prev_min_dump_alloc;
unsigned int total_nla_size = 0;
unsigned int msg_len;
int err;
msg_len = skb_tail_pointer(skb) - (unsigned char *)nlh;
err = bpf_sk_storage_diag_put(cb_data->bpf_stg_diag, sk, skb,
INET_DIAG_SK_BPF_STORAGES,
&total_nla_size);
if (!err)
goto out;
total_nla_size += msg_len;
prev_min_dump_alloc = cb->min_dump_alloc;
if (total_nla_size > prev_min_dump_alloc)
cb->min_dump_alloc = min_t(u32, total_nla_size,
MAX_DUMP_ALLOC_SIZE);
if (!first_nlmsg)
goto errout;
if (cb->min_dump_alloc > prev_min_dump_alloc)
/* Retry with pskb_expand_head() with
* __GFP_DIRECT_RECLAIM
*/
goto errout;
WARN_ON_ONCE(total_nla_size <= prev_min_dump_alloc);
/* Send what we have for this sk
* and move on to the next sk in the following
* dump()
*/
}
out: out:
nlmsg_end(skb, nlh); nlmsg_end(skb, nlh);
return 0; return 0;
...@@ -312,30 +359,19 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, ...@@ -312,30 +359,19 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
} }
EXPORT_SYMBOL_GPL(inet_sk_diag_fill); EXPORT_SYMBOL_GPL(inet_sk_diag_fill);
static int inet_csk_diag_fill(struct sock *sk,
struct sk_buff *skb,
const struct inet_diag_req_v2 *req,
struct user_namespace *user_ns,
u32 portid, u32 seq, u16 nlmsg_flags,
const struct nlmsghdr *unlh,
bool net_admin)
{
return inet_sk_diag_fill(sk, inet_csk(sk), skb, req, user_ns,
portid, seq, nlmsg_flags, unlh, net_admin);
}
static int inet_twsk_diag_fill(struct sock *sk, static int inet_twsk_diag_fill(struct sock *sk,
struct sk_buff *skb, struct sk_buff *skb,
u32 portid, u32 seq, u16 nlmsg_flags, struct netlink_callback *cb,
const struct nlmsghdr *unlh) u16 nlmsg_flags)
{ {
struct inet_timewait_sock *tw = inet_twsk(sk); struct inet_timewait_sock *tw = inet_twsk(sk);
struct inet_diag_msg *r; struct inet_diag_msg *r;
struct nlmsghdr *nlh; struct nlmsghdr *nlh;
long tmo; long tmo;
nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r), nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
nlmsg_flags); cb->nlh->nlmsg_seq, cb->nlh->nlmsg_type,
sizeof(*r), nlmsg_flags);
if (!nlh) if (!nlh)
return -EMSGSIZE; return -EMSGSIZE;
...@@ -359,16 +395,16 @@ static int inet_twsk_diag_fill(struct sock *sk, ...@@ -359,16 +395,16 @@ static int inet_twsk_diag_fill(struct sock *sk,
} }
static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb, static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
u32 portid, u32 seq, u16 nlmsg_flags, struct netlink_callback *cb,
const struct nlmsghdr *unlh, bool net_admin) u16 nlmsg_flags, bool net_admin)
{ {
struct request_sock *reqsk = inet_reqsk(sk); struct request_sock *reqsk = inet_reqsk(sk);
struct inet_diag_msg *r; struct inet_diag_msg *r;
struct nlmsghdr *nlh; struct nlmsghdr *nlh;
long tmo; long tmo;
nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r), nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq,
nlmsg_flags); cb->nlh->nlmsg_type, sizeof(*r), nlmsg_flags);
if (!nlh) if (!nlh)
return -EMSGSIZE; return -EMSGSIZE;
...@@ -397,21 +433,18 @@ static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb, ...@@ -397,21 +433,18 @@ static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
} }
static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, const struct inet_diag_req_v2 *r,
struct user_namespace *user_ns, u16 nlmsg_flags, bool net_admin)
u32 portid, u32 seq, u16 nlmsg_flags,
const struct nlmsghdr *unlh, bool net_admin)
{ {
if (sk->sk_state == TCP_TIME_WAIT) if (sk->sk_state == TCP_TIME_WAIT)
return inet_twsk_diag_fill(sk, skb, portid, seq, return inet_twsk_diag_fill(sk, skb, cb, nlmsg_flags);
nlmsg_flags, unlh);
if (sk->sk_state == TCP_NEW_SYN_RECV) if (sk->sk_state == TCP_NEW_SYN_RECV)
return inet_req_diag_fill(sk, skb, portid, seq, return inet_req_diag_fill(sk, skb, cb, nlmsg_flags, net_admin);
nlmsg_flags, unlh, net_admin);
return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq, return inet_sk_diag_fill(sk, inet_csk(sk), skb, cb, r, nlmsg_flags,
nlmsg_flags, unlh, net_admin); net_admin);
} }
struct sock *inet_diag_find_one_icsk(struct net *net, struct sock *inet_diag_find_one_icsk(struct net *net,
...@@ -459,10 +492,10 @@ struct sock *inet_diag_find_one_icsk(struct net *net, ...@@ -459,10 +492,10 @@ struct sock *inet_diag_find_one_icsk(struct net *net,
EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk); EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk);
int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
struct sk_buff *in_skb, struct netlink_callback *cb,
const struct nlmsghdr *nlh,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
struct sk_buff *in_skb = cb->skb;
bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN); bool net_admin = netlink_net_capable(in_skb, CAP_NET_ADMIN);
struct net *net = sock_net(in_skb->sk); struct net *net = sock_net(in_skb->sk);
struct sk_buff *rep; struct sk_buff *rep;
...@@ -479,10 +512,7 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, ...@@ -479,10 +512,7 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo,
goto out; goto out;
} }
err = sk_diag_fill(sk, rep, req, err = sk_diag_fill(sk, rep, cb, req, 0, net_admin);
sk_user_ns(NETLINK_CB(in_skb).sk),
NETLINK_CB(in_skb).portid,
nlh->nlmsg_seq, 0, nlh, net_admin);
if (err < 0) { if (err < 0) {
WARN_ON(err == -EMSGSIZE); WARN_ON(err == -EMSGSIZE);
nlmsg_free(rep); nlmsg_free(rep);
...@@ -509,14 +539,21 @@ static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb, ...@@ -509,14 +539,21 @@ static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb,
int err; int err;
handler = inet_diag_lock_handler(req->sdiag_protocol); handler = inet_diag_lock_handler(req->sdiag_protocol);
if (IS_ERR(handler)) if (IS_ERR(handler)) {
err = PTR_ERR(handler); err = PTR_ERR(handler);
else if (cmd == SOCK_DIAG_BY_FAMILY) } else if (cmd == SOCK_DIAG_BY_FAMILY) {
err = handler->dump_one(in_skb, nlh, req); struct inet_diag_dump_data empty_dump_data = {};
else if (cmd == SOCK_DESTROY && handler->destroy) struct netlink_callback cb = {
.nlh = nlh,
.skb = in_skb,
.data = &empty_dump_data,
};
err = handler->dump_one(&cb, req);
} else if (cmd == SOCK_DESTROY && handler->destroy) {
err = handler->destroy(in_skb, req); err = handler->destroy(in_skb, req);
else } else {
err = -EOPNOTSUPP; err = -EOPNOTSUPP;
}
inet_diag_unlock_handler(handler); inet_diag_unlock_handler(handler);
return err; return err;
...@@ -847,23 +884,6 @@ static int inet_diag_bc_audit(const struct nlattr *attr, ...@@ -847,23 +884,6 @@ static int inet_diag_bc_audit(const struct nlattr *attr,
return len == 0 ? 0 : -EINVAL; return len == 0 ? 0 : -EINVAL;
} }
static int inet_csk_diag_dump(struct sock *sk,
struct sk_buff *skb,
struct netlink_callback *cb,
const struct inet_diag_req_v2 *r,
const struct nlattr *bc,
bool net_admin)
{
if (!inet_diag_bc_sk(bc, sk))
return 0;
return inet_csk_diag_fill(sk, skb, r,
sk_user_ns(NETLINK_CB(cb->skb).sk),
NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh,
net_admin);
}
static void twsk_build_assert(void) static void twsk_build_assert(void)
{ {
BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) != BUILD_BUG_ON(offsetof(struct inet_timewait_sock, tw_family) !=
...@@ -892,14 +912,17 @@ static void twsk_build_assert(void) ...@@ -892,14 +912,17 @@ static void twsk_build_assert(void)
void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
struct netlink_callback *cb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, struct nlattr *bc) const struct inet_diag_req_v2 *r)
{ {
bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
struct inet_diag_dump_data *cb_data = cb->data;
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
u32 idiag_states = r->idiag_states; u32 idiag_states = r->idiag_states;
int i, num, s_i, s_num; int i, num, s_i, s_num;
struct nlattr *bc;
struct sock *sk; struct sock *sk;
bc = cb_data->inet_diag_nla_bc;
if (idiag_states & TCPF_SYN_RECV) if (idiag_states & TCPF_SYN_RECV)
idiag_states |= TCPF_NEW_SYN_RECV; idiag_states |= TCPF_NEW_SYN_RECV;
s_i = cb->args[1]; s_i = cb->args[1];
...@@ -935,8 +958,12 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, ...@@ -935,8 +958,12 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
r->id.idiag_sport) r->id.idiag_sport)
goto next_listen; goto next_listen;
if (inet_csk_diag_dump(sk, skb, cb, r, if (!inet_diag_bc_sk(bc, sk))
bc, net_admin) < 0) { goto next_listen;
if (inet_sk_diag_fill(sk, inet_csk(sk), skb,
cb, r, NLM_F_MULTI,
net_admin) < 0) {
spin_unlock(&ilb->lock); spin_unlock(&ilb->lock);
goto done; goto done;
} }
...@@ -1014,11 +1041,8 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, ...@@ -1014,11 +1041,8 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
res = 0; res = 0;
for (idx = 0; idx < accum; idx++) { for (idx = 0; idx < accum; idx++) {
if (res >= 0) { if (res >= 0) {
res = sk_diag_fill(sk_arr[idx], skb, r, res = sk_diag_fill(sk_arr[idx], skb, cb, r,
sk_user_ns(NETLINK_CB(cb->skb).sk), NLM_F_MULTI, net_admin);
NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq, NLM_F_MULTI,
cb->nlh, net_admin);
if (res < 0) if (res < 0)
num = num_arr[idx]; num = num_arr[idx];
} }
...@@ -1042,31 +1066,101 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, ...@@ -1042,31 +1066,101 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
EXPORT_SYMBOL_GPL(inet_diag_dump_icsk); EXPORT_SYMBOL_GPL(inet_diag_dump_icsk);
static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, static int __inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, const struct inet_diag_req_v2 *r)
struct nlattr *bc)
{ {
const struct inet_diag_handler *handler; const struct inet_diag_handler *handler;
u32 prev_min_dump_alloc;
int err = 0; int err = 0;
again:
prev_min_dump_alloc = cb->min_dump_alloc;
handler = inet_diag_lock_handler(r->sdiag_protocol); handler = inet_diag_lock_handler(r->sdiag_protocol);
if (!IS_ERR(handler)) if (!IS_ERR(handler))
handler->dump(skb, cb, r, bc); handler->dump(skb, cb, r);
else else
err = PTR_ERR(handler); err = PTR_ERR(handler);
inet_diag_unlock_handler(handler); inet_diag_unlock_handler(handler);
/* The skb is not large enough to fit one sk info and
* inet_sk_diag_fill() has requested for a larger skb.
*/
if (!skb->len && cb->min_dump_alloc > prev_min_dump_alloc) {
err = pskb_expand_head(skb, 0, cb->min_dump_alloc, GFP_KERNEL);
if (!err)
goto again;
}
return err ? : skb->len; return err ? : skb->len;
} }
static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
{ {
int hdrlen = sizeof(struct inet_diag_req_v2); return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh));
struct nlattr *bc = NULL; }
if (nlmsg_attrlen(cb->nlh, hdrlen)) static int __inet_diag_dump_start(struct netlink_callback *cb, int hdrlen)
bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE); {
const struct nlmsghdr *nlh = cb->nlh;
struct inet_diag_dump_data *cb_data;
struct sk_buff *skb = cb->skb;
struct nlattr *nla;
int rem, err;
cb_data = kzalloc(sizeof(*cb_data), GFP_KERNEL);
if (!cb_data)
return -ENOMEM;
nla_for_each_attr(nla, nlmsg_attrdata(nlh, hdrlen),
nlmsg_attrlen(nlh, hdrlen), rem) {
int type = nla_type(nla);
if (type < __INET_DIAG_REQ_MAX)
cb_data->req_nlas[type] = nla;
}
nla = cb_data->inet_diag_nla_bc;
if (nla) {
err = inet_diag_bc_audit(nla, skb);
if (err) {
kfree(cb_data);
return err;
}
}
nla = cb_data->inet_diag_nla_bpf_stgs;
if (nla) {
struct bpf_sk_storage_diag *bpf_stg_diag;
bpf_stg_diag = bpf_sk_storage_diag_alloc(nla);
if (IS_ERR(bpf_stg_diag)) {
kfree(cb_data);
return PTR_ERR(bpf_stg_diag);
}
cb_data->bpf_stg_diag = bpf_stg_diag;
}
cb->data = cb_data;
return 0;
}
static int inet_diag_dump_start(struct netlink_callback *cb)
{
return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req_v2));
}
static int inet_diag_dump_start_compat(struct netlink_callback *cb)
{
return __inet_diag_dump_start(cb, sizeof(struct inet_diag_req));
}
return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc); static int inet_diag_dump_done(struct netlink_callback *cb)
{
struct inet_diag_dump_data *cb_data = cb->data;
bpf_sk_storage_diag_free(cb_data->bpf_stg_diag);
kfree(cb->data);
return 0;
} }
static int inet_diag_type2proto(int type) static int inet_diag_type2proto(int type)
...@@ -1085,9 +1179,7 @@ static int inet_diag_dump_compat(struct sk_buff *skb, ...@@ -1085,9 +1179,7 @@ static int inet_diag_dump_compat(struct sk_buff *skb,
struct netlink_callback *cb) struct netlink_callback *cb)
{ {
struct inet_diag_req *rc = nlmsg_data(cb->nlh); struct inet_diag_req *rc = nlmsg_data(cb->nlh);
int hdrlen = sizeof(struct inet_diag_req);
struct inet_diag_req_v2 req; struct inet_diag_req_v2 req;
struct nlattr *bc = NULL;
req.sdiag_family = AF_UNSPEC; /* compatibility */ req.sdiag_family = AF_UNSPEC; /* compatibility */
req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type); req.sdiag_protocol = inet_diag_type2proto(cb->nlh->nlmsg_type);
...@@ -1095,10 +1187,7 @@ static int inet_diag_dump_compat(struct sk_buff *skb, ...@@ -1095,10 +1187,7 @@ static int inet_diag_dump_compat(struct sk_buff *skb,
req.idiag_states = rc->idiag_states; req.idiag_states = rc->idiag_states;
req.id = rc->id; req.id = rc->id;
if (nlmsg_attrlen(cb->nlh, hdrlen)) return __inet_diag_dump(skb, cb, &req);
bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE);
return __inet_diag_dump(skb, cb, &req, bc);
} }
static int inet_diag_get_exact_compat(struct sk_buff *in_skb, static int inet_diag_get_exact_compat(struct sk_buff *in_skb,
...@@ -1126,22 +1215,12 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) ...@@ -1126,22 +1215,12 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)
return -EINVAL; return -EINVAL;
if (nlh->nlmsg_flags & NLM_F_DUMP) { if (nlh->nlmsg_flags & NLM_F_DUMP) {
if (nlmsg_attrlen(nlh, hdrlen)) { struct netlink_dump_control c = {
struct nlattr *attr; .start = inet_diag_dump_start_compat,
int err; .done = inet_diag_dump_done,
.dump = inet_diag_dump_compat,
attr = nlmsg_find_attr(nlh, hdrlen, };
INET_DIAG_REQ_BYTECODE); return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
err = inet_diag_bc_audit(attr, skb);
if (err)
return err;
}
{
struct netlink_dump_control c = {
.dump = inet_diag_dump_compat,
};
return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);
}
} }
return inet_diag_get_exact_compat(skb, nlh); return inet_diag_get_exact_compat(skb, nlh);
...@@ -1157,22 +1236,12 @@ static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h) ...@@ -1157,22 +1236,12 @@ static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h)
if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY && if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY &&
h->nlmsg_flags & NLM_F_DUMP) { h->nlmsg_flags & NLM_F_DUMP) {
if (nlmsg_attrlen(h, hdrlen)) { struct netlink_dump_control c = {
struct nlattr *attr; .start = inet_diag_dump_start,
int err; .done = inet_diag_dump_done,
.dump = inet_diag_dump,
attr = nlmsg_find_attr(h, hdrlen, };
INET_DIAG_REQ_BYTECODE); return netlink_dump_start(net->diag_nlsk, skb, h, &c);
err = inet_diag_bc_audit(attr, skb);
if (err)
return err;
}
{
struct netlink_dump_control c = {
.dump = inet_diag_dump,
};
return netlink_dump_start(net->diag_nlsk, skb, h, &c);
}
} }
return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h)); return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h));
......
...@@ -87,15 +87,16 @@ static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 ...@@ -87,15 +87,16 @@ static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2
return sk ? sk : ERR_PTR(-ENOENT); return sk ? sk : ERR_PTR(-ENOENT);
} }
static int raw_diag_dump_one(struct sk_buff *in_skb, static int raw_diag_dump_one(struct netlink_callback *cb,
const struct nlmsghdr *nlh,
const struct inet_diag_req_v2 *r) const struct inet_diag_req_v2 *r)
{ {
struct net *net = sock_net(in_skb->sk); struct sk_buff *in_skb = cb->skb;
struct sk_buff *rep; struct sk_buff *rep;
struct sock *sk; struct sock *sk;
struct net *net;
int err; int err;
net = sock_net(in_skb->sk);
sk = raw_sock_get(net, r); sk = raw_sock_get(net, r);
if (IS_ERR(sk)) if (IS_ERR(sk))
return PTR_ERR(sk); return PTR_ERR(sk);
...@@ -108,10 +109,7 @@ static int raw_diag_dump_one(struct sk_buff *in_skb, ...@@ -108,10 +109,7 @@ static int raw_diag_dump_one(struct sk_buff *in_skb,
return -ENOMEM; return -ENOMEM;
} }
err = inet_sk_diag_fill(sk, NULL, rep, r, err = inet_sk_diag_fill(sk, NULL, rep, cb, r, 0,
sk_user_ns(NETLINK_CB(in_skb).sk),
NETLINK_CB(in_skb).portid,
nlh->nlmsg_seq, 0, nlh,
netlink_net_capable(in_skb, CAP_NET_ADMIN)); netlink_net_capable(in_skb, CAP_NET_ADMIN));
sock_put(sk); sock_put(sk);
...@@ -136,25 +134,25 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, ...@@ -136,25 +134,25 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb,
if (!inet_diag_bc_sk(bc, sk)) if (!inet_diag_bc_sk(bc, sk))
return 0; return 0;
return inet_sk_diag_fill(sk, NULL, skb, r, return inet_sk_diag_fill(sk, NULL, skb, cb, r, NLM_F_MULTI, net_admin);
sk_user_ns(NETLINK_CB(cb->skb).sk),
NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq, NLM_F_MULTI,
cb->nlh, net_admin);
} }
static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, struct nlattr *bc) const struct inet_diag_req_v2 *r)
{ {
bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
struct raw_hashinfo *hashinfo = raw_get_hashinfo(r); struct raw_hashinfo *hashinfo = raw_get_hashinfo(r);
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct inet_diag_dump_data *cb_data;
int num, s_num, slot, s_slot; int num, s_num, slot, s_slot;
struct sock *sk = NULL; struct sock *sk = NULL;
struct nlattr *bc;
if (IS_ERR(hashinfo)) if (IS_ERR(hashinfo))
return; return;
cb_data = cb->data;
bc = cb_data->inet_diag_nla_bc;
s_slot = cb->args[0]; s_slot = cb->args[0];
num = s_num = cb->args[1]; num = s_num = cb->args[1];
......
...@@ -179,15 +179,15 @@ static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin) ...@@ -179,15 +179,15 @@ static size_t tcp_diag_get_aux_size(struct sock *sk, bool net_admin)
} }
static void tcp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, static void tcp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, struct nlattr *bc) const struct inet_diag_req_v2 *r)
{ {
inet_diag_dump_icsk(&tcp_hashinfo, skb, cb, r, bc); inet_diag_dump_icsk(&tcp_hashinfo, skb, cb, r);
} }
static int tcp_diag_dump_one(struct sk_buff *in_skb, const struct nlmsghdr *nlh, static int tcp_diag_dump_one(struct netlink_callback *cb,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
return inet_diag_dump_one_icsk(&tcp_hashinfo, in_skb, nlh, req); return inet_diag_dump_one_icsk(&tcp_hashinfo, cb, req);
} }
#ifdef CONFIG_INET_DIAG_DESTROY #ifdef CONFIG_INET_DIAG_DESTROY
......
...@@ -21,16 +21,15 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, ...@@ -21,16 +21,15 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb,
if (!inet_diag_bc_sk(bc, sk)) if (!inet_diag_bc_sk(bc, sk))
return 0; return 0;
return inet_sk_diag_fill(sk, NULL, skb, req, return inet_sk_diag_fill(sk, NULL, skb, cb, req, NLM_F_MULTI,
sk_user_ns(NETLINK_CB(cb->skb).sk), net_admin);
NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh, net_admin);
} }
static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, static int udp_dump_one(struct udp_table *tbl,
const struct nlmsghdr *nlh, struct netlink_callback *cb,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
struct sk_buff *in_skb = cb->skb;
int err = -EINVAL; int err = -EINVAL;
struct sock *sk = NULL; struct sock *sk = NULL;
struct sk_buff *rep; struct sk_buff *rep;
...@@ -70,11 +69,8 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, ...@@ -70,11 +69,8 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
if (!rep) if (!rep)
goto out; goto out;
err = inet_sk_diag_fill(sk, NULL, rep, req, err = inet_sk_diag_fill(sk, NULL, rep, cb, req, 0,
sk_user_ns(NETLINK_CB(in_skb).sk), netlink_net_capable(in_skb, CAP_NET_ADMIN));
NETLINK_CB(in_skb).portid,
nlh->nlmsg_seq, 0, nlh,
netlink_net_capable(in_skb, CAP_NET_ADMIN));
if (err < 0) { if (err < 0) {
WARN_ON(err == -EMSGSIZE); WARN_ON(err == -EMSGSIZE);
kfree_skb(rep); kfree_skb(rep);
...@@ -93,12 +89,16 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, ...@@ -93,12 +89,16 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb,
static void udp_dump(struct udp_table *table, struct sk_buff *skb, static void udp_dump(struct udp_table *table, struct sk_buff *skb,
struct netlink_callback *cb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, struct nlattr *bc) const struct inet_diag_req_v2 *r)
{ {
bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN);
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct inet_diag_dump_data *cb_data;
int num, s_num, slot, s_slot; int num, s_num, slot, s_slot;
struct nlattr *bc;
cb_data = cb->data;
bc = cb_data->inet_diag_nla_bc;
s_slot = cb->args[0]; s_slot = cb->args[0];
num = s_num = cb->args[1]; num = s_num = cb->args[1];
...@@ -146,15 +146,15 @@ static void udp_dump(struct udp_table *table, struct sk_buff *skb, ...@@ -146,15 +146,15 @@ static void udp_dump(struct udp_table *table, struct sk_buff *skb,
} }
static void udp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, static void udp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, struct nlattr *bc) const struct inet_diag_req_v2 *r)
{ {
udp_dump(&udp_table, skb, cb, r, bc); udp_dump(&udp_table, skb, cb, r);
} }
static int udp_diag_dump_one(struct sk_buff *in_skb, const struct nlmsghdr *nlh, static int udp_diag_dump_one(struct netlink_callback *cb,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
return udp_dump_one(&udp_table, in_skb, nlh, req); return udp_dump_one(&udp_table, cb, req);
} }
static void udp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, static void udp_diag_get_info(struct sock *sk, struct inet_diag_msg *r,
...@@ -249,16 +249,15 @@ static const struct inet_diag_handler udp_diag_handler = { ...@@ -249,16 +249,15 @@ static const struct inet_diag_handler udp_diag_handler = {
}; };
static void udplite_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, static void udplite_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, const struct inet_diag_req_v2 *r)
struct nlattr *bc)
{ {
udp_dump(&udplite_table, skb, cb, r, bc); udp_dump(&udplite_table, skb, cb, r);
} }
static int udplite_diag_dump_one(struct sk_buff *in_skb, const struct nlmsghdr *nlh, static int udplite_diag_dump_one(struct netlink_callback *cb,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
return udp_dump_one(&udplite_table, in_skb, nlh, req); return udp_dump_one(&udplite_table, cb, req);
} }
static const struct inet_diag_handler udplite_diag_handler = { static const struct inet_diag_handler udplite_diag_handler = {
......
...@@ -432,11 +432,12 @@ static void sctp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, ...@@ -432,11 +432,12 @@ static void sctp_diag_get_info(struct sock *sk, struct inet_diag_msg *r,
sctp_get_sctp_info(sk, infox->asoc, infox->sctpinfo); sctp_get_sctp_info(sk, infox->asoc, infox->sctpinfo);
} }
static int sctp_diag_dump_one(struct sk_buff *in_skb, static int sctp_diag_dump_one(struct netlink_callback *cb,
const struct nlmsghdr *nlh,
const struct inet_diag_req_v2 *req) const struct inet_diag_req_v2 *req)
{ {
struct sk_buff *in_skb = cb->skb;
struct net *net = sock_net(in_skb->sk); struct net *net = sock_net(in_skb->sk);
const struct nlmsghdr *nlh = cb->nlh;
union sctp_addr laddr, paddr; union sctp_addr laddr, paddr;
struct sctp_comm_param commp = { struct sctp_comm_param commp = {
.skb = in_skb, .skb = in_skb,
...@@ -470,7 +471,7 @@ static int sctp_diag_dump_one(struct sk_buff *in_skb, ...@@ -470,7 +471,7 @@ static int sctp_diag_dump_one(struct sk_buff *in_skb,
} }
static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
const struct inet_diag_req_v2 *r, struct nlattr *bc) const struct inet_diag_req_v2 *r)
{ {
u32 idiag_states = r->idiag_states; u32 idiag_states = r->idiag_states;
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment