Commit 39bed42d authored by Linus Torvalds's avatar Linus Torvalds

Merge tag 'for-linus-hmm' of git://git.kernel.org/pub/scm/linux/kernel/git/rdma/rdma

Pull mmu_notifier updates from Jason Gunthorpe:
 "This small series revises the names in mmu_notifier to make the code
  clearer and more readable"

* tag 'for-linus-hmm' of git://git.kernel.org/pub/scm/linux/kernel/git/rdma/rdma:
  mm/mmu_notifiers: Use 'interval_sub' as the variable for mmu_interval_notifier
  mm/mmu_notifiers: Use 'subscription' as the variable name for mmu_notifier
  mm/mmu_notifier: Rename struct mmu_notifier_mm to mmu_notifier_subscriptions
parents 83fa805b 5292e24a
...@@ -149,14 +149,14 @@ CPU page table into a device page table; HMM helps keep both synchronized. A ...@@ -149,14 +149,14 @@ CPU page table into a device page table; HMM helps keep both synchronized. A
device driver that wants to mirror a process address space must start with the device driver that wants to mirror a process address space must start with the
registration of a mmu_interval_notifier:: registration of a mmu_interval_notifier::
mni->ops = &driver_ops; int mmu_interval_notifier_insert(struct mmu_interval_notifier *interval_sub,
int mmu_interval_notifier_insert(struct mmu_interval_notifier *mni, struct mm_struct *mm, unsigned long start,
unsigned long start, unsigned long length, unsigned long length,
struct mm_struct *mm); const struct mmu_interval_notifier_ops *ops);
During the driver_ops->invalidate() callback the device driver must perform During the ops->invalidate() callback the device driver must perform the
the update action to the range (mark range read only, or fully unmap, update action to the range (mark range read only, or fully unmap, etc.). The
etc.). The device must complete the update before the driver callback returns. device must complete the update before the driver callback returns.
When the device driver wants to populate a range of virtual addresses, it can When the device driver wants to populate a range of virtual addresses, it can
use:: use::
...@@ -183,7 +183,7 @@ The usage pattern is:: ...@@ -183,7 +183,7 @@ The usage pattern is::
struct hmm_range range; struct hmm_range range;
... ...
range.notifier = &mni; range.notifier = &interval_sub;
range.start = ...; range.start = ...;
range.end = ...; range.end = ...;
range.pfns = ...; range.pfns = ...;
...@@ -191,11 +191,11 @@ The usage pattern is:: ...@@ -191,11 +191,11 @@ The usage pattern is::
range.values = ...; range.values = ...;
range.pfn_shift = ...; range.pfn_shift = ...;
if (!mmget_not_zero(mni->notifier.mm)) if (!mmget_not_zero(interval_sub->notifier.mm))
return -EFAULT; return -EFAULT;
again: again:
range.notifier_seq = mmu_interval_read_begin(&mni); range.notifier_seq = mmu_interval_read_begin(&interval_sub);
down_read(&mm->mmap_sem); down_read(&mm->mmap_sem);
ret = hmm_range_fault(&range, HMM_RANGE_SNAPSHOT); ret = hmm_range_fault(&range, HMM_RANGE_SNAPSHOT);
if (ret) { if (ret) {
......
...@@ -490,7 +490,7 @@ struct mm_struct { ...@@ -490,7 +490,7 @@ struct mm_struct {
/* store ref to file /proc/<pid>/exe symlink points to */ /* store ref to file /proc/<pid>/exe symlink points to */
struct file __rcu *exe_file; struct file __rcu *exe_file;
#ifdef CONFIG_MMU_NOTIFIER #ifdef CONFIG_MMU_NOTIFIER
struct mmu_notifier_mm *mmu_notifier_mm; struct mmu_notifier_subscriptions *notifier_subscriptions;
#endif #endif
#if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS #if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS
pgtable_t pmd_huge_pte; /* protected by page_table_lock */ pgtable_t pmd_huge_pte; /* protected by page_table_lock */
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <linux/srcu.h> #include <linux/srcu.h>
#include <linux/interval_tree.h> #include <linux/interval_tree.h>
struct mmu_notifier_mm; struct mmu_notifier_subscriptions;
struct mmu_notifier; struct mmu_notifier;
struct mmu_notifier_range; struct mmu_notifier_range;
struct mmu_interval_notifier; struct mmu_interval_notifier;
...@@ -73,7 +73,7 @@ struct mmu_notifier_ops { ...@@ -73,7 +73,7 @@ struct mmu_notifier_ops {
* through the gart alias address, so leading to memory * through the gart alias address, so leading to memory
* corruption. * corruption.
*/ */
void (*release)(struct mmu_notifier *mn, void (*release)(struct mmu_notifier *subscription,
struct mm_struct *mm); struct mm_struct *mm);
/* /*
...@@ -85,7 +85,7 @@ struct mmu_notifier_ops { ...@@ -85,7 +85,7 @@ struct mmu_notifier_ops {
* Start-end is necessary in case the secondary MMU is mapping the page * Start-end is necessary in case the secondary MMU is mapping the page
* at a smaller granularity than the primary MMU. * at a smaller granularity than the primary MMU.
*/ */
int (*clear_flush_young)(struct mmu_notifier *mn, int (*clear_flush_young)(struct mmu_notifier *subscription,
struct mm_struct *mm, struct mm_struct *mm,
unsigned long start, unsigned long start,
unsigned long end); unsigned long end);
...@@ -95,7 +95,7 @@ struct mmu_notifier_ops { ...@@ -95,7 +95,7 @@ struct mmu_notifier_ops {
* latter, it is supposed to test-and-clear the young/accessed bitflag * latter, it is supposed to test-and-clear the young/accessed bitflag
* in the secondary pte, but it may omit flushing the secondary tlb. * in the secondary pte, but it may omit flushing the secondary tlb.
*/ */
int (*clear_young)(struct mmu_notifier *mn, int (*clear_young)(struct mmu_notifier *subscription,
struct mm_struct *mm, struct mm_struct *mm,
unsigned long start, unsigned long start,
unsigned long end); unsigned long end);
...@@ -106,7 +106,7 @@ struct mmu_notifier_ops { ...@@ -106,7 +106,7 @@ struct mmu_notifier_ops {
* frequently used without actually clearing the flag or tearing * frequently used without actually clearing the flag or tearing
* down the secondary mapping on the page. * down the secondary mapping on the page.
*/ */
int (*test_young)(struct mmu_notifier *mn, int (*test_young)(struct mmu_notifier *subscription,
struct mm_struct *mm, struct mm_struct *mm,
unsigned long address); unsigned long address);
...@@ -114,7 +114,7 @@ struct mmu_notifier_ops { ...@@ -114,7 +114,7 @@ struct mmu_notifier_ops {
* change_pte is called in cases that pte mapping to page is changed: * change_pte is called in cases that pte mapping to page is changed:
* for example, when ksm remaps pte to point to a new shared page. * for example, when ksm remaps pte to point to a new shared page.
*/ */
void (*change_pte)(struct mmu_notifier *mn, void (*change_pte)(struct mmu_notifier *subscription,
struct mm_struct *mm, struct mm_struct *mm,
unsigned long address, unsigned long address,
pte_t pte); pte_t pte);
...@@ -169,9 +169,9 @@ struct mmu_notifier_ops { ...@@ -169,9 +169,9 @@ struct mmu_notifier_ops {
* invalidate_range_end. * invalidate_range_end.
* *
*/ */
int (*invalidate_range_start)(struct mmu_notifier *mn, int (*invalidate_range_start)(struct mmu_notifier *subscription,
const struct mmu_notifier_range *range); const struct mmu_notifier_range *range);
void (*invalidate_range_end)(struct mmu_notifier *mn, void (*invalidate_range_end)(struct mmu_notifier *subscription,
const struct mmu_notifier_range *range); const struct mmu_notifier_range *range);
/* /*
...@@ -192,8 +192,10 @@ struct mmu_notifier_ops { ...@@ -192,8 +192,10 @@ struct mmu_notifier_ops {
* of what was passed to invalidate_range_start()/end(), if * of what was passed to invalidate_range_start()/end(), if
* called between those functions. * called between those functions.
*/ */
void (*invalidate_range)(struct mmu_notifier *mn, struct mm_struct *mm, void (*invalidate_range)(struct mmu_notifier *subscription,
unsigned long start, unsigned long end); struct mm_struct *mm,
unsigned long start,
unsigned long end);
/* /*
* These callbacks are used with the get/put interface to manage the * These callbacks are used with the get/put interface to manage the
...@@ -206,7 +208,7 @@ struct mmu_notifier_ops { ...@@ -206,7 +208,7 @@ struct mmu_notifier_ops {
* and cannot sleep. * and cannot sleep.
*/ */
struct mmu_notifier *(*alloc_notifier)(struct mm_struct *mm); struct mmu_notifier *(*alloc_notifier)(struct mm_struct *mm);
void (*free_notifier)(struct mmu_notifier *mn); void (*free_notifier)(struct mmu_notifier *subscription);
}; };
/* /*
...@@ -235,7 +237,7 @@ struct mmu_notifier { ...@@ -235,7 +237,7 @@ struct mmu_notifier {
* was required but mmu_notifier_range_blockable(range) is false. * was required but mmu_notifier_range_blockable(range) is false.
*/ */
struct mmu_interval_notifier_ops { struct mmu_interval_notifier_ops {
bool (*invalidate)(struct mmu_interval_notifier *mni, bool (*invalidate)(struct mmu_interval_notifier *interval_sub,
const struct mmu_notifier_range *range, const struct mmu_notifier_range *range,
unsigned long cur_seq); unsigned long cur_seq);
}; };
...@@ -265,7 +267,7 @@ struct mmu_notifier_range { ...@@ -265,7 +267,7 @@ struct mmu_notifier_range {
static inline int mm_has_notifiers(struct mm_struct *mm) static inline int mm_has_notifiers(struct mm_struct *mm)
{ {
return unlikely(mm->mmu_notifier_mm); return unlikely(mm->notifier_subscriptions);
} }
struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops, struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
...@@ -280,30 +282,31 @@ mmu_notifier_get(const struct mmu_notifier_ops *ops, struct mm_struct *mm) ...@@ -280,30 +282,31 @@ mmu_notifier_get(const struct mmu_notifier_ops *ops, struct mm_struct *mm)
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
return ret; return ret;
} }
void mmu_notifier_put(struct mmu_notifier *mn); void mmu_notifier_put(struct mmu_notifier *subscription);
void mmu_notifier_synchronize(void); void mmu_notifier_synchronize(void);
extern int mmu_notifier_register(struct mmu_notifier *mn, extern int mmu_notifier_register(struct mmu_notifier *subscription,
struct mm_struct *mm); struct mm_struct *mm);
extern int __mmu_notifier_register(struct mmu_notifier *mn, extern int __mmu_notifier_register(struct mmu_notifier *subscription,
struct mm_struct *mm); struct mm_struct *mm);
extern void mmu_notifier_unregister(struct mmu_notifier *mn, extern void mmu_notifier_unregister(struct mmu_notifier *subscription,
struct mm_struct *mm); struct mm_struct *mm);
unsigned long mmu_interval_read_begin(struct mmu_interval_notifier *mni); unsigned long
int mmu_interval_notifier_insert(struct mmu_interval_notifier *mni, mmu_interval_read_begin(struct mmu_interval_notifier *interval_sub);
int mmu_interval_notifier_insert(struct mmu_interval_notifier *interval_sub,
struct mm_struct *mm, unsigned long start, struct mm_struct *mm, unsigned long start,
unsigned long length, unsigned long length,
const struct mmu_interval_notifier_ops *ops); const struct mmu_interval_notifier_ops *ops);
int mmu_interval_notifier_insert_locked( int mmu_interval_notifier_insert_locked(
struct mmu_interval_notifier *mni, struct mm_struct *mm, struct mmu_interval_notifier *interval_sub, struct mm_struct *mm,
unsigned long start, unsigned long length, unsigned long start, unsigned long length,
const struct mmu_interval_notifier_ops *ops); const struct mmu_interval_notifier_ops *ops);
void mmu_interval_notifier_remove(struct mmu_interval_notifier *mni); void mmu_interval_notifier_remove(struct mmu_interval_notifier *interval_sub);
/** /**
* mmu_interval_set_seq - Save the invalidation sequence * mmu_interval_set_seq - Save the invalidation sequence
* @mni - The mni passed to invalidate * @interval_sub - The subscription passed to invalidate
* @cur_seq - The cur_seq passed to the invalidate() callback * @cur_seq - The cur_seq passed to the invalidate() callback
* *
* This must be called unconditionally from the invalidate callback of a * This must be called unconditionally from the invalidate callback of a
...@@ -314,15 +317,16 @@ void mmu_interval_notifier_remove(struct mmu_interval_notifier *mni); ...@@ -314,15 +317,16 @@ void mmu_interval_notifier_remove(struct mmu_interval_notifier *mni);
* If the caller does not call mmu_interval_read_begin() or * If the caller does not call mmu_interval_read_begin() or
* mmu_interval_read_retry() then this call is not required. * mmu_interval_read_retry() then this call is not required.
*/ */
static inline void mmu_interval_set_seq(struct mmu_interval_notifier *mni, static inline void
unsigned long cur_seq) mmu_interval_set_seq(struct mmu_interval_notifier *interval_sub,
unsigned long cur_seq)
{ {
WRITE_ONCE(mni->invalidate_seq, cur_seq); WRITE_ONCE(interval_sub->invalidate_seq, cur_seq);
} }
/** /**
* mmu_interval_read_retry - End a read side critical section against a VA range * mmu_interval_read_retry - End a read side critical section against a VA range
* mni: The range * interval_sub: The subscription
* seq: The return of the paired mmu_interval_read_begin() * seq: The return of the paired mmu_interval_read_begin()
* *
* This MUST be called under a user provided lock that is also held * This MUST be called under a user provided lock that is also held
...@@ -334,15 +338,16 @@ static inline void mmu_interval_set_seq(struct mmu_interval_notifier *mni, ...@@ -334,15 +338,16 @@ static inline void mmu_interval_set_seq(struct mmu_interval_notifier *mni,
* Returns true if an invalidation collided with this critical section, and * Returns true if an invalidation collided with this critical section, and
* the caller should retry. * the caller should retry.
*/ */
static inline bool mmu_interval_read_retry(struct mmu_interval_notifier *mni, static inline bool
unsigned long seq) mmu_interval_read_retry(struct mmu_interval_notifier *interval_sub,
unsigned long seq)
{ {
return mni->invalidate_seq != seq; return interval_sub->invalidate_seq != seq;
} }
/** /**
* mmu_interval_check_retry - Test if a collision has occurred * mmu_interval_check_retry - Test if a collision has occurred
* mni: The range * interval_sub: The subscription
* seq: The return of the matching mmu_interval_read_begin() * seq: The return of the matching mmu_interval_read_begin()
* *
* This can be used in the critical section between mmu_interval_read_begin() * This can be used in the critical section between mmu_interval_read_begin()
...@@ -357,14 +362,15 @@ static inline bool mmu_interval_read_retry(struct mmu_interval_notifier *mni, ...@@ -357,14 +362,15 @@ static inline bool mmu_interval_read_retry(struct mmu_interval_notifier *mni,
* This call can be used as part of loops and other expensive operations to * This call can be used as part of loops and other expensive operations to
* expedite a retry. * expedite a retry.
*/ */
static inline bool mmu_interval_check_retry(struct mmu_interval_notifier *mni, static inline bool
unsigned long seq) mmu_interval_check_retry(struct mmu_interval_notifier *interval_sub,
unsigned long seq)
{ {
/* Pairs with the WRITE_ONCE in mmu_interval_set_seq() */ /* Pairs with the WRITE_ONCE in mmu_interval_set_seq() */
return READ_ONCE(mni->invalidate_seq) != seq; return READ_ONCE(interval_sub->invalidate_seq) != seq;
} }
extern void __mmu_notifier_mm_destroy(struct mm_struct *mm); extern void __mmu_notifier_subscriptions_destroy(struct mm_struct *mm);
extern void __mmu_notifier_release(struct mm_struct *mm); extern void __mmu_notifier_release(struct mm_struct *mm);
extern int __mmu_notifier_clear_flush_young(struct mm_struct *mm, extern int __mmu_notifier_clear_flush_young(struct mm_struct *mm,
unsigned long start, unsigned long start,
...@@ -480,15 +486,15 @@ static inline void mmu_notifier_invalidate_range(struct mm_struct *mm, ...@@ -480,15 +486,15 @@ static inline void mmu_notifier_invalidate_range(struct mm_struct *mm,
__mmu_notifier_invalidate_range(mm, start, end); __mmu_notifier_invalidate_range(mm, start, end);
} }
static inline void mmu_notifier_mm_init(struct mm_struct *mm) static inline void mmu_notifier_subscriptions_init(struct mm_struct *mm)
{ {
mm->mmu_notifier_mm = NULL; mm->notifier_subscriptions = NULL;
} }
static inline void mmu_notifier_mm_destroy(struct mm_struct *mm) static inline void mmu_notifier_subscriptions_destroy(struct mm_struct *mm)
{ {
if (mm_has_notifiers(mm)) if (mm_has_notifiers(mm))
__mmu_notifier_mm_destroy(mm); __mmu_notifier_subscriptions_destroy(mm);
} }
...@@ -692,11 +698,11 @@ static inline void mmu_notifier_invalidate_range(struct mm_struct *mm, ...@@ -692,11 +698,11 @@ static inline void mmu_notifier_invalidate_range(struct mm_struct *mm,
{ {
} }
static inline void mmu_notifier_mm_init(struct mm_struct *mm) static inline void mmu_notifier_subscriptions_init(struct mm_struct *mm)
{ {
} }
static inline void mmu_notifier_mm_destroy(struct mm_struct *mm) static inline void mmu_notifier_subscriptions_destroy(struct mm_struct *mm)
{ {
} }
......
...@@ -692,7 +692,7 @@ void __mmdrop(struct mm_struct *mm) ...@@ -692,7 +692,7 @@ void __mmdrop(struct mm_struct *mm)
WARN_ON_ONCE(mm == current->active_mm); WARN_ON_ONCE(mm == current->active_mm);
mm_free_pgd(mm); mm_free_pgd(mm);
destroy_context(mm); destroy_context(mm);
mmu_notifier_mm_destroy(mm); mmu_notifier_subscriptions_destroy(mm);
check_mm(mm); check_mm(mm);
put_user_ns(mm->user_ns); put_user_ns(mm->user_ns);
free_mm(mm); free_mm(mm);
...@@ -1025,7 +1025,7 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p, ...@@ -1025,7 +1025,7 @@ static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
mm_init_aio(mm); mm_init_aio(mm);
mm_init_owner(mm, p); mm_init_owner(mm, p);
RCU_INIT_POINTER(mm->exe_file, NULL); RCU_INIT_POINTER(mm->exe_file, NULL);
mmu_notifier_mm_init(mm); mmu_notifier_subscriptions_init(mm);
init_tlb_flush_pending(mm); init_tlb_flush_pending(mm);
#if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS #if defined(CONFIG_TRANSPARENT_HUGEPAGE) && !USE_SPLIT_PMD_PTLOCKS
mm->pmd_huge_pte = NULL; mm->pmd_huge_pte = NULL;
......
...@@ -153,7 +153,7 @@ void dump_mm(const struct mm_struct *mm) ...@@ -153,7 +153,7 @@ void dump_mm(const struct mm_struct *mm)
#endif #endif
"exe_file %px\n" "exe_file %px\n"
#ifdef CONFIG_MMU_NOTIFIER #ifdef CONFIG_MMU_NOTIFIER
"mmu_notifier_mm %px\n" "notifier_subscriptions %px\n"
#endif #endif
#ifdef CONFIG_NUMA_BALANCING #ifdef CONFIG_NUMA_BALANCING
"numa_next_scan %lu numa_scan_offset %lu numa_scan_seq %d\n" "numa_next_scan %lu numa_scan_offset %lu numa_scan_seq %d\n"
...@@ -185,7 +185,7 @@ void dump_mm(const struct mm_struct *mm) ...@@ -185,7 +185,7 @@ void dump_mm(const struct mm_struct *mm)
#endif #endif
mm->exe_file, mm->exe_file,
#ifdef CONFIG_MMU_NOTIFIER #ifdef CONFIG_MMU_NOTIFIER
mm->mmu_notifier_mm, mm->notifier_subscriptions,
#endif #endif
#ifdef CONFIG_NUMA_BALANCING #ifdef CONFIG_NUMA_BALANCING
mm->numa_next_scan, mm->numa_scan_offset, mm->numa_scan_seq, mm->numa_next_scan, mm->numa_scan_offset, mm->numa_scan_seq,
......
...@@ -29,12 +29,12 @@ struct lockdep_map __mmu_notifier_invalidate_range_start_map = { ...@@ -29,12 +29,12 @@ struct lockdep_map __mmu_notifier_invalidate_range_start_map = {
#endif #endif
/* /*
* The mmu notifier_mm structure is allocated and installed in * The mmu_notifier_subscriptions structure is allocated and installed in
* mm->mmu_notifier_mm inside the mm_take_all_locks() protected * mm->notifier_subscriptions inside the mm_take_all_locks() protected
* critical section and it's released only when mm_count reaches zero * critical section and it's released only when mm_count reaches zero
* in mmdrop(). * in mmdrop().
*/ */
struct mmu_notifier_mm { struct mmu_notifier_subscriptions {
/* all mmu notifiers registered in this mm are queued in this list */ /* all mmu notifiers registered in this mm are queued in this list */
struct hlist_head list; struct hlist_head list;
bool has_itree; bool has_itree;
...@@ -65,80 +65,81 @@ struct mmu_notifier_mm { ...@@ -65,80 +65,81 @@ struct mmu_notifier_mm {
* *
* The write side has two states, fully excluded: * The write side has two states, fully excluded:
* - mm->active_invalidate_ranges != 0 * - mm->active_invalidate_ranges != 0
* - mnn->invalidate_seq & 1 == True (odd) * - subscriptions->invalidate_seq & 1 == True (odd)
* - some range on the mm_struct is being invalidated * - some range on the mm_struct is being invalidated
* - the itree is not allowed to change * - the itree is not allowed to change
* *
* And partially excluded: * And partially excluded:
* - mm->active_invalidate_ranges != 0 * - mm->active_invalidate_ranges != 0
* - mnn->invalidate_seq & 1 == False (even) * - subscriptions->invalidate_seq & 1 == False (even)
* - some range on the mm_struct is being invalidated * - some range on the mm_struct is being invalidated
* - the itree is allowed to change * - the itree is allowed to change
* *
* Operations on mmu_notifier_mm->invalidate_seq (under spinlock): * Operations on notifier_subscriptions->invalidate_seq (under spinlock):
* seq |= 1 # Begin writing * seq |= 1 # Begin writing
* seq++ # Release the writing state * seq++ # Release the writing state
* seq & 1 # True if a writer exists * seq & 1 # True if a writer exists
* *
* The later state avoids some expensive work on inv_end in the common case of * The later state avoids some expensive work on inv_end in the common case of
* no mni monitoring the VA. * no mmu_interval_notifier monitoring the VA.
*/ */
static bool mn_itree_is_invalidating(struct mmu_notifier_mm *mmn_mm) static bool
mn_itree_is_invalidating(struct mmu_notifier_subscriptions *subscriptions)
{ {
lockdep_assert_held(&mmn_mm->lock); lockdep_assert_held(&subscriptions->lock);
return mmn_mm->invalidate_seq & 1; return subscriptions->invalidate_seq & 1;
} }
static struct mmu_interval_notifier * static struct mmu_interval_notifier *
mn_itree_inv_start_range(struct mmu_notifier_mm *mmn_mm, mn_itree_inv_start_range(struct mmu_notifier_subscriptions *subscriptions,
const struct mmu_notifier_range *range, const struct mmu_notifier_range *range,
unsigned long *seq) unsigned long *seq)
{ {
struct interval_tree_node *node; struct interval_tree_node *node;
struct mmu_interval_notifier *res = NULL; struct mmu_interval_notifier *res = NULL;
spin_lock(&mmn_mm->lock); spin_lock(&subscriptions->lock);
mmn_mm->active_invalidate_ranges++; subscriptions->active_invalidate_ranges++;
node = interval_tree_iter_first(&mmn_mm->itree, range->start, node = interval_tree_iter_first(&subscriptions->itree, range->start,
range->end - 1); range->end - 1);
if (node) { if (node) {
mmn_mm->invalidate_seq |= 1; subscriptions->invalidate_seq |= 1;
res = container_of(node, struct mmu_interval_notifier, res = container_of(node, struct mmu_interval_notifier,
interval_tree); interval_tree);
} }
*seq = mmn_mm->invalidate_seq; *seq = subscriptions->invalidate_seq;
spin_unlock(&mmn_mm->lock); spin_unlock(&subscriptions->lock);
return res; return res;
} }
static struct mmu_interval_notifier * static struct mmu_interval_notifier *
mn_itree_inv_next(struct mmu_interval_notifier *mni, mn_itree_inv_next(struct mmu_interval_notifier *interval_sub,
const struct mmu_notifier_range *range) const struct mmu_notifier_range *range)
{ {
struct interval_tree_node *node; struct interval_tree_node *node;
node = interval_tree_iter_next(&mni->interval_tree, range->start, node = interval_tree_iter_next(&interval_sub->interval_tree,
range->end - 1); range->start, range->end - 1);
if (!node) if (!node)
return NULL; return NULL;
return container_of(node, struct mmu_interval_notifier, interval_tree); return container_of(node, struct mmu_interval_notifier, interval_tree);
} }
static void mn_itree_inv_end(struct mmu_notifier_mm *mmn_mm) static void mn_itree_inv_end(struct mmu_notifier_subscriptions *subscriptions)
{ {
struct mmu_interval_notifier *mni; struct mmu_interval_notifier *interval_sub;
struct hlist_node *next; struct hlist_node *next;
spin_lock(&mmn_mm->lock); spin_lock(&subscriptions->lock);
if (--mmn_mm->active_invalidate_ranges || if (--subscriptions->active_invalidate_ranges ||
!mn_itree_is_invalidating(mmn_mm)) { !mn_itree_is_invalidating(subscriptions)) {
spin_unlock(&mmn_mm->lock); spin_unlock(&subscriptions->lock);
return; return;
} }
/* Make invalidate_seq even */ /* Make invalidate_seq even */
mmn_mm->invalidate_seq++; subscriptions->invalidate_seq++;
/* /*
* The inv_end incorporates a deferred mechanism like rtnl_unlock(). * The inv_end incorporates a deferred mechanism like rtnl_unlock().
...@@ -146,30 +147,31 @@ static void mn_itree_inv_end(struct mmu_notifier_mm *mmn_mm) ...@@ -146,30 +147,31 @@ static void mn_itree_inv_end(struct mmu_notifier_mm *mmn_mm)
* they are progressed. This arrangement for tree updates is used to * they are progressed. This arrangement for tree updates is used to
* avoid using a blocking lock during invalidate_range_start. * avoid using a blocking lock during invalidate_range_start.
*/ */
hlist_for_each_entry_safe(mni, next, &mmn_mm->deferred_list, hlist_for_each_entry_safe(interval_sub, next,
&subscriptions->deferred_list,
deferred_item) { deferred_item) {
if (RB_EMPTY_NODE(&mni->interval_tree.rb)) if (RB_EMPTY_NODE(&interval_sub->interval_tree.rb))
interval_tree_insert(&mni->interval_tree, interval_tree_insert(&interval_sub->interval_tree,
&mmn_mm->itree); &subscriptions->itree);
else else
interval_tree_remove(&mni->interval_tree, interval_tree_remove(&interval_sub->interval_tree,
&mmn_mm->itree); &subscriptions->itree);
hlist_del(&mni->deferred_item); hlist_del(&interval_sub->deferred_item);
} }
spin_unlock(&mmn_mm->lock); spin_unlock(&subscriptions->lock);
wake_up_all(&mmn_mm->wq); wake_up_all(&subscriptions->wq);
} }
/** /**
* mmu_interval_read_begin - Begin a read side critical section against a VA * mmu_interval_read_begin - Begin a read side critical section against a VA
* range * range
* mni: The range to use * interval_sub: The interval subscription
* *
* mmu_iterval_read_begin()/mmu_iterval_read_retry() implement a * mmu_iterval_read_begin()/mmu_iterval_read_retry() implement a
* collision-retry scheme similar to seqcount for the VA range under mni. If * collision-retry scheme similar to seqcount for the VA range under
* the mm invokes invalidation during the critical section then * subscription. If the mm invokes invalidation during the critical section
* mmu_interval_read_retry() will return true. * then mmu_interval_read_retry() will return true.
* *
* This is useful to obtain shadow PTEs where teardown or setup of the SPTEs * This is useful to obtain shadow PTEs where teardown or setup of the SPTEs
* require a blocking context. The critical region formed by this can sleep, * require a blocking context. The critical region formed by this can sleep,
...@@ -180,68 +182,71 @@ static void mn_itree_inv_end(struct mmu_notifier_mm *mmn_mm) ...@@ -180,68 +182,71 @@ static void mn_itree_inv_end(struct mmu_notifier_mm *mmn_mm)
* *
* The return value should be passed to mmu_interval_read_retry(). * The return value should be passed to mmu_interval_read_retry().
*/ */
unsigned long mmu_interval_read_begin(struct mmu_interval_notifier *mni) unsigned long
mmu_interval_read_begin(struct mmu_interval_notifier *interval_sub)
{ {
struct mmu_notifier_mm *mmn_mm = mni->mm->mmu_notifier_mm; struct mmu_notifier_subscriptions *subscriptions =
interval_sub->mm->notifier_subscriptions;
unsigned long seq; unsigned long seq;
bool is_invalidating; bool is_invalidating;
/* /*
* If the mni has a different seq value under the user_lock than we * If the subscription has a different seq value under the user_lock
* started with then it has collided. * than we started with then it has collided.
* *
* If the mni currently has the same seq value as the mmn_mm seq, then * If the subscription currently has the same seq value as the
* it is currently between invalidate_start/end and is colliding. * subscriptions seq, then it is currently between
* invalidate_start/end and is colliding.
* *
* The locking looks broadly like this: * The locking looks broadly like this:
* mn_tree_invalidate_start(): mmu_interval_read_begin(): * mn_tree_invalidate_start(): mmu_interval_read_begin():
* spin_lock * spin_lock
* seq = READ_ONCE(mni->invalidate_seq); * seq = READ_ONCE(interval_sub->invalidate_seq);
* seq == mmn_mm->invalidate_seq * seq == subs->invalidate_seq
* spin_unlock * spin_unlock
* spin_lock * spin_lock
* seq = ++mmn_mm->invalidate_seq * seq = ++subscriptions->invalidate_seq
* spin_unlock * spin_unlock
* op->invalidate_range(): * op->invalidate_range():
* user_lock * user_lock
* mmu_interval_set_seq() * mmu_interval_set_seq()
* mni->invalidate_seq = seq * interval_sub->invalidate_seq = seq
* user_unlock * user_unlock
* *
* [Required: mmu_interval_read_retry() == true] * [Required: mmu_interval_read_retry() == true]
* *
* mn_itree_inv_end(): * mn_itree_inv_end():
* spin_lock * spin_lock
* seq = ++mmn_mm->invalidate_seq * seq = ++subscriptions->invalidate_seq
* spin_unlock * spin_unlock
* *
* user_lock * user_lock
* mmu_interval_read_retry(): * mmu_interval_read_retry():
* mni->invalidate_seq != seq * interval_sub->invalidate_seq != seq
* user_unlock * user_unlock
* *
* Barriers are not needed here as any races here are closed by an * Barriers are not needed here as any races here are closed by an
* eventual mmu_interval_read_retry(), which provides a barrier via the * eventual mmu_interval_read_retry(), which provides a barrier via the
* user_lock. * user_lock.
*/ */
spin_lock(&mmn_mm->lock); spin_lock(&subscriptions->lock);
/* Pairs with the WRITE_ONCE in mmu_interval_set_seq() */ /* Pairs with the WRITE_ONCE in mmu_interval_set_seq() */
seq = READ_ONCE(mni->invalidate_seq); seq = READ_ONCE(interval_sub->invalidate_seq);
is_invalidating = seq == mmn_mm->invalidate_seq; is_invalidating = seq == subscriptions->invalidate_seq;
spin_unlock(&mmn_mm->lock); spin_unlock(&subscriptions->lock);
/* /*
* mni->invalidate_seq must always be set to an odd value via * interval_sub->invalidate_seq must always be set to an odd value via
* mmu_interval_set_seq() using the provided cur_seq from * mmu_interval_set_seq() using the provided cur_seq from
* mn_itree_inv_start_range(). This ensures that if seq does wrap we * mn_itree_inv_start_range(). This ensures that if seq does wrap we
* will always clear the below sleep in some reasonable time as * will always clear the below sleep in some reasonable time as
* mmn_mm->invalidate_seq is even in the idle state. * subscriptions->invalidate_seq is even in the idle state.
*/ */
lock_map_acquire(&__mmu_notifier_invalidate_range_start_map); lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
lock_map_release(&__mmu_notifier_invalidate_range_start_map); lock_map_release(&__mmu_notifier_invalidate_range_start_map);
if (is_invalidating) if (is_invalidating)
wait_event(mmn_mm->wq, wait_event(subscriptions->wq,
READ_ONCE(mmn_mm->invalidate_seq) != seq); READ_ONCE(subscriptions->invalidate_seq) != seq);
/* /*
* Notice that mmu_interval_read_retry() can already be true at this * Notice that mmu_interval_read_retry() can already be true at this
...@@ -253,7 +258,7 @@ unsigned long mmu_interval_read_begin(struct mmu_interval_notifier *mni) ...@@ -253,7 +258,7 @@ unsigned long mmu_interval_read_begin(struct mmu_interval_notifier *mni)
} }
EXPORT_SYMBOL_GPL(mmu_interval_read_begin); EXPORT_SYMBOL_GPL(mmu_interval_read_begin);
static void mn_itree_release(struct mmu_notifier_mm *mmn_mm, static void mn_itree_release(struct mmu_notifier_subscriptions *subscriptions,
struct mm_struct *mm) struct mm_struct *mm)
{ {
struct mmu_notifier_range range = { struct mmu_notifier_range range = {
...@@ -263,17 +268,20 @@ static void mn_itree_release(struct mmu_notifier_mm *mmn_mm, ...@@ -263,17 +268,20 @@ static void mn_itree_release(struct mmu_notifier_mm *mmn_mm,
.start = 0, .start = 0,
.end = ULONG_MAX, .end = ULONG_MAX,
}; };
struct mmu_interval_notifier *mni; struct mmu_interval_notifier *interval_sub;
unsigned long cur_seq; unsigned long cur_seq;
bool ret; bool ret;
for (mni = mn_itree_inv_start_range(mmn_mm, &range, &cur_seq); mni; for (interval_sub =
mni = mn_itree_inv_next(mni, &range)) { mn_itree_inv_start_range(subscriptions, &range, &cur_seq);
ret = mni->ops->invalidate(mni, &range, cur_seq); interval_sub;
interval_sub = mn_itree_inv_next(interval_sub, &range)) {
ret = interval_sub->ops->invalidate(interval_sub, &range,
cur_seq);
WARN_ON(!ret); WARN_ON(!ret);
} }
mn_itree_inv_end(mmn_mm); mn_itree_inv_end(subscriptions);
} }
/* /*
...@@ -283,15 +291,15 @@ static void mn_itree_release(struct mmu_notifier_mm *mmn_mm, ...@@ -283,15 +291,15 @@ static void mn_itree_release(struct mmu_notifier_mm *mmn_mm,
* in parallel despite there being no task using this mm any more, * in parallel despite there being no task using this mm any more,
* through the vmas outside of the exit_mmap context, such as with * through the vmas outside of the exit_mmap context, such as with
* vmtruncate. This serializes against mmu_notifier_unregister with * vmtruncate. This serializes against mmu_notifier_unregister with
* the mmu_notifier_mm->lock in addition to SRCU and it serializes * the notifier_subscriptions->lock in addition to SRCU and it serializes
* against the other mmu notifiers with SRCU. struct mmu_notifier_mm * against the other mmu notifiers with SRCU. struct mmu_notifier_subscriptions
* can't go away from under us as exit_mmap holds an mm_count pin * can't go away from under us as exit_mmap holds an mm_count pin
* itself. * itself.
*/ */
static void mn_hlist_release(struct mmu_notifier_mm *mmn_mm, static void mn_hlist_release(struct mmu_notifier_subscriptions *subscriptions,
struct mm_struct *mm) struct mm_struct *mm)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int id; int id;
/* /*
...@@ -299,29 +307,29 @@ static void mn_hlist_release(struct mmu_notifier_mm *mmn_mm, ...@@ -299,29 +307,29 @@ static void mn_hlist_release(struct mmu_notifier_mm *mmn_mm,
* ->release returns. * ->release returns.
*/ */
id = srcu_read_lock(&srcu); id = srcu_read_lock(&srcu);
hlist_for_each_entry_rcu(mn, &mmn_mm->list, hlist) hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist)
/* /*
* If ->release runs before mmu_notifier_unregister it must be * If ->release runs before mmu_notifier_unregister it must be
* handled, as it's the only way for the driver to flush all * handled, as it's the only way for the driver to flush all
* existing sptes and stop the driver from establishing any more * existing sptes and stop the driver from establishing any more
* sptes before all the pages in the mm are freed. * sptes before all the pages in the mm are freed.
*/ */
if (mn->ops->release) if (subscription->ops->release)
mn->ops->release(mn, mm); subscription->ops->release(subscription, mm);
spin_lock(&mmn_mm->lock); spin_lock(&subscriptions->lock);
while (unlikely(!hlist_empty(&mmn_mm->list))) { while (unlikely(!hlist_empty(&subscriptions->list))) {
mn = hlist_entry(mmn_mm->list.first, struct mmu_notifier, subscription = hlist_entry(subscriptions->list.first,
hlist); struct mmu_notifier, hlist);
/* /*
* We arrived before mmu_notifier_unregister so * We arrived before mmu_notifier_unregister so
* mmu_notifier_unregister will do nothing other than to wait * mmu_notifier_unregister will do nothing other than to wait
* for ->release to finish and for mmu_notifier_unregister to * for ->release to finish and for mmu_notifier_unregister to
* return. * return.
*/ */
hlist_del_init_rcu(&mn->hlist); hlist_del_init_rcu(&subscription->hlist);
} }
spin_unlock(&mmn_mm->lock); spin_unlock(&subscriptions->lock);
srcu_read_unlock(&srcu, id); srcu_read_unlock(&srcu, id);
/* /*
...@@ -330,21 +338,22 @@ static void mn_hlist_release(struct mmu_notifier_mm *mmn_mm, ...@@ -330,21 +338,22 @@ static void mn_hlist_release(struct mmu_notifier_mm *mmn_mm,
* until the ->release method returns, if it was invoked by * until the ->release method returns, if it was invoked by
* mmu_notifier_unregister. * mmu_notifier_unregister.
* *
* The mmu_notifier_mm can't go away from under us because one mm_count * The notifier_subscriptions can't go away from under us because
* is held by exit_mmap. * one mm_count is held by exit_mmap.
*/ */
synchronize_srcu(&srcu); synchronize_srcu(&srcu);
} }
void __mmu_notifier_release(struct mm_struct *mm) void __mmu_notifier_release(struct mm_struct *mm)
{ {
struct mmu_notifier_mm *mmn_mm = mm->mmu_notifier_mm; struct mmu_notifier_subscriptions *subscriptions =
mm->notifier_subscriptions;
if (mmn_mm->has_itree) if (subscriptions->has_itree)
mn_itree_release(mmn_mm, mm); mn_itree_release(subscriptions, mm);
if (!hlist_empty(&mmn_mm->list)) if (!hlist_empty(&subscriptions->list))
mn_hlist_release(mmn_mm, mm); mn_hlist_release(subscriptions, mm);
} }
/* /*
...@@ -356,13 +365,15 @@ int __mmu_notifier_clear_flush_young(struct mm_struct *mm, ...@@ -356,13 +365,15 @@ int __mmu_notifier_clear_flush_young(struct mm_struct *mm,
unsigned long start, unsigned long start,
unsigned long end) unsigned long end)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int young = 0, id; int young = 0, id;
id = srcu_read_lock(&srcu); id = srcu_read_lock(&srcu);
hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist) { hlist_for_each_entry_rcu(subscription,
if (mn->ops->clear_flush_young) &mm->notifier_subscriptions->list, hlist) {
young |= mn->ops->clear_flush_young(mn, mm, start, end); if (subscription->ops->clear_flush_young)
young |= subscription->ops->clear_flush_young(
subscription, mm, start, end);
} }
srcu_read_unlock(&srcu, id); srcu_read_unlock(&srcu, id);
...@@ -373,13 +384,15 @@ int __mmu_notifier_clear_young(struct mm_struct *mm, ...@@ -373,13 +384,15 @@ int __mmu_notifier_clear_young(struct mm_struct *mm,
unsigned long start, unsigned long start,
unsigned long end) unsigned long end)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int young = 0, id; int young = 0, id;
id = srcu_read_lock(&srcu); id = srcu_read_lock(&srcu);
hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist) { hlist_for_each_entry_rcu(subscription,
if (mn->ops->clear_young) &mm->notifier_subscriptions->list, hlist) {
young |= mn->ops->clear_young(mn, mm, start, end); if (subscription->ops->clear_young)
young |= subscription->ops->clear_young(subscription,
mm, start, end);
} }
srcu_read_unlock(&srcu, id); srcu_read_unlock(&srcu, id);
...@@ -389,13 +402,15 @@ int __mmu_notifier_clear_young(struct mm_struct *mm, ...@@ -389,13 +402,15 @@ int __mmu_notifier_clear_young(struct mm_struct *mm,
int __mmu_notifier_test_young(struct mm_struct *mm, int __mmu_notifier_test_young(struct mm_struct *mm,
unsigned long address) unsigned long address)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int young = 0, id; int young = 0, id;
id = srcu_read_lock(&srcu); id = srcu_read_lock(&srcu);
hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist) { hlist_for_each_entry_rcu(subscription,
if (mn->ops->test_young) { &mm->notifier_subscriptions->list, hlist) {
young = mn->ops->test_young(mn, mm, address); if (subscription->ops->test_young) {
young = subscription->ops->test_young(subscription, mm,
address);
if (young) if (young)
break; break;
} }
...@@ -408,28 +423,33 @@ int __mmu_notifier_test_young(struct mm_struct *mm, ...@@ -408,28 +423,33 @@ int __mmu_notifier_test_young(struct mm_struct *mm,
void __mmu_notifier_change_pte(struct mm_struct *mm, unsigned long address, void __mmu_notifier_change_pte(struct mm_struct *mm, unsigned long address,
pte_t pte) pte_t pte)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int id; int id;
id = srcu_read_lock(&srcu); id = srcu_read_lock(&srcu);
hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist) { hlist_for_each_entry_rcu(subscription,
if (mn->ops->change_pte) &mm->notifier_subscriptions->list, hlist) {
mn->ops->change_pte(mn, mm, address, pte); if (subscription->ops->change_pte)
subscription->ops->change_pte(subscription, mm, address,
pte);
} }
srcu_read_unlock(&srcu, id); srcu_read_unlock(&srcu, id);
} }
static int mn_itree_invalidate(struct mmu_notifier_mm *mmn_mm, static int mn_itree_invalidate(struct mmu_notifier_subscriptions *subscriptions,
const struct mmu_notifier_range *range) const struct mmu_notifier_range *range)
{ {
struct mmu_interval_notifier *mni; struct mmu_interval_notifier *interval_sub;
unsigned long cur_seq; unsigned long cur_seq;
for (mni = mn_itree_inv_start_range(mmn_mm, range, &cur_seq); mni; for (interval_sub =
mni = mn_itree_inv_next(mni, range)) { mn_itree_inv_start_range(subscriptions, range, &cur_seq);
interval_sub;
interval_sub = mn_itree_inv_next(interval_sub, range)) {
bool ret; bool ret;
ret = mni->ops->invalidate(mni, range, cur_seq); ret = interval_sub->ops->invalidate(interval_sub, range,
cur_seq);
if (!ret) { if (!ret) {
if (WARN_ON(mmu_notifier_range_blockable(range))) if (WARN_ON(mmu_notifier_range_blockable(range)))
continue; continue;
...@@ -443,31 +463,36 @@ static int mn_itree_invalidate(struct mmu_notifier_mm *mmn_mm, ...@@ -443,31 +463,36 @@ static int mn_itree_invalidate(struct mmu_notifier_mm *mmn_mm,
* On -EAGAIN the non-blocking caller is not allowed to call * On -EAGAIN the non-blocking caller is not allowed to call
* invalidate_range_end() * invalidate_range_end()
*/ */
mn_itree_inv_end(mmn_mm); mn_itree_inv_end(subscriptions);
return -EAGAIN; return -EAGAIN;
} }
static int mn_hlist_invalidate_range_start(struct mmu_notifier_mm *mmn_mm, static int mn_hlist_invalidate_range_start(
struct mmu_notifier_range *range) struct mmu_notifier_subscriptions *subscriptions,
struct mmu_notifier_range *range)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int ret = 0; int ret = 0;
int id; int id;
id = srcu_read_lock(&srcu); id = srcu_read_lock(&srcu);
hlist_for_each_entry_rcu(mn, &mmn_mm->list, hlist) { hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist) {
if (mn->ops->invalidate_range_start) { const struct mmu_notifier_ops *ops = subscription->ops;
if (ops->invalidate_range_start) {
int _ret; int _ret;
if (!mmu_notifier_range_blockable(range)) if (!mmu_notifier_range_blockable(range))
non_block_start(); non_block_start();
_ret = mn->ops->invalidate_range_start(mn, range); _ret = ops->invalidate_range_start(subscription, range);
if (!mmu_notifier_range_blockable(range)) if (!mmu_notifier_range_blockable(range))
non_block_end(); non_block_end();
if (_ret) { if (_ret) {
pr_info("%pS callback failed with %d in %sblockable context.\n", pr_info("%pS callback failed with %d in %sblockable context.\n",
mn->ops->invalidate_range_start, _ret, ops->invalidate_range_start, _ret,
!mmu_notifier_range_blockable(range) ? "non-" : ""); !mmu_notifier_range_blockable(range) ?
"non-" :
"");
WARN_ON(mmu_notifier_range_blockable(range) || WARN_ON(mmu_notifier_range_blockable(range) ||
_ret != -EAGAIN); _ret != -EAGAIN);
ret = _ret; ret = _ret;
...@@ -481,28 +506,29 @@ static int mn_hlist_invalidate_range_start(struct mmu_notifier_mm *mmn_mm, ...@@ -481,28 +506,29 @@ static int mn_hlist_invalidate_range_start(struct mmu_notifier_mm *mmn_mm,
int __mmu_notifier_invalidate_range_start(struct mmu_notifier_range *range) int __mmu_notifier_invalidate_range_start(struct mmu_notifier_range *range)
{ {
struct mmu_notifier_mm *mmn_mm = range->mm->mmu_notifier_mm; struct mmu_notifier_subscriptions *subscriptions =
range->mm->notifier_subscriptions;
int ret; int ret;
if (mmn_mm->has_itree) { if (subscriptions->has_itree) {
ret = mn_itree_invalidate(mmn_mm, range); ret = mn_itree_invalidate(subscriptions, range);
if (ret) if (ret)
return ret; return ret;
} }
if (!hlist_empty(&mmn_mm->list)) if (!hlist_empty(&subscriptions->list))
return mn_hlist_invalidate_range_start(mmn_mm, range); return mn_hlist_invalidate_range_start(subscriptions, range);
return 0; return 0;
} }
static void mn_hlist_invalidate_end(struct mmu_notifier_mm *mmn_mm, static void
struct mmu_notifier_range *range, mn_hlist_invalidate_end(struct mmu_notifier_subscriptions *subscriptions,
bool only_end) struct mmu_notifier_range *range, bool only_end)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int id; int id;
id = srcu_read_lock(&srcu); id = srcu_read_lock(&srcu);
hlist_for_each_entry_rcu(mn, &mmn_mm->list, hlist) { hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist) {
/* /*
* Call invalidate_range here too to avoid the need for the * Call invalidate_range here too to avoid the need for the
* subsystem of having to register an invalidate_range_end * subsystem of having to register an invalidate_range_end
...@@ -516,14 +542,16 @@ static void mn_hlist_invalidate_end(struct mmu_notifier_mm *mmn_mm, ...@@ -516,14 +542,16 @@ static void mn_hlist_invalidate_end(struct mmu_notifier_mm *mmn_mm,
* is safe to do when we know that a call to invalidate_range() * is safe to do when we know that a call to invalidate_range()
* already happen under page table lock. * already happen under page table lock.
*/ */
if (!only_end && mn->ops->invalidate_range) if (!only_end && subscription->ops->invalidate_range)
mn->ops->invalidate_range(mn, range->mm, subscription->ops->invalidate_range(subscription,
range->start, range->mm,
range->end); range->start,
if (mn->ops->invalidate_range_end) { range->end);
if (subscription->ops->invalidate_range_end) {
if (!mmu_notifier_range_blockable(range)) if (!mmu_notifier_range_blockable(range))
non_block_start(); non_block_start();
mn->ops->invalidate_range_end(mn, range); subscription->ops->invalidate_range_end(subscription,
range);
if (!mmu_notifier_range_blockable(range)) if (!mmu_notifier_range_blockable(range))
non_block_end(); non_block_end();
} }
...@@ -534,27 +562,30 @@ static void mn_hlist_invalidate_end(struct mmu_notifier_mm *mmn_mm, ...@@ -534,27 +562,30 @@ static void mn_hlist_invalidate_end(struct mmu_notifier_mm *mmn_mm,
void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range, void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
bool only_end) bool only_end)
{ {
struct mmu_notifier_mm *mmn_mm = range->mm->mmu_notifier_mm; struct mmu_notifier_subscriptions *subscriptions =
range->mm->notifier_subscriptions;
lock_map_acquire(&__mmu_notifier_invalidate_range_start_map); lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
if (mmn_mm->has_itree) if (subscriptions->has_itree)
mn_itree_inv_end(mmn_mm); mn_itree_inv_end(subscriptions);
if (!hlist_empty(&mmn_mm->list)) if (!hlist_empty(&subscriptions->list))
mn_hlist_invalidate_end(mmn_mm, range, only_end); mn_hlist_invalidate_end(subscriptions, range, only_end);
lock_map_release(&__mmu_notifier_invalidate_range_start_map); lock_map_release(&__mmu_notifier_invalidate_range_start_map);
} }
void __mmu_notifier_invalidate_range(struct mm_struct *mm, void __mmu_notifier_invalidate_range(struct mm_struct *mm,
unsigned long start, unsigned long end) unsigned long start, unsigned long end)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int id; int id;
id = srcu_read_lock(&srcu); id = srcu_read_lock(&srcu);
hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist) { hlist_for_each_entry_rcu(subscription,
if (mn->ops->invalidate_range) &mm->notifier_subscriptions->list, hlist) {
mn->ops->invalidate_range(mn, mm, start, end); if (subscription->ops->invalidate_range)
subscription->ops->invalidate_range(subscription, mm,
start, end);
} }
srcu_read_unlock(&srcu, id); srcu_read_unlock(&srcu, id);
} }
...@@ -564,9 +595,10 @@ void __mmu_notifier_invalidate_range(struct mm_struct *mm, ...@@ -564,9 +595,10 @@ void __mmu_notifier_invalidate_range(struct mm_struct *mm,
* write mode. A NULL mn signals the notifier is being registered for itree * write mode. A NULL mn signals the notifier is being registered for itree
* mode. * mode.
*/ */
int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm) int __mmu_notifier_register(struct mmu_notifier *subscription,
struct mm_struct *mm)
{ {
struct mmu_notifier_mm *mmu_notifier_mm = NULL; struct mmu_notifier_subscriptions *subscriptions = NULL;
int ret; int ret;
lockdep_assert_held_write(&mm->mmap_sem); lockdep_assert_held_write(&mm->mmap_sem);
...@@ -579,23 +611,23 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm) ...@@ -579,23 +611,23 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
fs_reclaim_release(GFP_KERNEL); fs_reclaim_release(GFP_KERNEL);
} }
if (!mm->mmu_notifier_mm) { if (!mm->notifier_subscriptions) {
/* /*
* kmalloc cannot be called under mm_take_all_locks(), but we * kmalloc cannot be called under mm_take_all_locks(), but we
* know that mm->mmu_notifier_mm can't change while we hold * know that mm->notifier_subscriptions can't change while we
* the write side of the mmap_sem. * hold the write side of the mmap_sem.
*/ */
mmu_notifier_mm = subscriptions = kzalloc(
kzalloc(sizeof(struct mmu_notifier_mm), GFP_KERNEL); sizeof(struct mmu_notifier_subscriptions), GFP_KERNEL);
if (!mmu_notifier_mm) if (!subscriptions)
return -ENOMEM; return -ENOMEM;
INIT_HLIST_HEAD(&mmu_notifier_mm->list); INIT_HLIST_HEAD(&subscriptions->list);
spin_lock_init(&mmu_notifier_mm->lock); spin_lock_init(&subscriptions->lock);
mmu_notifier_mm->invalidate_seq = 2; subscriptions->invalidate_seq = 2;
mmu_notifier_mm->itree = RB_ROOT_CACHED; subscriptions->itree = RB_ROOT_CACHED;
init_waitqueue_head(&mmu_notifier_mm->wq); init_waitqueue_head(&subscriptions->wq);
INIT_HLIST_HEAD(&mmu_notifier_mm->deferred_list); INIT_HLIST_HEAD(&subscriptions->deferred_list);
} }
ret = mm_take_all_locks(mm); ret = mm_take_all_locks(mm);
...@@ -610,34 +642,36 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm) ...@@ -610,34 +642,36 @@ int __mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm)
* We can't race against any other mmu notifier method either * We can't race against any other mmu notifier method either
* thanks to mm_take_all_locks(). * thanks to mm_take_all_locks().
* *
* release semantics on the initialization of the mmu_notifier_mm's * release semantics on the initialization of the
* contents are provided for unlocked readers. acquire can only be * mmu_notifier_subscriptions's contents are provided for unlocked
* used while holding the mmgrab or mmget, and is safe because once * readers. acquire can only be used while holding the mmgrab or
* created the mmu_notififer_mm is not freed until the mm is * mmget, and is safe because once created the
* destroyed. As above, users holding the mmap_sem or one of the * mmu_notifier_subscriptions is not freed until the mm is destroyed.
* As above, users holding the mmap_sem or one of the
* mm_take_all_locks() do not need to use acquire semantics. * mm_take_all_locks() do not need to use acquire semantics.
*/ */
if (mmu_notifier_mm) if (subscriptions)
smp_store_release(&mm->mmu_notifier_mm, mmu_notifier_mm); smp_store_release(&mm->notifier_subscriptions, subscriptions);
if (mn) { if (subscription) {
/* Pairs with the mmdrop in mmu_notifier_unregister_* */ /* Pairs with the mmdrop in mmu_notifier_unregister_* */
mmgrab(mm); mmgrab(mm);
mn->mm = mm; subscription->mm = mm;
mn->users = 1; subscription->users = 1;
spin_lock(&mm->mmu_notifier_mm->lock); spin_lock(&mm->notifier_subscriptions->lock);
hlist_add_head_rcu(&mn->hlist, &mm->mmu_notifier_mm->list); hlist_add_head_rcu(&subscription->hlist,
spin_unlock(&mm->mmu_notifier_mm->lock); &mm->notifier_subscriptions->list);
spin_unlock(&mm->notifier_subscriptions->lock);
} else } else
mm->mmu_notifier_mm->has_itree = true; mm->notifier_subscriptions->has_itree = true;
mm_drop_all_locks(mm); mm_drop_all_locks(mm);
BUG_ON(atomic_read(&mm->mm_users) <= 0); BUG_ON(atomic_read(&mm->mm_users) <= 0);
return 0; return 0;
out_clean: out_clean:
kfree(mmu_notifier_mm); kfree(subscriptions);
return ret; return ret;
} }
EXPORT_SYMBOL_GPL(__mmu_notifier_register); EXPORT_SYMBOL_GPL(__mmu_notifier_register);
...@@ -658,15 +692,16 @@ EXPORT_SYMBOL_GPL(__mmu_notifier_register); ...@@ -658,15 +692,16 @@ EXPORT_SYMBOL_GPL(__mmu_notifier_register);
* mmu_notifier_unregister() or mmu_notifier_put() must be always called to * mmu_notifier_unregister() or mmu_notifier_put() must be always called to
* unregister the notifier. * unregister the notifier.
* *
* While the caller has a mmu_notifier get the mn->mm pointer will remain * While the caller has a mmu_notifier get the subscription->mm pointer will remain
* valid, and can be converted to an active mm pointer via mmget_not_zero(). * valid, and can be converted to an active mm pointer via mmget_not_zero().
*/ */
int mmu_notifier_register(struct mmu_notifier *mn, struct mm_struct *mm) int mmu_notifier_register(struct mmu_notifier *subscription,
struct mm_struct *mm)
{ {
int ret; int ret;
down_write(&mm->mmap_sem); down_write(&mm->mmap_sem);
ret = __mmu_notifier_register(mn, mm); ret = __mmu_notifier_register(subscription, mm);
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
return ret; return ret;
} }
...@@ -675,21 +710,22 @@ EXPORT_SYMBOL_GPL(mmu_notifier_register); ...@@ -675,21 +710,22 @@ EXPORT_SYMBOL_GPL(mmu_notifier_register);
static struct mmu_notifier * static struct mmu_notifier *
find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops) find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
spin_lock(&mm->mmu_notifier_mm->lock); spin_lock(&mm->notifier_subscriptions->lock);
hlist_for_each_entry_rcu (mn, &mm->mmu_notifier_mm->list, hlist) { hlist_for_each_entry_rcu(subscription,
if (mn->ops != ops) &mm->notifier_subscriptions->list, hlist) {
if (subscription->ops != ops)
continue; continue;
if (likely(mn->users != UINT_MAX)) if (likely(subscription->users != UINT_MAX))
mn->users++; subscription->users++;
else else
mn = ERR_PTR(-EOVERFLOW); subscription = ERR_PTR(-EOVERFLOW);
spin_unlock(&mm->mmu_notifier_mm->lock); spin_unlock(&mm->notifier_subscriptions->lock);
return mn; return subscription;
} }
spin_unlock(&mm->mmu_notifier_mm->lock); spin_unlock(&mm->notifier_subscriptions->lock);
return NULL; return NULL;
} }
...@@ -713,37 +749,37 @@ find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops) ...@@ -713,37 +749,37 @@ find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops)
struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops, struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
struct mm_struct *mm) struct mm_struct *mm)
{ {
struct mmu_notifier *mn; struct mmu_notifier *subscription;
int ret; int ret;
lockdep_assert_held_write(&mm->mmap_sem); lockdep_assert_held_write(&mm->mmap_sem);
if (mm->mmu_notifier_mm) { if (mm->notifier_subscriptions) {
mn = find_get_mmu_notifier(mm, ops); subscription = find_get_mmu_notifier(mm, ops);
if (mn) if (subscription)
return mn; return subscription;
} }
mn = ops->alloc_notifier(mm); subscription = ops->alloc_notifier(mm);
if (IS_ERR(mn)) if (IS_ERR(subscription))
return mn; return subscription;
mn->ops = ops; subscription->ops = ops;
ret = __mmu_notifier_register(mn, mm); ret = __mmu_notifier_register(subscription, mm);
if (ret) if (ret)
goto out_free; goto out_free;
return mn; return subscription;
out_free: out_free:
mn->ops->free_notifier(mn); subscription->ops->free_notifier(subscription);
return ERR_PTR(ret); return ERR_PTR(ret);
} }
EXPORT_SYMBOL_GPL(mmu_notifier_get_locked); EXPORT_SYMBOL_GPL(mmu_notifier_get_locked);
/* this is called after the last mmu_notifier_unregister() returned */ /* this is called after the last mmu_notifier_unregister() returned */
void __mmu_notifier_mm_destroy(struct mm_struct *mm) void __mmu_notifier_subscriptions_destroy(struct mm_struct *mm)
{ {
BUG_ON(!hlist_empty(&mm->mmu_notifier_mm->list)); BUG_ON(!hlist_empty(&mm->notifier_subscriptions->list));
kfree(mm->mmu_notifier_mm); kfree(mm->notifier_subscriptions);
mm->mmu_notifier_mm = LIST_POISON1; /* debug */ mm->notifier_subscriptions = LIST_POISON1; /* debug */
} }
/* /*
...@@ -756,11 +792,12 @@ void __mmu_notifier_mm_destroy(struct mm_struct *mm) ...@@ -756,11 +792,12 @@ void __mmu_notifier_mm_destroy(struct mm_struct *mm)
* and only after mmu_notifier_unregister returned we're guaranteed * and only after mmu_notifier_unregister returned we're guaranteed
* that ->release or any other method can't run anymore. * that ->release or any other method can't run anymore.
*/ */
void mmu_notifier_unregister(struct mmu_notifier *mn, struct mm_struct *mm) void mmu_notifier_unregister(struct mmu_notifier *subscription,
struct mm_struct *mm)
{ {
BUG_ON(atomic_read(&mm->mm_count) <= 0); BUG_ON(atomic_read(&mm->mm_count) <= 0);
if (!hlist_unhashed(&mn->hlist)) { if (!hlist_unhashed(&subscription->hlist)) {
/* /*
* SRCU here will force exit_mmap to wait for ->release to * SRCU here will force exit_mmap to wait for ->release to
* finish before freeing the pages. * finish before freeing the pages.
...@@ -772,17 +809,17 @@ void mmu_notifier_unregister(struct mmu_notifier *mn, struct mm_struct *mm) ...@@ -772,17 +809,17 @@ void mmu_notifier_unregister(struct mmu_notifier *mn, struct mm_struct *mm)
* exit_mmap will block in mmu_notifier_release to guarantee * exit_mmap will block in mmu_notifier_release to guarantee
* that ->release is called before freeing the pages. * that ->release is called before freeing the pages.
*/ */
if (mn->ops->release) if (subscription->ops->release)
mn->ops->release(mn, mm); subscription->ops->release(subscription, mm);
srcu_read_unlock(&srcu, id); srcu_read_unlock(&srcu, id);
spin_lock(&mm->mmu_notifier_mm->lock); spin_lock(&mm->notifier_subscriptions->lock);
/* /*
* Can not use list_del_rcu() since __mmu_notifier_release * Can not use list_del_rcu() since __mmu_notifier_release
* can delete it before we hold the lock. * can delete it before we hold the lock.
*/ */
hlist_del_init_rcu(&mn->hlist); hlist_del_init_rcu(&subscription->hlist);
spin_unlock(&mm->mmu_notifier_mm->lock); spin_unlock(&mm->notifier_subscriptions->lock);
} }
/* /*
...@@ -799,10 +836,11 @@ EXPORT_SYMBOL_GPL(mmu_notifier_unregister); ...@@ -799,10 +836,11 @@ EXPORT_SYMBOL_GPL(mmu_notifier_unregister);
static void mmu_notifier_free_rcu(struct rcu_head *rcu) static void mmu_notifier_free_rcu(struct rcu_head *rcu)
{ {
struct mmu_notifier *mn = container_of(rcu, struct mmu_notifier, rcu); struct mmu_notifier *subscription =
struct mm_struct *mm = mn->mm; container_of(rcu, struct mmu_notifier, rcu);
struct mm_struct *mm = subscription->mm;
mn->ops->free_notifier(mn); subscription->ops->free_notifier(subscription);
/* Pairs with the get in __mmu_notifier_register() */ /* Pairs with the get in __mmu_notifier_register() */
mmdrop(mm); mmdrop(mm);
} }
...@@ -829,39 +867,40 @@ static void mmu_notifier_free_rcu(struct rcu_head *rcu) ...@@ -829,39 +867,40 @@ static void mmu_notifier_free_rcu(struct rcu_head *rcu)
* Modules calling this function must call mmu_notifier_synchronize() in * Modules calling this function must call mmu_notifier_synchronize() in
* their __exit functions to ensure the async work is completed. * their __exit functions to ensure the async work is completed.
*/ */
void mmu_notifier_put(struct mmu_notifier *mn) void mmu_notifier_put(struct mmu_notifier *subscription)
{ {
struct mm_struct *mm = mn->mm; struct mm_struct *mm = subscription->mm;
spin_lock(&mm->mmu_notifier_mm->lock); spin_lock(&mm->notifier_subscriptions->lock);
if (WARN_ON(!mn->users) || --mn->users) if (WARN_ON(!subscription->users) || --subscription->users)
goto out_unlock; goto out_unlock;
hlist_del_init_rcu(&mn->hlist); hlist_del_init_rcu(&subscription->hlist);
spin_unlock(&mm->mmu_notifier_mm->lock); spin_unlock(&mm->notifier_subscriptions->lock);
call_srcu(&srcu, &mn->rcu, mmu_notifier_free_rcu); call_srcu(&srcu, &subscription->rcu, mmu_notifier_free_rcu);
return; return;
out_unlock: out_unlock:
spin_unlock(&mm->mmu_notifier_mm->lock); spin_unlock(&mm->notifier_subscriptions->lock);
} }
EXPORT_SYMBOL_GPL(mmu_notifier_put); EXPORT_SYMBOL_GPL(mmu_notifier_put);
static int __mmu_interval_notifier_insert( static int __mmu_interval_notifier_insert(
struct mmu_interval_notifier *mni, struct mm_struct *mm, struct mmu_interval_notifier *interval_sub, struct mm_struct *mm,
struct mmu_notifier_mm *mmn_mm, unsigned long start, struct mmu_notifier_subscriptions *subscriptions, unsigned long start,
unsigned long length, const struct mmu_interval_notifier_ops *ops) unsigned long length, const struct mmu_interval_notifier_ops *ops)
{ {
mni->mm = mm; interval_sub->mm = mm;
mni->ops = ops; interval_sub->ops = ops;
RB_CLEAR_NODE(&mni->interval_tree.rb); RB_CLEAR_NODE(&interval_sub->interval_tree.rb);
mni->interval_tree.start = start; interval_sub->interval_tree.start = start;
/* /*
* Note that the representation of the intervals in the interval tree * Note that the representation of the intervals in the interval tree
* considers the ending point as contained in the interval. * considers the ending point as contained in the interval.
*/ */
if (length == 0 || if (length == 0 ||
check_add_overflow(start, length - 1, &mni->interval_tree.last)) check_add_overflow(start, length - 1,
&interval_sub->interval_tree.last))
return -EOVERFLOW; return -EOVERFLOW;
/* Must call with a mmget() held */ /* Must call with a mmget() held */
...@@ -881,38 +920,40 @@ static int __mmu_interval_notifier_insert( ...@@ -881,38 +920,40 @@ static int __mmu_interval_notifier_insert(
* possibility for live lock, instead defer the add to * possibility for live lock, instead defer the add to
* mn_itree_inv_end() so this algorithm is deterministic. * mn_itree_inv_end() so this algorithm is deterministic.
* *
* In all cases the value for the mni->invalidate_seq should be * In all cases the value for the interval_sub->invalidate_seq should be
* odd, see mmu_interval_read_begin() * odd, see mmu_interval_read_begin()
*/ */
spin_lock(&mmn_mm->lock); spin_lock(&subscriptions->lock);
if (mmn_mm->active_invalidate_ranges) { if (subscriptions->active_invalidate_ranges) {
if (mn_itree_is_invalidating(mmn_mm)) if (mn_itree_is_invalidating(subscriptions))
hlist_add_head(&mni->deferred_item, hlist_add_head(&interval_sub->deferred_item,
&mmn_mm->deferred_list); &subscriptions->deferred_list);
else { else {
mmn_mm->invalidate_seq |= 1; subscriptions->invalidate_seq |= 1;
interval_tree_insert(&mni->interval_tree, interval_tree_insert(&interval_sub->interval_tree,
&mmn_mm->itree); &subscriptions->itree);
} }
mni->invalidate_seq = mmn_mm->invalidate_seq; interval_sub->invalidate_seq = subscriptions->invalidate_seq;
} else { } else {
WARN_ON(mn_itree_is_invalidating(mmn_mm)); WARN_ON(mn_itree_is_invalidating(subscriptions));
/* /*
* The starting seq for a mni not under invalidation should be * The starting seq for a subscription not under invalidation
* odd, not equal to the current invalidate_seq and * should be odd, not equal to the current invalidate_seq and
* invalidate_seq should not 'wrap' to the new seq any time * invalidate_seq should not 'wrap' to the new seq any time
* soon. * soon.
*/ */
mni->invalidate_seq = mmn_mm->invalidate_seq - 1; interval_sub->invalidate_seq =
interval_tree_insert(&mni->interval_tree, &mmn_mm->itree); subscriptions->invalidate_seq - 1;
interval_tree_insert(&interval_sub->interval_tree,
&subscriptions->itree);
} }
spin_unlock(&mmn_mm->lock); spin_unlock(&subscriptions->lock);
return 0; return 0;
} }
/** /**
* mmu_interval_notifier_insert - Insert an interval notifier * mmu_interval_notifier_insert - Insert an interval notifier
* @mni: Interval notifier to register * @interval_sub: Interval subscription to register
* @start: Starting virtual address to monitor * @start: Starting virtual address to monitor
* @length: Length of the range to monitor * @length: Length of the range to monitor
* @mm : mm_struct to attach to * @mm : mm_struct to attach to
...@@ -925,53 +966,53 @@ static int __mmu_interval_notifier_insert( ...@@ -925,53 +966,53 @@ static int __mmu_interval_notifier_insert(
* The caller must use the normal interval notifier read flow via * The caller must use the normal interval notifier read flow via
* mmu_interval_read_begin() to establish SPTEs for this range. * mmu_interval_read_begin() to establish SPTEs for this range.
*/ */
int mmu_interval_notifier_insert(struct mmu_interval_notifier *mni, int mmu_interval_notifier_insert(struct mmu_interval_notifier *interval_sub,
struct mm_struct *mm, unsigned long start, struct mm_struct *mm, unsigned long start,
unsigned long length, unsigned long length,
const struct mmu_interval_notifier_ops *ops) const struct mmu_interval_notifier_ops *ops)
{ {
struct mmu_notifier_mm *mmn_mm; struct mmu_notifier_subscriptions *subscriptions;
int ret; int ret;
might_lock(&mm->mmap_sem); might_lock(&mm->mmap_sem);
mmn_mm = smp_load_acquire(&mm->mmu_notifier_mm); subscriptions = smp_load_acquire(&mm->notifier_subscriptions);
if (!mmn_mm || !mmn_mm->has_itree) { if (!subscriptions || !subscriptions->has_itree) {
ret = mmu_notifier_register(NULL, mm); ret = mmu_notifier_register(NULL, mm);
if (ret) if (ret)
return ret; return ret;
mmn_mm = mm->mmu_notifier_mm; subscriptions = mm->notifier_subscriptions;
} }
return __mmu_interval_notifier_insert(mni, mm, mmn_mm, start, length, return __mmu_interval_notifier_insert(interval_sub, mm, subscriptions,
ops); start, length, ops);
} }
EXPORT_SYMBOL_GPL(mmu_interval_notifier_insert); EXPORT_SYMBOL_GPL(mmu_interval_notifier_insert);
int mmu_interval_notifier_insert_locked( int mmu_interval_notifier_insert_locked(
struct mmu_interval_notifier *mni, struct mm_struct *mm, struct mmu_interval_notifier *interval_sub, struct mm_struct *mm,
unsigned long start, unsigned long length, unsigned long start, unsigned long length,
const struct mmu_interval_notifier_ops *ops) const struct mmu_interval_notifier_ops *ops)
{ {
struct mmu_notifier_mm *mmn_mm; struct mmu_notifier_subscriptions *subscriptions =
mm->notifier_subscriptions;
int ret; int ret;
lockdep_assert_held_write(&mm->mmap_sem); lockdep_assert_held_write(&mm->mmap_sem);
mmn_mm = mm->mmu_notifier_mm; if (!subscriptions || !subscriptions->has_itree) {
if (!mmn_mm || !mmn_mm->has_itree) {
ret = __mmu_notifier_register(NULL, mm); ret = __mmu_notifier_register(NULL, mm);
if (ret) if (ret)
return ret; return ret;
mmn_mm = mm->mmu_notifier_mm; subscriptions = mm->notifier_subscriptions;
} }
return __mmu_interval_notifier_insert(mni, mm, mmn_mm, start, length, return __mmu_interval_notifier_insert(interval_sub, mm, subscriptions,
ops); start, length, ops);
} }
EXPORT_SYMBOL_GPL(mmu_interval_notifier_insert_locked); EXPORT_SYMBOL_GPL(mmu_interval_notifier_insert_locked);
/** /**
* mmu_interval_notifier_remove - Remove a interval notifier * mmu_interval_notifier_remove - Remove a interval notifier
* @mni: Interval notifier to unregister * @interval_sub: Interval subscription to unregister
* *
* This function must be paired with mmu_interval_notifier_insert(). It cannot * This function must be paired with mmu_interval_notifier_insert(). It cannot
* be called from any ops callback. * be called from any ops callback.
...@@ -979,32 +1020,34 @@ EXPORT_SYMBOL_GPL(mmu_interval_notifier_insert_locked); ...@@ -979,32 +1020,34 @@ EXPORT_SYMBOL_GPL(mmu_interval_notifier_insert_locked);
* Once this returns ops callbacks are no longer running on other CPUs and * Once this returns ops callbacks are no longer running on other CPUs and
* will not be called in future. * will not be called in future.
*/ */
void mmu_interval_notifier_remove(struct mmu_interval_notifier *mni) void mmu_interval_notifier_remove(struct mmu_interval_notifier *interval_sub)
{ {
struct mm_struct *mm = mni->mm; struct mm_struct *mm = interval_sub->mm;
struct mmu_notifier_mm *mmn_mm = mm->mmu_notifier_mm; struct mmu_notifier_subscriptions *subscriptions =
mm->notifier_subscriptions;
unsigned long seq = 0; unsigned long seq = 0;
might_sleep(); might_sleep();
spin_lock(&mmn_mm->lock); spin_lock(&subscriptions->lock);
if (mn_itree_is_invalidating(mmn_mm)) { if (mn_itree_is_invalidating(subscriptions)) {
/* /*
* remove is being called after insert put this on the * remove is being called after insert put this on the
* deferred list, but before the deferred list was processed. * deferred list, but before the deferred list was processed.
*/ */
if (RB_EMPTY_NODE(&mni->interval_tree.rb)) { if (RB_EMPTY_NODE(&interval_sub->interval_tree.rb)) {
hlist_del(&mni->deferred_item); hlist_del(&interval_sub->deferred_item);
} else { } else {
hlist_add_head(&mni->deferred_item, hlist_add_head(&interval_sub->deferred_item,
&mmn_mm->deferred_list); &subscriptions->deferred_list);
seq = mmn_mm->invalidate_seq; seq = subscriptions->invalidate_seq;
} }
} else { } else {
WARN_ON(RB_EMPTY_NODE(&mni->interval_tree.rb)); WARN_ON(RB_EMPTY_NODE(&interval_sub->interval_tree.rb));
interval_tree_remove(&mni->interval_tree, &mmn_mm->itree); interval_tree_remove(&interval_sub->interval_tree,
&subscriptions->itree);
} }
spin_unlock(&mmn_mm->lock); spin_unlock(&subscriptions->lock);
/* /*
* The possible sleep on progress in the invalidation requires the * The possible sleep on progress in the invalidation requires the
...@@ -1013,8 +1056,8 @@ void mmu_interval_notifier_remove(struct mmu_interval_notifier *mni) ...@@ -1013,8 +1056,8 @@ void mmu_interval_notifier_remove(struct mmu_interval_notifier *mni)
lock_map_acquire(&__mmu_notifier_invalidate_range_start_map); lock_map_acquire(&__mmu_notifier_invalidate_range_start_map);
lock_map_release(&__mmu_notifier_invalidate_range_start_map); lock_map_release(&__mmu_notifier_invalidate_range_start_map);
if (seq) if (seq)
wait_event(mmn_mm->wq, wait_event(subscriptions->wq,
READ_ONCE(mmn_mm->invalidate_seq) != seq); READ_ONCE(subscriptions->invalidate_seq) != seq);
/* pairs with mmgrab in mmu_interval_notifier_insert() */ /* pairs with mmgrab in mmu_interval_notifier_insert() */
mmdrop(mm); mmdrop(mm);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment