Commit b73c8bfd authored by David S. Miller's avatar David S. Miller

Merge branch 'skb_to_full_sk'

Eric Dumazet says:

====================
net: add skb_to_full_sk() helper

Many contexts need to reach listener socket from skb attached
to a request socket. This patch series add skb_to_full_sk() to
clearly express this need and use it where appropriate.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents fb9a10d9 3aed8225
...@@ -210,6 +210,18 @@ struct inet_sock { ...@@ -210,6 +210,18 @@ struct inet_sock {
#define IP_CMSG_ORIGDSTADDR BIT(6) #define IP_CMSG_ORIGDSTADDR BIT(6)
#define IP_CMSG_CHECKSUM BIT(7) #define IP_CMSG_CHECKSUM BIT(7)
/* SYNACK messages might be attached to request sockets.
* Some places want to reach the listener in this case.
*/
static inline struct sock *skb_to_full_sk(const struct sk_buff *skb)
{
struct sock *sk = skb->sk;
if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
sk = inet_reqsk(sk)->rsk_listener;
return sk;
}
static inline struct inet_sock *inet_sk(const struct sock *sk) static inline struct inet_sock *inet_sk(const struct sock *sk)
{ {
return (struct inet_sock *)sk; return (struct inet_sock *)sk;
......
...@@ -31,6 +31,7 @@ void nft_meta_get_eval(const struct nft_expr *expr, ...@@ -31,6 +31,7 @@ void nft_meta_get_eval(const struct nft_expr *expr,
const struct nft_meta *priv = nft_expr_priv(expr); const struct nft_meta *priv = nft_expr_priv(expr);
const struct sk_buff *skb = pkt->skb; const struct sk_buff *skb = pkt->skb;
const struct net_device *in = pkt->in, *out = pkt->out; const struct net_device *in = pkt->in, *out = pkt->out;
struct sock *sk;
u32 *dest = &regs->data[priv->dreg]; u32 *dest = &regs->data[priv->dreg];
switch (priv->key) { switch (priv->key) {
...@@ -86,33 +87,35 @@ void nft_meta_get_eval(const struct nft_expr *expr, ...@@ -86,33 +87,35 @@ void nft_meta_get_eval(const struct nft_expr *expr,
*(u16 *)dest = out->type; *(u16 *)dest = out->type;
break; break;
case NFT_META_SKUID: case NFT_META_SKUID:
if (skb->sk == NULL || !sk_fullsock(skb->sk)) sk = skb_to_full_sk(skb);
if (!sk || !sk_fullsock(sk))
goto err; goto err;
read_lock_bh(&skb->sk->sk_callback_lock); read_lock_bh(&sk->sk_callback_lock);
if (skb->sk->sk_socket == NULL || if (sk->sk_socket == NULL ||
skb->sk->sk_socket->file == NULL) { sk->sk_socket->file == NULL) {
read_unlock_bh(&skb->sk->sk_callback_lock); read_unlock_bh(&sk->sk_callback_lock);
goto err; goto err;
} }
*dest = from_kuid_munged(&init_user_ns, *dest = from_kuid_munged(&init_user_ns,
skb->sk->sk_socket->file->f_cred->fsuid); sk->sk_socket->file->f_cred->fsuid);
read_unlock_bh(&skb->sk->sk_callback_lock); read_unlock_bh(&sk->sk_callback_lock);
break; break;
case NFT_META_SKGID: case NFT_META_SKGID:
if (skb->sk == NULL || !sk_fullsock(skb->sk)) sk = skb_to_full_sk(skb);
if (!sk || !sk_fullsock(sk))
goto err; goto err;
read_lock_bh(&skb->sk->sk_callback_lock); read_lock_bh(&sk->sk_callback_lock);
if (skb->sk->sk_socket == NULL || if (sk->sk_socket == NULL ||
skb->sk->sk_socket->file == NULL) { sk->sk_socket->file == NULL) {
read_unlock_bh(&skb->sk->sk_callback_lock); read_unlock_bh(&sk->sk_callback_lock);
goto err; goto err;
} }
*dest = from_kgid_munged(&init_user_ns, *dest = from_kgid_munged(&init_user_ns,
skb->sk->sk_socket->file->f_cred->fsgid); sk->sk_socket->file->f_cred->fsgid);
read_unlock_bh(&skb->sk->sk_callback_lock); read_unlock_bh(&sk->sk_callback_lock);
break; break;
#ifdef CONFIG_IP_ROUTE_CLASSID #ifdef CONFIG_IP_ROUTE_CLASSID
case NFT_META_RTCLASSID: { case NFT_META_RTCLASSID: {
...@@ -168,9 +171,10 @@ void nft_meta_get_eval(const struct nft_expr *expr, ...@@ -168,9 +171,10 @@ void nft_meta_get_eval(const struct nft_expr *expr,
break; break;
#ifdef CONFIG_CGROUP_NET_CLASSID #ifdef CONFIG_CGROUP_NET_CLASSID
case NFT_META_CGROUP: case NFT_META_CGROUP:
if (skb->sk == NULL || !sk_fullsock(skb->sk)) sk = skb_to_full_sk(skb);
if (!sk || !sk_fullsock(sk))
goto err; goto err;
*dest = skb->sk->sk_classid; *dest = sk->sk_classid;
break; break;
#endif #endif
default: default:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <linux/skbuff.h> #include <linux/skbuff.h>
#include <linux/file.h> #include <linux/file.h>
#include <net/sock.h> #include <net/sock.h>
#include <net/inet_sock.h>
#include <linux/netfilter/x_tables.h> #include <linux/netfilter/x_tables.h>
#include <linux/netfilter/xt_owner.h> #include <linux/netfilter/xt_owner.h>
...@@ -33,8 +34,9 @@ owner_mt(const struct sk_buff *skb, struct xt_action_param *par) ...@@ -33,8 +34,9 @@ owner_mt(const struct sk_buff *skb, struct xt_action_param *par)
{ {
const struct xt_owner_match_info *info = par->matchinfo; const struct xt_owner_match_info *info = par->matchinfo;
const struct file *filp; const struct file *filp;
struct sock *sk = skb_to_full_sk(skb);
if (skb->sk == NULL || skb->sk->sk_socket == NULL) if (sk == NULL || sk->sk_socket == NULL)
return (info->match ^ info->invert) == 0; return (info->match ^ info->invert) == 0;
else if (info->match & info->invert & XT_OWNER_SOCKET) else if (info->match & info->invert & XT_OWNER_SOCKET)
/* /*
...@@ -43,7 +45,7 @@ owner_mt(const struct sk_buff *skb, struct xt_action_param *par) ...@@ -43,7 +45,7 @@ owner_mt(const struct sk_buff *skb, struct xt_action_param *par)
*/ */
return false; return false;
filp = skb->sk->sk_socket->file; filp = sk->sk_socket->file;
if (filp == NULL) if (filp == NULL)
return ((info->match ^ info->invert) & return ((info->match ^ info->invert) &
(XT_OWNER_UID | XT_OWNER_GID)) == 0; (XT_OWNER_UID | XT_OWNER_GID)) == 0;
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <linux/if_vlan.h> #include <linux/if_vlan.h>
#include <linux/slab.h> #include <linux/slab.h>
#include <linux/module.h> #include <linux/module.h>
#include <net/inet_sock.h>
#include <net/pkt_cls.h> #include <net/pkt_cls.h>
#include <net/ip.h> #include <net/ip.h>
...@@ -197,8 +198,11 @@ static u32 flow_get_rtclassid(const struct sk_buff *skb) ...@@ -197,8 +198,11 @@ static u32 flow_get_rtclassid(const struct sk_buff *skb)
static u32 flow_get_skuid(const struct sk_buff *skb) static u32 flow_get_skuid(const struct sk_buff *skb)
{ {
if (skb->sk && skb->sk->sk_socket && skb->sk->sk_socket->file) { struct sock *sk = skb_to_full_sk(skb);
kuid_t skuid = skb->sk->sk_socket->file->f_cred->fsuid;
if (sk && sk->sk_socket && sk->sk_socket->file) {
kuid_t skuid = sk->sk_socket->file->f_cred->fsuid;
return from_kuid(&init_user_ns, skuid); return from_kuid(&init_user_ns, skuid);
} }
return 0; return 0;
...@@ -206,8 +210,11 @@ static u32 flow_get_skuid(const struct sk_buff *skb) ...@@ -206,8 +210,11 @@ static u32 flow_get_skuid(const struct sk_buff *skb)
static u32 flow_get_skgid(const struct sk_buff *skb) static u32 flow_get_skgid(const struct sk_buff *skb)
{ {
if (skb->sk && skb->sk->sk_socket && skb->sk->sk_socket->file) { struct sock *sk = skb_to_full_sk(skb);
kgid_t skgid = skb->sk->sk_socket->file->f_cred->fsgid;
if (sk && sk->sk_socket && sk->sk_socket->file) {
kgid_t skgid = sk->sk_socket->file->f_cred->fsgid;
return from_kgid(&init_user_ns, skgid); return from_kgid(&init_user_ns, skgid);
} }
return 0; return 0;
......
...@@ -343,119 +343,145 @@ META_COLLECTOR(int_sk_refcnt) ...@@ -343,119 +343,145 @@ META_COLLECTOR(int_sk_refcnt)
META_COLLECTOR(int_sk_rcvbuf) META_COLLECTOR(int_sk_rcvbuf)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_rcvbuf; dst->value = sk->sk_rcvbuf;
} }
META_COLLECTOR(int_sk_shutdown) META_COLLECTOR(int_sk_shutdown)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_shutdown; dst->value = sk->sk_shutdown;
} }
META_COLLECTOR(int_sk_proto) META_COLLECTOR(int_sk_proto)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_protocol; dst->value = sk->sk_protocol;
} }
META_COLLECTOR(int_sk_type) META_COLLECTOR(int_sk_type)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_type; dst->value = sk->sk_type;
} }
META_COLLECTOR(int_sk_rmem_alloc) META_COLLECTOR(int_sk_rmem_alloc)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = sk_rmem_alloc_get(skb->sk); dst->value = sk_rmem_alloc_get(sk);
} }
META_COLLECTOR(int_sk_wmem_alloc) META_COLLECTOR(int_sk_wmem_alloc)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = sk_wmem_alloc_get(skb->sk); dst->value = sk_wmem_alloc_get(sk);
} }
META_COLLECTOR(int_sk_omem_alloc) META_COLLECTOR(int_sk_omem_alloc)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = atomic_read(&skb->sk->sk_omem_alloc); dst->value = atomic_read(&sk->sk_omem_alloc);
} }
META_COLLECTOR(int_sk_rcv_qlen) META_COLLECTOR(int_sk_rcv_qlen)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_receive_queue.qlen; dst->value = sk->sk_receive_queue.qlen;
} }
META_COLLECTOR(int_sk_snd_qlen) META_COLLECTOR(int_sk_snd_qlen)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_write_queue.qlen; dst->value = sk->sk_write_queue.qlen;
} }
META_COLLECTOR(int_sk_wmem_queued) META_COLLECTOR(int_sk_wmem_queued)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_wmem_queued; dst->value = sk->sk_wmem_queued;
} }
META_COLLECTOR(int_sk_fwd_alloc) META_COLLECTOR(int_sk_fwd_alloc)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_forward_alloc; dst->value = sk->sk_forward_alloc;
} }
META_COLLECTOR(int_sk_sndbuf) META_COLLECTOR(int_sk_sndbuf)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_sndbuf; dst->value = sk->sk_sndbuf;
} }
META_COLLECTOR(int_sk_alloc) META_COLLECTOR(int_sk_alloc)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = (__force int) skb->sk->sk_allocation; dst->value = (__force int) sk->sk_allocation;
} }
META_COLLECTOR(int_sk_hash) META_COLLECTOR(int_sk_hash)
...@@ -469,92 +495,112 @@ META_COLLECTOR(int_sk_hash) ...@@ -469,92 +495,112 @@ META_COLLECTOR(int_sk_hash)
META_COLLECTOR(int_sk_lingertime) META_COLLECTOR(int_sk_lingertime)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_lingertime / HZ; dst->value = sk->sk_lingertime / HZ;
} }
META_COLLECTOR(int_sk_err_qlen) META_COLLECTOR(int_sk_err_qlen)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_error_queue.qlen; dst->value = sk->sk_error_queue.qlen;
} }
META_COLLECTOR(int_sk_ack_bl) META_COLLECTOR(int_sk_ack_bl)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_ack_backlog; dst->value = sk->sk_ack_backlog;
} }
META_COLLECTOR(int_sk_max_ack_bl) META_COLLECTOR(int_sk_max_ack_bl)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_max_ack_backlog; dst->value = sk->sk_max_ack_backlog;
} }
META_COLLECTOR(int_sk_prio) META_COLLECTOR(int_sk_prio)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_priority; dst->value = sk->sk_priority;
} }
META_COLLECTOR(int_sk_rcvlowat) META_COLLECTOR(int_sk_rcvlowat)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_rcvlowat; dst->value = sk->sk_rcvlowat;
} }
META_COLLECTOR(int_sk_rcvtimeo) META_COLLECTOR(int_sk_rcvtimeo)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_rcvtimeo / HZ; dst->value = sk->sk_rcvtimeo / HZ;
} }
META_COLLECTOR(int_sk_sndtimeo) META_COLLECTOR(int_sk_sndtimeo)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_sndtimeo / HZ; dst->value = sk->sk_sndtimeo / HZ;
} }
META_COLLECTOR(int_sk_sendmsg_off) META_COLLECTOR(int_sk_sendmsg_off)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_frag.offset; dst->value = sk->sk_frag.offset;
} }
META_COLLECTOR(int_sk_write_pend) META_COLLECTOR(int_sk_write_pend)
{ {
if (skip_nonlocal(skb)) { const struct sock *sk = skb_to_full_sk(skb);
if (!sk) {
*err = -1; *err = -1;
return; return;
} }
dst->value = skb->sk->sk_write_pending; dst->value = sk->sk_write_pending;
} }
/************************************************************************** /**************************************************************************
......
...@@ -4931,23 +4931,11 @@ static unsigned int selinux_ipv4_output(void *priv, ...@@ -4931,23 +4931,11 @@ static unsigned int selinux_ipv4_output(void *priv,
return selinux_ip_output(skb, PF_INET); return selinux_ip_output(skb, PF_INET);
} }
/* SYNACK messages might be attached to request sockets.
* To get back to sk_security, we need to look at the listener.
*/
static struct sock *selinux_skb_sk(const struct sk_buff *skb)
{
struct sock *sk = skb->sk;
if (sk && sk->sk_state == TCP_NEW_SYN_RECV)
sk = inet_reqsk(sk)->rsk_listener;
return sk;
}
static unsigned int selinux_ip_postroute_compat(struct sk_buff *skb, static unsigned int selinux_ip_postroute_compat(struct sk_buff *skb,
int ifindex, int ifindex,
u16 family) u16 family)
{ {
struct sock *sk = selinux_skb_sk(skb); struct sock *sk = skb_to_full_sk(skb);
struct sk_security_struct *sksec; struct sk_security_struct *sksec;
struct common_audit_data ad; struct common_audit_data ad;
struct lsm_network_audit net = {0,}; struct lsm_network_audit net = {0,};
...@@ -5002,7 +4990,7 @@ static unsigned int selinux_ip_postroute(struct sk_buff *skb, ...@@ -5002,7 +4990,7 @@ static unsigned int selinux_ip_postroute(struct sk_buff *skb,
if (!secmark_active && !peerlbl_active) if (!secmark_active && !peerlbl_active)
return NF_ACCEPT; return NF_ACCEPT;
sk = selinux_skb_sk(skb); sk = skb_to_full_sk(skb);
#ifdef CONFIG_XFRM #ifdef CONFIG_XFRM
/* If skb->dst->xfrm is non-NULL then the packet is undergoing an IPsec /* If skb->dst->xfrm is non-NULL then the packet is undergoing an IPsec
......
...@@ -245,7 +245,7 @@ int selinux_netlbl_skbuff_setsid(struct sk_buff *skb, ...@@ -245,7 +245,7 @@ int selinux_netlbl_skbuff_setsid(struct sk_buff *skb,
/* if this is a locally generated packet check to see if it is already /* if this is a locally generated packet check to see if it is already
* being labeled by it's parent socket, if it is just exit */ * being labeled by it's parent socket, if it is just exit */
sk = skb->sk; sk = skb_to_full_sk(skb);
if (sk != NULL) { if (sk != NULL) {
struct sk_security_struct *sksec = sk->sk_security; struct sk_security_struct *sksec = sk->sk_security;
if (sksec->nlbl_state != NLBL_REQSKB) if (sksec->nlbl_state != NLBL_REQSKB)
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <linux/netfilter_ipv4.h> #include <linux/netfilter_ipv4.h>
#include <linux/netfilter_ipv6.h> #include <linux/netfilter_ipv6.h>
#include <linux/netdevice.h> #include <linux/netdevice.h>
#include <net/inet_sock.h>
#include "smack.h" #include "smack.h"
#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
...@@ -25,11 +26,12 @@ static unsigned int smack_ipv6_output(void *priv, ...@@ -25,11 +26,12 @@ static unsigned int smack_ipv6_output(void *priv,
struct sk_buff *skb, struct sk_buff *skb,
const struct nf_hook_state *state) const struct nf_hook_state *state)
{ {
struct sock *sk = skb_to_full_sk(skb);
struct socket_smack *ssp; struct socket_smack *ssp;
struct smack_known *skp; struct smack_known *skp;
if (skb && skb->sk && skb->sk->sk_security) { if (sk && sk->sk_security) {
ssp = skb->sk->sk_security; ssp = sk->sk_security;
skp = ssp->smk_out; skp = ssp->smk_out;
skb->secmark = skp->smk_secid; skb->secmark = skp->smk_secid;
} }
...@@ -42,11 +44,12 @@ static unsigned int smack_ipv4_output(void *priv, ...@@ -42,11 +44,12 @@ static unsigned int smack_ipv4_output(void *priv,
struct sk_buff *skb, struct sk_buff *skb,
const struct nf_hook_state *state) const struct nf_hook_state *state)
{ {
struct sock *sk = skb_to_full_sk(skb);
struct socket_smack *ssp; struct socket_smack *ssp;
struct smack_known *skp; struct smack_known *skp;
if (skb && skb->sk && skb->sk->sk_security) { if (sk && sk->sk_security) {
ssp = skb->sk->sk_security; ssp = sk->sk_security;
skp = ssp->smk_out; skp = ssp->smk_out;
skb->secmark = skp->smk_secid; skb->secmark = skp->smk_secid;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment