Commit 47dceb8e authored by Willem de Bruijn's avatar Willem de Bruijn Committed by David S. Miller

packet: add classic BPF fanout mode

Add fanout mode PACKET_FANOUT_CBPF that accepts a classic BPF program
to select a socket.

This avoids having to keep adding special case fanout modes. One
example use case is application layer load balancing. The QUIC
protocol, for instance, encodes a connection ID in UDP payload.

Also add socket option SOL_PACKET/PACKET_FANOUT_DATA that updates data
associated with the socket group. Fanout mode PACKET_FANOUT_CBPF is the
only user so far.
Signed-off-by: default avatarWillem de Bruijn <willemb@google.com>
Acked-by: default avatarAlexei Starovoitov <ast@plumgrid.com>
Acked-by: default avatarDaniel Borkmann <daniel@iogearbox.net>
Acked-by: default avatarEric Dumazet <edumazet@google.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent a1c234f9
...@@ -55,6 +55,7 @@ struct sockaddr_ll { ...@@ -55,6 +55,7 @@ struct sockaddr_ll {
#define PACKET_TX_HAS_OFF 19 #define PACKET_TX_HAS_OFF 19
#define PACKET_QDISC_BYPASS 20 #define PACKET_QDISC_BYPASS 20
#define PACKET_ROLLOVER_STATS 21 #define PACKET_ROLLOVER_STATS 21
#define PACKET_FANOUT_DATA 22
#define PACKET_FANOUT_HASH 0 #define PACKET_FANOUT_HASH 0
#define PACKET_FANOUT_LB 1 #define PACKET_FANOUT_LB 1
...@@ -62,6 +63,7 @@ struct sockaddr_ll { ...@@ -62,6 +63,7 @@ struct sockaddr_ll {
#define PACKET_FANOUT_ROLLOVER 3 #define PACKET_FANOUT_ROLLOVER 3
#define PACKET_FANOUT_RND 4 #define PACKET_FANOUT_RND 4
#define PACKET_FANOUT_QM 5 #define PACKET_FANOUT_QM 5
#define PACKET_FANOUT_CBPF 6
#define PACKET_FANOUT_FLAG_ROLLOVER 0x1000 #define PACKET_FANOUT_FLAG_ROLLOVER 0x1000
#define PACKET_FANOUT_FLAG_DEFRAG 0x8000 #define PACKET_FANOUT_FLAG_DEFRAG 0x8000
......
...@@ -92,6 +92,7 @@ ...@@ -92,6 +92,7 @@
#ifdef CONFIG_INET #ifdef CONFIG_INET
#include <net/inet_common.h> #include <net/inet_common.h>
#endif #endif
#include <linux/bpf.h>
#include "internal.h" #include "internal.h"
...@@ -1410,6 +1411,22 @@ static unsigned int fanout_demux_qm(struct packet_fanout *f, ...@@ -1410,6 +1411,22 @@ static unsigned int fanout_demux_qm(struct packet_fanout *f,
return skb_get_queue_mapping(skb) % num; return skb_get_queue_mapping(skb) % num;
} }
static unsigned int fanout_demux_bpf(struct packet_fanout *f,
struct sk_buff *skb,
unsigned int num)
{
struct bpf_prog *prog;
unsigned int ret = 0;
rcu_read_lock();
prog = rcu_dereference(f->bpf_prog);
if (prog)
ret = BPF_PROG_RUN(prog, skb) % num;
rcu_read_unlock();
return ret;
}
static bool fanout_has_flag(struct packet_fanout *f, u16 flag) static bool fanout_has_flag(struct packet_fanout *f, u16 flag)
{ {
return f->flags & (flag >> 8); return f->flags & (flag >> 8);
...@@ -1454,6 +1471,9 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev, ...@@ -1454,6 +1471,9 @@ static int packet_rcv_fanout(struct sk_buff *skb, struct net_device *dev,
case PACKET_FANOUT_ROLLOVER: case PACKET_FANOUT_ROLLOVER:
idx = fanout_demux_rollover(f, skb, 0, false, num); idx = fanout_demux_rollover(f, skb, 0, false, num);
break; break;
case PACKET_FANOUT_CBPF:
idx = fanout_demux_bpf(f, skb, num);
break;
} }
if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER)) if (fanout_has_flag(f, PACKET_FANOUT_FLAG_ROLLOVER))
...@@ -1502,6 +1522,74 @@ static bool match_fanout_group(struct packet_type *ptype, struct sock *sk) ...@@ -1502,6 +1522,74 @@ static bool match_fanout_group(struct packet_type *ptype, struct sock *sk)
return false; return false;
} }
static void fanout_init_data(struct packet_fanout *f)
{
switch (f->type) {
case PACKET_FANOUT_LB:
atomic_set(&f->rr_cur, 0);
break;
case PACKET_FANOUT_CBPF:
RCU_INIT_POINTER(f->bpf_prog, NULL);
break;
}
}
static void __fanout_set_data_bpf(struct packet_fanout *f, struct bpf_prog *new)
{
struct bpf_prog *old;
spin_lock(&f->lock);
old = rcu_dereference_protected(f->bpf_prog, lockdep_is_held(&f->lock));
rcu_assign_pointer(f->bpf_prog, new);
spin_unlock(&f->lock);
if (old) {
synchronize_net();
bpf_prog_destroy(old);
}
}
static int fanout_set_data_cbpf(struct packet_sock *po, char __user *data,
unsigned int len)
{
struct bpf_prog *new;
struct sock_fprog fprog;
int ret;
if (sock_flag(&po->sk, SOCK_FILTER_LOCKED))
return -EPERM;
if (len != sizeof(fprog))
return -EINVAL;
if (copy_from_user(&fprog, data, len))
return -EFAULT;
ret = bpf_prog_create_from_user(&new, &fprog, NULL);
if (ret)
return ret;
__fanout_set_data_bpf(po->fanout, new);
return 0;
}
static int fanout_set_data(struct packet_sock *po, char __user *data,
unsigned int len)
{
switch (po->fanout->type) {
case PACKET_FANOUT_CBPF:
return fanout_set_data_cbpf(po, data, len);
default:
return -EINVAL;
};
}
static void fanout_release_data(struct packet_fanout *f)
{
switch (f->type) {
case PACKET_FANOUT_CBPF:
__fanout_set_data_bpf(f, NULL);
};
}
static int fanout_add(struct sock *sk, u16 id, u16 type_flags) static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
{ {
struct packet_sock *po = pkt_sk(sk); struct packet_sock *po = pkt_sk(sk);
...@@ -1519,6 +1607,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags) ...@@ -1519,6 +1607,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
case PACKET_FANOUT_CPU: case PACKET_FANOUT_CPU:
case PACKET_FANOUT_RND: case PACKET_FANOUT_RND:
case PACKET_FANOUT_QM: case PACKET_FANOUT_QM:
case PACKET_FANOUT_CBPF:
break; break;
default: default:
return -EINVAL; return -EINVAL;
...@@ -1561,10 +1650,10 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags) ...@@ -1561,10 +1650,10 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
match->id = id; match->id = id;
match->type = type; match->type = type;
match->flags = flags; match->flags = flags;
atomic_set(&match->rr_cur, 0);
INIT_LIST_HEAD(&match->list); INIT_LIST_HEAD(&match->list);
spin_lock_init(&match->lock); spin_lock_init(&match->lock);
atomic_set(&match->sk_ref, 0); atomic_set(&match->sk_ref, 0);
fanout_init_data(match);
match->prot_hook.type = po->prot_hook.type; match->prot_hook.type = po->prot_hook.type;
match->prot_hook.dev = po->prot_hook.dev; match->prot_hook.dev = po->prot_hook.dev;
match->prot_hook.func = packet_rcv_fanout; match->prot_hook.func = packet_rcv_fanout;
...@@ -1610,6 +1699,7 @@ static void fanout_release(struct sock *sk) ...@@ -1610,6 +1699,7 @@ static void fanout_release(struct sock *sk)
if (atomic_dec_and_test(&f->sk_ref)) { if (atomic_dec_and_test(&f->sk_ref)) {
list_del(&f->list); list_del(&f->list);
dev_remove_pack(&f->prot_hook); dev_remove_pack(&f->prot_hook);
fanout_release_data(f);
kfree(f); kfree(f);
} }
mutex_unlock(&fanout_mutex); mutex_unlock(&fanout_mutex);
...@@ -3529,6 +3619,13 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv ...@@ -3529,6 +3619,13 @@ packet_setsockopt(struct socket *sock, int level, int optname, char __user *optv
return fanout_add(sk, val & 0xffff, val >> 16); return fanout_add(sk, val & 0xffff, val >> 16);
} }
case PACKET_FANOUT_DATA:
{
if (!po->fanout)
return -EINVAL;
return fanout_set_data(po, optval, optlen);
}
case PACKET_TX_HAS_OFF: case PACKET_TX_HAS_OFF:
{ {
unsigned int val; unsigned int val;
......
...@@ -79,7 +79,10 @@ struct packet_fanout { ...@@ -79,7 +79,10 @@ struct packet_fanout {
u16 id; u16 id;
u8 type; u8 type;
u8 flags; u8 flags;
atomic_t rr_cur; union {
atomic_t rr_cur;
struct bpf_prog __rcu *bpf_prog;
};
struct list_head list; struct list_head list;
struct sock *arr[PACKET_FANOUT_MAX]; struct sock *arr[PACKET_FANOUT_MAX];
spinlock_t lock; spinlock_t lock;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment