Commit 2da59918 authored by Allan Stephens's avatar Allan Stephens Committed by David S. Miller

tipc: Fix race condition that could cause accept() to fail

This patch ensurs that accept() returns successfully even when
the newly created socket is immediately disconnected by its peer.
Previously, accept() would fail if it was unable to pass back
the optional address info for the socket's peer before the
socket became disconnected; TIPC now allows accept() to gather
peer address information after disconnection.  As a bonus, the
revised code accesses the socket's port more efficiently, without
the overhead incurred by a reference table lookup.
Signed-off-by: default avatarAllan Stephens <allan.stephens@windriver.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 8642bd9e
...@@ -63,6 +63,7 @@ ...@@ -63,6 +63,7 @@
struct tipc_sock { struct tipc_sock {
struct sock sk; struct sock sk;
struct tipc_port *p; struct tipc_port *p;
struct tipc_portid peer_name;
}; };
#define tipc_sk(sk) ((struct tipc_sock *)(sk)) #define tipc_sk(sk) ((struct tipc_sock *)(sk))
...@@ -377,27 +378,29 @@ static int bind(struct socket *sock, struct sockaddr *uaddr, int uaddr_len) ...@@ -377,27 +378,29 @@ static int bind(struct socket *sock, struct sockaddr *uaddr, int uaddr_len)
* @sock: socket structure * @sock: socket structure
* @uaddr: area for returned socket address * @uaddr: area for returned socket address
* @uaddr_len: area for returned length of socket address * @uaddr_len: area for returned length of socket address
* @peer: 0 to obtain socket name, 1 to obtain peer socket name * @peer: 0 = own ID, 1 = current peer ID, 2 = current/former peer ID
* *
* Returns 0 on success, errno otherwise * Returns 0 on success, errno otherwise
* *
* NOTE: This routine doesn't need to take the socket lock since it doesn't * NOTE: This routine doesn't need to take the socket lock since it only
* access any non-constant socket information. * accesses socket information that is unchanging (or which changes in
* a completely predictable manner).
*/ */
static int get_name(struct socket *sock, struct sockaddr *uaddr, static int get_name(struct socket *sock, struct sockaddr *uaddr,
int *uaddr_len, int peer) int *uaddr_len, int peer)
{ {
struct sockaddr_tipc *addr = (struct sockaddr_tipc *)uaddr; struct sockaddr_tipc *addr = (struct sockaddr_tipc *)uaddr;
u32 portref = tipc_sk_port(sock->sk)->ref; struct tipc_sock *tsock = tipc_sk(sock->sk);
u32 res;
if (peer) { if (peer) {
res = tipc_peer(portref, &addr->addr.id); if ((sock->state != SS_CONNECTED) &&
if (res) ((peer != 2) || (sock->state != SS_DISCONNECTING)))
return res; return -ENOTCONN;
addr->addr.id.ref = tsock->peer_name.ref;
addr->addr.id.node = tsock->peer_name.node;
} else { } else {
tipc_ownidentity(portref, &addr->addr.id); tipc_ownidentity(tsock->p->ref, &addr->addr.id);
} }
*uaddr_len = sizeof(*addr); *uaddr_len = sizeof(*addr);
...@@ -766,18 +769,17 @@ static int send_stream(struct kiocb *iocb, struct socket *sock, ...@@ -766,18 +769,17 @@ static int send_stream(struct kiocb *iocb, struct socket *sock,
static int auto_connect(struct socket *sock, struct tipc_msg *msg) static int auto_connect(struct socket *sock, struct tipc_msg *msg)
{ {
struct tipc_port *tport = tipc_sk_port(sock->sk); struct tipc_sock *tsock = tipc_sk(sock->sk);
struct tipc_portid peer;
if (msg_errcode(msg)) { if (msg_errcode(msg)) {
sock->state = SS_DISCONNECTING; sock->state = SS_DISCONNECTING;
return -ECONNREFUSED; return -ECONNREFUSED;
} }
peer.ref = msg_origport(msg); tsock->peer_name.ref = msg_origport(msg);
peer.node = msg_orignode(msg); tsock->peer_name.node = msg_orignode(msg);
tipc_connect2port(tport->ref, &peer); tipc_connect2port(tsock->p->ref, &tsock->peer_name);
tipc_set_portimportance(tport->ref, msg_importance(msg)); tipc_set_portimportance(tsock->p->ref, msg_importance(msg));
sock->state = SS_CONNECTED; sock->state = SS_CONNECTED;
return 0; return 0;
} }
...@@ -1529,9 +1531,9 @@ static int accept(struct socket *sock, struct socket *new_sock, int flags) ...@@ -1529,9 +1531,9 @@ static int accept(struct socket *sock, struct socket *new_sock, int flags)
res = tipc_create(sock_net(sock->sk), new_sock, 0); res = tipc_create(sock_net(sock->sk), new_sock, 0);
if (!res) { if (!res) {
struct sock *new_sk = new_sock->sk; struct sock *new_sk = new_sock->sk;
struct tipc_port *new_tport = tipc_sk_port(new_sk); struct tipc_sock *new_tsock = tipc_sk(new_sk);
struct tipc_port *new_tport = new_tsock->p;
u32 new_ref = new_tport->ref; u32 new_ref = new_tport->ref;
struct tipc_portid id;
struct tipc_msg *msg = buf_msg(buf); struct tipc_msg *msg = buf_msg(buf);
lock_sock(new_sk); lock_sock(new_sk);
...@@ -1545,9 +1547,9 @@ static int accept(struct socket *sock, struct socket *new_sock, int flags) ...@@ -1545,9 +1547,9 @@ static int accept(struct socket *sock, struct socket *new_sock, int flags)
/* Connect new socket to it's peer */ /* Connect new socket to it's peer */
id.ref = msg_origport(msg); new_tsock->peer_name.ref = msg_origport(msg);
id.node = msg_orignode(msg); new_tsock->peer_name.node = msg_orignode(msg);
tipc_connect2port(new_ref, &id); tipc_connect2port(new_ref, &new_tsock->peer_name);
new_sock->state = SS_CONNECTED; new_sock->state = SS_CONNECTED;
tipc_set_portimportance(new_ref, msg_importance(msg)); tipc_set_portimportance(new_ref, msg_importance(msg));
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment