Commit 897ab3e0 authored by Mike Rapoport, committed by Linus Torvalds

userfaultfd: non-cooperative: add event for memory unmaps

When a non-cooperative userfaultfd monitor copies pages in the
background, it may encounter regions that were already unmapped.
Addition of UFFD_EVENT_UNMAP allows the uffd monitor to track precisely
changes in the virtual memory layout.

Since there might be different uffd contexts for the affected VMAs, we
first should create a temporary representation for the unmap event for
each uffd context and then notify them one by one to the appropriate
userfault file descriptors.

The event notification occurs after the mmap_sem has been released.

[arnd@arndb.de: fix nommu build]
  Link: http://lkml.kernel.org/r/20170203165141.3665284-1-arnd@arndb.de
[mhocko@suse.com: fix nommu build]
  Link: http://lkml.kernel.org/r/20170202091503.GA22823@dhcp22.suse.cz
Link: http://lkml.kernel.org/r/1485542673-24387-3-git-send-email-rppt@linux.vnet.ibm.com
Signed-off-by: Mike Rapoport <rppt@linux.vnet.ibm.com>
Signed-off-by: Michal Hocko <mhocko@suse.com>
Signed-off-by: Arnd Bergmann <arnd@arndb.de>
Acked-by: Hillf Danton <hillf.zj@alibaba-inc.com>
Cc: Andrea Arcangeli <aarcange@redhat.com>
Cc: "Dr. David Alan Gilbert" <dgilbert@redhat.com>
Cc: Mike Kravetz <mike.kravetz@oracle.com>
Cc: Pavel Emelyanov <xemul@virtuozzo.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
parent 846b1a0f
...@@ -111,7 +111,7 @@ int arch_setup_additional_pages(struct linux_binprm *bprm, int uses_interp) ...@@ -111,7 +111,7 @@ int arch_setup_additional_pages(struct linux_binprm *bprm, int uses_interp)
base = mmap_region(NULL, STACK_TOP, PAGE_SIZE, base = mmap_region(NULL, STACK_TOP, PAGE_SIZE,
VM_READ|VM_WRITE|VM_EXEC| VM_READ|VM_WRITE|VM_EXEC|
VM_MAYREAD|VM_MAYWRITE|VM_MAYEXEC, VM_MAYREAD|VM_MAYWRITE|VM_MAYEXEC,
0); 0, NULL);
if (IS_ERR_VALUE(base)) { if (IS_ERR_VALUE(base)) {
ret = base; ret = base;
goto out; goto out;
......
...@@ -143,7 +143,7 @@ int arch_setup_additional_pages(struct linux_binprm *bprm, ...@@ -143,7 +143,7 @@ int arch_setup_additional_pages(struct linux_binprm *bprm,
unsigned long addr = MEM_USER_INTRPT; unsigned long addr = MEM_USER_INTRPT;
addr = mmap_region(NULL, addr, INTRPT_SIZE, addr = mmap_region(NULL, addr, INTRPT_SIZE,
VM_READ|VM_EXEC| VM_READ|VM_EXEC|
VM_MAYREAD|VM_MAYWRITE|VM_MAYEXEC, 0); VM_MAYREAD|VM_MAYWRITE|VM_MAYEXEC, 0, NULL);
if (addr > (unsigned long) -PAGE_SIZE) if (addr > (unsigned long) -PAGE_SIZE)
retval = (int) addr; retval = (int) addr;
} }
......
...@@ -186,7 +186,7 @@ static int map_vdso(const struct vdso_image *image, unsigned long addr) ...@@ -186,7 +186,7 @@ static int map_vdso(const struct vdso_image *image, unsigned long addr)
if (IS_ERR(vma)) { if (IS_ERR(vma)) {
ret = PTR_ERR(vma); ret = PTR_ERR(vma);
do_munmap(mm, text_start, image->size); do_munmap(mm, text_start, image->size, NULL);
} else { } else {
current->mm->context.vdso = (void __user *)text_start; current->mm->context.vdso = (void __user *)text_start;
current->mm->context.vdso_image = image; current->mm->context.vdso_image = image;
......
...@@ -51,7 +51,7 @@ static unsigned long mpx_mmap(unsigned long len) ...@@ -51,7 +51,7 @@ static unsigned long mpx_mmap(unsigned long len)
down_write(&mm->mmap_sem); down_write(&mm->mmap_sem);
addr = do_mmap(NULL, 0, len, PROT_READ | PROT_WRITE, addr = do_mmap(NULL, 0, len, PROT_READ | PROT_WRITE,
MAP_ANONYMOUS | MAP_PRIVATE, VM_MPX, 0, &populate); MAP_ANONYMOUS | MAP_PRIVATE, VM_MPX, 0, &populate, NULL);
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
if (populate) if (populate)
mm_populate(addr, populate); mm_populate(addr, populate);
...@@ -893,7 +893,7 @@ static int unmap_entire_bt(struct mm_struct *mm, ...@@ -893,7 +893,7 @@ static int unmap_entire_bt(struct mm_struct *mm,
* avoid recursion, do_munmap() will check whether it comes * avoid recursion, do_munmap() will check whether it comes
* from one bounds table through VM_MPX flag. * from one bounds table through VM_MPX flag.
*/ */
return do_munmap(mm, bt_addr, mpx_bt_size_bytes(mm)); return do_munmap(mm, bt_addr, mpx_bt_size_bytes(mm), NULL);
} }
static int try_unmap_single_bt(struct mm_struct *mm, static int try_unmap_single_bt(struct mm_struct *mm,
......
...@@ -512,7 +512,7 @@ static int aio_setup_ring(struct kioctx *ctx) ...@@ -512,7 +512,7 @@ static int aio_setup_ring(struct kioctx *ctx)
ctx->mmap_base = do_mmap_pgoff(ctx->aio_ring_file, 0, ctx->mmap_size, ctx->mmap_base = do_mmap_pgoff(ctx->aio_ring_file, 0, ctx->mmap_size,
PROT_READ | PROT_WRITE, PROT_READ | PROT_WRITE,
MAP_SHARED, 0, &unused); MAP_SHARED, 0, &unused, NULL);
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
if (IS_ERR((void *)ctx->mmap_base)) { if (IS_ERR((void *)ctx->mmap_base)) {
ctx->mmap_size = 0; ctx->mmap_size = 0;
......
...@@ -388,7 +388,7 @@ static int remap_oldmem_pfn_checked(struct vm_area_struct *vma, ...@@ -388,7 +388,7 @@ static int remap_oldmem_pfn_checked(struct vm_area_struct *vma,
} }
return 0; return 0;
fail: fail:
do_munmap(vma->vm_mm, from, len); do_munmap(vma->vm_mm, from, len, NULL);
return -EAGAIN; return -EAGAIN;
} }
...@@ -481,7 +481,7 @@ static int mmap_vmcore(struct file *file, struct vm_area_struct *vma) ...@@ -481,7 +481,7 @@ static int mmap_vmcore(struct file *file, struct vm_area_struct *vma)
return 0; return 0;
fail: fail:
do_munmap(vma->vm_mm, vma->vm_start, len); do_munmap(vma->vm_mm, vma->vm_start, len, NULL);
return -EAGAIN; return -EAGAIN;
} }
#else #else
......
...@@ -71,6 +71,13 @@ struct userfaultfd_fork_ctx { ...@@ -71,6 +71,13 @@ struct userfaultfd_fork_ctx {
struct list_head list; struct list_head list;
}; };
/*
 * Temporary, per-context record of a pending unmap notification.
 * userfaultfd_unmap_prep() queues one of these on a caller-supplied list
 * for each distinct userfaultfd context covered by the range being
 * unmapped; userfaultfd_unmap_complete() later turns each record into a
 * UFFD_EVENT_UNMAP message after mmap_sem has been released.
 */
struct userfaultfd_unmap_ctx {
	struct userfaultfd_ctx *ctx;	/* context to notify (holds a reference) */
	unsigned long start;		/* start of the unmapped range */
	unsigned long end;		/* end of the unmapped range */
	struct list_head list;		/* link in the caller's unmap list */
};
struct userfaultfd_wait_queue { struct userfaultfd_wait_queue {
struct uffd_msg msg; struct uffd_msg msg;
wait_queue_t wq; wait_queue_t wq;
...@@ -709,6 +716,64 @@ void userfaultfd_remove(struct vm_area_struct *vma, ...@@ -709,6 +716,64 @@ void userfaultfd_remove(struct vm_area_struct *vma,
down_read(&mm->mmap_sem); down_read(&mm->mmap_sem);
} }
static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps,
unsigned long start, unsigned long end)
{
struct userfaultfd_unmap_ctx *unmap_ctx;
list_for_each_entry(unmap_ctx, unmaps, list)
if (unmap_ctx->ctx == ctx && unmap_ctx->start == start &&
unmap_ctx->end == end)
return true;
return false;
}
/*
 * Walk the VMAs intersecting [@start, @end) and, for every userfaultfd
 * context that registered for UFFD_FEATURE_EVENT_UNMAP, queue one
 * temporary record on @unmaps (skipping contexts already recorded for
 * this range).  Each queued record pins its context with
 * userfaultfd_ctx_get().
 *
 * Called with mmap_sem held; the events themselves are delivered later
 * by userfaultfd_unmap_complete() once the lock has been dropped.
 *
 * Returns 0 on success, -ENOMEM if a record could not be allocated.
 */
int userfaultfd_unmap_prep(struct vm_area_struct *vma,
			   unsigned long start, unsigned long end,
			   struct list_head *unmaps)
{
	while (vma && vma->vm_start < end) {
		struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx;
		struct userfaultfd_unmap_ctx *entry;

		if (ctx && (ctx->features & UFFD_FEATURE_EVENT_UNMAP) &&
		    !has_unmap_ctx(ctx, unmaps, start, end)) {
			entry = kzalloc(sizeof(*entry), GFP_KERNEL);
			if (!entry)
				return -ENOMEM;

			userfaultfd_ctx_get(ctx);
			entry->ctx = ctx;
			entry->start = start;
			entry->end = end;
			list_add_tail(&entry->list, unmaps);
		}

		vma = vma->vm_next;
	}

	return 0;
}
/*
 * Deliver a UFFD_EVENT_UNMAP message for every record queued on @uf by
 * userfaultfd_unmap_prep(), then unlink and free each record.  Must be
 * called without mmap_sem held: userfaultfd_event_wait_completion()
 * blocks until the monitor has read the message.
 */
void userfaultfd_unmap_complete(struct mm_struct *mm, struct list_head *uf)
{
	struct userfaultfd_unmap_ctx *entry, *tmp;
	struct userfaultfd_wait_queue ewq;

	list_for_each_entry_safe(entry, tmp, uf, list) {
		msg_init(&ewq.msg);

		ewq.msg.event = UFFD_EVENT_UNMAP;
		/* UNMAP reuses the 'remove' payload layout for its range */
		ewq.msg.arg.remove.start = entry->start;
		ewq.msg.arg.remove.end = entry->end;

		userfaultfd_event_wait_completion(entry->ctx, &ewq);

		list_del(&entry->list);
		kfree(entry);
	}
}
static int userfaultfd_release(struct inode *inode, struct file *file) static int userfaultfd_release(struct inode *inode, struct file *file)
{ {
struct userfaultfd_ctx *ctx = file->private_data; struct userfaultfd_ctx *ctx = file->private_data;
......
...@@ -2090,18 +2090,22 @@ extern int install_special_mapping(struct mm_struct *mm, ...@@ -2090,18 +2090,22 @@ extern int install_special_mapping(struct mm_struct *mm,
extern unsigned long get_unmapped_area(struct file *, unsigned long, unsigned long, unsigned long, unsigned long); extern unsigned long get_unmapped_area(struct file *, unsigned long, unsigned long, unsigned long, unsigned long);
extern unsigned long mmap_region(struct file *file, unsigned long addr, extern unsigned long mmap_region(struct file *file, unsigned long addr,
unsigned long len, vm_flags_t vm_flags, unsigned long pgoff); unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
struct list_head *uf);
extern unsigned long do_mmap(struct file *file, unsigned long addr, extern unsigned long do_mmap(struct file *file, unsigned long addr,
unsigned long len, unsigned long prot, unsigned long flags, unsigned long len, unsigned long prot, unsigned long flags,
vm_flags_t vm_flags, unsigned long pgoff, unsigned long *populate); vm_flags_t vm_flags, unsigned long pgoff, unsigned long *populate,
extern int do_munmap(struct mm_struct *, unsigned long, size_t); struct list_head *uf);
extern int do_munmap(struct mm_struct *, unsigned long, size_t,
struct list_head *uf);
static inline unsigned long static inline unsigned long
do_mmap_pgoff(struct file *file, unsigned long addr, do_mmap_pgoff(struct file *file, unsigned long addr,
unsigned long len, unsigned long prot, unsigned long flags, unsigned long len, unsigned long prot, unsigned long flags,
unsigned long pgoff, unsigned long *populate) unsigned long pgoff, unsigned long *populate,
struct list_head *uf)
{ {
return do_mmap(file, addr, len, prot, flags, 0, pgoff, populate); return do_mmap(file, addr, len, prot, flags, 0, pgoff, populate, uf);
} }
#ifdef CONFIG_MMU #ifdef CONFIG_MMU
......
...@@ -66,6 +66,12 @@ extern void userfaultfd_remove(struct vm_area_struct *vma, ...@@ -66,6 +66,12 @@ extern void userfaultfd_remove(struct vm_area_struct *vma,
unsigned long start, unsigned long start,
unsigned long end); unsigned long end);
extern int userfaultfd_unmap_prep(struct vm_area_struct *vma,
unsigned long start, unsigned long end,
struct list_head *uf);
extern void userfaultfd_unmap_complete(struct mm_struct *mm,
struct list_head *uf);
#else /* CONFIG_USERFAULTFD */ #else /* CONFIG_USERFAULTFD */
/* mm helpers */ /* mm helpers */
...@@ -118,6 +124,18 @@ static inline void userfaultfd_remove(struct vm_area_struct *vma, ...@@ -118,6 +124,18 @@ static inline void userfaultfd_remove(struct vm_area_struct *vma,
unsigned long end) unsigned long end)
{ {
} }
/*
 * CONFIG_USERFAULTFD=n stub: no contexts can exist, so there is nothing
 * to record; always succeed.
 */
static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma,
			unsigned long start, unsigned long end,
			struct list_head *uf)
{
	return 0;
}
/* CONFIG_USERFAULTFD=n stub: no events to deliver; the list stays empty. */
static inline void userfaultfd_unmap_complete(struct mm_struct *mm,
					struct list_head *uf)
{
}
#endif /* CONFIG_USERFAULTFD */ #endif /* CONFIG_USERFAULTFD */
#endif /* _LINUX_USERFAULTFD_K_H */ #endif /* _LINUX_USERFAULTFD_K_H */
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#define UFFD_API_FEATURES (UFFD_FEATURE_EVENT_FORK | \ #define UFFD_API_FEATURES (UFFD_FEATURE_EVENT_FORK | \
UFFD_FEATURE_EVENT_REMAP | \ UFFD_FEATURE_EVENT_REMAP | \
UFFD_FEATURE_EVENT_REMOVE | \ UFFD_FEATURE_EVENT_REMOVE | \
UFFD_FEATURE_EVENT_UNMAP | \
UFFD_FEATURE_MISSING_HUGETLBFS | \ UFFD_FEATURE_MISSING_HUGETLBFS | \
UFFD_FEATURE_MISSING_SHMEM) UFFD_FEATURE_MISSING_SHMEM)
#define UFFD_API_IOCTLS \ #define UFFD_API_IOCTLS \
...@@ -110,6 +111,7 @@ struct uffd_msg { ...@@ -110,6 +111,7 @@ struct uffd_msg {
#define UFFD_EVENT_FORK 0x13 #define UFFD_EVENT_FORK 0x13
#define UFFD_EVENT_REMAP 0x14 #define UFFD_EVENT_REMAP 0x14
#define UFFD_EVENT_REMOVE 0x15 #define UFFD_EVENT_REMOVE 0x15
#define UFFD_EVENT_UNMAP 0x16
/* flags for UFFD_EVENT_PAGEFAULT */ /* flags for UFFD_EVENT_PAGEFAULT */
#define UFFD_PAGEFAULT_FLAG_WRITE (1<<0) /* If this was a write fault */ #define UFFD_PAGEFAULT_FLAG_WRITE (1<<0) /* If this was a write fault */
...@@ -158,6 +160,7 @@ struct uffdio_api { ...@@ -158,6 +160,7 @@ struct uffdio_api {
#define UFFD_FEATURE_EVENT_REMOVE (1<<3) #define UFFD_FEATURE_EVENT_REMOVE (1<<3)
#define UFFD_FEATURE_MISSING_HUGETLBFS (1<<4) #define UFFD_FEATURE_MISSING_HUGETLBFS (1<<4)
#define UFFD_FEATURE_MISSING_SHMEM (1<<5) #define UFFD_FEATURE_MISSING_SHMEM (1<<5)
#define UFFD_FEATURE_EVENT_UNMAP (1<<6)
__u64 features; __u64 features;
__u64 ioctls; __u64 ioctls;
......
...@@ -1222,7 +1222,7 @@ long do_shmat(int shmid, char __user *shmaddr, int shmflg, ulong *raddr, ...@@ -1222,7 +1222,7 @@ long do_shmat(int shmid, char __user *shmaddr, int shmflg, ulong *raddr,
goto invalid; goto invalid;
} }
addr = do_mmap_pgoff(file, addr, size, prot, flags, 0, &populate); addr = do_mmap_pgoff(file, addr, size, prot, flags, 0, &populate, NULL);
*raddr = addr; *raddr = addr;
err = 0; err = 0;
if (IS_ERR_VALUE(addr)) if (IS_ERR_VALUE(addr))
...@@ -1329,7 +1329,7 @@ SYSCALL_DEFINE1(shmdt, char __user *, shmaddr) ...@@ -1329,7 +1329,7 @@ SYSCALL_DEFINE1(shmdt, char __user *, shmaddr)
*/ */
file = vma->vm_file; file = vma->vm_file;
size = i_size_read(file_inode(vma->vm_file)); size = i_size_read(file_inode(vma->vm_file));
do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start); do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start, NULL);
/* /*
* We discovered the size of the shm segment, so * We discovered the size of the shm segment, so
* break out of here and fall through to the next * break out of here and fall through to the next
...@@ -1356,7 +1356,7 @@ SYSCALL_DEFINE1(shmdt, char __user *, shmaddr) ...@@ -1356,7 +1356,7 @@ SYSCALL_DEFINE1(shmdt, char __user *, shmaddr)
if ((vma->vm_ops == &shm_vm_ops) && if ((vma->vm_ops == &shm_vm_ops) &&
((vma->vm_start - addr)/PAGE_SIZE == vma->vm_pgoff) && ((vma->vm_start - addr)/PAGE_SIZE == vma->vm_pgoff) &&
(vma->vm_file == file)) (vma->vm_file == file))
do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start); do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start, NULL);
vma = next; vma = next;
} }
...@@ -1365,7 +1365,7 @@ SYSCALL_DEFINE1(shmdt, char __user *, shmaddr) ...@@ -1365,7 +1365,7 @@ SYSCALL_DEFINE1(shmdt, char __user *, shmaddr)
* given * given
*/ */
if (vma && vma->vm_start == addr && vma->vm_ops == &shm_vm_ops) { if (vma && vma->vm_start == addr && vma->vm_ops == &shm_vm_ops) {
do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start); do_munmap(mm, vma->vm_start, vma->vm_end - vma->vm_start, NULL);
retval = 0; retval = 0;
} }
......
...@@ -176,7 +176,7 @@ static struct vm_area_struct *remove_vma(struct vm_area_struct *vma) ...@@ -176,7 +176,7 @@ static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
return next; return next;
} }
static int do_brk(unsigned long addr, unsigned long len); static int do_brk(unsigned long addr, unsigned long len, struct list_head *uf);
SYSCALL_DEFINE1(brk, unsigned long, brk) SYSCALL_DEFINE1(brk, unsigned long, brk)
{ {
...@@ -185,6 +185,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -185,6 +185,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
unsigned long min_brk; unsigned long min_brk;
bool populate; bool populate;
LIST_HEAD(uf);
if (down_write_killable(&mm->mmap_sem)) if (down_write_killable(&mm->mmap_sem))
return -EINTR; return -EINTR;
...@@ -222,7 +223,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -222,7 +223,7 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
/* Always allow shrinking brk. */ /* Always allow shrinking brk. */
if (brk <= mm->brk) { if (brk <= mm->brk) {
if (!do_munmap(mm, newbrk, oldbrk-newbrk)) if (!do_munmap(mm, newbrk, oldbrk-newbrk, &uf))
goto set_brk; goto set_brk;
goto out; goto out;
} }
...@@ -232,13 +233,14 @@ SYSCALL_DEFINE1(brk, unsigned long, brk) ...@@ -232,13 +233,14 @@ SYSCALL_DEFINE1(brk, unsigned long, brk)
goto out; goto out;
/* Ok, looks good - let it rip. */ /* Ok, looks good - let it rip. */
if (do_brk(oldbrk, newbrk-oldbrk) < 0) if (do_brk(oldbrk, newbrk-oldbrk, &uf) < 0)
goto out; goto out;
set_brk: set_brk:
mm->brk = brk; mm->brk = brk;
populate = newbrk > oldbrk && (mm->def_flags & VM_LOCKED) != 0; populate = newbrk > oldbrk && (mm->def_flags & VM_LOCKED) != 0;
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
userfaultfd_unmap_complete(mm, &uf);
if (populate) if (populate)
mm_populate(oldbrk, newbrk - oldbrk); mm_populate(oldbrk, newbrk - oldbrk);
return brk; return brk;
...@@ -1304,7 +1306,8 @@ static inline int mlock_future_check(struct mm_struct *mm, ...@@ -1304,7 +1306,8 @@ static inline int mlock_future_check(struct mm_struct *mm,
unsigned long do_mmap(struct file *file, unsigned long addr, unsigned long do_mmap(struct file *file, unsigned long addr,
unsigned long len, unsigned long prot, unsigned long len, unsigned long prot,
unsigned long flags, vm_flags_t vm_flags, unsigned long flags, vm_flags_t vm_flags,
unsigned long pgoff, unsigned long *populate) unsigned long pgoff, unsigned long *populate,
struct list_head *uf)
{ {
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
int pkey = 0; int pkey = 0;
...@@ -1447,7 +1450,7 @@ unsigned long do_mmap(struct file *file, unsigned long addr, ...@@ -1447,7 +1450,7 @@ unsigned long do_mmap(struct file *file, unsigned long addr,
vm_flags |= VM_NORESERVE; vm_flags |= VM_NORESERVE;
} }
addr = mmap_region(file, addr, len, vm_flags, pgoff); addr = mmap_region(file, addr, len, vm_flags, pgoff, uf);
if (!IS_ERR_VALUE(addr) && if (!IS_ERR_VALUE(addr) &&
((vm_flags & VM_LOCKED) || ((vm_flags & VM_LOCKED) ||
(flags & (MAP_POPULATE | MAP_NONBLOCK)) == MAP_POPULATE)) (flags & (MAP_POPULATE | MAP_NONBLOCK)) == MAP_POPULATE))
...@@ -1583,7 +1586,8 @@ static inline int accountable_mapping(struct file *file, vm_flags_t vm_flags) ...@@ -1583,7 +1586,8 @@ static inline int accountable_mapping(struct file *file, vm_flags_t vm_flags)
} }
unsigned long mmap_region(struct file *file, unsigned long addr, unsigned long mmap_region(struct file *file, unsigned long addr,
unsigned long len, vm_flags_t vm_flags, unsigned long pgoff) unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
struct list_head *uf)
{ {
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
struct vm_area_struct *vma, *prev; struct vm_area_struct *vma, *prev;
...@@ -1609,7 +1613,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr, ...@@ -1609,7 +1613,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
/* Clear old maps */ /* Clear old maps */
while (find_vma_links(mm, addr, addr + len, &prev, &rb_link, while (find_vma_links(mm, addr, addr + len, &prev, &rb_link,
&rb_parent)) { &rb_parent)) {
if (do_munmap(mm, addr, len)) if (do_munmap(mm, addr, len, uf))
return -ENOMEM; return -ENOMEM;
} }
...@@ -2579,7 +2583,8 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma, ...@@ -2579,7 +2583,8 @@ int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
* work. This now handles partial unmappings. * work. This now handles partial unmappings.
* Jeremy Fitzhardinge <jeremy@goop.org> * Jeremy Fitzhardinge <jeremy@goop.org>
*/ */
int do_munmap(struct mm_struct *mm, unsigned long start, size_t len) int do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
struct list_head *uf)
{ {
unsigned long end; unsigned long end;
struct vm_area_struct *vma, *prev, *last; struct vm_area_struct *vma, *prev, *last;
...@@ -2603,6 +2608,13 @@ int do_munmap(struct mm_struct *mm, unsigned long start, size_t len) ...@@ -2603,6 +2608,13 @@ int do_munmap(struct mm_struct *mm, unsigned long start, size_t len)
if (vma->vm_start >= end) if (vma->vm_start >= end)
return 0; return 0;
if (uf) {
int error = userfaultfd_unmap_prep(vma, start, end, uf);
if (error)
return error;
}
/* /*
* If we need to split any vma, do it now to save pain later. * If we need to split any vma, do it now to save pain later.
* *
...@@ -2668,12 +2680,14 @@ int vm_munmap(unsigned long start, size_t len) ...@@ -2668,12 +2680,14 @@ int vm_munmap(unsigned long start, size_t len)
{ {
int ret; int ret;
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
LIST_HEAD(uf);
if (down_write_killable(&mm->mmap_sem)) if (down_write_killable(&mm->mmap_sem))
return -EINTR; return -EINTR;
ret = do_munmap(mm, start, len); ret = do_munmap(mm, start, len, &uf);
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
userfaultfd_unmap_complete(mm, &uf);
return ret; return ret;
} }
EXPORT_SYMBOL(vm_munmap); EXPORT_SYMBOL(vm_munmap);
...@@ -2773,7 +2787,7 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size, ...@@ -2773,7 +2787,7 @@ SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
file = get_file(vma->vm_file); file = get_file(vma->vm_file);
ret = do_mmap_pgoff(vma->vm_file, start, size, ret = do_mmap_pgoff(vma->vm_file, start, size,
prot, flags, pgoff, &populate); prot, flags, pgoff, &populate, NULL);
fput(file); fput(file);
out: out:
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
...@@ -2799,7 +2813,7 @@ static inline void verify_mm_writelocked(struct mm_struct *mm) ...@@ -2799,7 +2813,7 @@ static inline void verify_mm_writelocked(struct mm_struct *mm)
* anonymous maps. eventually we may be able to do some * anonymous maps. eventually we may be able to do some
* brk-specific accounting here. * brk-specific accounting here.
*/ */
static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long flags) static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long flags, struct list_head *uf)
{ {
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
struct vm_area_struct *vma, *prev; struct vm_area_struct *vma, *prev;
...@@ -2838,7 +2852,7 @@ static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long ...@@ -2838,7 +2852,7 @@ static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long
*/ */
while (find_vma_links(mm, addr, addr + len, &prev, &rb_link, while (find_vma_links(mm, addr, addr + len, &prev, &rb_link,
&rb_parent)) { &rb_parent)) {
if (do_munmap(mm, addr, len)) if (do_munmap(mm, addr, len, uf))
return -ENOMEM; return -ENOMEM;
} }
...@@ -2885,9 +2899,9 @@ static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long ...@@ -2885,9 +2899,9 @@ static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long
return 0; return 0;
} }
static int do_brk(unsigned long addr, unsigned long len) static int do_brk(unsigned long addr, unsigned long len, struct list_head *uf)
{ {
return do_brk_flags(addr, len, 0); return do_brk_flags(addr, len, 0, uf);
} }
int vm_brk_flags(unsigned long addr, unsigned long len, unsigned long flags) int vm_brk_flags(unsigned long addr, unsigned long len, unsigned long flags)
...@@ -2895,13 +2909,15 @@ int vm_brk_flags(unsigned long addr, unsigned long len, unsigned long flags) ...@@ -2895,13 +2909,15 @@ int vm_brk_flags(unsigned long addr, unsigned long len, unsigned long flags)
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
int ret; int ret;
bool populate; bool populate;
LIST_HEAD(uf);
if (down_write_killable(&mm->mmap_sem)) if (down_write_killable(&mm->mmap_sem))
return -EINTR; return -EINTR;
ret = do_brk_flags(addr, len, flags); ret = do_brk_flags(addr, len, flags, &uf);
populate = ((mm->def_flags & VM_LOCKED) != 0); populate = ((mm->def_flags & VM_LOCKED) != 0);
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
userfaultfd_unmap_complete(mm, &uf);
if (populate && !ret) if (populate && !ret)
mm_populate(addr, len); mm_populate(addr, len);
return ret; return ret;
......
...@@ -252,7 +252,8 @@ unsigned long move_page_tables(struct vm_area_struct *vma, ...@@ -252,7 +252,8 @@ unsigned long move_page_tables(struct vm_area_struct *vma,
static unsigned long move_vma(struct vm_area_struct *vma, static unsigned long move_vma(struct vm_area_struct *vma,
unsigned long old_addr, unsigned long old_len, unsigned long old_addr, unsigned long old_len,
unsigned long new_len, unsigned long new_addr, unsigned long new_len, unsigned long new_addr,
bool *locked, struct vm_userfaultfd_ctx *uf) bool *locked, struct vm_userfaultfd_ctx *uf,
struct list_head *uf_unmap)
{ {
struct mm_struct *mm = vma->vm_mm; struct mm_struct *mm = vma->vm_mm;
struct vm_area_struct *new_vma; struct vm_area_struct *new_vma;
...@@ -341,7 +342,7 @@ static unsigned long move_vma(struct vm_area_struct *vma, ...@@ -341,7 +342,7 @@ static unsigned long move_vma(struct vm_area_struct *vma,
if (unlikely(vma->vm_flags & VM_PFNMAP)) if (unlikely(vma->vm_flags & VM_PFNMAP))
untrack_pfn_moved(vma); untrack_pfn_moved(vma);
if (do_munmap(mm, old_addr, old_len) < 0) { if (do_munmap(mm, old_addr, old_len, uf_unmap) < 0) {
/* OOM: unable to split vma, just get accounts right */ /* OOM: unable to split vma, just get accounts right */
vm_unacct_memory(excess >> PAGE_SHIFT); vm_unacct_memory(excess >> PAGE_SHIFT);
excess = 0; excess = 0;
...@@ -417,7 +418,8 @@ static struct vm_area_struct *vma_to_resize(unsigned long addr, ...@@ -417,7 +418,8 @@ static struct vm_area_struct *vma_to_resize(unsigned long addr,
static unsigned long mremap_to(unsigned long addr, unsigned long old_len, static unsigned long mremap_to(unsigned long addr, unsigned long old_len,
unsigned long new_addr, unsigned long new_len, bool *locked, unsigned long new_addr, unsigned long new_len, bool *locked,
struct vm_userfaultfd_ctx *uf) struct vm_userfaultfd_ctx *uf,
struct list_head *uf_unmap)
{ {
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
struct vm_area_struct *vma; struct vm_area_struct *vma;
...@@ -435,12 +437,12 @@ static unsigned long mremap_to(unsigned long addr, unsigned long old_len, ...@@ -435,12 +437,12 @@ static unsigned long mremap_to(unsigned long addr, unsigned long old_len,
if (addr + old_len > new_addr && new_addr + new_len > addr) if (addr + old_len > new_addr && new_addr + new_len > addr)
goto out; goto out;
ret = do_munmap(mm, new_addr, new_len); ret = do_munmap(mm, new_addr, new_len, NULL);
if (ret) if (ret)
goto out; goto out;
if (old_len >= new_len) { if (old_len >= new_len) {
ret = do_munmap(mm, addr+new_len, old_len - new_len); ret = do_munmap(mm, addr+new_len, old_len - new_len, uf_unmap);
if (ret && old_len != new_len) if (ret && old_len != new_len)
goto out; goto out;
old_len = new_len; old_len = new_len;
...@@ -462,7 +464,8 @@ static unsigned long mremap_to(unsigned long addr, unsigned long old_len, ...@@ -462,7 +464,8 @@ static unsigned long mremap_to(unsigned long addr, unsigned long old_len,
if (offset_in_page(ret)) if (offset_in_page(ret))
goto out1; goto out1;
ret = move_vma(vma, addr, old_len, new_len, new_addr, locked, uf); ret = move_vma(vma, addr, old_len, new_len, new_addr, locked, uf,
uf_unmap);
if (!(offset_in_page(ret))) if (!(offset_in_page(ret)))
goto out; goto out;
out1: out1:
...@@ -502,6 +505,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, ...@@ -502,6 +505,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
unsigned long charged = 0; unsigned long charged = 0;
bool locked = false; bool locked = false;
struct vm_userfaultfd_ctx uf = NULL_VM_UFFD_CTX; struct vm_userfaultfd_ctx uf = NULL_VM_UFFD_CTX;
LIST_HEAD(uf_unmap);
if (flags & ~(MREMAP_FIXED | MREMAP_MAYMOVE)) if (flags & ~(MREMAP_FIXED | MREMAP_MAYMOVE))
return ret; return ret;
...@@ -528,7 +532,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, ...@@ -528,7 +532,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
if (flags & MREMAP_FIXED) { if (flags & MREMAP_FIXED) {
ret = mremap_to(addr, old_len, new_addr, new_len, ret = mremap_to(addr, old_len, new_addr, new_len,
&locked, &uf); &locked, &uf, &uf_unmap);
goto out; goto out;
} }
...@@ -538,7 +542,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, ...@@ -538,7 +542,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
* do_munmap does all the needed commit accounting * do_munmap does all the needed commit accounting
*/ */
if (old_len >= new_len) { if (old_len >= new_len) {
ret = do_munmap(mm, addr+new_len, old_len - new_len); ret = do_munmap(mm, addr+new_len, old_len - new_len, &uf_unmap);
if (ret && old_len != new_len) if (ret && old_len != new_len)
goto out; goto out;
ret = addr; ret = addr;
...@@ -598,7 +602,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, ...@@ -598,7 +602,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
} }
ret = move_vma(vma, addr, old_len, new_len, new_addr, ret = move_vma(vma, addr, old_len, new_len, new_addr,
&locked, &uf); &locked, &uf, &uf_unmap);
} }
out: out:
if (offset_in_page(ret)) { if (offset_in_page(ret)) {
...@@ -609,5 +613,6 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len, ...@@ -609,5 +613,6 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
if (locked && new_len > old_len) if (locked && new_len > old_len)
mm_populate(new_addr + old_len, new_len - old_len); mm_populate(new_addr + old_len, new_len - old_len);
mremap_userfaultfd_complete(&uf, addr, new_addr, old_len); mremap_userfaultfd_complete(&uf, addr, new_addr, old_len);
userfaultfd_unmap_complete(mm, &uf_unmap);
return ret; return ret;
} }
...@@ -1205,7 +1205,8 @@ unsigned long do_mmap(struct file *file, ...@@ -1205,7 +1205,8 @@ unsigned long do_mmap(struct file *file,
unsigned long flags, unsigned long flags,
vm_flags_t vm_flags, vm_flags_t vm_flags,
unsigned long pgoff, unsigned long pgoff,
unsigned long *populate) unsigned long *populate,
struct list_head *uf)
{ {
struct vm_area_struct *vma; struct vm_area_struct *vma;
struct vm_region *region; struct vm_region *region;
...@@ -1577,7 +1578,7 @@ static int shrink_vma(struct mm_struct *mm, ...@@ -1577,7 +1578,7 @@ static int shrink_vma(struct mm_struct *mm,
* - under NOMMU conditions the chunk to be unmapped must be backed by a single * - under NOMMU conditions the chunk to be unmapped must be backed by a single
* VMA, though it need not cover the whole VMA * VMA, though it need not cover the whole VMA
*/ */
int do_munmap(struct mm_struct *mm, unsigned long start, size_t len) int do_munmap(struct mm_struct *mm, unsigned long start, size_t len, struct list_head *uf)
{ {
struct vm_area_struct *vma; struct vm_area_struct *vma;
unsigned long end; unsigned long end;
...@@ -1643,7 +1644,7 @@ int vm_munmap(unsigned long addr, size_t len) ...@@ -1643,7 +1644,7 @@ int vm_munmap(unsigned long addr, size_t len)
int ret; int ret;
down_write(&mm->mmap_sem); down_write(&mm->mmap_sem);
ret = do_munmap(mm, addr, len); ret = do_munmap(mm, addr, len, NULL);
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
return ret; return ret;
} }
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <linux/mman.h> #include <linux/mman.h>
#include <linux/hugetlb.h> #include <linux/hugetlb.h>
#include <linux/vmalloc.h> #include <linux/vmalloc.h>
#include <linux/userfaultfd_k.h>
#include <asm/sections.h> #include <asm/sections.h>
#include <linux/uaccess.h> #include <linux/uaccess.h>
...@@ -297,14 +298,16 @@ unsigned long vm_mmap_pgoff(struct file *file, unsigned long addr, ...@@ -297,14 +298,16 @@ unsigned long vm_mmap_pgoff(struct file *file, unsigned long addr,
unsigned long ret; unsigned long ret;
struct mm_struct *mm = current->mm; struct mm_struct *mm = current->mm;
unsigned long populate; unsigned long populate;
LIST_HEAD(uf);
ret = security_mmap_file(file, prot, flag); ret = security_mmap_file(file, prot, flag);
if (!ret) { if (!ret) {
if (down_write_killable(&mm->mmap_sem)) if (down_write_killable(&mm->mmap_sem))
return -EINTR; return -EINTR;
ret = do_mmap_pgoff(file, addr, len, prot, flag, pgoff, ret = do_mmap_pgoff(file, addr, len, prot, flag, pgoff,
&populate); &populate, &uf);
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
userfaultfd_unmap_complete(mm, &uf);
if (populate) if (populate)
mm_populate(ret, populate); mm_populate(ret, populate);
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment