Commit d2005e3f authored by Oleg Nesterov's avatar Oleg Nesterov Committed by Linus Torvalds

userfaultfd: don't pin the user memory in userfaultfd_file_create()

userfaultfd_file_create() increments mm->mm_users; this means that the
memory won't be unmapped/freed if mm owner exits/execs, and UFFDIO_COPY
after that can populate the orphaned mm more.

Change userfaultfd_file_create() and userfaultfd_ctx_put() to use
mm->mm_count to pin mm_struct.  This means that
atomic_inc_not_zero(mm->mm_users) is needed when we are going to
actually play with this memory.  Except handle_userfault() path doesn't
need this, the caller must already have a reference.

The patch adds the new trivial helper, mmget_not_zero(), it can have
more users.

Link: http://lkml.kernel.org/r/20160516172254.GA8595@redhat.comSigned-off-by: default avatarOleg Nesterov <oleg@redhat.com>
Cc: Andrea Arcangeli <aarcange@redhat.com>
Cc: Michal Hocko <mhocko@kernel.org>
Signed-off-by: default avatarAndrew Morton <akpm@linux-foundation.org>
Signed-off-by: default avatarLinus Torvalds <torvalds@linux-foundation.org>
parent cd33a76b
...@@ -137,7 +137,7 @@ static void userfaultfd_ctx_put(struct userfaultfd_ctx *ctx) ...@@ -137,7 +137,7 @@ static void userfaultfd_ctx_put(struct userfaultfd_ctx *ctx)
VM_BUG_ON(waitqueue_active(&ctx->fault_wqh)); VM_BUG_ON(waitqueue_active(&ctx->fault_wqh));
VM_BUG_ON(spin_is_locked(&ctx->fd_wqh.lock)); VM_BUG_ON(spin_is_locked(&ctx->fd_wqh.lock));
VM_BUG_ON(waitqueue_active(&ctx->fd_wqh)); VM_BUG_ON(waitqueue_active(&ctx->fd_wqh));
mmput(ctx->mm); mmdrop(ctx->mm);
kmem_cache_free(userfaultfd_ctx_cachep, ctx); kmem_cache_free(userfaultfd_ctx_cachep, ctx);
} }
} }
...@@ -434,6 +434,9 @@ static int userfaultfd_release(struct inode *inode, struct file *file) ...@@ -434,6 +434,9 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
ACCESS_ONCE(ctx->released) = true; ACCESS_ONCE(ctx->released) = true;
if (!mmget_not_zero(mm))
goto wakeup;
/* /*
* Flush page faults out of all CPUs. NOTE: all page faults * Flush page faults out of all CPUs. NOTE: all page faults
* must be retried without returning VM_FAULT_SIGBUS if * must be retried without returning VM_FAULT_SIGBUS if
...@@ -466,7 +469,8 @@ static int userfaultfd_release(struct inode *inode, struct file *file) ...@@ -466,7 +469,8 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
} }
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
mmput(mm);
wakeup:
/* /*
* After no new page faults can wait on this fault_*wqh, flush * After no new page faults can wait on this fault_*wqh, flush
* the last page faults that may have been already waiting on * the last page faults that may have been already waiting on
...@@ -760,10 +764,12 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, ...@@ -760,10 +764,12 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
start = uffdio_register.range.start; start = uffdio_register.range.start;
end = start + uffdio_register.range.len; end = start + uffdio_register.range.len;
ret = -ENOMEM;
if (!mmget_not_zero(mm))
goto out;
down_write(&mm->mmap_sem); down_write(&mm->mmap_sem);
vma = find_vma_prev(mm, start, &prev); vma = find_vma_prev(mm, start, &prev);
ret = -ENOMEM;
if (!vma) if (!vma)
goto out_unlock; goto out_unlock;
...@@ -864,6 +870,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, ...@@ -864,6 +870,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
} while (vma && vma->vm_start < end); } while (vma && vma->vm_start < end);
out_unlock: out_unlock:
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
mmput(mm);
if (!ret) { if (!ret) {
/* /*
* Now that we scanned all vmas we can already tell * Now that we scanned all vmas we can already tell
...@@ -902,10 +909,12 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, ...@@ -902,10 +909,12 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
start = uffdio_unregister.start; start = uffdio_unregister.start;
end = start + uffdio_unregister.len; end = start + uffdio_unregister.len;
ret = -ENOMEM;
if (!mmget_not_zero(mm))
goto out;
down_write(&mm->mmap_sem); down_write(&mm->mmap_sem);
vma = find_vma_prev(mm, start, &prev); vma = find_vma_prev(mm, start, &prev);
ret = -ENOMEM;
if (!vma) if (!vma)
goto out_unlock; goto out_unlock;
...@@ -998,6 +1007,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, ...@@ -998,6 +1007,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
} while (vma && vma->vm_start < end); } while (vma && vma->vm_start < end);
out_unlock: out_unlock:
up_write(&mm->mmap_sem); up_write(&mm->mmap_sem);
mmput(mm);
out: out:
return ret; return ret;
} }
...@@ -1067,9 +1077,11 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx, ...@@ -1067,9 +1077,11 @@ static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
goto out; goto out;
if (uffdio_copy.mode & ~UFFDIO_COPY_MODE_DONTWAKE) if (uffdio_copy.mode & ~UFFDIO_COPY_MODE_DONTWAKE)
goto out; goto out;
if (mmget_not_zero(ctx->mm)) {
ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src, ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src,
uffdio_copy.len); uffdio_copy.len);
mmput(ctx->mm);
}
if (unlikely(put_user(ret, &user_uffdio_copy->copy))) if (unlikely(put_user(ret, &user_uffdio_copy->copy)))
return -EFAULT; return -EFAULT;
if (ret < 0) if (ret < 0)
...@@ -1110,8 +1122,11 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx, ...@@ -1110,8 +1122,11 @@ static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
if (uffdio_zeropage.mode & ~UFFDIO_ZEROPAGE_MODE_DONTWAKE) if (uffdio_zeropage.mode & ~UFFDIO_ZEROPAGE_MODE_DONTWAKE)
goto out; goto out;
if (mmget_not_zero(ctx->mm)) {
ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start, ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start,
uffdio_zeropage.range.len); uffdio_zeropage.range.len);
mmput(ctx->mm);
}
if (unlikely(put_user(ret, &user_uffdio_zeropage->zeropage))) if (unlikely(put_user(ret, &user_uffdio_zeropage->zeropage)))
return -EFAULT; return -EFAULT;
if (ret < 0) if (ret < 0)
...@@ -1289,12 +1304,12 @@ static struct file *userfaultfd_file_create(int flags) ...@@ -1289,12 +1304,12 @@ static struct file *userfaultfd_file_create(int flags)
ctx->released = false; ctx->released = false;
ctx->mm = current->mm; ctx->mm = current->mm;
/* prevent the mm struct to be freed */ /* prevent the mm struct to be freed */
atomic_inc(&ctx->mm->mm_users); atomic_inc(&ctx->mm->mm_count);
file = anon_inode_getfile("[userfaultfd]", &userfaultfd_fops, ctx, file = anon_inode_getfile("[userfaultfd]", &userfaultfd_fops, ctx,
O_RDWR | (flags & UFFD_SHARED_FCNTL_FLAGS)); O_RDWR | (flags & UFFD_SHARED_FCNTL_FLAGS));
if (IS_ERR(file)) { if (IS_ERR(file)) {
mmput(ctx->mm); mmdrop(ctx->mm);
kmem_cache_free(userfaultfd_ctx_cachep, ctx); kmem_cache_free(userfaultfd_ctx_cachep, ctx);
} }
out: out:
......
...@@ -2723,12 +2723,17 @@ extern struct mm_struct * mm_alloc(void); ...@@ -2723,12 +2723,17 @@ extern struct mm_struct * mm_alloc(void);
/* mmdrop drops the mm and the page tables */ /* mmdrop drops the mm and the page tables */
extern void __mmdrop(struct mm_struct *); extern void __mmdrop(struct mm_struct *);
static inline void mmdrop(struct mm_struct * mm) static inline void mmdrop(struct mm_struct *mm)
{ {
if (unlikely(atomic_dec_and_test(&mm->mm_count))) if (unlikely(atomic_dec_and_test(&mm->mm_count)))
__mmdrop(mm); __mmdrop(mm);
} }
static inline bool mmget_not_zero(struct mm_struct *mm)
{
return atomic_inc_not_zero(&mm->mm_users);
}
/* mmput gets rid of the mappings and all user-space */ /* mmput gets rid of the mappings and all user-space */
extern void mmput(struct mm_struct *); extern void mmput(struct mm_struct *);
/* same as above but performs the slow path from the async kontext. Can /* same as above but performs the slow path from the async kontext. Can
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment