Commit 408579cd authored by Liam R. Howlett, committed by Linus Torvalds

mm: Update do_vmi_align_munmap() return semantics

Since do_vmi_align_munmap() will always honor the downgrade request on
success, the callers no longer have to deal with confusing return
codes.  Since all callers that request downgrade actually want the lock
to be dropped, change the downgrade to an unlock request.

Note that the lock still needs to be held in read mode during the page
table clean up to avoid races with a map request.

Update do_vmi_align_munmap() to return 0 for success.  Clean up the
callers and comments to always expect the unlock to be honored on the
success path.  The error path will always leave the lock untouched.

As part of the cleanup, the wrapper function do_vmi_munmap() and callers
to the wrapper are also updated.
Suggested-by: Linus Torvalds <torvalds@linux-foundation.org>
Link: https://lore.kernel.org/linux-mm/20230629191414.1215929-1-willy@infradead.org/
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
parent e4bd84c0
...@@ -3177,7 +3177,7 @@ extern unsigned long do_mmap(struct file *file, unsigned long addr, ...@@ -3177,7 +3177,7 @@ extern unsigned long do_mmap(struct file *file, unsigned long addr,
unsigned long pgoff, unsigned long *populate, struct list_head *uf); unsigned long pgoff, unsigned long *populate, struct list_head *uf);
extern int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm, extern int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm,
unsigned long start, size_t len, struct list_head *uf, unsigned long start, size_t len, struct list_head *uf,
bool downgrade); bool unlock);
extern int do_munmap(struct mm_struct *, unsigned long, size_t, extern int do_munmap(struct mm_struct *, unsigned long, size_t,
struct list_head *uf); struct list_head *uf);
extern int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int behavior); extern int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int behavior);
...@@ -3185,7 +3185,7 @@ extern int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, ...@@ -3185,7 +3185,7 @@ extern int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in,
#ifdef CONFIG_MMU #ifdef CONFIG_MMU
extern int do_vma_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, extern int do_vma_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
unsigned long start, unsigned long end, unsigned long start, unsigned long end,
struct list_head *uf, bool downgrade); struct list_head *uf, bool unlock);
extern int __mm_populate(unsigned long addr, unsigned long len, extern int __mm_populate(unsigned long addr, unsigned long len,
int ignore_errors); int ignore_errors);
static inline void mm_populate(unsigned long addr, unsigned long len) static inline void mm_populate(unsigned long addr, unsigned long len)
......
...@@ -193,8 +193,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -193,8 +193,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
struct vm_area_struct *brkvma, *next = NULL; struct vm_area_struct *brkvma, *next = NULL;
unsigned long min_brk; unsigned long min_brk;
bool populate; bool populate = false;
bool downgraded = false;
LIST_HEAD(uf); LIST_HEAD(uf);
struct vma_iterator vmi; struct vma_iterator vmi;
...@@ -236,13 +235,8 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -236,13 +235,8 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
goto success; goto success;
} }
/* /* Always allow shrinking brk. */
* Always allow shrinking brk.
* do_vma_munmap() may downgrade mmap_lock to read.
*/
if (brk <= mm->brk) { if (brk <= mm->brk) {
int ret;
/* Search one past newbrk */ /* Search one past newbrk */
vma_iter_init(&vmi, mm, newbrk); vma_iter_init(&vmi, mm, newbrk);
brkvma = vma_find(&vmi, oldbrk); brkvma = vma_find(&vmi, oldbrk);
...@@ -250,19 +244,14 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -250,19 +244,14 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
goto out; /* mapping intersects with an existing non-brk vma. */ goto out; /* mapping intersects with an existing non-brk vma. */
/* /*
* mm->brk must be protected by write mmap_lock. * mm->brk must be protected by write mmap_lock.
* do_vma_munmap() may downgrade the lock, so update it * do_vma_munmap() will drop the lock on success, so update it
* before calling do_vma_munmap(). * before calling do_vma_munmap().
*/ */
mm->brk = brk; mm->brk = brk;
ret = do_vma_munmap(&vmi, brkvma, newbrk, oldbrk, &uf, true); if (do_vma_munmap(&vmi, brkvma, newbrk, oldbrk, &uf, true))
if (ret == 1) {
downgraded = true;
goto success;
} else if (!ret)
goto success;
mm->brk = origbrk;
goto out; goto out;
goto success_unlocked;
} }
if (check_brk_limits(oldbrk, newbrk - oldbrk)) if (check_brk_limits(oldbrk, newbrk - oldbrk))
...@@ -283,19 +272,19 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -283,19 +272,19 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
goto out; goto out;
mm->brk = brk; mm->brk = brk;
if (mm->def_flags & VM_LOCKED)
populate = true;
success: success:
populate = newbrk > oldbrk && (mm->def_flags & VM_LOCKED) != 0;
if (downgraded)
mmap_read_unlock(mm);
else
mmap_write_unlock(mm); mmap_write_unlock(mm);
success_unlocked:
userfaultfd_unmap_complete(mm, &uf); userfaultfd_unmap_complete(mm, &uf);
if (populate) if (populate)
mm_populate(oldbrk, newbrk - oldbrk); mm_populate(oldbrk, newbrk - oldbrk);
return brk; return brk;
out: out:
mm->brk = origbrk;
mmap_write_unlock(mm); mmap_write_unlock(mm);
return origbrk; return origbrk;
} }
...@@ -2428,14 +2417,16 @@ int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma, ...@@ -2428,14 +2417,16 @@ int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
* @start: The aligned start address to munmap. * @start: The aligned start address to munmap.
* @end: The aligned end address to munmap. * @end: The aligned end address to munmap.
* @uf: The userfaultfd list_head * @uf: The userfaultfd list_head
* @downgrade: Set to true to attempt a write downgrade of the mmap_lock * @unlock: Set to true to drop the mmap_lock. unlocking only happens on
* success.
* *
* If @downgrade is true, check return code for potential release of the lock. * Return: 0 on success and drops the lock if so directed, error and leaves the
* lock held otherwise.
*/ */
static int static int
do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
struct mm_struct *mm, unsigned long start, struct mm_struct *mm, unsigned long start,
unsigned long end, struct list_head *uf, bool downgrade) unsigned long end, struct list_head *uf, bool unlock)
{ {
struct vm_area_struct *prev, *next = NULL; struct vm_area_struct *prev, *next = NULL;
struct maple_tree mt_detach; struct maple_tree mt_detach;
...@@ -2551,22 +2542,24 @@ do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, ...@@ -2551,22 +2542,24 @@ do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
/* Point of no return */ /* Point of no return */
mm->locked_vm -= locked_vm; mm->locked_vm -= locked_vm;
mm->map_count -= count; mm->map_count -= count;
if (downgrade) if (unlock)
mmap_write_downgrade(mm); mmap_write_downgrade(mm);
/* /*
* We can free page tables without write-locking mmap_lock because VMAs * We can free page tables without write-locking mmap_lock because VMAs
* were isolated before we downgraded mmap_lock. * were isolated before we downgraded mmap_lock.
*/ */
unmap_region(mm, &mt_detach, vma, prev, next, start, end, !downgrade); unmap_region(mm, &mt_detach, vma, prev, next, start, end, !unlock);
/* Statistics and freeing VMAs */ /* Statistics and freeing VMAs */
mas_set(&mas_detach, start); mas_set(&mas_detach, start);
remove_mt(mm, &mas_detach); remove_mt(mm, &mas_detach);
__mt_destroy(&mt_detach); __mt_destroy(&mt_detach);
if (unlock)
mmap_read_unlock(mm);
validate_mm(mm); validate_mm(mm);
return downgrade ? 1 : 0; return 0;
clear_tree_failed: clear_tree_failed:
userfaultfd_error: userfaultfd_error:
...@@ -2589,18 +2582,18 @@ do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, ...@@ -2589,18 +2582,18 @@ do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
* @start: The start address to munmap * @start: The start address to munmap
* @len: The length of the range to munmap * @len: The length of the range to munmap
* @uf: The userfaultfd list_head * @uf: The userfaultfd list_head
* @downgrade: set to true if the user wants to attempt to write_downgrade the * @unlock: set to true if the user wants to drop the mmap_lock on success
* mmap_lock
* *
* This function takes a @mas that is either pointing to the previous VMA or set * This function takes a @mas that is either pointing to the previous VMA or set
* to MA_START and sets it up to remove the mapping(s). The @len will be * to MA_START and sets it up to remove the mapping(s). The @len will be
* aligned and any arch_unmap work will be performed. * aligned and any arch_unmap work will be performed.
* *
* Returns: -EINVAL on failure, 1 on success and unlock, 0 otherwise. * Return: 0 on success and drops the lock if so directed, error and leaves the
* lock held otherwise.
*/ */
int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm, int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm,
unsigned long start, size_t len, struct list_head *uf, unsigned long start, size_t len, struct list_head *uf,
bool downgrade) bool unlock)
{ {
unsigned long end; unsigned long end;
struct vm_area_struct *vma; struct vm_area_struct *vma;
...@@ -2617,10 +2610,13 @@ int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm, ...@@ -2617,10 +2610,13 @@ int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm,
/* Find the first overlapping VMA */ /* Find the first overlapping VMA */
vma = vma_find(vmi, end); vma = vma_find(vmi, end);
if (!vma) if (!vma) {
if (unlock)
mmap_write_unlock(mm);
return 0; return 0;
}
return do_vmi_align_munmap(vmi, vma, mm, start, end, uf, downgrade); return do_vmi_align_munmap(vmi, vma, mm, start, end, uf, unlock);
} }
/* do_munmap() - Wrapper function for non-maple tree aware do_munmap() calls. /* do_munmap() - Wrapper function for non-maple tree aware do_munmap() calls.
...@@ -2628,6 +2624,8 @@ int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm, ...@@ -2628,6 +2624,8 @@ int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm,
* @start: The start address to munmap * @start: The start address to munmap
* @len: The length to be munmapped. * @len: The length to be munmapped.
* @uf: The userfaultfd list_head * @uf: The userfaultfd list_head
*
* Return: 0 on success, error otherwise.
*/ */
int do_munmap(struct mm_struct *mm, unsigned long start, size_t len, int do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
struct list_head *uf) struct list_head *uf)
...@@ -2888,7 +2886,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr, ...@@ -2888,7 +2886,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
return error; return error;
} }
static int __vm_munmap(unsigned long start, size_t len, bool downgrade) static int __vm_munmap(unsigned long start, size_t len, bool unlock)
{ {
int ret; int ret;
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
...@@ -2898,16 +2896,8 @@ static int __vm_munmap(unsigned long start, size_t len, bool downgrade) ...@@ -2898,16 +2896,8 @@ static int __vm_munmap(unsigned long start, size_t len, bool downgrade)
if (mmap_write_lock_killable(mm)) if (mmap_write_lock_killable(mm))
return -EINTR; return -EINTR;
ret = do_vmi_munmap(&vmi, mm, start, len, &uf, downgrade); ret = do_vmi_munmap(&vmi, mm, start, len, &uf, unlock);
/* if (ret || !unlock)
* Returning 1 indicates mmap_lock is downgraded.
* But 1 is not legal return value of vm_munmap() and munmap(), reset
* it to 0 before return.
*/
if (ret == 1) {
mmap_read_unlock(mm);
ret = 0;
} else
mmap_write_unlock(mm); mmap_write_unlock(mm);
userfaultfd_unmap_complete(mm, &uf); userfaultfd_unmap_complete(mm, &uf);
...@@ -3017,21 +3007,23 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size, ...@@ -3017,21 +3007,23 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
* @start: the start of the address to unmap * @start: the start of the address to unmap
* @end: The end of the address to unmap * @end: The end of the address to unmap
* @uf: The userfaultfd list_head * @uf: The userfaultfd list_head
* @downgrade: Attempt to downgrade or not * @unlock: Drop the lock on success
* *
* Returns: 0 on success and not downgraded, 1 on success and downgraded.
* unmaps a VMA mapping when the vma iterator is already in position. * unmaps a VMA mapping when the vma iterator is already in position.
* Does not handle alignment. * Does not handle alignment.
*
* Return: 0 on success and drops the lock if so directed, error on failure and will
* still hold the lock.
*/ */
int do_vma_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma, int do_vma_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
unsigned long start, unsigned long end, unsigned long start, unsigned long end, struct list_head *uf,
struct list_head *uf, bool downgrade) bool unlock)
{ {
struct mm_struct *mm = vma->vm_mm; struct mm_struct *mm = vma->vm_mm;
int ret; int ret;
arch_unmap(mm, start, end); arch_unmap(mm, start, end);
ret = do_vmi_align_munmap(vmi, vma, mm, start, end, uf, downgrade); ret = do_vmi_align_munmap(vmi, vma, mm, start, end, uf, unlock);
validate_mm(mm); validate_mm(mm);
return ret; return ret;
} }
......
...@@ -715,7 +715,7 @@ static unsigned long move_vma(struct vm_area_struct *vma, ...@@ -715,7 +715,7 @@ static unsigned long move_vma(struct vm_area_struct *vma,
} }
vma_iter_init(&vmi, mm, old_addr); vma_iter_init(&vmi, mm, old_addr);
if (do_vmi_munmap(&vmi, mm, old_addr, old_len, uf_unmap, false) < 0) { if (!do_vmi_munmap(&vmi, mm, old_addr, old_len, uf_unmap, false)) {
/* OOM: unable to split vma, just get accounts right */ /* OOM: unable to split vma, just get accounts right */
if (vm_flags & VM_ACCOUNT && !(flags & MREMAP_DONTUNMAP)) if (vm_flags & VM_ACCOUNT && !(flags & MREMAP_DONTUNMAP))
vm_acct_memory(old_len >> PAGE_SHIFT); vm_acct_memory(old_len >> PAGE_SHIFT);
...@@ -913,7 +913,6 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, ...@@ -913,7 +913,6 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
struct vm_area_struct *vma; struct vm_area_struct *vma;
unsigned long ret = -EINVAL; unsigned long ret = -EINVAL;
bool locked = false; bool locked = false;
bool downgraded = false;
struct vm_userfaultfd_ctx uf = NULL_VM_UFFD_CTX; struct vm_userfaultfd_ctx uf = NULL_VM_UFFD_CTX;
LIST_HEAD(uf_unmap_early); LIST_HEAD(uf_unmap_early);
LIST_HEAD(uf_unmap); LIST_HEAD(uf_unmap);
...@@ -999,24 +998,23 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, ...@@ -999,24 +998,23 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
* Always allow a shrinking remap: that just unmaps * Always allow a shrinking remap: that just unmaps
* the unnecessary pages.. * the unnecessary pages..
* do_vmi_munmap does all the needed commit accounting, and * do_vmi_munmap does all the needed commit accounting, and
* downgrades mmap_lock to read if so directed. * unlocks the mmap_lock if so directed.
*/ */
if (old_len >= new_len) { if (old_len >= new_len) {
int retval;
VMA_ITERATOR(vmi, mm, addr + new_len); VMA_ITERATOR(vmi, mm, addr + new_len);
retval = do_vmi_munmap(&vmi, mm, addr + new_len, if (old_len == new_len) {
old_len - new_len, &uf_unmap, true); ret = addr;
/* Returning 1 indicates mmap_lock is downgraded to read. */
if (retval == 1) {
downgraded = true;
} else if (retval < 0 && old_len != new_len) {
ret = retval;
goto out; goto out;
} }
ret = addr; ret = do_vmi_munmap(&vmi, mm, addr + new_len, old_len - new_len,
&uf_unmap, true);
if (ret)
goto out; goto out;
ret = addr;
goto out_unlocked;
} }
/* /*
...@@ -1101,12 +1099,10 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, ...@@ -1101,12 +1099,10 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
out: out:
if (offset_in_page(ret)) if (offset_in_page(ret))
locked = false; locked = false;
if (downgraded)
mmap_read_unlock(current->mm);
else
mmap_write_unlock(current->mm); mmap_write_unlock(current->mm);
if (locked && new_len > old_len) if (locked && new_len > old_len)
mm_populate(new_addr + old_len, new_len - old_len); mm_populate(new_addr + old_len, new_len - old_len);
out_unlocked:
userfaultfd_unmap_complete(mm, &uf_unmap_early); userfaultfd_unmap_complete(mm, &uf_unmap_early);
mremap_userfaultfd_complete(&uf, addr, ret, old_len); mremap_userfaultfd_complete(&uf, addr, ret, old_len);
userfaultfd_unmap_complete(mm, &uf_unmap); userfaultfd_unmap_complete(mm, &uf_unmap);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment