Commit 704f3f2c authored by Jérôme Glisse, committed by Linus Torvalds

mm/hmm: use reference counting for HMM struct

Every time I read the code to check that the HMM structure does not vanish
before it should, thanks to the many locks protecting its removal, I get a
headache.  Switch to reference counting instead; it is much easier to
follow and harder to break.  This also removes some code that is no longer
needed with refcounting.

Link: http://lkml.kernel.org/r/20190403193318.16478-3-jglisse@redhat.com
Signed-off-by: Jérôme Glisse <jglisse@redhat.com>
Reviewed-by: Ralph Campbell <rcampbell@nvidia.com>
Cc: John Hubbard <jhubbard@nvidia.com>
Cc: Dan Williams <dan.j.williams@intel.com>
Cc: Arnd Bergmann <arnd@arndb.de>
Cc: Balbir Singh <bsingharora@gmail.com>
Cc: Dan Carpenter <dan.carpenter@oracle.com>
Cc: Ira Weiny <ira.weiny@intel.com>
Cc: Matthew Wilcox <willy@infradead.org>
Cc: Souptick Joarder <jrdr.linux@gmail.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
parent 734fb899
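
The change is an instance of the kernel's standard kref idiom: lookups only
succeed while the count is non-zero, and the last put runs the release
callback.  Below is a minimal sketch of that idiom distilled out of the HMM
context; the obj/obj_get/obj_put names are illustrative only, while
kref_init(), kref_get_unless_zero(), kref_put() and container_of() are the
real kernel APIs the patch uses (see mm_get_hmm()/hmm_put() in the diff).

#include <linux/kref.h>
#include <linux/slab.h>

struct obj {
        struct kref kref;
        /* ... payload ... */
};

/* Lookup side: only hand out the object if its count was not already zero. */
static struct obj *obj_get(struct obj *obj)
{
        if (obj && kref_get_unless_zero(&obj->kref))
                return obj;
        return NULL;
}

/* Release callback, run by kref_put() when the last reference is dropped. */
static void obj_free(struct kref *kref)
{
        struct obj *obj = container_of(kref, struct obj, kref);

        kfree(obj);
}

static void obj_put(struct obj *obj)
{
        kref_put(&obj->kref, obj_free);
}

kref_get_unless_zero() is what lets mm_get_hmm() do an unlocked
READ_ONCE(mm->hmm) lookup: a caller racing with the final hmm_put() gets
NULL rather than a pointer to an object that is going away (the mm->hmm
pointer itself is published and cleared under mm->page_table_lock).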
--- a/include/linux/hmm.h
+++ b/include/linux/hmm.h
@@ -131,6 +131,7 @@ enum hmm_pfn_value_e {
 /*
  * struct hmm_range - track invalidation lock on virtual address range
  *
+ * @hmm: the core HMM structure this range is active against
  * @vma: the vm area struct for the range
  * @list: all range lock are on a list
  * @start: range virtual start address (inclusive)
@@ -142,6 +143,7 @@ enum hmm_pfn_value_e {
  * @valid: pfns array did not change since it has been fill by an HMM function
  */
 struct hmm_range {
+        struct hmm              *hmm;
         struct vm_area_struct   *vma;
         struct list_head        list;
         unsigned long           start;
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -50,6 +50,7 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
  */
 struct hmm {
         struct mm_struct        *mm;
+        struct kref             kref;
         spinlock_t              lock;
         struct list_head        ranges;
         struct list_head        mirrors;
@@ -57,24 +58,33 @@ struct hmm {
         struct rw_semaphore     mirrors_sem;
 };
 
-/*
- * hmm_register - register HMM against an mm (HMM internal)
+static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
+{
+        struct hmm *hmm = READ_ONCE(mm->hmm);
+
+        if (hmm && kref_get_unless_zero(&hmm->kref))
+                return hmm;
+
+        return NULL;
+}
+
+/**
+ * hmm_get_or_create - register HMM against an mm (HMM internal)
  *
  * @mm: mm struct to attach to
+ * Returns: returns an HMM object, either by referencing the existing
+ * (per-process) object, or by creating a new one.
  *
- * This is not intended to be used directly by device drivers. It allocates an
- * HMM struct if mm does not have one, and initializes it.
+ * This is not intended to be used directly by device drivers. If mm already
+ * has an HMM struct then it get a reference on it and returns it. Otherwise
+ * it allocates an HMM struct, initializes it, associate it with the mm and
+ * returns it.
  */
-static struct hmm *hmm_register(struct mm_struct *mm)
+static struct hmm *hmm_get_or_create(struct mm_struct *mm)
 {
-        struct hmm *hmm = READ_ONCE(mm->hmm);
+        struct hmm *hmm = mm_get_hmm(mm);
         bool cleanup = false;
 
-        /*
-         * The hmm struct can only be freed once the mm_struct goes away,
-         * hence we should always have pre-allocated an new hmm struct
-         * above.
-         */
         if (hmm)
                 return hmm;
 
@@ -86,6 +96,7 @@ static struct hmm *hmm_register(struct mm_struct *mm)
         hmm->mmu_notifier.ops = NULL;
         INIT_LIST_HEAD(&hmm->ranges);
         spin_lock_init(&hmm->lock);
+        kref_init(&hmm->kref);
         hmm->mm = mm;
 
         spin_lock(&mm->page_table_lock);
@@ -106,7 +117,7 @@ static struct hmm *hmm_register(struct mm_struct *mm)
         if (__mmu_notifier_register(&hmm->mmu_notifier, mm))
                 goto error_mm;
 
-        return mm->hmm;
+        return hmm;
 
 error_mm:
         spin_lock(&mm->page_table_lock);
@@ -118,9 +129,41 @@ static struct hmm *hmm_register(struct mm_struct *mm)
         return NULL;
 }
 
+static void hmm_free(struct kref *kref)
+{
+        struct hmm *hmm = container_of(kref, struct hmm, kref);
+        struct mm_struct *mm = hmm->mm;
+
+        mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
+
+        spin_lock(&mm->page_table_lock);
+        if (mm->hmm == hmm)
+                mm->hmm = NULL;
+        spin_unlock(&mm->page_table_lock);
+
+        kfree(hmm);
+}
+
+static inline void hmm_put(struct hmm *hmm)
+{
+        kref_put(&hmm->kref, hmm_free);
+}
+
 void hmm_mm_destroy(struct mm_struct *mm)
 {
-        kfree(mm->hmm);
+        struct hmm *hmm;
+
+        spin_lock(&mm->page_table_lock);
+        hmm = mm_get_hmm(mm);
+        mm->hmm = NULL;
+        if (hmm) {
+                hmm->mm = NULL;
+                spin_unlock(&mm->page_table_lock);
+                hmm_put(hmm);
+                return;
+        }
+
+        spin_unlock(&mm->page_table_lock);
 }
 
 static int hmm_invalidate_range(struct hmm *hmm, bool device,
@@ -165,7 +208,7 @@ static int hmm_invalidate_range(struct hmm *hmm, bool device,
 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
         struct hmm_mirror *mirror;
-        struct hmm *hmm = mm->hmm;
+        struct hmm *hmm = mm_get_hmm(mm);
 
         down_write(&hmm->mirrors_sem);
         mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
@@ -186,13 +229,16 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
                                           struct hmm_mirror, list);
         }
         up_write(&hmm->mirrors_sem);
+        hmm_put(hmm);
 }
 
 static int hmm_invalidate_range_start(struct mmu_notifier *mn,
                         const struct mmu_notifier_range *range)
 {
+        struct hmm *hmm = mm_get_hmm(range->mm);
         struct hmm_update update;
-        struct hmm *hmm = range->mm->hmm;
+        int ret;
 
         VM_BUG_ON(!hmm);
@@ -200,14 +246,16 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
         update.end = range->end;
         update.event = HMM_UPDATE_INVALIDATE;
         update.blockable = range->blockable;
-        return hmm_invalidate_range(hmm, true, &update);
+        ret = hmm_invalidate_range(hmm, true, &update);
+        hmm_put(hmm);
+        return ret;
 }
 
 static void hmm_invalidate_range_end(struct mmu_notifier *mn,
                         const struct mmu_notifier_range *range)
 {
+        struct hmm *hmm = mm_get_hmm(range->mm);
         struct hmm_update update;
-        struct hmm *hmm = range->mm->hmm;
 
         VM_BUG_ON(!hmm);
@@ -216,6 +264,7 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
         update.event = HMM_UPDATE_INVALIDATE;
         update.blockable = true;
         hmm_invalidate_range(hmm, false, &update);
+        hmm_put(hmm);
 }
 
 static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
@@ -241,24 +290,13 @@ int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
         if (!mm || !mirror || !mirror->ops)
                 return -EINVAL;
 
-again:
-        mirror->hmm = hmm_register(mm);
+        mirror->hmm = hmm_get_or_create(mm);
         if (!mirror->hmm)
                 return -ENOMEM;
 
         down_write(&mirror->hmm->mirrors_sem);
-        if (mirror->hmm->mm == NULL) {
-                /*
-                 * A racing hmm_mirror_unregister() is about to destroy the hmm
-                 * struct. Try again to allocate a new one.
-                 */
-                up_write(&mirror->hmm->mirrors_sem);
-                mirror->hmm = NULL;
-                goto again;
-        } else {
-                list_add(&mirror->list, &mirror->hmm->mirrors);
-                up_write(&mirror->hmm->mirrors_sem);
-        }
+        list_add(&mirror->list, &mirror->hmm->mirrors);
+        up_write(&mirror->hmm->mirrors_sem);
 
         return 0;
 }
@@ -273,33 +311,18 @@ EXPORT_SYMBOL(hmm_mirror_register);
  */
 void hmm_mirror_unregister(struct hmm_mirror *mirror)
 {
-        bool should_unregister = false;
-        struct mm_struct *mm;
-        struct hmm *hmm;
+        struct hmm *hmm = READ_ONCE(mirror->hmm);
 
-        if (mirror->hmm == NULL)
+        if (hmm == NULL)
                 return;
 
-        hmm = mirror->hmm;
         down_write(&hmm->mirrors_sem);
         list_del_init(&mirror->list);
-        should_unregister = list_empty(&hmm->mirrors);
+        /* To protect us against double unregister ... */
         mirror->hmm = NULL;
-        mm = hmm->mm;
-        hmm->mm = NULL;
         up_write(&hmm->mirrors_sem);
 
-        if (!should_unregister || mm == NULL)
-                return;
-
-        mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
-
-        spin_lock(&mm->page_table_lock);
-        if (mm->hmm == hmm)
-                mm->hmm = NULL;
-        spin_unlock(&mm->page_table_lock);
-
-        kfree(hmm);
+        hmm_put(hmm);
 }
 EXPORT_SYMBOL(hmm_mirror_unregister);
@@ -708,23 +731,29 @@ int hmm_vma_get_pfns(struct hmm_range *range)
         struct mm_walk mm_walk;
         struct hmm *hmm;
 
+        range->hmm = NULL;
+
         /* Sanity check, this really should not happen ! */
         if (range->start < vma->vm_start || range->start >= vma->vm_end)
                 return -EINVAL;
         if (range->end < vma->vm_start || range->end > vma->vm_end)
                 return -EINVAL;
 
-        hmm = hmm_register(vma->vm_mm);
+        hmm = hmm_get_or_create(vma->vm_mm);
         if (!hmm)
                 return -ENOMEM;
-        /* Caller must have registered a mirror, via hmm_mirror_register() ! */
-        if (!hmm->mmu_notifier.ops)
+
+        /* Check if hmm_mm_destroy() was call. */
+        if (hmm->mm == NULL) {
+                hmm_put(hmm);
                 return -EINVAL;
+        }
 
         /* FIXME support hugetlb fs */
         if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
                         vma_is_dax(vma)) {
                 hmm_pfns_special(range);
+                hmm_put(hmm);
                 return -EINVAL;
         }
@@ -736,6 +765,7 @@ int hmm_vma_get_pfns(struct hmm_range *range)
                  * operations such has atomic access would not work.
                  */
                 hmm_pfns_clear(range, range->pfns, range->start, range->end);
+                hmm_put(hmm);
                 return -EPERM;
         }
@@ -758,6 +788,12 @@ int hmm_vma_get_pfns(struct hmm_range *range)
         mm_walk.pte_hole = hmm_vma_walk_hole;
 
         walk_page_range(range->start, range->end, &mm_walk);
+
+        /*
+         * Transfer hmm reference to the range struct it will be drop inside
+         * the hmm_vma_range_done() function (which _must_ be call if this
+         * function return 0).
+         */
+        range->hmm = hmm;
         return 0;
 }
 EXPORT_SYMBOL(hmm_vma_get_pfns);
@@ -802,25 +838,27 @@ EXPORT_SYMBOL(hmm_vma_get_pfns);
  */
 bool hmm_vma_range_done(struct hmm_range *range)
 {
-        unsigned long npages = (range->end - range->start) >> PAGE_SHIFT;
-        struct hmm *hmm;
+        bool ret = false;
 
-        if (range->end <= range->start) {
+        /* Sanity check this really should not happen. */
+        if (range->hmm == NULL || range->end <= range->start) {
                 BUG();
                 return false;
         }
 
-        hmm = hmm_register(range->vma->vm_mm);
-        if (!hmm) {
-                memset(range->pfns, 0, sizeof(*range->pfns) * npages);
-                return false;
-        }
-
-        spin_lock(&hmm->lock);
+        spin_lock(&range->hmm->lock);
         list_del_rcu(&range->list);
-        spin_unlock(&hmm->lock);
+        ret = range->valid;
+        spin_unlock(&range->hmm->lock);
 
-        return range->valid;
+        /* Is the mm still alive ? */
+        if (range->hmm->mm == NULL)
+                ret = false;
+
+        /* Drop reference taken by hmm_vma_fault() or hmm_vma_get_pfns() */
+        hmm_put(range->hmm);
+        range->hmm = NULL;
+        return ret;
 }
 EXPORT_SYMBOL(hmm_vma_range_done);
@@ -880,25 +918,31 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
         struct hmm *hmm;
         int ret;
 
+        range->hmm = NULL;
+
         /* Sanity check, this really should not happen ! */
         if (range->start < vma->vm_start || range->start >= vma->vm_end)
                 return -EINVAL;
         if (range->end < vma->vm_start || range->end > vma->vm_end)
                 return -EINVAL;
 
-        hmm = hmm_register(vma->vm_mm);
+        hmm = hmm_get_or_create(vma->vm_mm);
         if (!hmm) {
                 hmm_pfns_clear(range, range->pfns, range->start, range->end);
                 return -ENOMEM;
         }
-        /* Caller must have registered a mirror using hmm_mirror_register() */
-        if (!hmm->mmu_notifier.ops)
+
+        /* Check if hmm_mm_destroy() was call. */
+        if (hmm->mm == NULL) {
+                hmm_put(hmm);
                 return -EINVAL;
+        }
 
         /* FIXME support hugetlb fs */
         if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
                         vma_is_dax(vma)) {
                 hmm_pfns_special(range);
+                hmm_put(hmm);
                 return -EINVAL;
         }
@@ -910,6 +954,7 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
                  * operations such has atomic access would not work.
                  */
                 hmm_pfns_clear(range, range->pfns, range->start, range->end);
+                hmm_put(hmm);
                 return -EPERM;
         }
@@ -945,7 +990,16 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
                 hmm_pfns_clear(range, &range->pfns[i], hmm_vma_walk.last,
                                range->end);
                 hmm_vma_range_done(range);
+                hmm_put(hmm);
+        } else {
+                /*
+                 * Transfer hmm reference to the range struct it will be drop
+                 * inside the hmm_vma_range_done() function (which _must_ be
+                 * call if this function return 0).
+                 */
+                range->hmm = hmm;
         }
 
         return ret;
 }
 EXPORT_SYMBOL(hmm_vma_fault);
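
The net effect for callers of the snapshot API is a reference handoff: on
success, hmm_vma_get_pfns() and hmm_vma_fault() store their reference in
range->hmm, and hmm_vma_range_done() drops it, which is why the comments
above insist that it _must_ be called whenever those functions return 0.
A hedged sketch of a consumer under the new rules; driver_snapshot() and
its error handling are hypothetical, while the hmm_* calls are the ones
patched above:

#include <linux/hmm.h>

static int driver_snapshot(struct hmm_range *range)
{
        int ret;

        ret = hmm_vma_get_pfns(range); /* on success, a ref travels in range->hmm */
        if (ret)
                return ret;            /* on failure no reference is held */

        /* ... consume range->pfns under the driver's own synchronization ... */

        /* Drops the ref taken above; false means an invalidation raced us. */
        if (!hmm_vma_range_done(range))
                return -EAGAIN;        /* hypothetical retry policy */

        return 0;
}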