Commit 3f326a82 authored by Paolo Abeni's avatar Paolo Abeni Committed by David S. Miller

mptcp: change the mpc check helper to return a sk

After the previous patch the __mptcp_nmpc_socket helper is used
only to ensure that the MPTCP socket is a suitable status - that
is, the mptcp capable handshake is not started yet.

Change the return value to the relevant subflow sock, to finally
remove the last references to first subflow socket in the MPTCP stack.

As a bonus, we can get rid of a few local variables in different
functions.

No functional change intended.
Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
Reviewed-by: default avatarMat Martineau <martineau@kernel.org>
Signed-off-by: default avatarMatthieu Baerts <matthieu.baerts@tessares.net>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 3aa36249
...@@ -1007,7 +1007,6 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, ...@@ -1007,7 +1007,6 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
int addrlen = sizeof(struct sockaddr_in); int addrlen = sizeof(struct sockaddr_in);
struct sockaddr_storage addr; struct sockaddr_storage addr;
struct sock *newsk, *ssk; struct sock *newsk, *ssk;
struct socket *ssock;
int backlog = 1024; int backlog = 1024;
int err; int err;
...@@ -1033,17 +1032,16 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk, ...@@ -1033,17 +1032,16 @@ static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
&mptcp_keys[is_ipv6]); &mptcp_keys[is_ipv6]);
lock_sock(newsk); lock_sock(newsk);
ssock = __mptcp_nmpc_socket(mptcp_sk(newsk)); ssk = __mptcp_nmpc_sk(mptcp_sk(newsk));
release_sock(newsk); release_sock(newsk);
if (IS_ERR(ssock)) if (IS_ERR(ssk))
return PTR_ERR(ssock); return PTR_ERR(ssk);
mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family); mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
#if IS_ENABLED(CONFIG_MPTCP_IPV6) #if IS_ENABLED(CONFIG_MPTCP_IPV6)
if (entry->addr.family == AF_INET6) if (entry->addr.family == AF_INET6)
addrlen = sizeof(struct sockaddr_in6); addrlen = sizeof(struct sockaddr_in6);
#endif #endif
ssk = mptcp_sk(newsk)->first;
if (ssk->sk_family == AF_INET) if (ssk->sk_family == AF_INET)
err = inet_bind_sk(ssk, (struct sockaddr *)&addr, addrlen); err = inet_bind_sk(ssk, (struct sockaddr *)&addr, addrlen);
#if IS_ENABLED(CONFIG_MPTCP_IPV6) #if IS_ENABLED(CONFIG_MPTCP_IPV6)
......
...@@ -109,7 +109,7 @@ static int __mptcp_socket_create(struct mptcp_sock *msk) ...@@ -109,7 +109,7 @@ static int __mptcp_socket_create(struct mptcp_sock *msk)
/* If the MPC handshake is not started, returns the first subflow, /* If the MPC handshake is not started, returns the first subflow,
* eventually allocating it. * eventually allocating it.
*/ */
struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk) struct sock *__mptcp_nmpc_sk(struct mptcp_sock *msk)
{ {
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
int ret; int ret;
...@@ -117,10 +117,7 @@ struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk) ...@@ -117,10 +117,7 @@ struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk)
if (!((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN))) if (!((1 << sk->sk_state) & (TCPF_CLOSE | TCPF_LISTEN)))
return ERR_PTR(-EINVAL); return ERR_PTR(-EINVAL);
if (!msk->subflow) { if (!msk->first) {
if (msk->first)
return ERR_PTR(-EINVAL);
ret = __mptcp_socket_create(msk); ret = __mptcp_socket_create(msk);
if (ret) if (ret)
return ERR_PTR(ret); return ERR_PTR(ret);
...@@ -128,7 +125,7 @@ struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk) ...@@ -128,7 +125,7 @@ struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk)
mptcp_sockopt_sync(msk, msk->first); mptcp_sockopt_sync(msk, msk->first);
} }
return msk->subflow; return msk->first;
} }
static void mptcp_drop(struct sock *sk, struct sk_buff *skb) static void mptcp_drop(struct sock *sk, struct sk_buff *skb)
...@@ -1643,7 +1640,6 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, ...@@ -1643,7 +1640,6 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg,
{ {
unsigned int saved_flags = msg->msg_flags; unsigned int saved_flags = msg->msg_flags;
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
struct socket *ssock;
struct sock *ssk; struct sock *ssk;
int ret; int ret;
...@@ -1654,9 +1650,9 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg, ...@@ -1654,9 +1650,9 @@ static int mptcp_sendmsg_fastopen(struct sock *sk, struct msghdr *msg,
* fastopen attempt, no need to check for additional subflow status. * fastopen attempt, no need to check for additional subflow status.
*/ */
if (msg->msg_flags & MSG_FASTOPEN) { if (msg->msg_flags & MSG_FASTOPEN) {
ssock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(ssock)) if (IS_ERR(ssk))
return PTR_ERR(ssock); return PTR_ERR(ssk);
} }
if (!msk->first) if (!msk->first)
return -EINVAL; return -EINVAL;
...@@ -3577,16 +3573,14 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) ...@@ -3577,16 +3573,14 @@ static int mptcp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
{ {
struct mptcp_subflow_context *subflow; struct mptcp_subflow_context *subflow;
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
struct socket *ssock;
int err = -EINVAL; int err = -EINVAL;
struct sock *ssk; struct sock *ssk;
ssock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(ssock)) if (IS_ERR(ssk))
return PTR_ERR(ssock); return PTR_ERR(ssk);
inet_sk_state_store(sk, TCP_SYN_SENT); inet_sk_state_store(sk, TCP_SYN_SENT);
ssk = msk->first;
subflow = mptcp_subflow_ctx(ssk); subflow = mptcp_subflow_ctx(ssk);
#ifdef CONFIG_TCP_MD5SIG #ifdef CONFIG_TCP_MD5SIG
/* no MPTCP if MD5SIG is enabled on this socket or we may run out of /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
...@@ -3682,17 +3676,15 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -3682,17 +3676,15 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{ {
struct mptcp_sock *msk = mptcp_sk(sock->sk); struct mptcp_sock *msk = mptcp_sk(sock->sk);
struct sock *ssk, *sk = sock->sk; struct sock *ssk, *sk = sock->sk;
struct socket *ssock;
int err = -EINVAL; int err = -EINVAL;
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(ssock)) { if (IS_ERR(ssk)) {
err = PTR_ERR(ssock); err = PTR_ERR(ssk);
goto unlock; goto unlock;
} }
ssk = msk->first;
if (sk->sk_family == AF_INET) if (sk->sk_family == AF_INET)
err = inet_bind_sk(ssk, uaddr, addr_len); err = inet_bind_sk(ssk, uaddr, addr_len);
#if IS_ENABLED(CONFIG_MPTCP_IPV6) #if IS_ENABLED(CONFIG_MPTCP_IPV6)
...@@ -3711,7 +3703,6 @@ static int mptcp_listen(struct socket *sock, int backlog) ...@@ -3711,7 +3703,6 @@ static int mptcp_listen(struct socket *sock, int backlog)
{ {
struct mptcp_sock *msk = mptcp_sk(sock->sk); struct mptcp_sock *msk = mptcp_sk(sock->sk);
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct socket *ssock;
struct sock *ssk; struct sock *ssk;
int err; int err;
...@@ -3723,13 +3714,12 @@ static int mptcp_listen(struct socket *sock, int backlog) ...@@ -3723,13 +3714,12 @@ static int mptcp_listen(struct socket *sock, int backlog)
if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM) if (sock->state != SS_UNCONNECTED || sock->type != SOCK_STREAM)
goto unlock; goto unlock;
ssock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(ssock)) { if (IS_ERR(ssk)) {
err = PTR_ERR(ssock); err = PTR_ERR(ssk);
goto unlock; goto unlock;
} }
ssk = msk->first;
inet_sk_state_store(sk, TCP_LISTEN); inet_sk_state_store(sk, TCP_LISTEN);
sock_set_flag(sk, SOCK_RCU_FREE); sock_set_flag(sk, SOCK_RCU_FREE);
......
...@@ -640,7 +640,7 @@ void __mptcp_subflow_send_ack(struct sock *ssk); ...@@ -640,7 +640,7 @@ void __mptcp_subflow_send_ack(struct sock *ssk);
void mptcp_subflow_reset(struct sock *ssk); void mptcp_subflow_reset(struct sock *ssk);
void mptcp_subflow_queue_clean(struct sock *sk, struct sock *ssk); void mptcp_subflow_queue_clean(struct sock *sk, struct sock *ssk);
void mptcp_sock_graft(struct sock *sk, struct socket *parent); void mptcp_sock_graft(struct sock *sk, struct socket *parent);
struct socket *__mptcp_nmpc_socket(struct mptcp_sock *msk); struct sock *__mptcp_nmpc_sk(struct mptcp_sock *msk);
bool __mptcp_close(struct sock *sk, long timeout); bool __mptcp_close(struct sock *sk, long timeout);
void mptcp_cancel_work(struct sock *sk); void mptcp_cancel_work(struct sock *sk);
void __mptcp_unaccepted_force_close(struct sock *sk); void __mptcp_unaccepted_force_close(struct sock *sk);
......
...@@ -292,7 +292,6 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname, ...@@ -292,7 +292,6 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
sockptr_t optval, unsigned int optlen) sockptr_t optval, unsigned int optlen)
{ {
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
struct socket *ssock;
struct sock *ssk; struct sock *ssk;
int ret; int ret;
...@@ -302,13 +301,12 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname, ...@@ -302,13 +301,12 @@ static int mptcp_setsockopt_sol_socket(struct mptcp_sock *msk, int optname,
case SO_BINDTODEVICE: case SO_BINDTODEVICE:
case SO_BINDTOIFINDEX: case SO_BINDTOIFINDEX:
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(ssock)) { if (IS_ERR(ssk)) {
release_sock(sk); release_sock(sk);
return PTR_ERR(ssock); return PTR_ERR(ssk);
} }
ssk = msk->first;
ret = sk_setsockopt(ssk, SOL_SOCKET, optname, optval, optlen); ret = sk_setsockopt(ssk, SOL_SOCKET, optname, optval, optlen);
if (ret == 0) { if (ret == 0) {
if (optname == SO_REUSEPORT) if (optname == SO_REUSEPORT)
...@@ -392,7 +390,6 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname, ...@@ -392,7 +390,6 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname,
{ {
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
int ret = -EOPNOTSUPP; int ret = -EOPNOTSUPP;
struct socket *ssock;
struct sock *ssk; struct sock *ssk;
switch (optname) { switch (optname) {
...@@ -400,13 +397,12 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname, ...@@ -400,13 +397,12 @@ static int mptcp_setsockopt_v6(struct mptcp_sock *msk, int optname,
case IPV6_TRANSPARENT: case IPV6_TRANSPARENT:
case IPV6_FREEBIND: case IPV6_FREEBIND:
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(ssock)) { if (IS_ERR(ssk)) {
release_sock(sk); release_sock(sk);
return PTR_ERR(ssock); return PTR_ERR(ssk);
} }
ssk = msk->first;
ret = tcp_setsockopt(ssk, SOL_IPV6, optname, optval, optlen); ret = tcp_setsockopt(ssk, SOL_IPV6, optname, optval, optlen);
if (ret != 0) { if (ret != 0) {
release_sock(sk); release_sock(sk);
...@@ -689,7 +685,7 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o ...@@ -689,7 +685,7 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o
{ {
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
struct inet_sock *issk; struct inet_sock *issk;
struct socket *ssock; struct sock *ssk;
int err; int err;
err = ip_setsockopt(sk, SOL_IP, optname, optval, optlen); err = ip_setsockopt(sk, SOL_IP, optname, optval, optlen);
...@@ -698,13 +694,13 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o ...@@ -698,13 +694,13 @@ static int mptcp_setsockopt_sol_ip_set_transparent(struct mptcp_sock *msk, int o
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(ssock)) { if (IS_ERR(ssk)) {
release_sock(sk); release_sock(sk);
return PTR_ERR(ssock); return PTR_ERR(ssk);
} }
issk = inet_sk(msk->first); issk = inet_sk(ssk);
switch (optname) { switch (optname) {
case IP_FREEBIND: case IP_FREEBIND:
...@@ -767,18 +763,18 @@ static int mptcp_setsockopt_first_sf_only(struct mptcp_sock *msk, int level, int ...@@ -767,18 +763,18 @@ static int mptcp_setsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
sockptr_t optval, unsigned int optlen) sockptr_t optval, unsigned int optlen)
{ {
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
struct socket *sock; struct sock *ssk;
int ret; int ret;
/* Limit to first subflow, before the connection establishment */ /* Limit to first subflow, before the connection establishment */
lock_sock(sk); lock_sock(sk);
sock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(sock)) { if (IS_ERR(ssk)) {
ret = PTR_ERR(sock); ret = PTR_ERR(ssk);
goto unlock; goto unlock;
} }
ret = tcp_setsockopt(sock->sk, level, optname, optval, optlen); ret = tcp_setsockopt(ssk, level, optname, optval, optlen);
unlock: unlock:
release_sock(sk); release_sock(sk);
...@@ -868,7 +864,6 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int ...@@ -868,7 +864,6 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
char __user *optval, int __user *optlen) char __user *optval, int __user *optlen)
{ {
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
struct socket *ssock;
struct sock *ssk; struct sock *ssk;
int ret; int ret;
...@@ -879,9 +874,9 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int ...@@ -879,9 +874,9 @@ static int mptcp_getsockopt_first_sf_only(struct mptcp_sock *msk, int level, int
goto out; goto out;
} }
ssock = __mptcp_nmpc_socket(msk); ssk = __mptcp_nmpc_sk(msk);
if (IS_ERR(ssock)) { if (IS_ERR(ssk)) {
ret = PTR_ERR(ssock); ret = PTR_ERR(ssk);
goto out; goto out;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment