Commit c571feca authored by Jason Gunthorpe

RDMA/odp: use mmu_notifier_get/put for 'struct ib_ucontext_per_mm'

This is a significant simplification: no extra list is kept per FD, and the
interval tree is now shared between all the ucontexts, reducing overhead when
multiple ucontexts are active.

Link: https://lore.kernel.org/r/20190806231548.25242-7-jgg@ziepe.ca
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parent daa138a5
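
The core of the change is visible in the hunks below: the per-FD list and the manual mmu_notifier_register()/mmu_notifier_unregister_no_release() bookkeeping are replaced by the core kernel's mmu_notifier_get()/mmu_notifier_put() reference counting, driven by alloc_notifier/free_notifier ops. As a rough, self-contained sketch of that pattern (the names my_state, my_ops, my_get, my_alloc_notifier and my_free_notifier are illustrative only and not part of this commit; the real equivalents here are ib_ucontext_per_mm, ib_umem_notifiers and ib_init_umem_odp()):

#include <linux/err.h>
#include <linux/mmu_notifier.h>
#include <linux/slab.h>

/* Hypothetical per-mm state; in this commit the real struct is ib_ucontext_per_mm. */
struct my_state {
        struct mmu_notifier mn;         /* must be embedded so container_of() works */
        /* ... data shared by every user of this mm ... */
};

/* Called by the core the first time an mm needs this notifier. */
static struct mmu_notifier *my_alloc_notifier(struct mm_struct *mm)
{
        struct my_state *s = kzalloc(sizeof(*s), GFP_KERNEL);

        if (!s)
                return ERR_PTR(-ENOMEM);
        return &s->mn;  /* the core registers it and holds the first reference */
}

/* Called via SRCU once the last mmu_notifier_put() reference is dropped. */
static void my_free_notifier(struct mmu_notifier *mn)
{
        kfree(container_of(mn, struct my_state, mn));
}

static const struct mmu_notifier_ops my_ops = {
        .alloc_notifier = my_alloc_notifier,
        .free_notifier = my_free_notifier,
};

/* Look up (or create) the shared state for an mm; pair with mmu_notifier_put(). */
static struct my_state *my_get(struct mm_struct *mm)
{
        struct mmu_notifier *mn = mmu_notifier_get(&my_ops, mm);

        if (IS_ERR(mn))
                return ERR_CAST(mn);
        return container_of(mn, struct my_state, mn);
}

Because free_notifier runs from an SRCU callback after the last put, a module using this API also calls mmu_notifier_synchronize() on exit, which is exactly the one-line hunk added to ib_uverbs_cleanup() below.
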
@@ -82,7 +82,7 @@ static void ib_umem_notifier_release(struct mmu_notifier *mn,
         struct rb_node *node;
 
         down_read(&per_mm->umem_rwsem);
-        if (!per_mm->active)
+        if (!per_mm->mn.users)
                 goto out;
 
         for (node = rb_first_cached(&per_mm->umem_tree); node;
@@ -125,10 +125,10 @@ static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
         else if (!down_read_trylock(&per_mm->umem_rwsem))
                 return -EAGAIN;
 
-        if (!per_mm->active) {
+        if (!per_mm->mn.users) {
                 up_read(&per_mm->umem_rwsem);
                 /*
-                 * At this point active is permanently set and visible to this
+                 * At this point users is permanently zero and visible to this
                  * CPU without a lock, that fact is relied on to skip the unlock
                  * in range_end.
                  */
@@ -158,7 +158,7 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
         struct ib_ucontext_per_mm *per_mm =
                 container_of(mn, struct ib_ucontext_per_mm, mn);
 
-        if (unlikely(!per_mm->active))
+        if (unlikely(!per_mm->mn.users))
                 return;
 
         rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
@@ -167,122 +167,47 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
         up_read(&per_mm->umem_rwsem);
 }
 
-static const struct mmu_notifier_ops ib_umem_notifiers = {
-        .release = ib_umem_notifier_release,
-        .invalidate_range_start = ib_umem_notifier_invalidate_range_start,
-        .invalidate_range_end = ib_umem_notifier_invalidate_range_end,
-};
-
-static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
-{
-        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-
-        down_write(&per_mm->umem_rwsem);
-        interval_tree_remove(&umem_odp->interval_tree, &per_mm->umem_tree);
-        complete_all(&umem_odp->notifier_completion);
-        up_write(&per_mm->umem_rwsem);
-}
-
-static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
-                                               struct mm_struct *mm)
+static struct mmu_notifier *ib_umem_alloc_notifier(struct mm_struct *mm)
 {
         struct ib_ucontext_per_mm *per_mm;
-        int ret;
 
         per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
         if (!per_mm)
                 return ERR_PTR(-ENOMEM);
 
-        per_mm->context = ctx;
-        per_mm->mm = mm;
         per_mm->umem_tree = RB_ROOT_CACHED;
         init_rwsem(&per_mm->umem_rwsem);
-        per_mm->active = true;
 
+        WARN_ON(mm != current->mm);
         rcu_read_lock();
         per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
         rcu_read_unlock();
-
-        WARN_ON(mm != current->mm);
-
-        per_mm->mn.ops = &ib_umem_notifiers;
-        ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
-        if (ret) {
-                dev_err(&ctx->device->dev,
-                        "Failed to register mmu_notifier %d\n", ret);
-                goto out_pid;
-        }
-
-        list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
-        return per_mm;
-
-out_pid:
-        put_pid(per_mm->tgid);
-        kfree(per_mm);
-        return ERR_PTR(ret);
-}
-
-static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
-{
-        struct ib_ucontext *ctx = umem_odp->umem.context;
-        struct ib_ucontext_per_mm *per_mm;
-
-        lockdep_assert_held(&ctx->per_mm_list_lock);
-
-        /*
-         * Generally speaking we expect only one or two per_mm in this list,
-         * so no reason to optimize this search today.
-         */
-        list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
-                if (per_mm->mm == umem_odp->umem.owning_mm)
-                        return per_mm;
-        }
-
-        return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
-}
-
-static void free_per_mm(struct rcu_head *rcu)
-{
-        kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
+        return &per_mm->mn;
 }
 
-static void put_per_mm(struct ib_umem_odp *umem_odp)
+static void ib_umem_free_notifier(struct mmu_notifier *mn)
 {
-        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-        struct ib_ucontext *ctx = umem_odp->umem.context;
-        bool need_free;
-
-        mutex_lock(&ctx->per_mm_list_lock);
-        umem_odp->per_mm = NULL;
-        per_mm->odp_mrs_count--;
-        need_free = per_mm->odp_mrs_count == 0;
-        if (need_free)
-                list_del(&per_mm->ucontext_list);
-        mutex_unlock(&ctx->per_mm_list_lock);
-
-        if (!need_free)
-                return;
-
-        /*
-         * NOTE! mmu_notifier_unregister() can happen between a start/end
-         * callback, resulting in an start/end, and thus an unbalanced
-         * lock. This doesn't really matter to us since we are about to kfree
-         * the memory that holds the lock, however LOCKDEP doesn't like this.
-         */
-        down_write(&per_mm->umem_rwsem);
-        per_mm->active = false;
-        up_write(&per_mm->umem_rwsem);
+        struct ib_ucontext_per_mm *per_mm =
+                container_of(mn, struct ib_ucontext_per_mm, mn);
 
         WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
-        mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);
         put_pid(per_mm->tgid);
-        mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
+        kfree(per_mm);
 }
 
-static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
-                                   struct ib_ucontext_per_mm *per_mm)
+static const struct mmu_notifier_ops ib_umem_notifiers = {
+        .release = ib_umem_notifier_release,
+        .invalidate_range_start = ib_umem_notifier_invalidate_range_start,
+        .invalidate_range_end = ib_umem_notifier_invalidate_range_end,
+        .alloc_notifier = ib_umem_alloc_notifier,
+        .free_notifier = ib_umem_free_notifier,
+};
+
+static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp)
 {
-        struct ib_ucontext *ctx = umem_odp->umem.context;
+        struct ib_ucontext_per_mm *per_mm;
+        struct mmu_notifier *mn;
         int ret;
 
         umem_odp->umem.is_odp = 1;
@@ -327,17 +252,13 @@ static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
                 }
         }
 
-        mutex_lock(&ctx->per_mm_list_lock);
-        if (!per_mm) {
-                per_mm = get_per_mm(umem_odp);
-                if (IS_ERR(per_mm)) {
-                        ret = PTR_ERR(per_mm);
-                        goto out_unlock;
-                }
+        mn = mmu_notifier_get(&ib_umem_notifiers, umem_odp->umem.owning_mm);
+        if (IS_ERR(mn)) {
+                ret = PTR_ERR(mn);
+                goto out_dma_list;
         }
-        umem_odp->per_mm = per_mm;
-        per_mm->odp_mrs_count++;
-        mutex_unlock(&ctx->per_mm_list_lock);
+        umem_odp->per_mm = per_mm =
+                container_of(mn, struct ib_ucontext_per_mm, mn);
 
         mutex_init(&umem_odp->umem_mutex);
         init_completion(&umem_odp->notifier_completion);
@@ -352,8 +273,7 @@ static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
 
         return 0;
 
-out_unlock:
-        mutex_unlock(&ctx->per_mm_list_lock);
+out_dma_list:
         kvfree(umem_odp->dma_list);
 out_page_list:
         kvfree(umem_odp->page_list);
@@ -398,7 +318,7 @@ struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
         umem_odp->is_implicit_odp = 1;
         umem_odp->page_shift = PAGE_SHIFT;
 
-        ret = ib_init_umem_odp(umem_odp, NULL);
+        ret = ib_init_umem_odp(umem_odp);
         if (ret) {
                 kfree(umem_odp);
                 return ERR_PTR(ret);
@@ -441,7 +361,7 @@ struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
         umem->owning_mm = root->umem.owning_mm;
         odp_data->page_shift = PAGE_SHIFT;
 
-        ret = ib_init_umem_odp(odp_data, root->per_mm);
+        ret = ib_init_umem_odp(odp_data);
         if (ret) {
                 kfree(odp_data);
                 return ERR_PTR(ret);
@@ -509,7 +429,7 @@ struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
                 up_read(&mm->mmap_sem);
         }
 
-        ret = ib_init_umem_odp(umem_odp, NULL);
+        ret = ib_init_umem_odp(umem_odp);
         if (ret)
                 goto err_free;
         return umem_odp;
@@ -522,6 +442,8 @@ EXPORT_SYMBOL(ib_umem_odp_get);
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
 {
+        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+
         /*
          * Ensure that no more pages are mapped in the umem.
          *
@@ -531,11 +453,27 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
         if (!umem_odp->is_implicit_odp) {
                 ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
                                             ib_umem_end(umem_odp));
-                remove_umem_from_per_mm(umem_odp);
                 kvfree(umem_odp->dma_list);
                 kvfree(umem_odp->page_list);
         }
-        put_per_mm(umem_odp);
+
+        down_write(&per_mm->umem_rwsem);
+        if (!umem_odp->is_implicit_odp) {
+                interval_tree_remove(&umem_odp->interval_tree,
+                                     &per_mm->umem_tree);
+                complete_all(&umem_odp->notifier_completion);
+        }
+        /*
+         * NOTE! mmu_notifier_unregister() can happen between a start/end
+         * callback, resulting in a missing end, and thus an unbalanced
+         * lock. This doesn't really matter to us since we are about to kfree
+         * the memory that holds the lock, however LOCKDEP doesn't like this.
+         * Thus we call the mmu_notifier_put under the rwsem and test the
+         * internal users count to reliably see if we are past this point.
+         */
+        mmu_notifier_put(&per_mm->mn);
+        up_write(&per_mm->umem_rwsem);
+
         mmdrop(umem_odp->umem.owning_mm);
         kfree(umem_odp);
 }
...
@@ -252,9 +252,6 @@ static int ib_uverbs_get_context(struct uverbs_attr_bundle *attrs)
         ucontext->closing = false;
         ucontext->cleanup_retryable = false;
 
-        mutex_init(&ucontext->per_mm_list_lock);
-        INIT_LIST_HEAD(&ucontext->per_mm_list);
-
         ret = get_unused_fd_flags(O_CLOEXEC);
         if (ret < 0)
                 goto err_free;
...
@@ -1487,6 +1487,7 @@ static void __exit ib_uverbs_cleanup(void)
                                  IB_UVERBS_NUM_FIXED_MINOR);
         unregister_chrdev_region(dynamic_uverbs_dev,
                                  IB_UVERBS_NUM_DYNAMIC_MINOR);
+        mmu_notifier_synchronize();
 }
 
 module_init(ib_uverbs_init);
...
@@ -1995,11 +1995,6 @@ static void mlx5_ib_dealloc_ucontext(struct ib_ucontext *ibcontext)
         struct mlx5_ib_dev *dev = to_mdev(ibcontext->device);
         struct mlx5_bfreg_info *bfregi;
 
-        /* All umem's must be destroyed before destroying the ucontext. */
-        mutex_lock(&ibcontext->per_mm_list_lock);
-        WARN_ON(!list_empty(&ibcontext->per_mm_list));
-        mutex_unlock(&ibcontext->per_mm_list_lock);
-
         bfregi = &context->bfregi;
         mlx5_ib_dealloc_transport_domain(dev, context->tdn, context->devx_uid);
...
@@ -122,20 +122,12 @@ static inline size_t ib_umem_odp_num_pages(struct ib_umem_odp *umem_odp)
 #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
 
 struct ib_ucontext_per_mm {
-        struct ib_ucontext *context;
-        struct mm_struct *mm;
+        struct mmu_notifier mn;
         struct pid *tgid;
-        bool active;
 
         struct rb_root_cached umem_tree;
         /* Protects umem_tree */
         struct rw_semaphore umem_rwsem;
-
-        struct mmu_notifier mn;
-        unsigned int odp_mrs_count;
-
-        struct list_head ucontext_list;
-        struct rcu_head rcu;
 };
 
 struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
...
@@ -1417,9 +1417,6 @@ struct ib_ucontext {
         bool cleanup_retryable;
 
-        struct mutex per_mm_list_lock;
-        struct list_head per_mm_list;
-
         struct ib_rdmacg_object cg_obj;
 
         /*
          * Implementation details of the RDMA core, don't use in drivers:
...