Commit adc6a3ab authored by NeilBrown's avatar NeilBrown Committed by David S. Miller

rhashtable: move dereference inside rht_ptr()

Rather than dereferencing a pointer to a bucket and then passing the
result to rht_ptr(), we now pass in the pointer and do the dereference
in rht_ptr().

This requires that we pass in the tbl and hash as well to support RCU
checks, and means that the various rht_for_each functions can expect a
pointer that can be dereferenced without further care.

There are two places where we dereference a bucket pointer
where there is no testable protection - in each case we know
that we much have exclusive access without having taken a lock.
The previous code used rht_dereference() to pretend that holding
the mutex provided protects, but holding the mutex never provides
protection for accessing buckets.

So instead introduce rht_ptr_exclusive() that can be used when
there is known to be exclusive access without holding any locks.
Signed-off-by: default avatarNeilBrown <neilb@suse.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent c5783311
...@@ -344,12 +344,28 @@ static inline void rht_unlock(struct bucket_table *tbl, ...@@ -344,12 +344,28 @@ static inline void rht_unlock(struct bucket_table *tbl,
} }
/* /*
* If 'p' is a bucket head and might be locked: * Where 'bkt' is a bucket and might be locked:
* rht_ptr() returns the address without the lock bit. * rht_ptr() dereferences that pointer and clears the lock bit.
* rht_ptr_locked() returns the address WITH the lock bit. * rht_ptr_exclusive() dereferences in a context where exclusive
* access is guaranteed, such as when destroying the table.
*/ */
static inline struct rhash_head __rcu *rht_ptr(const struct rhash_lock_head *p) static inline struct rhash_head *rht_ptr(
struct rhash_lock_head __rcu * const *bkt,
struct bucket_table *tbl,
unsigned int hash)
{ {
const struct rhash_lock_head *p =
rht_dereference_bucket_rcu(*bkt, tbl, hash);
return (void *)(((unsigned long)p) & ~BIT(1));
}
static inline struct rhash_head *rht_ptr_exclusive(
struct rhash_lock_head __rcu * const *bkt)
{
const struct rhash_lock_head *p =
rcu_dereference_protected(*bkt, 1);
return (void *)(((unsigned long)p) & ~BIT(1)); return (void *)(((unsigned long)p) & ~BIT(1));
} }
...@@ -380,8 +396,8 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -380,8 +396,8 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
* @hash: the hash value / bucket index * @hash: the hash value / bucket index
*/ */
#define rht_for_each_from(pos, head, tbl, hash) \ #define rht_for_each_from(pos, head, tbl, hash) \
for (pos = rht_dereference_bucket(head, tbl, hash); \ for (pos = head; \
!rht_is_a_nulls(pos); \ !rht_is_a_nulls(pos); \
pos = rht_dereference_bucket((pos)->next, tbl, hash)) pos = rht_dereference_bucket((pos)->next, tbl, hash))
/** /**
...@@ -391,7 +407,8 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -391,7 +407,8 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
* @hash: the hash value / bucket index * @hash: the hash value / bucket index
*/ */
#define rht_for_each(pos, tbl, hash) \ #define rht_for_each(pos, tbl, hash) \
rht_for_each_from(pos, rht_ptr(*rht_bucket(tbl, hash)), tbl, hash) rht_for_each_from(pos, rht_ptr(rht_bucket(tbl, hash), tbl, hash), \
tbl, hash)
/** /**
* rht_for_each_entry_from - iterate over hash chain from given head * rht_for_each_entry_from - iterate over hash chain from given head
...@@ -403,7 +420,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -403,7 +420,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
* @member: name of the &struct rhash_head within the hashable struct. * @member: name of the &struct rhash_head within the hashable struct.
*/ */
#define rht_for_each_entry_from(tpos, pos, head, tbl, hash, member) \ #define rht_for_each_entry_from(tpos, pos, head, tbl, hash, member) \
for (pos = rht_dereference_bucket(head, tbl, hash); \ for (pos = head; \
(!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \ (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \
pos = rht_dereference_bucket((pos)->next, tbl, hash)) pos = rht_dereference_bucket((pos)->next, tbl, hash))
...@@ -416,8 +433,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -416,8 +433,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
* @member: name of the &struct rhash_head within the hashable struct. * @member: name of the &struct rhash_head within the hashable struct.
*/ */
#define rht_for_each_entry(tpos, pos, tbl, hash, member) \ #define rht_for_each_entry(tpos, pos, tbl, hash, member) \
rht_for_each_entry_from(tpos, pos, rht_ptr(*rht_bucket(tbl, hash)), \ rht_for_each_entry_from(tpos, pos, \
tbl, hash, member) rht_ptr(rht_bucket(tbl, hash), tbl, hash), \
tbl, hash, member)
/** /**
* rht_for_each_entry_safe - safely iterate over hash chain of given type * rht_for_each_entry_safe - safely iterate over hash chain of given type
...@@ -432,8 +450,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -432,8 +450,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
* remove the loop cursor from the list. * remove the loop cursor from the list.
*/ */
#define rht_for_each_entry_safe(tpos, pos, next, tbl, hash, member) \ #define rht_for_each_entry_safe(tpos, pos, next, tbl, hash, member) \
for (pos = rht_dereference_bucket(rht_ptr(*rht_bucket(tbl, hash)), \ for (pos = rht_ptr(rht_bucket(tbl, hash), tbl, hash), \
tbl, hash), \
next = !rht_is_a_nulls(pos) ? \ next = !rht_is_a_nulls(pos) ? \
rht_dereference_bucket(pos->next, tbl, hash) : NULL; \ rht_dereference_bucket(pos->next, tbl, hash) : NULL; \
(!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \ (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \
...@@ -454,7 +471,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -454,7 +471,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
*/ */
#define rht_for_each_rcu_from(pos, head, tbl, hash) \ #define rht_for_each_rcu_from(pos, head, tbl, hash) \
for (({barrier(); }), \ for (({barrier(); }), \
pos = rht_dereference_bucket_rcu(head, tbl, hash); \ pos = head; \
!rht_is_a_nulls(pos); \ !rht_is_a_nulls(pos); \
pos = rcu_dereference_raw(pos->next)) pos = rcu_dereference_raw(pos->next))
...@@ -469,10 +486,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -469,10 +486,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
* traversal is guarded by rcu_read_lock(). * traversal is guarded by rcu_read_lock().
*/ */
#define rht_for_each_rcu(pos, tbl, hash) \ #define rht_for_each_rcu(pos, tbl, hash) \
for (({barrier(); }), \ for (({barrier(); }), \
pos = rht_ptr(rht_dereference_bucket_rcu( \ pos = rht_ptr(rht_bucket(tbl, hash), tbl, hash); \
*rht_bucket(tbl, hash), tbl, hash)); \ !rht_is_a_nulls(pos); \
!rht_is_a_nulls(pos); \
pos = rcu_dereference_raw(pos->next)) pos = rcu_dereference_raw(pos->next))
/** /**
...@@ -490,7 +506,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -490,7 +506,7 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
*/ */
#define rht_for_each_entry_rcu_from(tpos, pos, head, tbl, hash, member) \ #define rht_for_each_entry_rcu_from(tpos, pos, head, tbl, hash, member) \
for (({barrier(); }), \ for (({barrier(); }), \
pos = rht_dereference_bucket_rcu(head, tbl, hash); \ pos = head; \
(!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \ (!rht_is_a_nulls(pos)) && rht_entry(tpos, pos, member); \
pos = rht_dereference_bucket_rcu(pos->next, tbl, hash)) pos = rht_dereference_bucket_rcu(pos->next, tbl, hash))
...@@ -508,8 +524,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl, ...@@ -508,8 +524,9 @@ static inline void rht_assign_unlock(struct bucket_table *tbl,
*/ */
#define rht_for_each_entry_rcu(tpos, pos, tbl, hash, member) \ #define rht_for_each_entry_rcu(tpos, pos, tbl, hash, member) \
rht_for_each_entry_rcu_from(tpos, pos, \ rht_for_each_entry_rcu_from(tpos, pos, \
rht_ptr(*rht_bucket(tbl, hash)), \ rht_ptr(rht_bucket(tbl, hash), \
tbl, hash, member) tbl, hash), \
tbl, hash, member)
/** /**
* rhl_for_each_rcu - iterate over rcu hash table list * rhl_for_each_rcu - iterate over rcu hash table list
...@@ -556,7 +573,6 @@ static inline struct rhash_head *__rhashtable_lookup( ...@@ -556,7 +573,6 @@ static inline struct rhash_head *__rhashtable_lookup(
}; };
struct rhash_lock_head __rcu * const *bkt; struct rhash_lock_head __rcu * const *bkt;
struct bucket_table *tbl; struct bucket_table *tbl;
struct rhash_head __rcu *head;
struct rhash_head *he; struct rhash_head *he;
unsigned int hash; unsigned int hash;
...@@ -565,8 +581,7 @@ static inline struct rhash_head *__rhashtable_lookup( ...@@ -565,8 +581,7 @@ static inline struct rhash_head *__rhashtable_lookup(
hash = rht_key_hashfn(ht, tbl, key, params); hash = rht_key_hashfn(ht, tbl, key, params);
bkt = rht_bucket(tbl, hash); bkt = rht_bucket(tbl, hash);
do { do {
head = rht_ptr(rht_dereference_bucket_rcu(*bkt, tbl, hash)); rht_for_each_rcu_from(he, rht_ptr(bkt, tbl, hash), tbl, hash) {
rht_for_each_rcu_from(he, head, tbl, hash) {
if (params.obj_cmpfn ? if (params.obj_cmpfn ?
params.obj_cmpfn(&arg, rht_obj(ht, he)) : params.obj_cmpfn(&arg, rht_obj(ht, he)) :
rhashtable_compare(&arg, rht_obj(ht, he))) rhashtable_compare(&arg, rht_obj(ht, he)))
...@@ -699,7 +714,7 @@ static inline void *__rhashtable_insert_fast( ...@@ -699,7 +714,7 @@ static inline void *__rhashtable_insert_fast(
return rhashtable_insert_slow(ht, key, obj); return rhashtable_insert_slow(ht, key, obj);
} }
rht_for_each_from(head, rht_ptr(*bkt), tbl, hash) { rht_for_each_from(head, rht_ptr(bkt, tbl, hash), tbl, hash) {
struct rhlist_head *plist; struct rhlist_head *plist;
struct rhlist_head *list; struct rhlist_head *list;
...@@ -744,7 +759,7 @@ static inline void *__rhashtable_insert_fast( ...@@ -744,7 +759,7 @@ static inline void *__rhashtable_insert_fast(
goto slow_path; goto slow_path;
/* Inserting at head of list makes unlocking free. */ /* Inserting at head of list makes unlocking free. */
head = rht_ptr(rht_dereference_bucket(*bkt, tbl, hash)); head = rht_ptr(bkt, tbl, hash);
RCU_INIT_POINTER(obj->next, head); RCU_INIT_POINTER(obj->next, head);
if (rhlist) { if (rhlist) {
...@@ -971,7 +986,7 @@ static inline int __rhashtable_remove_fast_one( ...@@ -971,7 +986,7 @@ static inline int __rhashtable_remove_fast_one(
pprev = NULL; pprev = NULL;
rht_lock(tbl, bkt); rht_lock(tbl, bkt);
rht_for_each_from(he, rht_ptr(*bkt), tbl, hash) { rht_for_each_from(he, rht_ptr(bkt, tbl, hash), tbl, hash) {
struct rhlist_head *list; struct rhlist_head *list;
list = container_of(he, struct rhlist_head, rhead); list = container_of(he, struct rhlist_head, rhead);
...@@ -1130,7 +1145,7 @@ static inline int __rhashtable_replace_fast( ...@@ -1130,7 +1145,7 @@ static inline int __rhashtable_replace_fast(
pprev = NULL; pprev = NULL;
rht_lock(tbl, bkt); rht_lock(tbl, bkt);
rht_for_each_from(he, rht_ptr(*bkt), tbl, hash) { rht_for_each_from(he, rht_ptr(bkt, tbl, hash), tbl, hash) {
if (he != obj_old) { if (he != obj_old) {
pprev = &he->next; pprev = &he->next;
continue; continue;
......
...@@ -231,7 +231,8 @@ static int rhashtable_rehash_one(struct rhashtable *ht, ...@@ -231,7 +231,8 @@ static int rhashtable_rehash_one(struct rhashtable *ht,
err = -ENOENT; err = -ENOENT;
rht_for_each_from(entry, rht_ptr(*bkt), old_tbl, old_hash) { rht_for_each_from(entry, rht_ptr(bkt, old_tbl, old_hash),
old_tbl, old_hash) {
err = 0; err = 0;
next = rht_dereference_bucket(entry->next, old_tbl, old_hash); next = rht_dereference_bucket(entry->next, old_tbl, old_hash);
...@@ -248,8 +249,7 @@ static int rhashtable_rehash_one(struct rhashtable *ht, ...@@ -248,8 +249,7 @@ static int rhashtable_rehash_one(struct rhashtable *ht,
rht_lock_nested(new_tbl, &new_tbl->buckets[new_hash], SINGLE_DEPTH_NESTING); rht_lock_nested(new_tbl, &new_tbl->buckets[new_hash], SINGLE_DEPTH_NESTING);
head = rht_ptr(rht_dereference_bucket(new_tbl->buckets[new_hash], head = rht_ptr(new_tbl->buckets + new_hash, new_tbl, new_hash);
new_tbl, new_hash));
RCU_INIT_POINTER(entry->next, head); RCU_INIT_POINTER(entry->next, head);
...@@ -491,7 +491,7 @@ static void *rhashtable_lookup_one(struct rhashtable *ht, ...@@ -491,7 +491,7 @@ static void *rhashtable_lookup_one(struct rhashtable *ht,
int elasticity; int elasticity;
elasticity = RHT_ELASTICITY; elasticity = RHT_ELASTICITY;
rht_for_each_from(head, rht_ptr(*bkt), tbl, hash) { rht_for_each_from(head, rht_ptr(bkt, tbl, hash), tbl, hash) {
struct rhlist_head *list; struct rhlist_head *list;
struct rhlist_head *plist; struct rhlist_head *plist;
...@@ -557,7 +557,7 @@ static struct bucket_table *rhashtable_insert_one(struct rhashtable *ht, ...@@ -557,7 +557,7 @@ static struct bucket_table *rhashtable_insert_one(struct rhashtable *ht,
if (unlikely(rht_grow_above_100(ht, tbl))) if (unlikely(rht_grow_above_100(ht, tbl)))
return ERR_PTR(-EAGAIN); return ERR_PTR(-EAGAIN);
head = rht_ptr(rht_dereference_bucket(*bkt, tbl, hash)); head = rht_ptr(bkt, tbl, hash);
RCU_INIT_POINTER(obj->next, head); RCU_INIT_POINTER(obj->next, head);
if (ht->rhlist) { if (ht->rhlist) {
...@@ -1139,7 +1139,7 @@ void rhashtable_free_and_destroy(struct rhashtable *ht, ...@@ -1139,7 +1139,7 @@ void rhashtable_free_and_destroy(struct rhashtable *ht,
struct rhash_head *pos, *next; struct rhash_head *pos, *next;
cond_resched(); cond_resched();
for (pos = rht_ptr(rht_dereference(*rht_bucket(tbl, i), ht)), for (pos = rht_ptr_exclusive(rht_bucket(tbl, i)),
next = !rht_is_a_nulls(pos) ? next = !rht_is_a_nulls(pos) ?
rht_dereference(pos->next, ht) : NULL; rht_dereference(pos->next, ht) : NULL;
!rht_is_a_nulls(pos); !rht_is_a_nulls(pos);
......
...@@ -500,7 +500,7 @@ static unsigned int __init print_ht(struct rhltable *rhlt) ...@@ -500,7 +500,7 @@ static unsigned int __init print_ht(struct rhltable *rhlt)
struct rhash_head *pos, *next; struct rhash_head *pos, *next;
struct test_obj_rhl *p; struct test_obj_rhl *p;
pos = rht_ptr(rht_dereference(tbl->buckets[i], ht)); pos = rht_ptr_exclusive(tbl->buckets + i);
next = !rht_is_a_nulls(pos) ? rht_dereference(pos->next, ht) : NULL; next = !rht_is_a_nulls(pos) ? rht_dereference(pos->next, ht) : NULL;
if (!rht_is_a_nulls(pos)) { if (!rht_is_a_nulls(pos)) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment