Commit 77146b5d authored by David S. Miller's avatar David S. Miller

Merge branch 'l2tp-tunnel-refs'

Guillaume Nault says:

====================
l2tp: fix some l2tp_tunnel_find() issues in l2tp_netlink

Since l2tp_tunnel_find() doesn't take a reference on the tunnel it
returns, its users are almost guaranteed to be racy.

This series defines l2tp_tunnel_get() which can be used as a safe
replacement, and converts some of l2tp_tunnel_find() users in the
l2tp_netlink module.

Other users often combine this issue with other more or less subtle
races. They will be fixed incrementally in followup series.
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 9ee369a4 e702c120
...@@ -113,7 +113,6 @@ struct l2tp_net { ...@@ -113,7 +113,6 @@ struct l2tp_net {
spinlock_t l2tp_session_hlist_lock; spinlock_t l2tp_session_hlist_lock;
}; };
static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel);
static inline struct l2tp_tunnel *l2tp_tunnel(struct sock *sk) static inline struct l2tp_tunnel *l2tp_tunnel(struct sock *sk)
{ {
...@@ -127,39 +126,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net) ...@@ -127,39 +126,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net)
return net_generic(net, l2tp_net_id); return net_generic(net, l2tp_net_id);
} }
/* Tunnel reference counts. Incremented per session that is added to
* the tunnel.
*/
static inline void l2tp_tunnel_inc_refcount_1(struct l2tp_tunnel *tunnel)
{
refcount_inc(&tunnel->ref_count);
}
static inline void l2tp_tunnel_dec_refcount_1(struct l2tp_tunnel *tunnel)
{
if (refcount_dec_and_test(&tunnel->ref_count))
l2tp_tunnel_free(tunnel);
}
#ifdef L2TP_REFCNT_DEBUG
#define l2tp_tunnel_inc_refcount(_t) \
do { \
pr_debug("l2tp_tunnel_inc_refcount: %s:%d %s: cnt=%d\n", \
__func__, __LINE__, (_t)->name, \
refcount_read(&_t->ref_count)); \
l2tp_tunnel_inc_refcount_1(_t); \
} while (0)
#define l2tp_tunnel_dec_refcount(_t) \
do { \
pr_debug("l2tp_tunnel_dec_refcount: %s:%d %s: cnt=%d\n", \
__func__, __LINE__, (_t)->name, \
refcount_read(&_t->ref_count)); \
l2tp_tunnel_dec_refcount_1(_t); \
} while (0)
#else
#define l2tp_tunnel_inc_refcount(t) l2tp_tunnel_inc_refcount_1(t)
#define l2tp_tunnel_dec_refcount(t) l2tp_tunnel_dec_refcount_1(t)
#endif
/* Session hash global list for L2TPv3. /* Session hash global list for L2TPv3.
* The session_id SHOULD be random according to RFC3931, but several * The session_id SHOULD be random according to RFC3931, but several
* L2TP implementations use incrementing session_ids. So we do a real * L2TP implementations use incrementing session_ids. So we do a real
...@@ -229,6 +195,27 @@ l2tp_session_id_hash(struct l2tp_tunnel *tunnel, u32 session_id) ...@@ -229,6 +195,27 @@ l2tp_session_id_hash(struct l2tp_tunnel *tunnel, u32 session_id)
return &tunnel->session_hlist[hash_32(session_id, L2TP_HASH_BITS)]; return &tunnel->session_hlist[hash_32(session_id, L2TP_HASH_BITS)];
} }
/* Lookup a tunnel. A new reference is held on the returned tunnel. */
struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id)
{
const struct l2tp_net *pn = l2tp_pernet(net);
struct l2tp_tunnel *tunnel;
rcu_read_lock_bh();
list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) {
if (tunnel->tunnel_id == tunnel_id) {
l2tp_tunnel_inc_refcount(tunnel);
rcu_read_unlock_bh();
return tunnel;
}
}
rcu_read_unlock_bh();
return NULL;
}
EXPORT_SYMBOL_GPL(l2tp_tunnel_get);
/* Lookup a session. A new reference is held on the returned session. /* Lookup a session. A new reference is held on the returned session.
* Optionally calls session->ref() too if do_ref is true. * Optionally calls session->ref() too if do_ref is true.
*/ */
...@@ -1348,17 +1335,6 @@ static void l2tp_udp_encap_destroy(struct sock *sk) ...@@ -1348,17 +1335,6 @@ static void l2tp_udp_encap_destroy(struct sock *sk)
} }
} }
/* Really kill the tunnel.
* Come here only when all sessions have been cleared from the tunnel.
*/
static void l2tp_tunnel_free(struct l2tp_tunnel *tunnel)
{
BUG_ON(refcount_read(&tunnel->ref_count) != 0);
BUG_ON(tunnel->sock != NULL);
l2tp_info(tunnel, L2TP_MSG_CONTROL, "%s: free...\n", tunnel->name);
kfree_rcu(tunnel, rcu);
}
/* Workqueue tunnel deletion function */ /* Workqueue tunnel deletion function */
static void l2tp_tunnel_del_work(struct work_struct *work) static void l2tp_tunnel_del_work(struct work_struct *work)
{ {
......
...@@ -231,6 +231,8 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk) ...@@ -231,6 +231,8 @@ static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk)
return tunnel; return tunnel;
} }
struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id);
struct l2tp_session *l2tp_session_get(const struct net *net, struct l2tp_session *l2tp_session_get(const struct net *net,
struct l2tp_tunnel *tunnel, struct l2tp_tunnel *tunnel,
u32 session_id, bool do_ref); u32 session_id, bool do_ref);
...@@ -269,6 +271,17 @@ int l2tp_nl_register_ops(enum l2tp_pwtype pw_type, ...@@ -269,6 +271,17 @@ int l2tp_nl_register_ops(enum l2tp_pwtype pw_type,
void l2tp_nl_unregister_ops(enum l2tp_pwtype pw_type); void l2tp_nl_unregister_ops(enum l2tp_pwtype pw_type);
int l2tp_ioctl(struct sock *sk, int cmd, unsigned long arg); int l2tp_ioctl(struct sock *sk, int cmd, unsigned long arg);
static inline void l2tp_tunnel_inc_refcount(struct l2tp_tunnel *tunnel)
{
refcount_inc(&tunnel->ref_count);
}
static inline void l2tp_tunnel_dec_refcount(struct l2tp_tunnel *tunnel)
{
if (refcount_dec_and_test(&tunnel->ref_count))
kfree_rcu(tunnel, rcu);
}
/* Session reference counts. Incremented when code obtains a reference /* Session reference counts. Incremented when code obtains a reference
* to a session. * to a session.
*/ */
......
...@@ -65,10 +65,12 @@ static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info, ...@@ -65,10 +65,12 @@ static struct l2tp_session *l2tp_nl_session_get(struct genl_info *info,
(info->attrs[L2TP_ATTR_CONN_ID])) { (info->attrs[L2TP_ATTR_CONN_ID])) {
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]); session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
tunnel = l2tp_tunnel_find(net, tunnel_id); tunnel = l2tp_tunnel_get(net, tunnel_id);
if (tunnel) if (tunnel) {
session = l2tp_session_get(net, tunnel, session_id, session = l2tp_session_get(net, tunnel, session_id,
do_ref); do_ref);
l2tp_tunnel_dec_refcount(tunnel);
}
} }
return session; return session;
...@@ -271,8 +273,8 @@ static int l2tp_nl_cmd_tunnel_delete(struct sk_buff *skb, struct genl_info *info ...@@ -271,8 +273,8 @@ static int l2tp_nl_cmd_tunnel_delete(struct sk_buff *skb, struct genl_info *info
} }
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
tunnel = l2tp_tunnel_find(net, tunnel_id); tunnel = l2tp_tunnel_get(net, tunnel_id);
if (tunnel == NULL) { if (!tunnel) {
ret = -ENODEV; ret = -ENODEV;
goto out; goto out;
} }
...@@ -282,6 +284,8 @@ static int l2tp_nl_cmd_tunnel_delete(struct sk_buff *skb, struct genl_info *info ...@@ -282,6 +284,8 @@ static int l2tp_nl_cmd_tunnel_delete(struct sk_buff *skb, struct genl_info *info
(void) l2tp_tunnel_delete(tunnel); (void) l2tp_tunnel_delete(tunnel);
l2tp_tunnel_dec_refcount(tunnel);
out: out:
return ret; return ret;
} }
...@@ -299,8 +303,8 @@ static int l2tp_nl_cmd_tunnel_modify(struct sk_buff *skb, struct genl_info *info ...@@ -299,8 +303,8 @@ static int l2tp_nl_cmd_tunnel_modify(struct sk_buff *skb, struct genl_info *info
} }
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
tunnel = l2tp_tunnel_find(net, tunnel_id); tunnel = l2tp_tunnel_get(net, tunnel_id);
if (tunnel == NULL) { if (!tunnel) {
ret = -ENODEV; ret = -ENODEV;
goto out; goto out;
} }
...@@ -311,6 +315,8 @@ static int l2tp_nl_cmd_tunnel_modify(struct sk_buff *skb, struct genl_info *info ...@@ -311,6 +315,8 @@ static int l2tp_nl_cmd_tunnel_modify(struct sk_buff *skb, struct genl_info *info
ret = l2tp_tunnel_notify(&l2tp_nl_family, info, ret = l2tp_tunnel_notify(&l2tp_nl_family, info,
tunnel, L2TP_CMD_TUNNEL_MODIFY); tunnel, L2TP_CMD_TUNNEL_MODIFY);
l2tp_tunnel_dec_refcount(tunnel);
out: out:
return ret; return ret;
} }
...@@ -438,34 +444,37 @@ static int l2tp_nl_cmd_tunnel_get(struct sk_buff *skb, struct genl_info *info) ...@@ -438,34 +444,37 @@ static int l2tp_nl_cmd_tunnel_get(struct sk_buff *skb, struct genl_info *info)
if (!info->attrs[L2TP_ATTR_CONN_ID]) { if (!info->attrs[L2TP_ATTR_CONN_ID]) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto err;
} }
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
tunnel = l2tp_tunnel_find(net, tunnel_id);
if (tunnel == NULL) {
ret = -ENODEV;
goto out;
}
msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
if (!msg) { if (!msg) {
ret = -ENOMEM; ret = -ENOMEM;
goto out; goto err;
}
tunnel = l2tp_tunnel_get(net, tunnel_id);
if (!tunnel) {
ret = -ENODEV;
goto err_nlmsg;
} }
ret = l2tp_nl_tunnel_send(msg, info->snd_portid, info->snd_seq, ret = l2tp_nl_tunnel_send(msg, info->snd_portid, info->snd_seq,
NLM_F_ACK, tunnel, L2TP_CMD_TUNNEL_GET); NLM_F_ACK, tunnel, L2TP_CMD_TUNNEL_GET);
if (ret < 0) if (ret < 0)
goto err_out; goto err_nlmsg_tunnel;
l2tp_tunnel_dec_refcount(tunnel);
return genlmsg_unicast(net, msg, info->snd_portid); return genlmsg_unicast(net, msg, info->snd_portid);
err_out: err_nlmsg_tunnel:
l2tp_tunnel_dec_refcount(tunnel);
err_nlmsg:
nlmsg_free(msg); nlmsg_free(msg);
err:
out:
return ret; return ret;
} }
...@@ -509,8 +518,9 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf ...@@ -509,8 +518,9 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
ret = -EINVAL; ret = -EINVAL;
goto out; goto out;
} }
tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]); tunnel_id = nla_get_u32(info->attrs[L2TP_ATTR_CONN_ID]);
tunnel = l2tp_tunnel_find(net, tunnel_id); tunnel = l2tp_tunnel_get(net, tunnel_id);
if (!tunnel) { if (!tunnel) {
ret = -ENODEV; ret = -ENODEV;
goto out; goto out;
...@@ -518,24 +528,24 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf ...@@ -518,24 +528,24 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
if (!info->attrs[L2TP_ATTR_SESSION_ID]) { if (!info->attrs[L2TP_ATTR_SESSION_ID]) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out_tunnel;
} }
session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]); session_id = nla_get_u32(info->attrs[L2TP_ATTR_SESSION_ID]);
if (!info->attrs[L2TP_ATTR_PEER_SESSION_ID]) { if (!info->attrs[L2TP_ATTR_PEER_SESSION_ID]) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out_tunnel;
} }
peer_session_id = nla_get_u32(info->attrs[L2TP_ATTR_PEER_SESSION_ID]); peer_session_id = nla_get_u32(info->attrs[L2TP_ATTR_PEER_SESSION_ID]);
if (!info->attrs[L2TP_ATTR_PW_TYPE]) { if (!info->attrs[L2TP_ATTR_PW_TYPE]) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out_tunnel;
} }
cfg.pw_type = nla_get_u16(info->attrs[L2TP_ATTR_PW_TYPE]); cfg.pw_type = nla_get_u16(info->attrs[L2TP_ATTR_PW_TYPE]);
if (cfg.pw_type >= __L2TP_PWTYPE_MAX) { if (cfg.pw_type >= __L2TP_PWTYPE_MAX) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out_tunnel;
} }
if (tunnel->version > 2) { if (tunnel->version > 2) {
...@@ -557,7 +567,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf ...@@ -557,7 +567,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
u16 len = nla_len(info->attrs[L2TP_ATTR_COOKIE]); u16 len = nla_len(info->attrs[L2TP_ATTR_COOKIE]);
if (len > 8) { if (len > 8) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out_tunnel;
} }
cfg.cookie_len = len; cfg.cookie_len = len;
memcpy(&cfg.cookie[0], nla_data(info->attrs[L2TP_ATTR_COOKIE]), len); memcpy(&cfg.cookie[0], nla_data(info->attrs[L2TP_ATTR_COOKIE]), len);
...@@ -566,7 +576,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf ...@@ -566,7 +576,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
u16 len = nla_len(info->attrs[L2TP_ATTR_PEER_COOKIE]); u16 len = nla_len(info->attrs[L2TP_ATTR_PEER_COOKIE]);
if (len > 8) { if (len > 8) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out_tunnel;
} }
cfg.peer_cookie_len = len; cfg.peer_cookie_len = len;
memcpy(&cfg.peer_cookie[0], nla_data(info->attrs[L2TP_ATTR_PEER_COOKIE]), len); memcpy(&cfg.peer_cookie[0], nla_data(info->attrs[L2TP_ATTR_PEER_COOKIE]), len);
...@@ -609,7 +619,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf ...@@ -609,7 +619,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
if ((l2tp_nl_cmd_ops[cfg.pw_type] == NULL) || if ((l2tp_nl_cmd_ops[cfg.pw_type] == NULL) ||
(l2tp_nl_cmd_ops[cfg.pw_type]->session_create == NULL)) { (l2tp_nl_cmd_ops[cfg.pw_type]->session_create == NULL)) {
ret = -EPROTONOSUPPORT; ret = -EPROTONOSUPPORT;
goto out; goto out_tunnel;
} }
/* Check that pseudowire-specific params are present */ /* Check that pseudowire-specific params are present */
...@@ -619,7 +629,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf ...@@ -619,7 +629,7 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
case L2TP_PWTYPE_ETH_VLAN: case L2TP_PWTYPE_ETH_VLAN:
if (!info->attrs[L2TP_ATTR_VLAN_ID]) { if (!info->attrs[L2TP_ATTR_VLAN_ID]) {
ret = -EINVAL; ret = -EINVAL;
goto out; goto out_tunnel;
} }
break; break;
case L2TP_PWTYPE_ETH: case L2TP_PWTYPE_ETH:
...@@ -647,6 +657,8 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf ...@@ -647,6 +657,8 @@ static int l2tp_nl_cmd_session_create(struct sk_buff *skb, struct genl_info *inf
} }
} }
out_tunnel:
l2tp_tunnel_dec_refcount(tunnel);
out: out:
return ret; return ret;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment