Commit 393a2a20 authored by David Howells's avatar David Howells

rxrpc: Extract the peer address from an incoming packet earlier

Extract the peer address from an incoming packet earlier, at the beginning
of rxrpc_input_packet() and thence pass a pointer to it to various
functions that use it as part of the lookup rather than doing it on several
separate paths.
Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
cc: Marc Dionne <marc.dionne@auristor.com>
cc: linux-afs@lists.infradead.org
parent cd21effb
...@@ -824,6 +824,7 @@ int rxrpc_service_prealloc(struct rxrpc_sock *, gfp_t); ...@@ -824,6 +824,7 @@ int rxrpc_service_prealloc(struct rxrpc_sock *, gfp_t);
void rxrpc_discard_prealloc(struct rxrpc_sock *); void rxrpc_discard_prealloc(struct rxrpc_sock *);
struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *, struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *,
struct rxrpc_sock *, struct rxrpc_sock *,
struct sockaddr_rxrpc *,
struct sk_buff *); struct sk_buff *);
void rxrpc_accept_incoming_calls(struct rxrpc_local *); void rxrpc_accept_incoming_calls(struct rxrpc_local *);
int rxrpc_user_charge_accept(struct rxrpc_sock *, unsigned long); int rxrpc_user_charge_accept(struct rxrpc_sock *, unsigned long);
...@@ -916,6 +917,7 @@ extern unsigned int rxrpc_closed_conn_expiry; ...@@ -916,6 +917,7 @@ extern unsigned int rxrpc_closed_conn_expiry;
struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *, gfp_t); struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *, gfp_t);
struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *, struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *,
struct sockaddr_rxrpc *,
struct sk_buff *, struct sk_buff *,
struct rxrpc_peer **); struct rxrpc_peer **);
void __rxrpc_disconnect_call(struct rxrpc_connection *, struct rxrpc_call *); void __rxrpc_disconnect_call(struct rxrpc_connection *, struct rxrpc_call *);
......
...@@ -258,6 +258,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, ...@@ -258,6 +258,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx,
struct rxrpc_peer *peer, struct rxrpc_peer *peer,
struct rxrpc_connection *conn, struct rxrpc_connection *conn,
const struct rxrpc_security *sec, const struct rxrpc_security *sec,
struct sockaddr_rxrpc *peer_srx,
struct sk_buff *skb) struct sk_buff *skb)
{ {
struct rxrpc_backlog *b = rx->backlog; struct rxrpc_backlog *b = rx->backlog;
...@@ -287,8 +288,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, ...@@ -287,8 +288,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx,
peer = NULL; peer = NULL;
if (!peer) { if (!peer) {
peer = b->peer_backlog[peer_tail]; peer = b->peer_backlog[peer_tail];
if (rxrpc_extract_addr_from_skb(&peer->srx, skb) < 0) peer->srx = *peer_srx;
return NULL;
b->peer_backlog[peer_tail] = NULL; b->peer_backlog[peer_tail] = NULL;
smp_store_release(&b->peer_backlog_tail, smp_store_release(&b->peer_backlog_tail,
(peer_tail + 1) & (peer_tail + 1) &
...@@ -346,6 +346,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx, ...@@ -346,6 +346,7 @@ static struct rxrpc_call *rxrpc_alloc_incoming_call(struct rxrpc_sock *rx,
*/ */
struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local, struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local,
struct rxrpc_sock *rx, struct rxrpc_sock *rx,
struct sockaddr_rxrpc *peer_srx,
struct sk_buff *skb) struct sk_buff *skb)
{ {
struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
...@@ -371,7 +372,7 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local, ...@@ -371,7 +372,7 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local,
* we have to recheck the routing. However, we're now holding * we have to recheck the routing. However, we're now holding
* rx->incoming_lock, so the values should remain stable. * rx->incoming_lock, so the values should remain stable.
*/ */
conn = rxrpc_find_connection_rcu(local, skb, &peer); conn = rxrpc_find_connection_rcu(local, peer_srx, skb, &peer);
if (!conn) { if (!conn) {
sec = rxrpc_get_incoming_security(rx, skb); sec = rxrpc_get_incoming_security(rx, skb);
...@@ -379,7 +380,8 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local, ...@@ -379,7 +380,8 @@ struct rxrpc_call *rxrpc_new_incoming_call(struct rxrpc_local *local,
goto no_call; goto no_call;
} }
call = rxrpc_alloc_incoming_call(rx, local, peer, conn, sec, skb); call = rxrpc_alloc_incoming_call(rx, local, peer, conn, sec, peer_srx,
skb);
if (!call) { if (!call) {
skb->mark = RXRPC_SKB_MARK_REJECT_BUSY; skb->mark = RXRPC_SKB_MARK_REJECT_BUSY;
goto no_call; goto no_call;
......
...@@ -73,29 +73,17 @@ struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *rxnet, ...@@ -73,29 +73,17 @@ struct rxrpc_connection *rxrpc_alloc_connection(struct rxrpc_net *rxnet,
* The caller must be holding the RCU read lock. * The caller must be holding the RCU read lock.
*/ */
struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
struct sockaddr_rxrpc *srx,
struct sk_buff *skb, struct sk_buff *skb,
struct rxrpc_peer **_peer) struct rxrpc_peer **_peer)
{ {
struct rxrpc_connection *conn; struct rxrpc_connection *conn;
struct rxrpc_conn_proto k; struct rxrpc_conn_proto k;
struct rxrpc_skb_priv *sp = rxrpc_skb(skb); struct rxrpc_skb_priv *sp = rxrpc_skb(skb);
struct sockaddr_rxrpc srx;
struct rxrpc_peer *peer; struct rxrpc_peer *peer;
_enter(",%x", sp->hdr.cid & RXRPC_CIDMASK); _enter(",%x", sp->hdr.cid & RXRPC_CIDMASK);
if (rxrpc_extract_addr_from_skb(&srx, skb) < 0)
goto not_found;
if (srx.transport.family != local->srx.transport.family &&
(srx.transport.family == AF_INET &&
local->srx.transport.family != AF_INET6)) {
pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n",
srx.transport.family,
local->srx.transport.family);
goto not_found;
}
k.epoch = sp->hdr.epoch; k.epoch = sp->hdr.epoch;
k.cid = sp->hdr.cid & RXRPC_CIDMASK; k.cid = sp->hdr.cid & RXRPC_CIDMASK;
...@@ -104,7 +92,7 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, ...@@ -104,7 +92,7 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
* parameter set. We look up the peer first as an intermediate * parameter set. We look up the peer first as an intermediate
* step and then the connection from the peer's tree. * step and then the connection from the peer's tree.
*/ */
peer = rxrpc_lookup_peer_rcu(local, &srx); peer = rxrpc_lookup_peer_rcu(local, srx);
if (!peer) if (!peer)
goto not_found; goto not_found;
*_peer = peer; *_peer = peer;
...@@ -117,8 +105,7 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, ...@@ -117,8 +105,7 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
/* Look up client connections by connection ID alone as their /* Look up client connections by connection ID alone as their
* IDs are unique for this machine. * IDs are unique for this machine.
*/ */
conn = idr_find(&rxrpc_client_conn_ids, conn = idr_find(&rxrpc_client_conn_ids, sp->hdr.cid >> RXRPC_CIDSHIFT);
sp->hdr.cid >> RXRPC_CIDSHIFT);
if (!conn || refcount_read(&conn->ref) == 0) { if (!conn || refcount_read(&conn->ref) == 0) {
_debug("no conn"); _debug("no conn");
goto not_found; goto not_found;
...@@ -129,20 +116,20 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local, ...@@ -129,20 +116,20 @@ struct rxrpc_connection *rxrpc_find_connection_rcu(struct rxrpc_local *local,
goto not_found; goto not_found;
peer = conn->peer; peer = conn->peer;
switch (srx.transport.family) { switch (srx->transport.family) {
case AF_INET: case AF_INET:
if (peer->srx.transport.sin.sin_port != if (peer->srx.transport.sin.sin_port !=
srx.transport.sin.sin_port || srx->transport.sin.sin_port ||
peer->srx.transport.sin.sin_addr.s_addr != peer->srx.transport.sin.sin_addr.s_addr !=
srx.transport.sin.sin_addr.s_addr) srx->transport.sin.sin_addr.s_addr)
goto not_found; goto not_found;
break; break;
#ifdef CONFIG_AF_RXRPC_IPV6 #ifdef CONFIG_AF_RXRPC_IPV6
case AF_INET6: case AF_INET6:
if (peer->srx.transport.sin6.sin6_port != if (peer->srx.transport.sin6.sin6_port !=
srx.transport.sin6.sin6_port || srx->transport.sin6.sin6_port ||
memcmp(&peer->srx.transport.sin6.sin6_addr, memcmp(&peer->srx.transport.sin6.sin6_addr,
&srx.transport.sin6.sin6_addr, &srx->transport.sin6.sin6_addr,
sizeof(struct in6_addr)) != 0) sizeof(struct in6_addr)) != 0)
goto not_found; goto not_found;
break; break;
......
...@@ -155,6 +155,7 @@ static bool rxrpc_extract_abort(struct sk_buff *skb) ...@@ -155,6 +155,7 @@ static bool rxrpc_extract_abort(struct sk_buff *skb)
static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb) static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
{ {
struct rxrpc_connection *conn; struct rxrpc_connection *conn;
struct sockaddr_rxrpc peer_srx;
struct rxrpc_channel *chan; struct rxrpc_channel *chan;
struct rxrpc_call *call = NULL; struct rxrpc_call *call = NULL;
struct rxrpc_skb_priv *sp; struct rxrpc_skb_priv *sp;
...@@ -257,6 +258,18 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb) ...@@ -257,6 +258,18 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
if (sp->hdr.serviceId == 0) if (sp->hdr.serviceId == 0)
goto bad_message; goto bad_message;
if (WARN_ON_ONCE(rxrpc_extract_addr_from_skb(&peer_srx, skb) < 0))
return 0; /* Unsupported address type - discard. */
if (peer_srx.transport.family != local->srx.transport.family &&
(peer_srx.transport.family == AF_INET &&
local->srx.transport.family != AF_INET6)) {
pr_warn_ratelimited("AF_RXRPC: Protocol mismatch %u not %u\n",
peer_srx.transport.family,
local->srx.transport.family);
return 0; /* Wrong address type - discard. */
}
rcu_read_lock(); rcu_read_lock();
if (rxrpc_to_server(sp)) { if (rxrpc_to_server(sp)) {
...@@ -276,7 +289,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb) ...@@ -276,7 +289,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
} }
} }
conn = rxrpc_find_connection_rcu(local, skb, &peer); conn = rxrpc_find_connection_rcu(local, &peer_srx, skb, &peer);
if (conn) { if (conn) {
if (sp->hdr.securityIndex != conn->security_ix) if (sp->hdr.securityIndex != conn->security_ix)
goto wrong_security; goto wrong_security;
...@@ -389,7 +402,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb) ...@@ -389,7 +402,7 @@ static int rxrpc_input_packet(struct rxrpc_local *local, struct sk_buff **_skb)
rcu_read_unlock(); rcu_read_unlock();
return 0; return 0;
} }
call = rxrpc_new_incoming_call(local, rx, skb); call = rxrpc_new_incoming_call(local, rx, &peer_srx, skb);
if (!call) { if (!call) {
rcu_read_unlock(); rcu_read_unlock();
goto reject_packet; goto reject_packet;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment