Commit b35f504a authored by David S. Miller

Merge branch 'listener_refactor'

Eric Dumazet says:

====================
inet: tcp listener refactoring, part 10

We are getting close to the point where request sockets will be hashed
into the generic hash table. Some followups are needed for netfilter and
will be handled in the next patch series.
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents f00bbd21 13854e5a
@@ -275,6 +275,11 @@ static inline void inet_csk_reqsk_queue_add(struct sock *sk,
 					    struct sock *child)
 {
 	reqsk_queue_add(&inet_csk(sk)->icsk_accept_queue, req, sk, child);
+	/* before letting lookups find us, make sure all req fields
+	 * are committed to memory.
+	 */
+	smp_wmb();
+	atomic_set(&req->rsk_refcnt, 1);
 }
 
 void inet_csk_reqsk_queue_hash_add(struct sock *sk, struct request_sock *req,
...
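The ordering in inet_csk_reqsk_queue_add() is the heart of this change: once rsk_refcnt is set to 1 the request sock counts as published, so every other field must already be visible to lookups running on other CPUs, and the smp_wmb() orders the field initialization before the refcount store. Below is a minimal userspace analogue of that publish pattern using C11 atomics; the struct and field names are illustrative, not kernel API.

/* Minimal userspace analogue of the publish pattern above (C11 atomics);
 * names are illustrative, not kernel API.
 */
#include <stdatomic.h>

struct req {
	int field_a;		/* stands in for the req fields */
	int field_b;
	atomic_int refcnt;	/* 0 until published, like rsk_refcnt */
};

static void publish(struct req *r)
{
	r->field_a = 1;
	r->field_b = 2;
	/* analogue of smp_wmb() + atomic_set(&req->rsk_refcnt, 1):
	 * the release store keeps the field writes ordered before it
	 */
	atomic_store_explicit(&r->refcnt, 1, memory_order_release);
}

static int read_if_published(struct req *r)
{
	/* the acquire load pairs with the release store above, so a
	 * reader that sees refcnt != 0 also sees initialized fields
	 */
	if (atomic_load_explicit(&r->refcnt, memory_order_acquire) == 0)
		return -1;	/* not yet published */
	return r->field_a + r->field_b;
}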
@@ -255,6 +255,11 @@ static inline struct request_sock *inet_reqsk_alloc(struct request_sock_ops *ops)
 		ireq->opt = NULL;
 		atomic64_set(&ireq->ir_cookie, 0);
 		ireq->ireq_state = TCP_NEW_SYN_RECV;
+
+		/* Following is temporary. It is coupled with debugging
+		 * helpers in reqsk_put() & reqsk_free()
+		 */
+		atomic_set(&ireq->ireq_refcnt, 0);
 	}
 
 	return req;
...
@@ -82,19 +82,20 @@ static inline struct request_sock *inet_reqsk(struct sock *sk)
 	return (struct request_sock *)sk;
 }
 
-static inline void __reqsk_free(struct request_sock *req)
-{
-	kmem_cache_free(req->rsk_ops->slab, req);
-}
-
 static inline void reqsk_free(struct request_sock *req)
 {
+	/* temporary debugging */
+	WARN_ON_ONCE(atomic_read(&req->rsk_refcnt) != 0);
+
 	req->rsk_ops->destructor(req);
-	__reqsk_free(req);
+	kmem_cache_free(req->rsk_ops->slab, req);
 }
 
 static inline void reqsk_put(struct request_sock *req)
 {
+	/* temporary debugging, until req sock are put into ehash table */
+	WARN_ON_ONCE(atomic_read(&req->rsk_refcnt) != 1);
+
 	if (atomic_dec_and_test(&req->rsk_refcnt))
 		reqsk_free(req);
 }
...
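With the refcount in place, callers release request socks through reqsk_put() instead of reqsk_free(): the object is destroyed only when the last reference is dropped. The WARN_ON_ONCE() checks are temporary tripwires; until request socks actually live in the ehash table, every put should observe a count of exactly 1. A hedged userspace sketch of the same dec-and-test idiom, with illustrative names and C11 atomics:

/* Userspace sketch of the reqsk_put() idiom (illustrative only). */
#include <stdatomic.h>
#include <stdlib.h>

struct obj {
	atomic_int refcnt;
	/* ... payload ... */
};

static void obj_put(struct obj *o)
{
	/* like atomic_dec_and_test(&req->rsk_refcnt): free only when
	 * the last reference is dropped (previous value was 1)
	 */
	if (atomic_fetch_sub_explicit(&o->refcnt, 1,
				      memory_order_acq_rel) == 1)
		free(o);
}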
@@ -67,6 +67,7 @@
 #include <linux/atomic.h>
 #include <net/dst.h>
 #include <net/checksum.h>
+#include <net/tcp_states.h>
 #include <linux/net_tstamp.h>
 
 struct cgroup;
@@ -2218,6 +2219,14 @@ static inline struct sock *skb_steal_sock(struct sk_buff *skb)
 	return NULL;
 }
 
+/* This helper checks if a socket is a full socket,
+ * ie _not_ a timewait or request socket.
+ */
+static inline bool sk_fullsock(const struct sock *sk)
+{
+	return (1 << sk->sk_state) & ~(TCPF_TIME_WAIT | TCPF_NEW_SYN_RECV);
+}
+
 void sock_enable_timestamp(struct sock *sk, int flag);
 int sock_get_timestamp(struct sock *, struct timeval __user *);
 int sock_get_timestampns(struct sock *, struct timespec __user *);
...
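sk_fullsock() relies on the kernel convention that TCPF_x == 1 << TCP_x: shifting 1 by sk_state yields a single-bit mask, and ANDing it with the complement of TCPF_TIME_WAIT | TCPF_NEW_SYN_RECV is nonzero exactly when the socket is in any other state. A standalone demonstration of the test; the state values mirror the kernel's enum, everything else is illustrative:

/* Standalone demo of the bitmask test in sk_fullsock(); state values
 * mirror the kernel's TCP_* enum, TCPF_x == 1 << TCP_x.
 */
#include <stdbool.h>
#include <stdio.h>

enum { TCP_ESTABLISHED = 1, TCP_TIME_WAIT = 6, TCP_NEW_SYN_RECV = 12 };
#define TCPF_TIME_WAIT    (1 << TCP_TIME_WAIT)
#define TCPF_NEW_SYN_RECV (1 << TCP_NEW_SYN_RECV)

static bool fullsock(int state)
{
	return (1 << state) & ~(TCPF_TIME_WAIT | TCPF_NEW_SYN_RECV);
}

int main(void)
{
	printf("%d %d %d\n",
	       fullsock(TCP_ESTABLISHED),	/* 1 */
	       fullsock(TCP_TIME_WAIT),		/* 0 */
	       fullsock(TCP_NEW_SYN_RECV));	/* 0 */
	return 0;
}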
@@ -103,7 +103,7 @@ void reqsk_queue_destroy(struct request_sock_queue *queue)
 			while ((req = lopt->syn_table[i]) != NULL) {
 				lopt->syn_table[i] = req->dl_next;
 				lopt->qlen--;
-				reqsk_free(req);
+				reqsk_put(req);
 			}
 		}
 	}
@@ -180,7 +180,7 @@ void reqsk_fastopen_remove(struct sock *sk, struct request_sock *req,
 		 */
 		spin_unlock_bh(&fastopenq->lock);
 		sock_put(lsk);
-		reqsk_free(req);
+		reqsk_put(req);
 		return;
 	}
 	/* Wait for 60secs before removing a req that has triggered RST.
...
@@ -1661,21 +1661,6 @@ void sock_efree(struct sk_buff *skb)
 }
 EXPORT_SYMBOL(sock_efree);
 
-#ifdef CONFIG_INET
-void sock_edemux(struct sk_buff *skb)
-{
-	struct sock *sk = skb->sk;
-
-	if (sk->sk_state == TCP_TIME_WAIT)
-		inet_twsk_put(inet_twsk(sk));
-	else if (sk->sk_state == TCP_NEW_SYN_RECV)
-		reqsk_put(inet_reqsk(sk));
-	else
-		sock_put(sk);
-}
-EXPORT_SYMBOL(sock_edemux);
-#endif
-
 kuid_t sock_i_uid(struct sock *sk)
 {
 	kuid_t uid;
...
@@ -340,7 +340,7 @@ struct sock *inet_csk_accept(struct sock *sk, int flags, int *err)
 out:
 	release_sock(sk);
 	if (req)
-		__reqsk_free(req);
+		reqsk_put(req);
 	return newsk;
 out_err:
 	newsk = NULL;
@@ -635,7 +635,7 @@ void inet_csk_reqsk_queue_prune(struct sock *parent,
 			/* Drop this request */
 			inet_csk_reqsk_queue_unlink(parent, req, reqp);
 			reqsk_queue_removed(queue, req);
-			reqsk_free(req);
+			reqsk_put(req);
 			continue;
 		}
 		reqp = &req->dl_next;
@@ -837,7 +837,7 @@ void inet_csk_listen_stop(struct sock *sk)
 		sock_put(child);
 
 		sk_acceptq_removed(sk);
-		__reqsk_free(req);
+		reqsk_put(req);
 	}
 	if (queue->fastopenq != NULL) {
 		/* Free all the reqs queued in rskq_rst_head. */
@@ -847,7 +847,7 @@ void inet_csk_listen_stop(struct sock *sk)
 		spin_unlock_bh(&queue->fastopenq->lock);
 		while ((req = acc_req) != NULL) {
 			acc_req = req->dl_next;
-			__reqsk_free(req);
+			reqsk_put(req);
 		}
 	}
 	WARN_ON(sk->sk_ack_backlog);
...
@@ -113,14 +113,13 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,
 		return -EMSGSIZE;
 
 	r = nlmsg_data(nlh);
-	BUG_ON((1 << sk->sk_state) & (TCPF_TIME_WAIT | TCPF_NEW_SYN_RECV));
-
+	BUG_ON(!sk_fullsock(sk));
 	inet_diag_msg_common_fill(r, sk);
 	r->idiag_state = sk->sk_state;
 	r->idiag_timer = 0;
 	r->idiag_retrans = 0;
 
 	if (nla_put_u8(skb, INET_DIAG_SHUTDOWN, sk->sk_shutdown))
 		goto errout;
@@ -229,7 +228,6 @@ static int inet_csk_diag_fill(struct sock *sk,
 static int inet_twsk_diag_fill(struct sock *sk,
 			       struct sk_buff *skb,
-			       const struct inet_diag_req_v2 *req,
 			       u32 portid, u32 seq, u16 nlmsg_flags,
 			       const struct nlmsghdr *unlh)
 {
@@ -265,6 +263,39 @@ static int inet_twsk_diag_fill(struct sock *sk,
 	return 0;
 }
 
+static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb,
+			      u32 portid, u32 seq, u16 nlmsg_flags,
+			      const struct nlmsghdr *unlh)
+{
+	struct inet_diag_msg *r;
+	struct nlmsghdr *nlh;
+	long tmo;
+
+	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
+			nlmsg_flags);
+	if (!nlh)
+		return -EMSGSIZE;
+
+	r = nlmsg_data(nlh);
+	inet_diag_msg_common_fill(r, sk);
+	r->idiag_state = TCP_SYN_RECV;
+	r->idiag_timer = 1;
+	r->idiag_retrans = inet_reqsk(sk)->num_retrans;
+
+	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
+		     offsetof(struct sock, sk_cookie));
+
+	tmo = inet_reqsk(sk)->expires - jiffies;
+	r->idiag_expires = (tmo >= 0) ? jiffies_to_msecs(tmo) : 0;
+	r->idiag_rqueue = 0;
+	r->idiag_wqueue = 0;
+	r->idiag_uid = 0;
+	r->idiag_inode = 0;
+
+	nlmsg_end(skb, nlh);
+	return 0;
+}
+
 static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 			const struct inet_diag_req_v2 *r,
 			struct user_namespace *user_ns,
@@ -272,9 +303,13 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb,
 			const struct nlmsghdr *unlh)
 {
 	if (sk->sk_state == TCP_TIME_WAIT)
-		return inet_twsk_diag_fill(sk, skb, r, portid, seq,
+		return inet_twsk_diag_fill(sk, skb, portid, seq,
 					   nlmsg_flags, unlh);
 
+	if (sk->sk_state == TCP_NEW_SYN_RECV)
+		return inet_req_diag_fill(sk, skb, portid, seq,
+					  nlmsg_flags, unlh);
+
 	return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq,
 				  nlmsg_flags, unlh);
 }
@@ -502,7 +537,7 @@ int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk)
 	entry_fill_addrs(&entry, sk);
 	entry.sport = inet->inet_num;
 	entry.dport = ntohs(inet->inet_dport);
-	entry.userlocks = (sk->sk_state != TCP_TIME_WAIT) ? sk->sk_userlocks : 0;
+	entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0;
 
 	return inet_diag_bc_run(bc, &entry);
 }
@@ -661,61 +696,6 @@ static void twsk_build_assert(void)
 #endif
 }
 
-static int inet_twsk_diag_dump(struct sock *sk,
-			       struct sk_buff *skb,
-			       struct netlink_callback *cb,
-			       const struct inet_diag_req_v2 *r,
-			       const struct nlattr *bc)
-{
-	twsk_build_assert();
-
-	if (!inet_diag_bc_sk(bc, sk))
-		return 0;
-
-	return inet_twsk_diag_fill(sk, skb, r,
-				   NETLINK_CB(cb->skb).portid,
-				   cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh);
-}
-
-static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
-			      struct request_sock *req,
-			      struct user_namespace *user_ns,
-			      u32 portid, u32 seq,
-			      const struct nlmsghdr *unlh)
-{
-	const struct inet_request_sock *ireq = inet_rsk(req);
-	struct inet_diag_msg *r;
-	struct nlmsghdr *nlh;
-	long tmo;
-
-	nlh = nlmsg_put(skb, portid, seq, unlh->nlmsg_type, sizeof(*r),
-			NLM_F_MULTI);
-	if (!nlh)
-		return -EMSGSIZE;
-
-	r = nlmsg_data(nlh);
-	inet_diag_msg_common_fill(r, (struct sock *)ireq);
-	r->idiag_state = TCP_SYN_RECV;
-	r->idiag_timer = 1;
-	r->idiag_retrans = req->num_retrans;
-
-	BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) !=
-		     offsetof(struct sock, sk_cookie));
-
-	tmo = req->expires - jiffies;
-	if (tmo < 0)
-		tmo = 0;
-
-	r->idiag_expires = jiffies_to_msecs(tmo);
-	r->idiag_rqueue = 0;
-	r->idiag_wqueue = 0;
-	r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk));
-	r->idiag_inode = 0;
-
-	nlmsg_end(skb, nlh);
-	return 0;
-}
-
 static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
 			       struct netlink_callback *cb,
 			       const struct inet_diag_req_v2 *r,
@@ -769,10 +749,10 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
 				continue;
 			}
 
-			err = inet_diag_fill_req(skb, sk, req,
-					sk_user_ns(NETLINK_CB(cb->skb).sk),
+			err = inet_req_diag_fill((struct sock *)req, skb,
 					NETLINK_CB(cb->skb).portid,
-					cb->nlh->nlmsg_seq, cb->nlh);
+					cb->nlh->nlmsg_seq,
+					NLM_F_MULTI, cb->nlh);
 			if (err < 0) {
 				cb->args[3] = j + 1;
 				cb->args[4] = reqnum;
@@ -903,10 +883,16 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,
 			if (r->id.idiag_dport != sk->sk_dport &&
 			    r->id.idiag_dport)
 				goto next_normal;
-			if (sk->sk_state == TCP_TIME_WAIT)
-				res = inet_twsk_diag_dump(sk, skb, cb, r, bc);
-			else
-				res = inet_csk_diag_dump(sk, skb, cb, r, bc);
+			twsk_build_assert();
+
+			if (!inet_diag_bc_sk(bc, sk))
+				goto next_normal;
+
+			res = sk_diag_fill(sk, skb, r,
+					   sk_user_ns(NETLINK_CB(cb->skb).sk),
+					   NETLINK_CB(cb->skb).portid,
+					   cb->nlh->nlmsg_seq, NLM_F_MULTI,
+					   cb->nlh);
 			if (res < 0) {
 				spin_unlock_bh(lock);
 				goto done;
...
@@ -269,6 +269,12 @@ void sock_gen_put(struct sock *sk)
 }
 EXPORT_SYMBOL_GPL(sock_gen_put);
 
+void sock_edemux(struct sk_buff *skb)
+{
+	sock_gen_put(skb->sk);
+}
+EXPORT_SYMBOL(sock_edemux);
+
 struct sock *__inet_lookup_established(struct net *net,
 				       struct inet_hashinfo *hashinfo,
 				       const __be32 saddr, const __be16 sport,
...
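sock_edemux() moves here and shrinks to a one-line wrapper: the state dispatch it used to open-code in net/core/sock.c (see the removal above) already exists in sock_gen_put(), which releases timewait, request, and full sockets through their respective helpers. A userspace model of that dispatch follows; all names are purely illustrative, not kernel API:

/* Userspace model of the dispatch sock_edemux()/sock_gen_put() perform:
 * one destructor entry point picks the right release routine from the
 * socket state. Illustrative names only.
 */
#include <stdio.h>

enum mini_state { MINI_ESTABLISHED, MINI_TIME_WAIT, MINI_NEW_SYN_RECV };

struct mini_sock {
	enum mini_state state;
};

static void mini_gen_put(struct mini_sock *sk)
{
	/* mirrors the removed open-coded sock_edemux() dispatch,
	 * now kept in one place instead of being duplicated
	 */
	if (sk->state == MINI_TIME_WAIT)
		puts("release as timewait (inet_twsk_put)");
	else if (sk->state == MINI_NEW_SYN_RECV)
		puts("release as request sock (reqsk_put)");
	else
		puts("release as full sock (sock_put)");
}

int main(void)
{
	struct mini_sock tw = { MINI_TIME_WAIT };
	struct mini_sock req = { MINI_NEW_SYN_RECV };

	mini_gen_put(&tw);
	mini_gen_put(&req);
	return 0;
}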
@@ -219,9 +219,9 @@ int __cookie_v4_check(const struct iphdr *iph, const struct tcphdr *th,
 }
 EXPORT_SYMBOL_GPL(__cookie_v4_check);
 
-static inline struct sock *get_cookie_sock(struct sock *sk, struct sk_buff *skb,
-					   struct request_sock *req,
-					   struct dst_entry *dst)
+static struct sock *get_cookie_sock(struct sock *sk, struct sk_buff *skb,
+				    struct request_sock *req,
+				    struct dst_entry *dst)
 {
 	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct sock *child;
@@ -357,7 +357,7 @@ struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb)
 	ireq->opt = tcp_v4_save_options(skb);
 
 	if (security_inet_conn_request(sk, skb, req)) {
-		reqsk_free(req);
+		reqsk_put(req);
 		goto out;
 	}
@@ -378,7 +378,7 @@ struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb)
 	security_req_classify_flow(req, flowi4_to_flowi(&fl4));
 	rt = ip_route_output_key(sock_net(sk), &fl4);
 	if (IS_ERR(rt)) {
-		reqsk_free(req);
+		reqsk_put(req);
 		goto out;
 	}
...
@@ -253,7 +253,7 @@ static bool tcp_fastopen_queue_check(struct sock *sk)
 			fastopenq->rskq_rst_head = req1->dl_next;
 			fastopenq->qlen--;
 			spin_unlock(&fastopenq->lock);
-			reqsk_free(req1);
+			reqsk_put(req1);
 		}
 	}
 	return true;
 }
...
@@ -1518,7 +1518,7 @@ void tcp_v4_early_demux(struct sk_buff *skb)
 	if (sk) {
 		skb->sk = sk;
 		skb->destructor = sock_edemux;
-		if (sk->sk_state != TCP_TIME_WAIT) {
+		if (sk_fullsock(sk)) {
 			struct dst_entry *dst = sk->sk_rx_dst;
 
 			if (dst)
...
@@ -1583,7 +1583,7 @@ static void tcp_v6_early_demux(struct sk_buff *skb)
 	if (sk) {
 		skb->sk = sk;
 		skb->destructor = sock_edemux;
-		if (sk->sk_state != TCP_TIME_WAIT) {
+		if (sk_fullsock(sk)) {
 			struct dst_entry *dst = sk->sk_rx_dst;
 
 			if (dst)
...