Commit 15b726ef authored by Andrea Arcangeli, committed by Linus Torvalds

userfaultfd: optimize read() and poll() to be O(1)

This makes read() O(1), and poll(), which was already O(1), becomes lockless.
Signed-off-by: Andrea Arcangeli <aarcange@redhat.com>
Acked-by: Pavel Emelyanov <xemul@parallels.com>
Cc: Sanidhya Kashyap <sanidhya.gatech@gmail.com>
Cc: zhang.zhanghailiang@huawei.com
Cc: "Kirill A. Shutemov" <kirill@shutemov.name>
Cc: Andres Lagar-Cavilla <andreslc@google.com>
Cc: Dave Hansen <dave.hansen@intel.com>
Cc: Paolo Bonzini <pbonzini@redhat.com>
Cc: Rik van Riel <riel@redhat.com>
Cc: Mel Gorman <mgorman@suse.de>
Cc: Andy Lutomirski <luto@amacapital.net>
Cc: Hugh Dickins <hughd@google.com>
Cc: Peter Feiner <pfeiner@google.com>
Cc: "Dr. David Alan Gilbert" <dgilbert@redhat.com>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: "Huangpeng (Peter)" <peter.huangpeng@huawei.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
parent ba85c702
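
For context, a minimal sketch (not part of this commit) of the userspace monitor loop that sits in the read()/poll() paths optimized below. It assumes the uffd UAPI header introduced by this series, leaves out UFFDIO_REGISTER and the actual fault resolution (e.g. UFFDIO_COPY), and trims error handling:

/* illustrative userspace monitor loop, not from the kernel tree */
#include <fcntl.h>
#include <linux/userfaultfd.h>
#include <poll.h>
#include <stdio.h>
#include <sys/ioctl.h>
#include <sys/syscall.h>
#include <unistd.h>

int main(void)
{
	/* O_NONBLOCK is required: userfaultfd_poll() returns POLLERR without it */
	int uffd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
	struct uffdio_api api = { .api = UFFD_API };

	if (uffd < 0 || ioctl(uffd, UFFDIO_API, &api))
		return 1;

	/* ... UFFDIO_REGISTER a memory range, start the faulting threads ... */

	for (;;) {
		struct pollfd pfd = { .fd = uffd, .events = POLLIN };
		struct uffd_msg msg;

		if (poll(&pfd, 1, -1) < 0)	/* wait for a pending userfault */
			return 1;
		if (read(uffd, &msg, sizeof(msg)) != sizeof(msg))
			continue;		/* non-blocking fd: tolerate a spurious wakeup */
		if (msg.event == UFFD_EVENT_PAGEFAULT)
			printf("userfault at 0x%llx\n",
			       (unsigned long long)msg.arg.pagefault.address);
		/* ... resolve the fault here, e.g. with UFFDIO_COPY ... */
	}
}

With the fault_pending_wqh/fault_wqh split below, the poll() above reduces to a lockless waitqueue_active() check and each read() dequeues the oldest pending userfault in O(1), independent of how many userfaults are queued.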
@@ -35,7 +35,9 @@ enum userfaultfd_state {
 struct userfaultfd_ctx {
 	/* pseudo fd refcounting */
 	atomic_t refcount;
-	/* waitqueue head for the userfaultfd page faults */
+	/* waitqueue head for the pending (i.e. not read) userfaults */
+	wait_queue_head_t fault_pending_wqh;
+	/* waitqueue head for the userfaults */
 	wait_queue_head_t fault_wqh;
 	/* waitqueue head for the pseudo fd to wakeup poll/read */
 	wait_queue_head_t fd_wqh;
@@ -52,11 +54,6 @@ struct userfaultfd_ctx {
 struct userfaultfd_wait_queue {
 	struct uffd_msg msg;
 	wait_queue_t wq;
-	/*
-	 * Only relevant when queued in fault_wqh and only used by the
-	 * read operation to avoid reading the same userfault twice.
-	 */
-	bool pending;
 	struct userfaultfd_ctx *ctx;
 };
 
@@ -263,17 +260,21 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address,
 	init_waitqueue_func_entry(&uwq.wq, userfaultfd_wake_function);
 	uwq.wq.private = current;
 	uwq.msg = userfault_msg(address, flags, reason);
-	uwq.pending = true;
 	uwq.ctx = ctx;
 
-	spin_lock(&ctx->fault_wqh.lock);
+	spin_lock(&ctx->fault_pending_wqh.lock);
 	/*
 	 * After the __add_wait_queue the uwq is visible to userland
 	 * through poll/read().
 	 */
-	__add_wait_queue(&ctx->fault_wqh, &uwq.wq);
+	__add_wait_queue(&ctx->fault_pending_wqh, &uwq.wq);
+	/*
+	 * The smp_mb() after __set_current_state prevents the reads
+	 * following the spin_unlock to happen before the list_add in
+	 * __add_wait_queue.
+	 */
 	set_current_state(TASK_KILLABLE);
-	spin_unlock(&ctx->fault_wqh.lock);
+	spin_unlock(&ctx->fault_pending_wqh.lock);
 
 	if (likely(!ACCESS_ONCE(ctx->released) &&
 		   !fatal_signal_pending(current))) {
@@ -283,11 +284,28 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address,
 	}
 
 	__set_current_state(TASK_RUNNING);
-	/* see finish_wait() comment for why list_empty_careful() */
+
+	/*
+	 * Here we race with the list_del; list_add in
+	 * userfaultfd_ctx_read(), however because we don't ever run
+	 * list_del_init() to refile across the two lists, the prev
+	 * and next pointers will never point to self. list_add also
+	 * would never let any of the two pointers to point to
+	 * self. So list_empty_careful won't risk to see both pointers
+	 * pointing to self at any time during the list refile. The
+	 * only case where list_del_init() is called is the full
+	 * removal in the wake function and there we don't re-list_add
+	 * and it's fine not to block on the spinlock. The uwq on this
+	 * kernel stack can be released after the list_del_init.
+	 */
 	if (!list_empty_careful(&uwq.wq.task_list)) {
-		spin_lock(&ctx->fault_wqh.lock);
-		list_del_init(&uwq.wq.task_list);
-		spin_unlock(&ctx->fault_wqh.lock);
+		spin_lock(&ctx->fault_pending_wqh.lock);
+		/*
+		 * No need of list_del_init(), the uwq on the stack
+		 * will be freed shortly anyway.
+		 */
+		list_del(&uwq.wq.task_list);
+		spin_unlock(&ctx->fault_pending_wqh.lock);
 	}
 
 	/*
@@ -345,59 +363,38 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
 	up_write(&mm->mmap_sem);
 
 	/*
-	 * After no new page faults can wait on this fault_wqh, flush
+	 * After no new page faults can wait on this fault_*wqh, flush
 	 * the last page faults that may have been already waiting on
-	 * the fault_wqh.
+	 * the fault_*wqh.
 	 */
-	spin_lock(&ctx->fault_wqh.lock);
+	spin_lock(&ctx->fault_pending_wqh.lock);
+	__wake_up_locked_key(&ctx->fault_pending_wqh, TASK_NORMAL, 0, &range);
 	__wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, 0, &range);
-	spin_unlock(&ctx->fault_wqh.lock);
+	spin_unlock(&ctx->fault_pending_wqh.lock);
 
 	wake_up_poll(&ctx->fd_wqh, POLLHUP);
 	userfaultfd_ctx_put(ctx);
 	return 0;
 }
 
-/* fault_wqh.lock must be hold by the caller */
-static inline unsigned int find_userfault(struct userfaultfd_ctx *ctx,
-					  struct userfaultfd_wait_queue **uwq)
+/* fault_pending_wqh.lock must be hold by the caller */
+static inline struct userfaultfd_wait_queue *find_userfault(
+		struct userfaultfd_ctx *ctx)
 {
 	wait_queue_t *wq;
-	struct userfaultfd_wait_queue *_uwq;
-	unsigned int ret = 0;
-
-	VM_BUG_ON(!spin_is_locked(&ctx->fault_wqh.lock));
-
-	list_for_each_entry(wq, &ctx->fault_wqh.task_list, task_list) {
-		_uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
-		if (_uwq->pending) {
-			ret = POLLIN;
-			if (!uwq)
-				/*
-				 * If there's at least a pending and
-				 * we don't care which one it is,
-				 * break immediately and leverage the
-				 * efficiency of the LIFO walk.
-				 */
-				break;
-			/*
-			 * If we need to find which one was pending we
-			 * keep walking until we find the first not
-			 * pending one, so we read() them in FIFO order.
-			 */
-			*uwq = _uwq;
-		} else
-			/*
-			 * break the loop at the first not pending
-			 * one, there cannot be pending userfaults
-			 * after the first not pending one, because
-			 * all new pending ones are inserted at the
-			 * head and we walk it in LIFO.
-			 */
-			break;
-	}
-
-	return ret;
+	struct userfaultfd_wait_queue *uwq;
+
+	VM_BUG_ON(!spin_is_locked(&ctx->fault_pending_wqh.lock));
+
+	uwq = NULL;
+	if (!waitqueue_active(&ctx->fault_pending_wqh))
+		goto out;
+	/* walk in reverse to provide FIFO behavior to read userfaults */
+	wq = list_last_entry(&ctx->fault_pending_wqh.task_list,
+			     typeof(*wq), task_list);
+	uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
+out:
+	return uwq;
 }
 
 static unsigned int userfaultfd_poll(struct file *file, poll_table *wait)
@@ -417,9 +414,20 @@ static unsigned int userfaultfd_poll(struct file *file, poll_table *wait)
 		 */
 		if (unlikely(!(file->f_flags & O_NONBLOCK)))
 			return POLLERR;
-		spin_lock(&ctx->fault_wqh.lock);
-		ret = find_userfault(ctx, NULL);
-		spin_unlock(&ctx->fault_wqh.lock);
+		/*
+		 * lockless access to see if there are pending faults
+		 * __pollwait last action is the add_wait_queue but
+		 * the spin_unlock would allow the waitqueue_active to
+		 * pass above the actual list_add inside
+		 * add_wait_queue critical section. So use a full
+		 * memory barrier to serialize the list_add write of
+		 * add_wait_queue() with the waitqueue_active read
+		 * below.
+		 */
+		ret = 0;
+		smp_mb();
+		if (waitqueue_active(&ctx->fault_pending_wqh))
+			ret = POLLIN;
 		return ret;
 	default:
 		BUG();
@@ -431,27 +439,47 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
 {
 	ssize_t ret;
 	DECLARE_WAITQUEUE(wait, current);
-	struct userfaultfd_wait_queue *uwq = NULL;
+	struct userfaultfd_wait_queue *uwq;
 
-	/* always take the fd_wqh lock before the fault_wqh lock */
+	/* always take the fd_wqh lock before the fault_pending_wqh lock */
 	spin_lock(&ctx->fd_wqh.lock);
 	__add_wait_queue(&ctx->fd_wqh, &wait);
 	for (;;) {
 		set_current_state(TASK_INTERRUPTIBLE);
-		spin_lock(&ctx->fault_wqh.lock);
-		if (find_userfault(ctx, &uwq)) {
+		spin_lock(&ctx->fault_pending_wqh.lock);
+		uwq = find_userfault(ctx);
+		if (uwq) {
 			/*
-			 * The fault_wqh.lock prevents the uwq to
-			 * disappear from under us.
+			 * The fault_pending_wqh.lock prevents the uwq
+			 * to disappear from under us.
+			 *
+			 * Refile this userfault from
+			 * fault_pending_wqh to fault_wqh, it's not
+			 * pending anymore after we read it.
+			 *
+			 * Use list_del() by hand (as
+			 * userfaultfd_wake_function also uses
+			 * list_del_init() by hand) to be sure nobody
+			 * changes __remove_wait_queue() to use
+			 * list_del_init() in turn breaking the
+			 * !list_empty_careful() check in
+			 * handle_userfault(). The uwq->wq.task_list
+			 * must never be empty at any time during the
+			 * refile, or the waitqueue could disappear
+			 * from under us. The "wait_queue_head_t"
+			 * parameter of __remove_wait_queue() is unused
+			 * anyway.
 			 */
-			uwq->pending = false;
+			list_del(&uwq->wq.task_list);
+			__add_wait_queue(&ctx->fault_wqh, &uwq->wq);
+
 			/* careful to always initialize msg if ret == 0 */
 			*msg = uwq->msg;
-			spin_unlock(&ctx->fault_wqh.lock);
+			spin_unlock(&ctx->fault_pending_wqh.lock);
 			ret = 0;
 			break;
 		}
-		spin_unlock(&ctx->fault_wqh.lock);
+		spin_unlock(&ctx->fault_pending_wqh.lock);
 		if (signal_pending(current)) {
 			ret = -ERESTARTSYS;
 			break;
@@ -510,10 +538,14 @@ static void __wake_userfault(struct userfaultfd_ctx *ctx,
 	start = range->start;
 	end = range->start + range->len;
 
-	spin_lock(&ctx->fault_wqh.lock);
+	spin_lock(&ctx->fault_pending_wqh.lock);
 	/* wake all in the range and autoremove */
-	__wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, 0, range);
-	spin_unlock(&ctx->fault_wqh.lock);
+	if (waitqueue_active(&ctx->fault_pending_wqh))
+		__wake_up_locked_key(&ctx->fault_pending_wqh, TASK_NORMAL, 0,
+				     range);
+	if (waitqueue_active(&ctx->fault_wqh))
+		__wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, 0, range);
+	spin_unlock(&ctx->fault_pending_wqh.lock);
 }
 
 static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
@@ -534,7 +566,8 @@ static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
 	 * userfaults yet. So we take the spinlock only when we're
 	 * sure we've userfaults to wake.
 	 */
-	if (waitqueue_active(&ctx->fault_wqh))
+	if (waitqueue_active(&ctx->fault_pending_wqh) ||
+	    waitqueue_active(&ctx->fault_wqh))
 		__wake_userfault(ctx, range);
 }
 
@@ -960,14 +993,17 @@ static void userfaultfd_show_fdinfo(struct seq_file *m, struct file *f)
 	struct userfaultfd_wait_queue *uwq;
 	unsigned long pending = 0, total = 0;
 
-	spin_lock(&ctx->fault_wqh.lock);
+	spin_lock(&ctx->fault_pending_wqh.lock);
+	list_for_each_entry(wq, &ctx->fault_pending_wqh.task_list, task_list) {
+		uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
+		pending++;
+		total++;
+	}
 	list_for_each_entry(wq, &ctx->fault_wqh.task_list, task_list) {
 		uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
-		if (uwq->pending)
-			pending++;
 		total++;
 	}
-	spin_unlock(&ctx->fault_wqh.lock);
+	spin_unlock(&ctx->fault_pending_wqh.lock);
 
 	/*
 	 * If more protocols will be added, there will be all shown
@@ -1027,6 +1063,7 @@ static struct file *userfaultfd_file_create(int flags)
 		goto out;
 
 	atomic_set(&ctx->refcount, 1);
+	init_waitqueue_head(&ctx->fault_pending_wqh);
 	init_waitqueue_head(&ctx->fault_wqh);
 	init_waitqueue_head(&ctx->fd_wqh);
 	ctx->flags = flags;
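
To see why reading from the tail of a head-inserted list gives FIFO order in O(1), here is a standalone userspace toy (names invented, not kernel code) that mimics the refile done in userfaultfd_ctx_read(): new "userfaults" are pushed at the head of a pending list, and the reader pops the tail and moves the entry onto a second list standing in for fault_wqh:

/* toy only: plain circular doubly-linked lists, no locking */
#include <stdio.h>

struct node {
	struct node *prev, *next;	/* circular list with a dedicated head */
	int msg;
};

static void list_init(struct node *head)
{
	head->prev = head->next = head;
}

/* insert right after the head, like __add_wait_queue()/list_add() */
static void add_head(struct node *head, struct node *n)
{
	n->next = head->next;
	n->prev = head;
	head->next->prev = n;
	head->next = n;
}

/* unlink without reinitializing, like list_del() */
static void del(struct node *n)
{
	n->prev->next = n->next;
	n->next->prev = n->prev;
}

int main(void)
{
	struct node pending, not_pending, faults[3];
	int i;

	list_init(&pending);
	list_init(&not_pending);

	/* three userfaults arrive in order 0, 1, 2: each is pushed at the head */
	for (i = 0; i < 3; i++) {
		faults[i].msg = i;
		add_head(&pending, &faults[i]);
	}

	/* "read()": O(1) per event, oldest first because the tail is the oldest */
	while (pending.next != &pending) {
		struct node *oldest = pending.prev;

		printf("read userfault %d\n", oldest->msg);	/* prints 0, 1, 2 */
		del(oldest);
		add_head(&not_pending, oldest);	/* refile until it is woken */
	}
	return 0;
}

Because an entry always sits on exactly one of the two lists, a non-empty pending list means "there is something to read", which is what lets poll() get away with the plain waitqueue_active() check.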