Commit 3fafd92e authored by David S. Miller

Merge branch 'l2tp-session-cleanup' into main

James Chapman says:

====================
l2tp: simplify tunnel and session cleanup

This series simplifies and improves l2tp tunnel and session cleanup.

 * refactor l2tp management code to not use the tunnel socket's
   sk_user_data. This allows the tunnel and its socket to be closed
   and freed without sequencing the two using the socket's sk_destruct
   hook.

 * export ip_flush_pending_frames and use it when closing l2tp_ip
   sockets.

 * move the work of closing all sessions in the tunnel to the work
   queue so that sessions are deleted using the same codepath whether
   they are closed by user API request or their parent tunnel is
   closing.

 * refactor l2tp_ppp pppox socket / session relationship to have the
   session keep the socket alive, not the other way around. Previously
   the pppox socket held a ref on the session, which complicated
   session delete by having to go through the pppox socket destructor.

 * free sessions and pppox sockets by rcu.

 * fix a possible tunnel refcount underflow.

 * avoid using rcu_barrier in net exit handler.
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents 0a658d08 5dfa598b
......@@ -1534,6 +1534,7 @@ void ip_flush_pending_frames(struct sock *sk)
{
__ip_flush_pending_frames(sk, &sk->sk_write_queue, &inet_sk(sk)->cork.base);
}
EXPORT_SYMBOL_GPL(ip_flush_pending_frames);
struct sk_buff *ip_make_skb(struct sock *sk,
struct flowi4 *fl4,
......
This diff is collapsed.
......@@ -16,7 +16,6 @@
#endif
/* Random numbers used for internal consistency checks of tunnel and session structures */
#define L2TP_TUNNEL_MAGIC 0x42114DDA
#define L2TP_SESSION_MAGIC 0x0C04EB7D
struct sk_buff;
......@@ -67,6 +66,7 @@ struct l2tp_session_coll_list {
struct l2tp_session {
int magic; /* should be L2TP_SESSION_MAGIC */
long dead;
struct rcu_head rcu;
struct l2tp_tunnel *tunnel; /* back pointer to tunnel context */
u32 session_id;
......@@ -103,6 +103,7 @@ struct l2tp_session {
int reorder_skip; /* set if skip to next nr */
enum l2tp_pwtype pwtype;
struct l2tp_stats stats;
struct work_struct del_work;
/* Session receive handler for data packets.
* Each pseudowire implementation should implement this callback in order to
......@@ -155,8 +156,6 @@ struct l2tp_tunnel_cfg {
*/
#define L2TP_TUNNEL_NAME_MAX 20
struct l2tp_tunnel {
int magic; /* Should be L2TP_TUNNEL_MAGIC */
unsigned long dead;
struct rcu_head rcu;
......@@ -176,7 +175,6 @@ struct l2tp_tunnel {
struct net *l2tp_net; /* the net we belong to */
refcount_t ref_count;
void (*old_sk_destruct)(struct sock *sk);
struct sock *sock; /* parent socket */
int fd; /* parent fd, if tunnel socket was created
* by userspace
......@@ -260,7 +258,8 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb);
/* Transmit path helpers for sending packets over the tunnel socket. */
void l2tp_session_set_header_len(struct l2tp_session *session, int version);
void l2tp_session_set_header_len(struct l2tp_session *session, int version,
enum l2tp_encap_type encap);
int l2tp_xmit_skb(struct l2tp_session *session, struct sk_buff *skb);
/* Pseudowire management.
......@@ -273,10 +272,7 @@ void l2tp_nl_unregister_ops(enum l2tp_pwtype pw_type);
/* IOCTL helper for IP encap modules. */
int l2tp_ioctl(struct sock *sk, int cmd, int *karg);
/* Extract the tunnel structure from a socket's sk_user_data pointer,
* validating the tunnel magic feather.
*/
struct l2tp_tunnel *l2tp_sk_to_tunnel(struct sock *sk);
struct l2tp_tunnel *l2tp_sk_to_tunnel(const struct sock *sk);
static inline int l2tp_get_l2specific_len(struct l2tp_session *session)
{
......
......@@ -322,7 +322,7 @@ static int l2tp_eth_create(struct net *net, struct l2tp_tunnel *tunnel,
l2tp_session_dec_refcount(session);
free_netdev(dev);
err_sess:
kfree(session);
l2tp_session_dec_refcount(session);
err:
return rc;
}
......
......@@ -235,14 +235,17 @@ static void l2tp_ip_close(struct sock *sk, long timeout)
static void l2tp_ip_destroy_sock(struct sock *sk)
{
struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk);
struct sk_buff *skb;
struct l2tp_tunnel *tunnel;
while ((skb = __skb_dequeue_tail(&sk->sk_write_queue)) != NULL)
kfree_skb(skb);
lock_sock(sk);
ip_flush_pending_frames(sk);
release_sock(sk);
if (tunnel)
tunnel = l2tp_sk_to_tunnel(sk);
if (tunnel) {
l2tp_tunnel_delete(tunnel);
l2tp_tunnel_dec_refcount(tunnel);
}
}
static int l2tp_ip_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len)
......
......@@ -246,14 +246,17 @@ static void l2tp_ip6_close(struct sock *sk, long timeout)
static void l2tp_ip6_destroy_sock(struct sock *sk)
{
struct l2tp_tunnel *tunnel = l2tp_sk_to_tunnel(sk);
struct l2tp_tunnel *tunnel;
lock_sock(sk);
ip6_flush_pending_frames(sk);
release_sock(sk);
if (tunnel)
tunnel = l2tp_sk_to_tunnel(sk);
if (tunnel) {
l2tp_tunnel_delete(tunnel);
l2tp_tunnel_dec_refcount(tunnel);
}
}
static int l2tp_ip6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len)
......
......@@ -692,8 +692,10 @@ static int l2tp_nl_cmd_session_modify(struct sk_buff *skb, struct genl_info *inf
session->recv_seq = nla_get_u8(info->attrs[L2TP_ATTR_RECV_SEQ]);
if (info->attrs[L2TP_ATTR_SEND_SEQ]) {
struct l2tp_tunnel *tunnel = session->tunnel;
session->send_seq = nla_get_u8(info->attrs[L2TP_ATTR_SEND_SEQ]);
l2tp_session_set_header_len(session, session->tunnel->version);
l2tp_session_set_header_len(session, tunnel->version, tunnel->encap);
}
if (info->attrs[L2TP_ATTR_LNS_MODE])
......
......@@ -119,7 +119,6 @@ struct pppol2tp_session {
struct mutex sk_lock; /* Protects .sk */
struct sock __rcu *sk; /* Pointer to the session PPPoX socket */
struct sock *__sk; /* Copy of .sk, for cleanup */
struct rcu_head rcu; /* For asynchronous release */
};
static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb);
......@@ -157,20 +156,16 @@ static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
if (!sk)
return NULL;
sock_hold(sk);
session = (struct l2tp_session *)(sk->sk_user_data);
if (!session) {
sock_put(sk);
goto out;
}
if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) {
session = NULL;
sock_put(sk);
goto out;
rcu_read_lock();
session = rcu_dereference_sk_user_data(sk);
if (session && refcount_inc_not_zero(&session->ref_count)) {
rcu_read_unlock();
WARN_ON_ONCE(session->magic != L2TP_SESSION_MAGIC);
return session;
}
rcu_read_unlock();
out:
return session;
return NULL;
}
/*****************************************************************************
......@@ -318,12 +313,12 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
l2tp_xmit_skb(session, skb);
local_bh_enable();
sock_put(sk);
l2tp_session_dec_refcount(session);
return total_len;
error_put_sess:
sock_put(sk);
l2tp_session_dec_refcount(session);
error:
return error;
}
......@@ -377,12 +372,12 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
l2tp_xmit_skb(session, skb);
local_bh_enable();
sock_put(sk);
l2tp_session_dec_refcount(session);
return 1;
abort_put_sess:
sock_put(sk);
l2tp_session_dec_refcount(session);
abort:
/* Free the original skb */
kfree_skb(skb);
......@@ -393,28 +388,31 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
* Session (and tunnel control) socket create/destroy.
*****************************************************************************/
static void pppol2tp_put_sk(struct rcu_head *head)
{
struct pppol2tp_session *ps;
ps = container_of(head, typeof(*ps), rcu);
sock_put(ps->__sk);
}
/* Really kill the session socket. (Called from sock_put() if
* refcnt == 0.)
*/
static void pppol2tp_session_destruct(struct sock *sk)
{
struct l2tp_session *session = sk->sk_user_data;
skb_queue_purge(&sk->sk_receive_queue);
skb_queue_purge(&sk->sk_write_queue);
}
if (session) {
sk->sk_user_data = NULL;
if (WARN_ON(session->magic != L2TP_SESSION_MAGIC))
return;
static void pppol2tp_session_close(struct l2tp_session *session)
{
struct pppol2tp_session *ps;
ps = l2tp_session_priv(session);
mutex_lock(&ps->sk_lock);
ps->__sk = rcu_dereference_protected(ps->sk,
lockdep_is_held(&ps->sk_lock));
RCU_INIT_POINTER(ps->sk, NULL);
mutex_unlock(&ps->sk_lock);
if (ps->__sk) {
/* detach socket */
rcu_assign_sk_user_data(ps->__sk, NULL);
sock_put(ps->__sk);
/* drop ref taken when we referenced socket via sk_user_data */
l2tp_session_dec_refcount(session);
}
}
......@@ -444,30 +442,13 @@ static int pppol2tp_release(struct socket *sock)
session = pppol2tp_sock_to_session(sk);
if (session) {
struct pppol2tp_session *ps;
l2tp_session_delete(session);
ps = l2tp_session_priv(session);
mutex_lock(&ps->sk_lock);
ps->__sk = rcu_dereference_protected(ps->sk,
lockdep_is_held(&ps->sk_lock));
RCU_INIT_POINTER(ps->sk, NULL);
mutex_unlock(&ps->sk_lock);
call_rcu(&ps->rcu, pppol2tp_put_sk);
/* Rely on the sock_put() call at the end of the function for
* dropping the reference held by pppol2tp_sock_to_session().
* The last reference will be dropped by pppol2tp_put_sk().
*/
/* drop ref taken by pppol2tp_sock_to_session */
l2tp_session_dec_refcount(session);
}
release_sock(sk);
/* This will delete the session context via
* pppol2tp_session_destruct() if the socket's refcnt drops to
* zero.
*/
sock_put(sk);
return 0;
......@@ -506,6 +487,7 @@ static int pppol2tp_create(struct net *net, struct socket *sock, int kern)
goto out;
sock_init_data(sock, sk);
sock_set_flag(sk, SOCK_RCU_FREE);
sock->state = SS_UNCONNECTED;
sock->ops = &pppol2tp_ops;
......@@ -542,6 +524,7 @@ static void pppol2tp_session_init(struct l2tp_session *session)
struct pppol2tp_session *ps;
session->recv_skb = pppol2tp_recv;
session->session_close = pppol2tp_session_close;
if (IS_ENABLED(CONFIG_L2TP_DEBUGFS))
session->show = pppol2tp_show;
......@@ -787,6 +770,8 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
goto end;
}
drop_refcnt = true;
pppol2tp_session_init(session);
ps = l2tp_session_priv(session);
l2tp_session_inc_refcount(session);
......@@ -795,10 +780,10 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
error = l2tp_session_register(session, tunnel);
if (error < 0) {
mutex_unlock(&ps->sk_lock);
kfree(session);
l2tp_session_dec_refcount(session);
goto end;
}
drop_refcnt = true;
new_session = true;
}
......@@ -830,12 +815,13 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
out_no_ppp:
/* This is how we get the session context from the socket. */
sk->sk_user_data = session;
sock_hold(sk);
rcu_assign_sk_user_data(sk, session);
rcu_assign_pointer(ps->sk, sk);
mutex_unlock(&ps->sk_lock);
/* Keep the reference we've grabbed on the session: sk doesn't expect
* the session to disappear. pppol2tp_session_destruct() is responsible
* the session to disappear. pppol2tp_session_close() is responsible
* for dropping it.
*/
drop_refcnt = false;
......@@ -891,7 +877,7 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel,
return 0;
err_sess:
kfree(session);
l2tp_session_dec_refcount(session);
err:
return error;
}
......@@ -1002,7 +988,7 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr,
error = len;
sock_put(sk);
l2tp_session_dec_refcount(session);
end:
return error;
}
......@@ -1205,7 +1191,8 @@ static int pppol2tp_session_setsockopt(struct sock *sk,
po->chan.hdrlen = val ? PPPOL2TP_L2TP_HDR_SIZE_SEQ :
PPPOL2TP_L2TP_HDR_SIZE_NOSEQ;
}
l2tp_session_set_header_len(session, session->tunnel->version);
l2tp_session_set_header_len(session, session->tunnel->version,
session->tunnel->encap);
break;
case PPPOL2TP_SO_LNSMODE:
......@@ -1274,7 +1261,7 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname,
err = pppol2tp_session_setsockopt(sk, session, optname, val);
}
sock_put(sk);
l2tp_session_dec_refcount(session);
end:
return err;
}
......@@ -1395,7 +1382,7 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname,
err = 0;
end_put_sess:
sock_put(sk);
l2tp_session_dec_refcount(session);
end:
return err;
}
......@@ -1513,7 +1500,7 @@ static void pppol2tp_seq_tunnel_show(struct seq_file *m, void *v)
seq_printf(m, "\nTUNNEL '%s', %c %d\n",
tunnel->name,
(tunnel == tunnel->sock->sk_user_data) ? 'Y' : 'N',
tunnel->sock ? 'Y' : 'N',
refcount_read(&tunnel->ref_count) - 1);
seq_printf(m, " %08x %ld/%ld/%ld %ld/%ld/%ld\n",
0,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment