Commit f27a0d50 authored by Jason Gunthorpe's avatar Jason Gunthorpe Committed by Doug Ledford

RDMA/umem: Use umem->owning_mm inside ODP

Since ODP had a single struct mmu_notifier located in the ucontext it
could only handle a single MM at a time, and this prevented it from using
the new owning_mm system.

With the prior rework it is now simple to let ODP track multiple MMs per
ucontext, finish the job so that the per_mm is allocated on a mm by mm
basis, and freed when the last umem is dropped from the ucontext.

As a side effect the new saner locking removes the lockdep splat about
nesting the umem_rwsem between mmu_notifier_unregister and
ib_umem_odp_release.

It also makes ODP work with multiple processes, across, fork, etc.
Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
Signed-off-by: default avatarLeon Romanovsky <leonro@mellanox.com>
Signed-off-by: default avatarDoug Ledford <dledford@redhat.com>
parent c9990ab3
...@@ -278,10 +278,135 @@ static const struct mmu_notifier_ops ib_umem_notifiers = { ...@@ -278,10 +278,135 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
.invalidate_range_end = ib_umem_notifier_invalidate_range_end, .invalidate_range_end = ib_umem_notifier_invalidate_range_end,
}; };
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
unsigned long addr, size_t size) {
struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
struct ib_umem *umem = &umem_odp->umem;
down_write(&per_mm->umem_rwsem);
if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
rbt_ib_umem_insert(&umem_odp->interval_tree,
&per_mm->umem_tree);
if (likely(!atomic_read(&per_mm->notifier_count)))
umem_odp->mn_counters_active = true;
else
list_add(&umem_odp->no_private_counters,
&per_mm->no_private_counters);
up_write(&per_mm->umem_rwsem);
}
static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
struct ib_umem *umem = &umem_odp->umem;
down_write(&per_mm->umem_rwsem);
if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
rbt_ib_umem_remove(&umem_odp->interval_tree,
&per_mm->umem_tree);
if (!umem_odp->mn_counters_active) {
list_del(&umem_odp->no_private_counters);
complete_all(&umem_odp->notifier_completion);
}
up_write(&per_mm->umem_rwsem);
}
static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
struct mm_struct *mm)
{ {
struct ib_ucontext_per_mm *per_mm; struct ib_ucontext_per_mm *per_mm;
int ret;
per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
if (!per_mm)
return ERR_PTR(-ENOMEM);
per_mm->context = ctx;
per_mm->mm = mm;
per_mm->umem_tree = RB_ROOT_CACHED;
init_rwsem(&per_mm->umem_rwsem);
INIT_LIST_HEAD(&per_mm->no_private_counters);
rcu_read_lock();
per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
rcu_read_unlock();
WARN_ON(mm != current->mm);
per_mm->mn.ops = &ib_umem_notifiers;
ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
if (ret) {
dev_err(&ctx->device->dev,
"Failed to register mmu_notifier %d\n", ret);
goto out_pid;
}
list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
return per_mm;
out_pid:
put_pid(per_mm->tgid);
kfree(per_mm);
return ERR_PTR(ret);
}
static int get_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext *ctx = umem_odp->umem.context;
struct ib_ucontext_per_mm *per_mm;
/*
* Generally speaking we expect only one or two per_mm in this list,
* so no reason to optimize this search today.
*/
mutex_lock(&ctx->per_mm_list_lock);
list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
if (per_mm->mm == umem_odp->umem.owning_mm)
goto found;
}
per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
if (IS_ERR(per_mm)) {
mutex_unlock(&ctx->per_mm_list_lock);
return PTR_ERR(per_mm);
}
found:
umem_odp->per_mm = per_mm;
per_mm->odp_mrs_count++;
mutex_unlock(&ctx->per_mm_list_lock);
return 0;
}
void put_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
struct ib_ucontext *ctx = umem_odp->umem.context;
bool need_free;
mutex_lock(&ctx->per_mm_list_lock);
umem_odp->per_mm = NULL;
per_mm->odp_mrs_count--;
need_free = per_mm->odp_mrs_count == 0;
if (need_free)
list_del(&per_mm->ucontext_list);
mutex_unlock(&ctx->per_mm_list_lock);
if (!need_free)
return;
mmu_notifier_unregister(&per_mm->mn, per_mm->mm);
put_pid(per_mm->tgid);
kfree(per_mm);
}
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
unsigned long addr, size_t size)
{
struct ib_ucontext *ctx = per_mm->context;
struct ib_umem_odp *odp_data; struct ib_umem_odp *odp_data;
struct ib_umem *umem; struct ib_umem *umem;
int pages = size >> PAGE_SHIFT; int pages = size >> PAGE_SHIFT;
...@@ -291,13 +416,13 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, ...@@ -291,13 +416,13 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
if (!odp_data) if (!odp_data)
return ERR_PTR(-ENOMEM); return ERR_PTR(-ENOMEM);
umem = &odp_data->umem; umem = &odp_data->umem;
umem->context = context; umem->context = ctx;
umem->length = size; umem->length = size;
umem->address = addr; umem->address = addr;
umem->page_shift = PAGE_SHIFT; umem->page_shift = PAGE_SHIFT;
umem->writable = 1; umem->writable = 1;
umem->is_odp = 1; umem->is_odp = 1;
odp_data->per_mm = per_mm = &context->per_mm; odp_data->per_mm = per_mm;
mutex_init(&odp_data->umem_mutex); mutex_init(&odp_data->umem_mutex);
init_completion(&odp_data->notifier_completion); init_completion(&odp_data->notifier_completion);
...@@ -316,15 +441,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, ...@@ -316,15 +441,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context,
goto out_page_list; goto out_page_list;
} }
down_write(&per_mm->umem_rwsem); /*
* Caller must ensure that the umem_odp that the per_mm came from
* cannot be freed during the call to ib_alloc_odp_umem.
*/
mutex_lock(&ctx->per_mm_list_lock);
per_mm->odp_mrs_count++; per_mm->odp_mrs_count++;
rbt_ib_umem_insert(&odp_data->interval_tree, &per_mm->umem_tree); mutex_unlock(&ctx->per_mm_list_lock);
if (likely(!atomic_read(&per_mm->notifier_count))) add_umem_to_per_mm(odp_data);
odp_data->mn_counters_active = true;
else
list_add(&odp_data->no_private_counters,
&per_mm->no_private_counters);
up_write(&per_mm->umem_rwsem);
return odp_data; return odp_data;
...@@ -338,15 +462,13 @@ EXPORT_SYMBOL(ib_alloc_odp_umem); ...@@ -338,15 +462,13 @@ EXPORT_SYMBOL(ib_alloc_odp_umem);
int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{ {
struct ib_ucontext *context = umem_odp->umem.context;
struct ib_umem *umem = &umem_odp->umem; struct ib_umem *umem = &umem_odp->umem;
struct ib_ucontext_per_mm *per_mm; /*
* NOTE: This must called in a process context where umem->owning_mm
* == current->mm
*/
struct mm_struct *mm = umem->owning_mm;
int ret_val; int ret_val;
struct pid *our_pid;
struct mm_struct *mm = get_task_mm(current);
if (!mm)
return -EINVAL;
if (access & IB_ACCESS_HUGETLB) { if (access & IB_ACCESS_HUGETLB) {
struct vm_area_struct *vma; struct vm_area_struct *vma;
...@@ -366,16 +488,6 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ...@@ -366,16 +488,6 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
umem->hugetlb = 0; umem->hugetlb = 0;
} }
/* Prevent creating ODP MRs in child processes */
rcu_read_lock();
our_pid = get_task_pid(current->group_leader, PIDTYPE_PID);
rcu_read_unlock();
put_pid(our_pid);
if (context->tgid != our_pid) {
ret_val = -EINVAL;
goto out_mm;
}
mutex_init(&umem_odp->umem_mutex); mutex_init(&umem_odp->umem_mutex);
init_completion(&umem_odp->notifier_completion); init_completion(&umem_odp->notifier_completion);
...@@ -384,10 +496,8 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ...@@ -384,10 +496,8 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
umem_odp->page_list = umem_odp->page_list =
vzalloc(array_size(sizeof(*umem_odp->page_list), vzalloc(array_size(sizeof(*umem_odp->page_list),
ib_umem_num_pages(umem))); ib_umem_num_pages(umem)));
if (!umem_odp->page_list) { if (!umem_odp->page_list)
ret_val = -ENOMEM; return -ENOMEM;
goto out_mm;
}
umem_odp->dma_list = umem_odp->dma_list =
vzalloc(array_size(sizeof(*umem_odp->dma_list), vzalloc(array_size(sizeof(*umem_odp->dma_list),
...@@ -398,67 +508,23 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access) ...@@ -398,67 +508,23 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
} }
} }
/* ret_val = get_per_mm(umem_odp);
* When using MMU notifiers, we will get a if (ret_val)
* notification before the "current" task (and MM) is goto out_dma_list;
* destroyed. We use the umem_rwsem semaphore to synchronize. add_umem_to_per_mm(umem_odp);
*/
umem_odp->per_mm = per_mm = &context->per_mm;
down_write(&per_mm->umem_rwsem);
per_mm->odp_mrs_count++;
if (likely(ib_umem_start(umem) != ib_umem_end(umem)))
rbt_ib_umem_insert(&umem_odp->interval_tree,
&per_mm->umem_tree);
if (likely(!atomic_read(&per_mm->notifier_count)) ||
per_mm->odp_mrs_count == 1)
umem_odp->mn_counters_active = true;
else
list_add(&umem_odp->no_private_counters,
&per_mm->no_private_counters);
downgrade_write(&per_mm->umem_rwsem);
if (per_mm->odp_mrs_count == 1) {
/*
* Note that at this point, no MMU notifier is running
* for this per_mm!
*/
atomic_set(&per_mm->notifier_count, 0);
INIT_HLIST_NODE(&per_mm->mn.hlist);
per_mm->mn.ops = &ib_umem_notifiers;
ret_val = mmu_notifier_register(&per_mm->mn, mm);
if (ret_val) {
pr_err("Failed to register mmu_notifier %d\n", ret_val);
ret_val = -EBUSY;
goto out_mutex;
}
}
up_read(&per_mm->umem_rwsem);
/*
* Note that doing an mmput can cause a notifier for the relevant mm.
* If the notifier is called while we hold the umem_rwsem, this will
* cause a deadlock. Therefore, we release the reference only after we
* released the semaphore.
*/
mmput(mm);
return 0; return 0;
out_mutex: out_dma_list:
up_read(&per_mm->umem_rwsem);
vfree(umem_odp->dma_list); vfree(umem_odp->dma_list);
out_page_list: out_page_list:
vfree(umem_odp->page_list); vfree(umem_odp->page_list);
out_mm:
mmput(mm);
return ret_val; return ret_val;
} }
void ib_umem_odp_release(struct ib_umem_odp *umem_odp) void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{ {
struct ib_umem *umem = &umem_odp->umem; struct ib_umem *umem = &umem_odp->umem;
struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
/* /*
* Ensure that no more pages are mapped in the umem. * Ensure that no more pages are mapped in the umem.
...@@ -469,54 +535,8 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp) ...@@ -469,54 +535,8 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem), ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem),
ib_umem_end(umem)); ib_umem_end(umem));
down_write(&per_mm->umem_rwsem); remove_umem_from_per_mm(umem_odp);
if (likely(ib_umem_start(umem) != ib_umem_end(umem))) put_per_mm(umem_odp);
rbt_ib_umem_remove(&umem_odp->interval_tree,
&per_mm->umem_tree);
per_mm->odp_mrs_count--;
if (!umem_odp->mn_counters_active) {
list_del(&umem_odp->no_private_counters);
complete_all(&umem_odp->notifier_completion);
}
/*
* Downgrade the lock to a read lock. This ensures that the notifiers
* (who lock the mutex for reading) will be able to finish, and we
* will be able to enventually obtain the mmu notifiers SRCU. Note
* that since we are doing it atomically, no other user could register
* and unregister while we do the check.
*/
downgrade_write(&per_mm->umem_rwsem);
if (!per_mm->odp_mrs_count) {
struct task_struct *owning_process = NULL;
struct mm_struct *owning_mm = NULL;
owning_process =
get_pid_task(umem_odp->umem.context->tgid, PIDTYPE_PID);
if (owning_process == NULL)
/*
* The process is already dead, notifier were removed
* already.
*/
goto out;
owning_mm = get_task_mm(owning_process);
if (owning_mm == NULL)
/*
* The process' mm is already dead, notifier were
* removed already.
*/
goto out_put_task;
mmu_notifier_unregister(&per_mm->mn, owning_mm);
mmput(owning_mm);
out_put_task:
put_task_struct(owning_process);
}
out:
up_read(&per_mm->umem_rwsem);
vfree(umem_odp->dma_list); vfree(umem_odp->dma_list);
vfree(umem_odp->page_list); vfree(umem_odp->page_list);
} }
...@@ -634,7 +654,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, ...@@ -634,7 +654,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
{ {
struct ib_umem *umem = &umem_odp->umem; struct ib_umem *umem = &umem_odp->umem;
struct task_struct *owning_process = NULL; struct task_struct *owning_process = NULL;
struct mm_struct *owning_mm = NULL; struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
struct page **local_page_list = NULL; struct page **local_page_list = NULL;
u64 page_mask, off; u64 page_mask, off;
int j, k, ret = 0, start_idx, npages = 0, page_shift; int j, k, ret = 0, start_idx, npages = 0, page_shift;
...@@ -658,15 +678,14 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, ...@@ -658,15 +678,14 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
user_virt = user_virt & page_mask; user_virt = user_virt & page_mask;
bcnt += off; /* Charge for the first page offset as well. */ bcnt += off; /* Charge for the first page offset as well. */
owning_process = get_pid_task(umem->context->tgid, PIDTYPE_PID); /*
if (owning_process == NULL) { * owning_process is allowed to be NULL, this means somehow the mm is
* existing beyond the lifetime of the originating process.. Presumably
* mmget_not_zero will fail in this case.
*/
owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
if (WARN_ON(!mmget_not_zero(umem_odp->umem.owning_mm))) {
ret = -EINVAL; ret = -EINVAL;
goto out_no_task;
}
owning_mm = get_task_mm(owning_process);
if (owning_mm == NULL) {
ret = -ENOENT;
goto out_put_task; goto out_put_task;
} }
...@@ -738,8 +757,8 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt, ...@@ -738,8 +757,8 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
mmput(owning_mm); mmput(owning_mm);
out_put_task: out_put_task:
put_task_struct(owning_process); if (owning_process)
out_no_task: put_task_struct(owning_process);
free_page((unsigned long)local_page_list); free_page((unsigned long)local_page_list);
return ret; return ret;
} }
......
...@@ -124,12 +124,8 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file, ...@@ -124,12 +124,8 @@ ssize_t ib_uverbs_get_context(struct ib_uverbs_file *file,
ucontext->cleanup_retryable = false; ucontext->cleanup_retryable = false;
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
ucontext->per_mm.umem_tree = RB_ROOT_CACHED; mutex_init(&ucontext->per_mm_list_lock);
init_rwsem(&ucontext->per_mm.umem_rwsem); INIT_LIST_HEAD(&ucontext->per_mm_list);
ucontext->per_mm.odp_mrs_count = 0;
INIT_LIST_HEAD(&ucontext->per_mm.no_private_counters);
ucontext->per_mm.context = ucontext;
if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING)) if (!(ib_dev->attrs.device_cap_flags & IB_DEVICE_ON_DEMAND_PAGING))
ucontext->invalidate_range = NULL; ucontext->invalidate_range = NULL;
......
...@@ -1861,6 +1861,13 @@ static int mlx5_ib_dealloc_ucontext(struct ib_ucontext *ibcontext) ...@@ -1861,6 +1861,13 @@ static int mlx5_ib_dealloc_ucontext(struct ib_ucontext *ibcontext)
struct mlx5_ib_dev *dev = to_mdev(ibcontext->device); struct mlx5_ib_dev *dev = to_mdev(ibcontext->device);
struct mlx5_bfreg_info *bfregi; struct mlx5_bfreg_info *bfregi;
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
/* All umem's must be destroyed before destroying the ucontext. */
mutex_lock(&ibcontext->per_mm_list_lock);
WARN_ON(!list_empty(&ibcontext->per_mm_list));
mutex_unlock(&ibcontext->per_mm_list_lock);
#endif
if (context->devx_uid) if (context->devx_uid)
mlx5_ib_devx_destroy(dev, context); mlx5_ib_devx_destroy(dev, context);
......
...@@ -393,7 +393,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr, ...@@ -393,7 +393,7 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
if (nentries) if (nentries)
nentries++; nentries++;
} else { } else {
odp = ib_alloc_odp_umem(odp_mr->umem.context, addr, odp = ib_alloc_odp_umem(odp_mr->per_mm, addr,
MLX5_IMR_MTT_SIZE); MLX5_IMR_MTT_SIZE);
if (IS_ERR(odp)) { if (IS_ERR(odp)) {
mutex_unlock(&odp_mr->umem_mutex); mutex_unlock(&odp_mr->umem_mutex);
......
...@@ -91,8 +91,26 @@ static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem) ...@@ -91,8 +91,26 @@ static inline struct ib_umem_odp *to_ib_umem_odp(struct ib_umem *umem)
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
struct ib_ucontext_per_mm {
struct ib_ucontext *context;
struct mm_struct *mm;
struct pid *tgid;
struct rb_root_cached umem_tree;
/* Protects umem_tree */
struct rw_semaphore umem_rwsem;
atomic_t notifier_count;
struct mmu_notifier mn;
/* A list of umems that don't have private mmu notifier counters yet. */
struct list_head no_private_counters;
unsigned int odp_mrs_count;
struct list_head ucontext_list;
};
int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access); int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext *context, struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
unsigned long addr, size_t size); unsigned long addr, size_t size);
void ib_umem_odp_release(struct ib_umem_odp *umem_odp); void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
......
...@@ -1488,25 +1488,6 @@ struct ib_rdmacg_object { ...@@ -1488,25 +1488,6 @@ struct ib_rdmacg_object {
#endif #endif
}; };
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
struct ib_ucontext_per_mm {
struct ib_ucontext *context;
struct rb_root_cached umem_tree;
/*
* Protects .umem_rbroot and tree, as well as odp_mrs_count and
* mmu notifiers registration.
*/
struct rw_semaphore umem_rwsem;
struct mmu_notifier mn;
atomic_t notifier_count;
/* A list of umems that don't have private mmu notifier counters yet. */
struct list_head no_private_counters;
unsigned int odp_mrs_count;
};
#endif
struct ib_ucontext { struct ib_ucontext {
struct ib_device *device; struct ib_device *device;
struct ib_uverbs_file *ufile; struct ib_uverbs_file *ufile;
...@@ -1523,7 +1504,8 @@ struct ib_ucontext { ...@@ -1523,7 +1504,8 @@ struct ib_ucontext {
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING #ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING
void (*invalidate_range)(struct ib_umem_odp *umem_odp, void (*invalidate_range)(struct ib_umem_odp *umem_odp,
unsigned long start, unsigned long end); unsigned long start, unsigned long end);
struct ib_ucontext_per_mm per_mm; struct mutex per_mm_list_lock;
struct list_head per_mm_list;
#endif #endif
struct ib_rdmacg_object cg_obj; struct ib_rdmacg_object cg_obj;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment