Commit 7727640c authored by Tim Smith's avatar Tim Smith Committed by David Howells

af_rxrpc: Keep rxrpc_call pointers in a hashtable

Keep track of rxrpc_call structures in a hashtable so they can be
found directly from the network parameters which define the call.

This allows incoming packets to be routed directly to a call without walking
through hierarchy of peer -> transport -> connection -> call and all the
spinlocks that that entailed.
Signed-off-by: default avatarTim Smith <tim@electronghost.co.uk>
Signed-off-by: default avatarDavid Howells <dhowells@redhat.com>
parent e8388eb1
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include <linux/slab.h> #include <linux/slab.h>
#include <linux/module.h> #include <linux/module.h>
#include <linux/circ_buf.h> #include <linux/circ_buf.h>
#include <linux/hashtable.h>
#include <linux/spinlock_types.h>
#include <net/sock.h> #include <net/sock.h>
#include <net/af_rxrpc.h> #include <net/af_rxrpc.h>
#include "ar-internal.h" #include "ar-internal.h"
...@@ -55,6 +57,145 @@ static void rxrpc_dead_call_expired(unsigned long _call); ...@@ -55,6 +57,145 @@ static void rxrpc_dead_call_expired(unsigned long _call);
static void rxrpc_ack_time_expired(unsigned long _call); static void rxrpc_ack_time_expired(unsigned long _call);
static void rxrpc_resend_time_expired(unsigned long _call); static void rxrpc_resend_time_expired(unsigned long _call);
static DEFINE_SPINLOCK(rxrpc_call_hash_lock);
static DEFINE_HASHTABLE(rxrpc_call_hash, 10);
/*
* Hash function for rxrpc_call_hash
*/
static unsigned long rxrpc_call_hashfunc(
u8 clientflag,
__be32 cid,
__be32 call_id,
__be32 epoch,
__be16 service_id,
sa_family_t proto,
void *localptr,
unsigned int addr_size,
const u8 *peer_addr)
{
const u16 *p;
unsigned int i;
unsigned long key;
u32 hcid = ntohl(cid);
_enter("");
key = (unsigned long)localptr;
/* We just want to add up the __be32 values, so forcing the
* cast should be okay.
*/
key += (__force u32)epoch;
key += (__force u16)service_id;
key += (__force u32)call_id;
key += (hcid & RXRPC_CIDMASK) >> RXRPC_CIDSHIFT;
key += hcid & RXRPC_CHANNELMASK;
key += clientflag;
key += proto;
/* Step through the peer address in 16-bit portions for speed */
for (i = 0, p = (const u16 *)peer_addr; i < addr_size >> 1; i++, p++)
key += *p;
_leave(" key = 0x%lx", key);
return key;
}
/*
* Add a call to the hashtable
*/
static void rxrpc_call_hash_add(struct rxrpc_call *call)
{
unsigned long key;
unsigned int addr_size = 0;
_enter("");
switch (call->proto) {
case AF_INET:
addr_size = sizeof(call->peer_ip.ipv4_addr);
break;
case AF_INET6:
addr_size = sizeof(call->peer_ip.ipv6_addr);
break;
default:
break;
}
key = rxrpc_call_hashfunc(call->in_clientflag, call->cid,
call->call_id, call->epoch,
call->service_id, call->proto,
call->conn->trans->local, addr_size,
call->peer_ip.ipv6_addr);
/* Store the full key in the call */
call->hash_key = key;
spin_lock(&rxrpc_call_hash_lock);
hash_add_rcu(rxrpc_call_hash, &call->hash_node, key);
spin_unlock(&rxrpc_call_hash_lock);
_leave("");
}
/*
* Remove a call from the hashtable
*/
static void rxrpc_call_hash_del(struct rxrpc_call *call)
{
_enter("");
spin_lock(&rxrpc_call_hash_lock);
hash_del_rcu(&call->hash_node);
spin_unlock(&rxrpc_call_hash_lock);
_leave("");
}
/*
* Find a call in the hashtable and return it, or NULL if it
* isn't there.
*/
struct rxrpc_call *rxrpc_find_call_hash(
u8 clientflag,
__be32 cid,
__be32 call_id,
__be32 epoch,
__be16 service_id,
void *localptr,
sa_family_t proto,
const u8 *peer_addr)
{
unsigned long key;
unsigned int addr_size = 0;
struct rxrpc_call *call = NULL;
struct rxrpc_call *ret = NULL;
_enter("");
switch (proto) {
case AF_INET:
addr_size = sizeof(call->peer_ip.ipv4_addr);
break;
case AF_INET6:
addr_size = sizeof(call->peer_ip.ipv6_addr);
break;
default:
break;
}
key = rxrpc_call_hashfunc(clientflag, cid, call_id, epoch,
service_id, proto, localptr, addr_size,
peer_addr);
hash_for_each_possible_rcu(rxrpc_call_hash, call, hash_node, key) {
if (call->hash_key == key &&
call->call_id == call_id &&
call->cid == cid &&
call->in_clientflag == clientflag &&
call->service_id == service_id &&
call->proto == proto &&
call->local == localptr &&
memcmp(call->peer_ip.ipv6_addr, peer_addr,
addr_size) == 0 &&
call->epoch == epoch) {
ret = call;
break;
}
}
_leave(" = %p", ret);
return ret;
}
/* /*
* allocate a new call * allocate a new call
*/ */
...@@ -136,6 +277,26 @@ static struct rxrpc_call *rxrpc_alloc_client_call( ...@@ -136,6 +277,26 @@ static struct rxrpc_call *rxrpc_alloc_client_call(
return ERR_PTR(ret); return ERR_PTR(ret);
} }
/* Record copies of information for hashtable lookup */
call->proto = rx->proto;
call->local = trans->local;
switch (call->proto) {
case AF_INET:
call->peer_ip.ipv4_addr =
trans->peer->srx.transport.sin.sin_addr.s_addr;
break;
case AF_INET6:
memcpy(call->peer_ip.ipv6_addr,
trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8,
sizeof(call->peer_ip.ipv6_addr));
break;
}
call->epoch = call->conn->epoch;
call->service_id = call->conn->service_id;
call->in_clientflag = call->conn->in_clientflag;
/* Add the new call to the hashtable */
rxrpc_call_hash_add(call);
spin_lock(&call->conn->trans->peer->lock); spin_lock(&call->conn->trans->peer->lock);
list_add(&call->error_link, &call->conn->trans->peer->error_targets); list_add(&call->error_link, &call->conn->trans->peer->error_targets);
spin_unlock(&call->conn->trans->peer->lock); spin_unlock(&call->conn->trans->peer->lock);
...@@ -328,9 +489,12 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx, ...@@ -328,9 +489,12 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx,
parent = *p; parent = *p;
call = rb_entry(parent, struct rxrpc_call, conn_node); call = rb_entry(parent, struct rxrpc_call, conn_node);
if (call_id < call->call_id) /* The tree is sorted in order of the __be32 value without
* turning it into host order.
*/
if ((__force u32)call_id < (__force u32)call->call_id)
p = &(*p)->rb_left; p = &(*p)->rb_left;
else if (call_id > call->call_id) else if ((__force u32)call_id > (__force u32)call->call_id)
p = &(*p)->rb_right; p = &(*p)->rb_right;
else else
goto old_call; goto old_call;
...@@ -355,6 +519,28 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx, ...@@ -355,6 +519,28 @@ struct rxrpc_call *rxrpc_incoming_call(struct rxrpc_sock *rx,
list_add_tail(&call->link, &rxrpc_calls); list_add_tail(&call->link, &rxrpc_calls);
write_unlock_bh(&rxrpc_call_lock); write_unlock_bh(&rxrpc_call_lock);
/* Record copies of information for hashtable lookup */
call->proto = rx->proto;
call->local = conn->trans->local;
switch (call->proto) {
case AF_INET:
call->peer_ip.ipv4_addr =
conn->trans->peer->srx.transport.sin.sin_addr.s_addr;
break;
case AF_INET6:
memcpy(call->peer_ip.ipv6_addr,
conn->trans->peer->srx.transport.sin6.sin6_addr.in6_u.u6_addr8,
sizeof(call->peer_ip.ipv6_addr));
break;
default:
break;
}
call->epoch = conn->epoch;
call->service_id = conn->service_id;
call->in_clientflag = conn->in_clientflag;
/* Add the new call to the hashtable */
rxrpc_call_hash_add(call);
_net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id); _net("CALL incoming %d on CONN %d", call->debug_id, call->conn->debug_id);
call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime; call->lifetimer.expires = jiffies + rxrpc_max_call_lifetime;
...@@ -673,6 +859,9 @@ static void rxrpc_cleanup_call(struct rxrpc_call *call) ...@@ -673,6 +859,9 @@ static void rxrpc_cleanup_call(struct rxrpc_call *call)
rxrpc_put_connection(call->conn); rxrpc_put_connection(call->conn);
} }
/* Remove the call from the hash */
rxrpc_call_hash_del(call);
if (call->acks_window) { if (call->acks_window) {
_debug("kill Tx window %d", _debug("kill Tx window %d",
CIRC_CNT(call->acks_head, call->acks_tail, CIRC_CNT(call->acks_head, call->acks_tail,
......
...@@ -523,36 +523,38 @@ static void rxrpc_process_jumbo_packet(struct rxrpc_call *call, ...@@ -523,36 +523,38 @@ static void rxrpc_process_jumbo_packet(struct rxrpc_call *call,
* post an incoming packet to the appropriate call/socket to deal with * post an incoming packet to the appropriate call/socket to deal with
* - must get rid of the sk_buff, either by freeing it or by queuing it * - must get rid of the sk_buff, either by freeing it or by queuing it
*/ */
static void rxrpc_post_packet_to_call(struct rxrpc_connection *conn, static void rxrpc_post_packet_to_call(struct rxrpc_call *call,
struct sk_buff *skb) struct sk_buff *skb)
{ {
struct rxrpc_skb_priv *sp; struct rxrpc_skb_priv *sp;
struct rxrpc_call *call;
struct rb_node *p;
__be32 call_id;
_enter("%p,%p", conn, skb);
read_lock_bh(&conn->lock); _enter("%p,%p", call, skb);
sp = rxrpc_skb(skb); sp = rxrpc_skb(skb);
/* look at extant calls by channel number first */
call = conn->channels[ntohl(sp->hdr.cid) & RXRPC_CHANNELMASK];
if (!call || call->call_id != sp->hdr.callNumber)
goto call_not_extant;
_debug("extant call [%d]", call->state); _debug("extant call [%d]", call->state);
ASSERTCMP(call->conn, ==, conn);
read_lock(&call->state_lock); read_lock(&call->state_lock);
switch (call->state) { switch (call->state) {
case RXRPC_CALL_LOCALLY_ABORTED: case RXRPC_CALL_LOCALLY_ABORTED:
if (!test_and_set_bit(RXRPC_CALL_ABORT, &call->events)) if (!test_and_set_bit(RXRPC_CALL_ABORT, &call->events)) {
rxrpc_queue_call(call); rxrpc_queue_call(call);
goto free_unlock;
}
case RXRPC_CALL_REMOTELY_ABORTED: case RXRPC_CALL_REMOTELY_ABORTED:
case RXRPC_CALL_NETWORK_ERROR: case RXRPC_CALL_NETWORK_ERROR:
case RXRPC_CALL_DEAD: case RXRPC_CALL_DEAD:
goto dead_call;
case RXRPC_CALL_COMPLETE:
case RXRPC_CALL_CLIENT_FINAL_ACK:
/* complete server call */
if (call->conn->in_clientflag)
goto dead_call;
/* resend last packet of a completed call */
_debug("final ack again");
rxrpc_get_call(call);
set_bit(RXRPC_CALL_ACK_FINAL, &call->events);
rxrpc_queue_call(call);
goto free_unlock; goto free_unlock;
default: default:
break; break;
...@@ -560,7 +562,6 @@ static void rxrpc_post_packet_to_call(struct rxrpc_connection *conn, ...@@ -560,7 +562,6 @@ static void rxrpc_post_packet_to_call(struct rxrpc_connection *conn,
read_unlock(&call->state_lock); read_unlock(&call->state_lock);
rxrpc_get_call(call); rxrpc_get_call(call);
read_unlock_bh(&conn->lock);
if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA && if (sp->hdr.type == RXRPC_PACKET_TYPE_DATA &&
sp->hdr.flags & RXRPC_JUMBO_PACKET) sp->hdr.flags & RXRPC_JUMBO_PACKET)
...@@ -571,80 +572,16 @@ static void rxrpc_post_packet_to_call(struct rxrpc_connection *conn, ...@@ -571,80 +572,16 @@ static void rxrpc_post_packet_to_call(struct rxrpc_connection *conn,
rxrpc_put_call(call); rxrpc_put_call(call);
goto done; goto done;
call_not_extant:
/* search the completed calls in case what we're dealing with is
* there */
_debug("call not extant");
call_id = sp->hdr.callNumber;
p = conn->calls.rb_node;
while (p) {
call = rb_entry(p, struct rxrpc_call, conn_node);
if (call_id < call->call_id)
p = p->rb_left;
else if (call_id > call->call_id)
p = p->rb_right;
else
goto found_completed_call;
}
dead_call: dead_call:
/* it's a either a really old call that we no longer remember or its a
* new incoming call */
read_unlock_bh(&conn->lock);
if (sp->hdr.flags & RXRPC_CLIENT_INITIATED &&
sp->hdr.seq == cpu_to_be32(1)) {
_debug("incoming call");
skb_queue_tail(&conn->trans->local->accept_queue, skb);
rxrpc_queue_work(&conn->trans->local->acceptor);
goto done;
}
_debug("dead call");
if (sp->hdr.type != RXRPC_PACKET_TYPE_ABORT) { if (sp->hdr.type != RXRPC_PACKET_TYPE_ABORT) {
skb->priority = RX_CALL_DEAD; skb->priority = RX_CALL_DEAD;
rxrpc_reject_packet(conn->trans->local, skb); rxrpc_reject_packet(call->conn->trans->local, skb);
} goto unlock;
goto done;
/* resend last packet of a completed call
* - client calls may have been aborted or ACK'd
* - server calls may have been aborted
*/
found_completed_call:
_debug("completed call");
if (atomic_read(&call->usage) == 0)
goto dead_call;
/* synchronise any state changes */
read_lock(&call->state_lock);
ASSERTIFCMP(call->state != RXRPC_CALL_CLIENT_FINAL_ACK,
call->state, >=, RXRPC_CALL_COMPLETE);
if (call->state == RXRPC_CALL_LOCALLY_ABORTED ||
call->state == RXRPC_CALL_REMOTELY_ABORTED ||
call->state == RXRPC_CALL_DEAD) {
read_unlock(&call->state_lock);
goto dead_call;
}
if (call->conn->in_clientflag) {
read_unlock(&call->state_lock);
goto dead_call; /* complete server call */
} }
_debug("final ack again");
rxrpc_get_call(call);
set_bit(RXRPC_CALL_ACK_FINAL, &call->events);
rxrpc_queue_call(call);
free_unlock: free_unlock:
read_unlock(&call->state_lock);
read_unlock_bh(&conn->lock);
rxrpc_free_skb(skb); rxrpc_free_skb(skb);
unlock:
read_unlock(&call->state_lock);
done: done:
_leave(""); _leave("");
} }
...@@ -663,17 +600,42 @@ static void rxrpc_post_packet_to_conn(struct rxrpc_connection *conn, ...@@ -663,17 +600,42 @@ static void rxrpc_post_packet_to_conn(struct rxrpc_connection *conn,
rxrpc_queue_conn(conn); rxrpc_queue_conn(conn);
} }
static struct rxrpc_connection *rxrpc_conn_from_local(struct rxrpc_local *local,
struct sk_buff *skb,
struct rxrpc_skb_priv *sp)
{
struct rxrpc_peer *peer;
struct rxrpc_transport *trans;
struct rxrpc_connection *conn;
peer = rxrpc_find_peer(local, ip_hdr(skb)->saddr,
udp_hdr(skb)->source);
if (IS_ERR(peer))
goto cant_find_conn;
trans = rxrpc_find_transport(local, peer);
rxrpc_put_peer(peer);
if (!trans)
goto cant_find_conn;
conn = rxrpc_find_connection(trans, &sp->hdr);
rxrpc_put_transport(trans);
if (!conn)
goto cant_find_conn;
return conn;
cant_find_conn:
return NULL;
}
/* /*
* handle data received on the local endpoint * handle data received on the local endpoint
* - may be called in interrupt context * - may be called in interrupt context
*/ */
void rxrpc_data_ready(struct sock *sk, int count) void rxrpc_data_ready(struct sock *sk, int count)
{ {
struct rxrpc_connection *conn;
struct rxrpc_transport *trans;
struct rxrpc_skb_priv *sp; struct rxrpc_skb_priv *sp;
struct rxrpc_local *local; struct rxrpc_local *local;
struct rxrpc_peer *peer;
struct sk_buff *skb; struct sk_buff *skb;
int ret; int ret;
...@@ -748,27 +710,34 @@ void rxrpc_data_ready(struct sock *sk, int count) ...@@ -748,27 +710,34 @@ void rxrpc_data_ready(struct sock *sk, int count)
(sp->hdr.callNumber == 0 || sp->hdr.seq == 0)) (sp->hdr.callNumber == 0 || sp->hdr.seq == 0))
goto bad_message; goto bad_message;
peer = rxrpc_find_peer(local, ip_hdr(skb)->saddr, udp_hdr(skb)->source); if (sp->hdr.callNumber == 0) {
if (IS_ERR(peer)) /* This is a connection-level packet. These should be
goto cant_route_call; * fairly rare, so the extra overhead of looking them up the
* old-fashioned way doesn't really hurt */
trans = rxrpc_find_transport(local, peer); struct rxrpc_connection *conn;
rxrpc_put_peer(peer);
if (!trans)
goto cant_route_call;
conn = rxrpc_find_connection(trans, &sp->hdr); conn = rxrpc_conn_from_local(local, skb, sp);
rxrpc_put_transport(trans);
if (!conn) if (!conn)
goto cant_route_call; goto cant_route_call;
_debug("CONN %p {%d}", conn, conn->debug_id); _debug("CONN %p {%d}", conn, conn->debug_id);
if (sp->hdr.callNumber == 0)
rxrpc_post_packet_to_conn(conn, skb); rxrpc_post_packet_to_conn(conn, skb);
else
rxrpc_post_packet_to_call(conn, skb);
rxrpc_put_connection(conn); rxrpc_put_connection(conn);
} else {
struct rxrpc_call *call;
u8 in_clientflag = 0;
if (sp->hdr.flags & RXRPC_CLIENT_INITIATED)
in_clientflag = RXRPC_CLIENT_INITIATED;
call = rxrpc_find_call_hash(in_clientflag, sp->hdr.cid,
sp->hdr.callNumber, sp->hdr.epoch,
sp->hdr.serviceId, local, AF_INET,
(u8 *)&ip_hdr(skb)->saddr);
if (call)
rxrpc_post_packet_to_call(call, skb);
else
goto cant_route_call;
}
rxrpc_put_local(local); rxrpc_put_local(local);
return; return;
......
...@@ -396,9 +396,20 @@ struct rxrpc_call { ...@@ -396,9 +396,20 @@ struct rxrpc_call {
#define RXRPC_ACKR_WINDOW_ASZ DIV_ROUND_UP(RXRPC_MAXACKS, BITS_PER_LONG) #define RXRPC_ACKR_WINDOW_ASZ DIV_ROUND_UP(RXRPC_MAXACKS, BITS_PER_LONG)
unsigned long ackr_window[RXRPC_ACKR_WINDOW_ASZ + 1]; unsigned long ackr_window[RXRPC_ACKR_WINDOW_ASZ + 1];
struct hlist_node hash_node;
unsigned long hash_key; /* Full hash key */
u8 in_clientflag; /* Copy of conn->in_clientflag for hashing */
struct rxrpc_local *local; /* Local endpoint. Used for hashing. */
sa_family_t proto; /* Frame protocol */
/* the following should all be in net order */ /* the following should all be in net order */
__be32 cid; /* connection ID + channel index */ __be32 cid; /* connection ID + channel index */
__be32 call_id; /* call ID on connection */ __be32 call_id; /* call ID on connection */
__be32 epoch; /* epoch of this connection */
__be16 service_id; /* service ID */
union { /* Peer IP address for hashing */
__be32 ipv4_addr;
__u8 ipv6_addr[16]; /* Anticipates eventual IPv6 support */
} peer_ip;
}; };
/* /*
...@@ -453,6 +464,8 @@ extern struct kmem_cache *rxrpc_call_jar; ...@@ -453,6 +464,8 @@ extern struct kmem_cache *rxrpc_call_jar;
extern struct list_head rxrpc_calls; extern struct list_head rxrpc_calls;
extern rwlock_t rxrpc_call_lock; extern rwlock_t rxrpc_call_lock;
struct rxrpc_call *rxrpc_find_call_hash(u8, __be32, __be32, __be32,
__be16, void *, sa_family_t, const u8 *);
struct rxrpc_call *rxrpc_get_client_call(struct rxrpc_sock *, struct rxrpc_call *rxrpc_get_client_call(struct rxrpc_sock *,
struct rxrpc_transport *, struct rxrpc_transport *,
struct rxrpc_conn_bundle *, struct rxrpc_conn_bundle *,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment