Commit 13f1555c authored by David S. Miller

Merge branch 'MPTCP-improve-fallback-to-TCP'

Davide Caratti says:

====================
MPTCP: improve fallback to TCP

there are situations where MPTCP sockets should fall back to regular TCP:
this series reworks the fallback code to pursue the following goals:

1) cleanup the non fallback code, removing most of 'if (<fallback>)' in
   the data path
2) improve performance for non-fallback sockets, avoiding locks in poll()

further work will also leverage these changes to achieve:

a) more consistent behavior of getsockopt()/setsockopt() on passive sockets
   after fallback
b) support for "infinite maps" as per RFC8684, section 3.7

the series is made of the following items:

- patch 1 lets sendmsg() / recvmsg() / poll() use the main socket also
  after fallback
- patch 2 fixes 'simultaneous connect' scenario after fallback. The
  problem was present also before the rework, but the fix is much easier
  to implement after patch 1
- patches 3, 4 and 5 are clean-ups for code that is no longer needed after the
  fallback rework
- patch 6 fixes a race condition between close() and poll(). The problem
  was theoretically present before the rework, but it became almost
  systematic after patch 1
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
parents e1170333 8a05661b
...@@ -624,6 +624,9 @@ bool mptcp_established_options(struct sock *sk, struct sk_buff *skb, ...@@ -624,6 +624,9 @@ bool mptcp_established_options(struct sock *sk, struct sk_buff *skb,
opts->suboptions = 0; opts->suboptions = 0;
if (unlikely(mptcp_check_fallback(sk)))
return false;
if (mptcp_established_options_mp(sk, skb, &opt_size, remaining, opts)) if (mptcp_established_options_mp(sk, skb, &opt_size, remaining, opts))
ret = true; ret = true;
else if (mptcp_established_options_dss(sk, skb, &opt_size, remaining, else if (mptcp_established_options_dss(sk, skb, &opt_size, remaining,
...@@ -714,7 +717,8 @@ static bool check_fully_established(struct mptcp_sock *msk, struct sock *sk, ...@@ -714,7 +717,8 @@ static bool check_fully_established(struct mptcp_sock *msk, struct sock *sk,
*/ */
if (!mp_opt->mp_capable) { if (!mp_opt->mp_capable) {
subflow->mp_capable = 0; subflow->mp_capable = 0;
tcp_sk(sk)->is_mptcp = 0; pr_fallback(msk);
__mptcp_do_fallback(msk);
return false; return false;
} }
...@@ -814,6 +818,9 @@ void mptcp_incoming_options(struct sock *sk, struct sk_buff *skb, ...@@ -814,6 +818,9 @@ void mptcp_incoming_options(struct sock *sk, struct sk_buff *skb,
struct mptcp_options_received mp_opt; struct mptcp_options_received mp_opt;
struct mptcp_ext *mpext; struct mptcp_ext *mpext;
if (__mptcp_check_fallback(msk))
return;
mptcp_get_options(skb, &mp_opt); mptcp_get_options(skb, &mp_opt);
if (!check_fully_established(msk, sk, subflow, skb, &mp_opt)) if (!check_fully_established(msk, sk, subflow, skb, &mp_opt))
return; return;
......
...@@ -52,18 +52,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk) ...@@ -52,18 +52,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
return msk->subflow; return msk->subflow;
} }
static bool __mptcp_needs_tcp_fallback(const struct mptcp_sock *msk) static bool mptcp_is_tcpsk(struct sock *sk)
{
return msk->first && !sk_is_mptcp(msk->first);
}
static struct socket *mptcp_is_tcpsk(struct sock *sk)
{ {
struct socket *sock = sk->sk_socket; struct socket *sock = sk->sk_socket;
if (sock->sk != sk)
return NULL;
if (unlikely(sk->sk_prot == &tcp_prot)) { if (unlikely(sk->sk_prot == &tcp_prot)) {
/* we are being invoked after mptcp_accept() has /* we are being invoked after mptcp_accept() has
* accepted a non-mp-capable flow: sk is a tcp_sk, * accepted a non-mp-capable flow: sk is a tcp_sk,
...@@ -73,59 +65,37 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk) ...@@ -73,59 +65,37 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk)
* bypass mptcp. * bypass mptcp.
*/ */
sock->ops = &inet_stream_ops; sock->ops = &inet_stream_ops;
return sock; return true;
#if IS_ENABLED(CONFIG_MPTCP_IPV6) #if IS_ENABLED(CONFIG_MPTCP_IPV6)
} else if (unlikely(sk->sk_prot == &tcpv6_prot)) { } else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
sock->ops = &inet6_stream_ops; sock->ops = &inet6_stream_ops;
return sock; return true;
#endif #endif
} }
return NULL; return false;
} }
static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk) static struct sock *__mptcp_tcp_fallback(struct mptcp_sock *msk)
{ {
struct socket *sock;
sock_owned_by_me((const struct sock *)msk); sock_owned_by_me((const struct sock *)msk);
sock = mptcp_is_tcpsk((struct sock *)msk); if (likely(!__mptcp_check_fallback(msk)))
if (unlikely(sock))
return sock;
if (likely(!__mptcp_needs_tcp_fallback(msk)))
return NULL; return NULL;
return msk->subflow; return msk->first;
}
static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
{
return !msk->first;
} }
static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state) static int __mptcp_socket_create(struct mptcp_sock *msk)
{ {
struct mptcp_subflow_context *subflow; struct mptcp_subflow_context *subflow;
struct sock *sk = (struct sock *)msk; struct sock *sk = (struct sock *)msk;
struct socket *ssock; struct socket *ssock;
int err; int err;
ssock = __mptcp_tcp_fallback(msk);
if (unlikely(ssock))
return ssock;
ssock = __mptcp_nmpc_socket(msk);
if (ssock)
goto set_state;
if (!__mptcp_can_create_subflow(msk))
return ERR_PTR(-EINVAL);
err = mptcp_subflow_create_socket(sk, &ssock); err = mptcp_subflow_create_socket(sk, &ssock);
if (err) if (err)
return ERR_PTR(err); return err;
msk->first = ssock->sk; msk->first = ssock->sk;
msk->subflow = ssock; msk->subflow = ssock;
...@@ -133,10 +103,12 @@ static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state) ...@@ -133,10 +103,12 @@ static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
list_add(&subflow->node, &msk->conn_list); list_add(&subflow->node, &msk->conn_list);
subflow->request_mptcp = 1; subflow->request_mptcp = 1;
set_state: /* accept() will wait on first subflow sk_wq, and we always wakes up
if (state != MPTCP_SAME_STATE) * via msk->sk_socket
inet_sk_state_store(sk, state); */
return ssock; RCU_INIT_POINTER(msk->first->sk_wq, &sk->sk_socket->wq);
return 0;
} }
static void __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk, static void __mptcp_move_skb(struct mptcp_sock *msk, struct sock *ssk,
...@@ -229,6 +201,15 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk, ...@@ -229,6 +201,15 @@ static bool __mptcp_move_skbs_from_subflow(struct mptcp_sock *msk,
if (!skb) if (!skb)
break; break;
if (__mptcp_check_fallback(msk)) {
/* if we are running under the workqueue, TCP could have
* collapsed skbs between dummy map creation and now
* be sure to adjust the size
*/
map_remaining = skb->len;
subflow->map_data_len = skb->len;
}
offset = seq - TCP_SKB_CB(skb)->seq; offset = seq - TCP_SKB_CB(skb)->seq;
fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN; fin = TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN;
if (fin) { if (fin) {
...@@ -466,8 +447,15 @@ static void mptcp_clean_una(struct sock *sk) ...@@ -466,8 +447,15 @@ static void mptcp_clean_una(struct sock *sk)
{ {
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
struct mptcp_data_frag *dtmp, *dfrag; struct mptcp_data_frag *dtmp, *dfrag;
u64 snd_una = atomic64_read(&msk->snd_una);
bool cleaned = false; bool cleaned = false;
u64 snd_una;
/* on fallback we just need to ignore snd_una, as this is really
* plain TCP
*/
if (__mptcp_check_fallback(msk))
atomic64_set(&msk->snd_una, msk->write_seq);
snd_una = atomic64_read(&msk->snd_una);
list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) { list_for_each_entry_safe(dfrag, dtmp, &msk->rtx_queue, list) {
if (after64(dfrag->data_seq + dfrag->data_len, snd_una)) if (after64(dfrag->data_seq + dfrag->data_len, snd_una))
...@@ -740,7 +728,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -740,7 +728,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
int mss_now = 0, size_goal = 0, ret = 0; int mss_now = 0, size_goal = 0, ret = 0;
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
struct page_frag *pfrag; struct page_frag *pfrag;
struct socket *ssock;
size_t copied = 0; size_t copied = 0;
struct sock *ssk; struct sock *ssk;
bool tx_ok; bool tx_ok;
...@@ -759,15 +746,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -759,15 +746,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
goto out; goto out;
} }
fallback:
ssock = __mptcp_tcp_fallback(msk);
if (unlikely(ssock)) {
release_sock(sk);
pr_debug("fallback passthrough");
ret = sock_sendmsg(ssock, msg);
return ret >= 0 ? ret + copied : (copied ? copied : ret);
}
pfrag = sk_page_frag(sk); pfrag = sk_page_frag(sk);
restart: restart:
mptcp_clean_una(sk); mptcp_clean_una(sk);
...@@ -819,17 +797,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) ...@@ -819,17 +797,6 @@ static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
} }
break; break;
} }
if (ret == 0 && unlikely(__mptcp_needs_tcp_fallback(msk))) {
/* Can happen for passive sockets:
* 3WHS negotiated MPTCP, but first packet after is
* plain TCP (e.g. due to middlebox filtering unknown
* options).
*
* Fall back to TCP.
*/
release_sock(ssk);
goto fallback;
}
copied += ret; copied += ret;
...@@ -972,7 +939,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -972,7 +939,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
int nonblock, int flags, int *addr_len) int nonblock, int flags, int *addr_len)
{ {
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
struct socket *ssock;
int copied = 0; int copied = 0;
int target; int target;
long timeo; long timeo;
...@@ -981,16 +947,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -981,16 +947,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
return -EOPNOTSUPP; return -EOPNOTSUPP;
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_tcp_fallback(msk);
if (unlikely(ssock)) {
fallback:
release_sock(sk);
pr_debug("fallback-read subflow=%p",
mptcp_subflow_ctx(ssock->sk));
copied = sock_recvmsg(ssock, msg, flags);
return copied;
}
timeo = sock_rcvtimeo(sk, nonblock); timeo = sock_rcvtimeo(sk, nonblock);
len = min_t(size_t, len, INT_MAX); len = min_t(size_t, len, INT_MAX);
...@@ -1056,9 +1012,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, ...@@ -1056,9 +1012,6 @@ static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
pr_debug("block timeout %ld", timeo); pr_debug("block timeout %ld", timeo);
mptcp_wait_data(sk, &timeo); mptcp_wait_data(sk, &timeo);
ssock = __mptcp_tcp_fallback(msk);
if (unlikely(ssock))
goto fallback;
} }
if (skb_queue_empty(&sk->sk_receive_queue)) { if (skb_queue_empty(&sk->sk_receive_queue)) {
...@@ -1283,6 +1236,10 @@ static int mptcp_init_sock(struct sock *sk) ...@@ -1283,6 +1236,10 @@ static int mptcp_init_sock(struct sock *sk)
if (ret) if (ret)
return ret; return ret;
ret = __mptcp_socket_create(mptcp_sk(sk));
if (ret)
return ret;
sk_sockets_allocated_inc(sk); sk_sockets_allocated_inc(sk);
sk->sk_sndbuf = sock_net(sk)->ipv4.sysctl_tcp_wmem[2]; sk->sk_sndbuf = sock_net(sk)->ipv4.sysctl_tcp_wmem[2];
...@@ -1335,8 +1292,6 @@ static void mptcp_subflow_shutdown(struct sock *ssk, int how, ...@@ -1335,8 +1292,6 @@ static void mptcp_subflow_shutdown(struct sock *ssk, int how,
break; break;
} }
/* Wake up anyone sleeping in poll. */
ssk->sk_state_change(ssk);
release_sock(ssk); release_sock(ssk);
} }
...@@ -1487,7 +1442,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err, ...@@ -1487,7 +1442,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
return NULL; return NULL;
pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk)); pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));
if (sk_is_mptcp(newsk)) { if (sk_is_mptcp(newsk)) {
struct mptcp_subflow_context *subflow; struct mptcp_subflow_context *subflow;
struct sock *new_mptcp_sock; struct sock *new_mptcp_sock;
...@@ -1544,7 +1498,7 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname, ...@@ -1544,7 +1498,7 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname,
char __user *optval, unsigned int optlen) char __user *optval, unsigned int optlen)
{ {
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
struct socket *ssock; struct sock *ssk;
pr_debug("msk=%p", msk); pr_debug("msk=%p", msk);
...@@ -1555,11 +1509,10 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname, ...@@ -1555,11 +1509,10 @@ static int mptcp_setsockopt(struct sock *sk, int level, int optname,
* to the one remaining subflow. * to the one remaining subflow.
*/ */
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_tcp_fallback(msk); ssk = __mptcp_tcp_fallback(msk);
release_sock(sk); release_sock(sk);
if (ssock) if (ssk)
return tcp_setsockopt(ssock->sk, level, optname, optval, return tcp_setsockopt(ssk, level, optname, optval, optlen);
optlen);
return -EOPNOTSUPP; return -EOPNOTSUPP;
} }
...@@ -1568,7 +1521,7 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname, ...@@ -1568,7 +1521,7 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname,
char __user *optval, int __user *option) char __user *optval, int __user *option)
{ {
struct mptcp_sock *msk = mptcp_sk(sk); struct mptcp_sock *msk = mptcp_sk(sk);
struct socket *ssock; struct sock *ssk;
pr_debug("msk=%p", msk); pr_debug("msk=%p", msk);
...@@ -1579,11 +1532,10 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname, ...@@ -1579,11 +1532,10 @@ static int mptcp_getsockopt(struct sock *sk, int level, int optname,
* to the one remaining subflow. * to the one remaining subflow.
*/ */
lock_sock(sk); lock_sock(sk);
ssock = __mptcp_tcp_fallback(msk); ssk = __mptcp_tcp_fallback(msk);
release_sock(sk); release_sock(sk);
if (ssock) if (ssk)
return tcp_getsockopt(ssock->sk, level, optname, optval, return tcp_getsockopt(ssk, level, optname, optval, option);
option);
return -EOPNOTSUPP; return -EOPNOTSUPP;
} }
...@@ -1660,12 +1612,6 @@ void mptcp_finish_connect(struct sock *ssk) ...@@ -1660,12 +1612,6 @@ void mptcp_finish_connect(struct sock *ssk)
sk = subflow->conn; sk = subflow->conn;
msk = mptcp_sk(sk); msk = mptcp_sk(sk);
if (!subflow->mp_capable) {
MPTCP_INC_STATS(sock_net(sk),
MPTCP_MIB_MPCAPABLEACTIVEFALLBACK);
return;
}
pr_debug("msk=%p, token=%u", sk, subflow->token); pr_debug("msk=%p, token=%u", sk, subflow->token);
mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq); mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq);
...@@ -1781,9 +1727,9 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) ...@@ -1781,9 +1727,9 @@ static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
int err; int err;
lock_sock(sock->sk); lock_sock(sock->sk);
ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE); ssock = __mptcp_nmpc_socket(msk);
if (IS_ERR(ssock)) { if (!ssock) {
err = PTR_ERR(ssock); err = -EINVAL;
goto unlock; goto unlock;
} }
...@@ -1813,13 +1759,14 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1813,13 +1759,14 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
goto do_connect; goto do_connect;
} }
mptcp_token_destroy(msk); ssock = __mptcp_nmpc_socket(msk);
ssock = __mptcp_socket_create(msk, TCP_SYN_SENT); if (!ssock) {
if (IS_ERR(ssock)) { err = -EINVAL;
err = PTR_ERR(ssock);
goto unlock; goto unlock;
} }
mptcp_token_destroy(msk);
inet_sk_state_store(sock->sk, TCP_SYN_SENT);
subflow = mptcp_subflow_ctx(ssock->sk); subflow = mptcp_subflow_ctx(ssock->sk);
#ifdef CONFIG_TCP_MD5SIG #ifdef CONFIG_TCP_MD5SIG
/* no MPTCP if MD5SIG is enabled on this socket or we may run out of /* no MPTCP if MD5SIG is enabled on this socket or we may run out of
...@@ -1848,42 +1795,6 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, ...@@ -1848,42 +1795,6 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
return err; return err;
} }
static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
int peer)
{
if (sock->sk->sk_prot == &tcp_prot) {
/* we are being invoked from __sys_accept4, after
* mptcp_accept() has just accepted a non-mp-capable
* flow: sk is a tcp_sk, not an mptcp one.
*
* Hand the socket over to tcp so all further socket ops
* bypass mptcp.
*/
sock->ops = &inet_stream_ops;
}
return inet_getname(sock, uaddr, peer);
}
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
int peer)
{
if (sock->sk->sk_prot == &tcpv6_prot) {
/* we are being invoked from __sys_accept4 after
* mptcp_accept() has accepted a non-mp-capable
* subflow: sk is a tcp_sk, not mptcp.
*
* Hand the socket over to tcp so all further
* socket ops bypass mptcp.
*/
sock->ops = &inet6_stream_ops;
}
return inet6_getname(sock, uaddr, peer);
}
#endif
static int mptcp_listen(struct socket *sock, int backlog) static int mptcp_listen(struct socket *sock, int backlog)
{ {
struct mptcp_sock *msk = mptcp_sk(sock->sk); struct mptcp_sock *msk = mptcp_sk(sock->sk);
...@@ -1893,13 +1804,14 @@ static int mptcp_listen(struct socket *sock, int backlog) ...@@ -1893,13 +1804,14 @@ static int mptcp_listen(struct socket *sock, int backlog)
pr_debug("msk=%p", msk); pr_debug("msk=%p", msk);
lock_sock(sock->sk); lock_sock(sock->sk);
mptcp_token_destroy(msk); ssock = __mptcp_nmpc_socket(msk);
ssock = __mptcp_socket_create(msk, TCP_LISTEN); if (!ssock) {
if (IS_ERR(ssock)) { err = -EINVAL;
err = PTR_ERR(ssock);
goto unlock; goto unlock;
} }
mptcp_token_destroy(msk);
inet_sk_state_store(sock->sk, TCP_LISTEN);
sock_set_flag(sock->sk, SOCK_RCU_FREE); sock_set_flag(sock->sk, SOCK_RCU_FREE);
err = ssock->ops->listen(ssock, backlog); err = ssock->ops->listen(ssock, backlog);
...@@ -1912,15 +1824,6 @@ static int mptcp_listen(struct socket *sock, int backlog) ...@@ -1912,15 +1824,6 @@ static int mptcp_listen(struct socket *sock, int backlog)
return err; return err;
} }
static bool is_tcp_proto(const struct proto *p)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
return p == &tcp_prot || p == &tcpv6_prot;
#else
return p == &tcp_prot;
#endif
}
static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
int flags, bool kern) int flags, bool kern)
{ {
...@@ -1938,11 +1841,12 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, ...@@ -1938,11 +1841,12 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
if (!ssock) if (!ssock)
goto unlock_fail; goto unlock_fail;
clear_bit(MPTCP_DATA_READY, &msk->flags);
sock_hold(ssock->sk); sock_hold(ssock->sk);
release_sock(sock->sk); release_sock(sock->sk);
err = ssock->ops->accept(sock, newsock, flags, kern); err = ssock->ops->accept(sock, newsock, flags, kern);
if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) { if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
struct mptcp_sock *msk = mptcp_sk(newsock->sk); struct mptcp_sock *msk = mptcp_sk(newsock->sk);
struct mptcp_subflow_context *subflow; struct mptcp_subflow_context *subflow;
...@@ -1958,6 +1862,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, ...@@ -1958,6 +1862,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
} }
} }
if (inet_csk_listen_poll(ssock->sk))
set_bit(MPTCP_DATA_READY, &msk->flags);
sock_put(ssock->sk); sock_put(ssock->sk);
return err; return err;
...@@ -1966,39 +1872,36 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, ...@@ -1966,39 +1872,36 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
return -EINVAL; return -EINVAL;
} }
/* Translate the MPTCP_DATA_READY flag of @msk into a poll mask:
 * EPOLLIN | EPOLLRDNORM when the bit is set, 0 otherwise.
 */
static __poll_t mptcp_check_readable(struct mptcp_sock *msk)
{
	if (!test_bit(MPTCP_DATA_READY, &msk->flags))
		return 0;

	return EPOLLIN | EPOLLRDNORM;
}
static __poll_t mptcp_poll(struct file *file, struct socket *sock, static __poll_t mptcp_poll(struct file *file, struct socket *sock,
struct poll_table_struct *wait) struct poll_table_struct *wait)
{ {
struct sock *sk = sock->sk; struct sock *sk = sock->sk;
struct mptcp_sock *msk; struct mptcp_sock *msk;
struct socket *ssock;
__poll_t mask = 0; __poll_t mask = 0;
int state;
msk = mptcp_sk(sk); msk = mptcp_sk(sk);
lock_sock(sk);
ssock = __mptcp_tcp_fallback(msk);
if (!ssock)
ssock = __mptcp_nmpc_socket(msk);
if (ssock) {
mask = ssock->ops->poll(file, ssock, wait);
release_sock(sk);
return mask;
}
release_sock(sk);
sock_poll_wait(file, sock, wait); sock_poll_wait(file, sock, wait);
lock_sock(sk);
if (test_bit(MPTCP_DATA_READY, &msk->flags)) state = inet_sk_state_load(sk);
mask = EPOLLIN | EPOLLRDNORM; if (state == TCP_LISTEN)
if (sk_stream_is_writeable(sk) && return mptcp_check_readable(msk);
test_bit(MPTCP_SEND_SPACE, &msk->flags))
mask |= EPOLLOUT | EPOLLWRNORM; if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) {
mask |= mptcp_check_readable(msk);
if (sk_stream_is_writeable(sk) &&
test_bit(MPTCP_SEND_SPACE, &msk->flags))
mask |= EPOLLOUT | EPOLLWRNORM;
}
if (sk->sk_shutdown & RCV_SHUTDOWN) if (sk->sk_shutdown & RCV_SHUTDOWN)
mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;
release_sock(sk);
return mask; return mask;
} }
...@@ -2006,18 +1909,11 @@ static int mptcp_shutdown(struct socket *sock, int how) ...@@ -2006,18 +1909,11 @@ static int mptcp_shutdown(struct socket *sock, int how)
{ {
struct mptcp_sock *msk = mptcp_sk(sock->sk); struct mptcp_sock *msk = mptcp_sk(sock->sk);
struct mptcp_subflow_context *subflow; struct mptcp_subflow_context *subflow;
struct socket *ssock;
int ret = 0; int ret = 0;
pr_debug("sk=%p, how=%d", msk, how); pr_debug("sk=%p, how=%d", msk, how);
lock_sock(sock->sk); lock_sock(sock->sk);
ssock = __mptcp_tcp_fallback(msk);
if (ssock) {
release_sock(sock->sk);
return inet_shutdown(ssock, how);
}
if (how == SHUT_WR || how == SHUT_RDWR) if (how == SHUT_WR || how == SHUT_RDWR)
inet_sk_state_store(sock->sk, TCP_FIN_WAIT1); inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);
...@@ -2043,6 +1939,9 @@ static int mptcp_shutdown(struct socket *sock, int how) ...@@ -2043,6 +1939,9 @@ static int mptcp_shutdown(struct socket *sock, int how)
mptcp_subflow_shutdown(tcp_sk, how, 1, msk->write_seq); mptcp_subflow_shutdown(tcp_sk, how, 1, msk->write_seq);
} }
/* Wake up anyone sleeping in poll. */
sock->sk->sk_state_change(sock->sk);
out_unlock: out_unlock:
release_sock(sock->sk); release_sock(sock->sk);
...@@ -2057,7 +1956,7 @@ static const struct proto_ops mptcp_stream_ops = { ...@@ -2057,7 +1956,7 @@ static const struct proto_ops mptcp_stream_ops = {
.connect = mptcp_stream_connect, .connect = mptcp_stream_connect,
.socketpair = sock_no_socketpair, .socketpair = sock_no_socketpair,
.accept = mptcp_stream_accept, .accept = mptcp_stream_accept,
.getname = mptcp_v4_getname, .getname = inet_getname,
.poll = mptcp_poll, .poll = mptcp_poll,
.ioctl = inet_ioctl, .ioctl = inet_ioctl,
.gettstamp = sock_gettstamp, .gettstamp = sock_gettstamp,
...@@ -2111,7 +2010,7 @@ static const struct proto_ops mptcp_v6_stream_ops = { ...@@ -2111,7 +2010,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
.connect = mptcp_stream_connect, .connect = mptcp_stream_connect,
.socketpair = sock_no_socketpair, .socketpair = sock_no_socketpair,
.accept = mptcp_stream_accept, .accept = mptcp_stream_accept,
.getname = mptcp_v6_getname, .getname = inet6_getname,
.poll = mptcp_poll, .poll = mptcp_poll,
.ioctl = inet6_ioctl, .ioctl = inet6_ioctl,
.gettstamp = sock_gettstamp, .gettstamp = sock_gettstamp,
......
...@@ -89,6 +89,7 @@ ...@@ -89,6 +89,7 @@
#define MPTCP_SEND_SPACE 1 #define MPTCP_SEND_SPACE 1
#define MPTCP_WORK_RTX 2 #define MPTCP_WORK_RTX 2
#define MPTCP_WORK_EOF 3 #define MPTCP_WORK_EOF 3
#define MPTCP_FALLBACK_DONE 4
struct mptcp_options_received { struct mptcp_options_received {
u64 sndr_key; u64 sndr_key;
...@@ -457,4 +458,46 @@ static inline bool before64(__u64 seq1, __u64 seq2) ...@@ -457,4 +458,46 @@ static inline bool before64(__u64 seq1, __u64 seq2)
void mptcp_diag_subflow_init(struct tcp_ulp_ops *ops); void mptcp_diag_subflow_init(struct tcp_ulp_ops *ops);
/* Report whether the MPTCP_FALLBACK_DONE bit is set in @msk->flags,
 * i.e. whether TCP fallback has already been recorded for this msk.
 */
static inline bool __mptcp_check_fallback(struct mptcp_sock *msk)
{
	bool fallback_done = test_bit(MPTCP_FALLBACK_DONE, &msk->flags);

	return fallback_done;
}
static inline bool mptcp_check_fallback(struct sock *sk)
{
struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
struct mptcp_sock *msk = mptcp_sk(subflow->conn);
return __mptcp_check_fallback(msk);
}
/* Record TCP fallback on @msk by setting MPTCP_FALLBACK_DONE.
 * If the bit is already set, only emit a debug message and leave the
 * flags untouched.
 */
static inline void __mptcp_do_fallback(struct mptcp_sock *msk)
{
	if (!test_bit(MPTCP_FALLBACK_DONE, &msk->flags)) {
		set_bit(MPTCP_FALLBACK_DONE, &msk->flags);
		return;
	}

	pr_debug("TCP fallback already done (msk=%p)", msk);
}
static inline void mptcp_do_fallback(struct sock *sk)
{
struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
struct mptcp_sock *msk = mptcp_sk(subflow->conn);
__mptcp_do_fallback(msk);
}
/* Debug trace for a fallback event: prints the calling function's name
 * (__func__) and the pointer passed as @a (callers pass the msk).
 */
#define pr_fallback(a) pr_debug("%s:fallback to TCP (msk=%p)", __func__, a)
static inline bool subflow_simultaneous_connect(struct sock *sk)
{
struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
struct sock *parent = subflow->conn;
return sk->sk_state == TCP_ESTABLISHED &&
!mptcp_sk(parent)->pm.server_side &&
!subflow->conn_finished;
}
#endif /* __MPTCP_PROTOCOL_H */ #endif /* __MPTCP_PROTOCOL_H */
...@@ -216,7 +216,6 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) ...@@ -216,7 +216,6 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb)
struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
struct mptcp_options_received mp_opt; struct mptcp_options_received mp_opt;
struct sock *parent = subflow->conn; struct sock *parent = subflow->conn;
struct tcp_sock *tp = tcp_sk(sk);
subflow->icsk_af_ops->sk_rx_dst_set(sk, skb); subflow->icsk_af_ops->sk_rx_dst_set(sk, skb);
...@@ -230,6 +229,8 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) ...@@ -230,6 +229,8 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb)
return; return;
subflow->conn_finished = 1; subflow->conn_finished = 1;
subflow->ssn_offset = TCP_SKB_CB(skb)->seq;
pr_debug("subflow=%p synack seq=%x", subflow, subflow->ssn_offset);
mptcp_get_options(skb, &mp_opt); mptcp_get_options(skb, &mp_opt);
if (subflow->request_mptcp && mp_opt.mp_capable) { if (subflow->request_mptcp && mp_opt.mp_capable) {
...@@ -245,21 +246,20 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) ...@@ -245,21 +246,20 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb)
pr_debug("subflow=%p, thmac=%llu, remote_nonce=%u", subflow, pr_debug("subflow=%p, thmac=%llu, remote_nonce=%u", subflow,
subflow->thmac, subflow->remote_nonce); subflow->thmac, subflow->remote_nonce);
} else { } else {
tp->is_mptcp = 0; if (subflow->request_mptcp)
MPTCP_INC_STATS(sock_net(sk),
MPTCP_MIB_MPCAPABLEACTIVEFALLBACK);
mptcp_do_fallback(sk);
pr_fallback(mptcp_sk(subflow->conn));
} }
if (!tp->is_mptcp) if (mptcp_check_fallback(sk))
return; return;
if (subflow->mp_capable) { if (subflow->mp_capable) {
pr_debug("subflow=%p, remote_key=%llu", mptcp_subflow_ctx(sk), pr_debug("subflow=%p, remote_key=%llu", mptcp_subflow_ctx(sk),
subflow->remote_key); subflow->remote_key);
mptcp_finish_connect(sk); mptcp_finish_connect(sk);
if (skb) {
pr_debug("synack seq=%u", TCP_SKB_CB(skb)->seq);
subflow->ssn_offset = TCP_SKB_CB(skb)->seq;
}
} else if (subflow->mp_join) { } else if (subflow->mp_join) {
u8 hmac[SHA256_DIGEST_SIZE]; u8 hmac[SHA256_DIGEST_SIZE];
...@@ -279,9 +279,6 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb) ...@@ -279,9 +279,6 @@ static void subflow_finish_connect(struct sock *sk, const struct sk_buff *skb)
memcpy(subflow->hmac, hmac, MPTCPOPT_HMAC_LEN); memcpy(subflow->hmac, hmac, MPTCPOPT_HMAC_LEN);
if (skb)
subflow->ssn_offset = TCP_SKB_CB(skb)->seq;
if (!mptcp_finish_join(sk)) if (!mptcp_finish_join(sk))
goto do_reset; goto do_reset;
...@@ -557,7 +554,8 @@ enum mapping_status { ...@@ -557,7 +554,8 @@ enum mapping_status {
MAPPING_OK, MAPPING_OK,
MAPPING_INVALID, MAPPING_INVALID,
MAPPING_EMPTY, MAPPING_EMPTY,
MAPPING_DATA_FIN MAPPING_DATA_FIN,
MAPPING_DUMMY
}; };
static u64 expand_seq(u64 old_seq, u16 old_data_len, u64 seq) static u64 expand_seq(u64 old_seq, u16 old_data_len, u64 seq)
...@@ -621,6 +619,9 @@ static enum mapping_status get_mapping_status(struct sock *ssk) ...@@ -621,6 +619,9 @@ static enum mapping_status get_mapping_status(struct sock *ssk)
if (!skb) if (!skb)
return MAPPING_EMPTY; return MAPPING_EMPTY;
if (mptcp_check_fallback(ssk))
return MAPPING_DUMMY;
mpext = mptcp_get_ext(skb); mpext = mptcp_get_ext(skb);
if (!mpext || !mpext->use_map) { if (!mpext || !mpext->use_map) {
if (!subflow->map_valid && !skb->len) { if (!subflow->map_valid && !skb->len) {
...@@ -762,6 +763,16 @@ static bool subflow_check_data_avail(struct sock *ssk) ...@@ -762,6 +763,16 @@ static bool subflow_check_data_avail(struct sock *ssk)
ssk->sk_err = EBADMSG; ssk->sk_err = EBADMSG;
goto fatal; goto fatal;
} }
if (status == MAPPING_DUMMY) {
__mptcp_do_fallback(msk);
skb = skb_peek(&ssk->sk_receive_queue);
subflow->map_valid = 1;
subflow->map_seq = READ_ONCE(msk->ack_seq);
subflow->map_data_len = skb->len;
subflow->map_subflow_seq = tcp_sk(ssk)->copied_seq -
subflow->ssn_offset;
return true;
}
if (status != MAPPING_OK) if (status != MAPPING_OK)
return false; return false;
...@@ -885,14 +896,18 @@ static void subflow_data_ready(struct sock *sk) ...@@ -885,14 +896,18 @@ static void subflow_data_ready(struct sock *sk)
{ {
struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk); struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(sk);
struct sock *parent = subflow->conn; struct sock *parent = subflow->conn;
struct mptcp_sock *msk;
if (!subflow->mp_capable && !subflow->mp_join) { msk = mptcp_sk(parent);
subflow->tcp_data_ready(sk); if (inet_sk_state_load(sk) == TCP_LISTEN) {
set_bit(MPTCP_DATA_READY, &msk->flags);
parent->sk_data_ready(parent); parent->sk_data_ready(parent);
return; return;
} }
WARN_ON_ONCE(!__mptcp_check_fallback(msk) && !subflow->mp_capable &&
!subflow->mp_join);
if (mptcp_subflow_data_available(sk)) if (mptcp_subflow_data_available(sk))
mptcp_data_ready(parent, sk); mptcp_data_ready(parent, sk);
} }
...@@ -1113,11 +1128,21 @@ static void subflow_state_change(struct sock *sk) ...@@ -1113,11 +1128,21 @@ static void subflow_state_change(struct sock *sk)
__subflow_state_change(sk); __subflow_state_change(sk);
if (subflow_simultaneous_connect(sk)) {
mptcp_do_fallback(sk);
pr_fallback(mptcp_sk(parent));
subflow->conn_finished = 1;
if (inet_sk_state_load(parent) == TCP_SYN_SENT) {
inet_sk_state_store(parent, TCP_ESTABLISHED);
parent->sk_state_change(parent);
}
}
/* as recvmsg() does not acquire the subflow socket for ssk selection /* as recvmsg() does not acquire the subflow socket for ssk selection
* a fin packet carrying a DSS can be unnoticed if we don't trigger * a fin packet carrying a DSS can be unnoticed if we don't trigger
* the data available machinery here. * the data available machinery here.
*/ */
if (subflow->mp_capable && mptcp_subflow_data_available(sk)) if (mptcp_subflow_data_available(sk))
mptcp_data_ready(parent, sk); mptcp_data_ready(parent, sk);
if (!(parent->sk_shutdown & RCV_SHUTDOWN) && if (!(parent->sk_shutdown & RCV_SHUTDOWN) &&
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment