Commit 4333a9b0 authored by Linus Torvalds's avatar Linus Torvalds

Merge tag 'io_uring-5.8-2020-06-19' of git://git.kernel.dk/linux-block

Pull io_uring fixes from Jens Axboe:

 - Catch a case where io_sq_thread() didn't do proper mm acquire

 - Ensure poll completions are reaped on shutdown

 - Async cancelation and run fixes (Pavel)

 - io-poll race fixes (Xiaoguang)

 - Request cleanup race fix (Xiaoguang)

* tag 'io_uring-5.8-2020-06-19' of git://git.kernel.dk/linux-block:
  io_uring: fix possible race condition against REQ_F_NEED_CLEANUP
  io_uring: reap poll completions while waiting for refs to drop on exit
  io_uring: acquire 'mm' for task_work for SQPOLL
  io_uring: add memory barrier to synchronize io_kiocb's result and iopoll_completed
  io_uring: don't fail links for EAGAIN error in IOPOLL mode
  io_uring: cancel by ->task not pid
  io_uring: lazy get task
  io_uring: batch cancel in io_uring_cancel_files()
  io_uring: cancel all task's requests on exit
  io-wq: add an option to cancel all matched reqs
  io-wq: reorder cancellation pending -> running
  io_uring: fix lazy work init
parents d2b1c81f 6f2cc166
...@@ -903,13 +903,15 @@ void io_wq_cancel_all(struct io_wq *wq) ...@@ -903,13 +903,15 @@ void io_wq_cancel_all(struct io_wq *wq)
struct io_cb_cancel_data { struct io_cb_cancel_data {
work_cancel_fn *fn; work_cancel_fn *fn;
void *data; void *data;
int nr_running;
int nr_pending;
bool cancel_all;
}; };
static bool io_wq_worker_cancel(struct io_worker *worker, void *data) static bool io_wq_worker_cancel(struct io_worker *worker, void *data)
{ {
struct io_cb_cancel_data *match = data; struct io_cb_cancel_data *match = data;
unsigned long flags; unsigned long flags;
bool ret = false;
/* /*
* Hold the lock to avoid ->cur_work going out of scope, caller * Hold the lock to avoid ->cur_work going out of scope, caller
...@@ -920,74 +922,90 @@ static bool io_wq_worker_cancel(struct io_worker *worker, void *data) ...@@ -920,74 +922,90 @@ static bool io_wq_worker_cancel(struct io_worker *worker, void *data)
!(worker->cur_work->flags & IO_WQ_WORK_NO_CANCEL) && !(worker->cur_work->flags & IO_WQ_WORK_NO_CANCEL) &&
match->fn(worker->cur_work, match->data)) { match->fn(worker->cur_work, match->data)) {
send_sig(SIGINT, worker->task, 1); send_sig(SIGINT, worker->task, 1);
ret = true; match->nr_running++;
} }
spin_unlock_irqrestore(&worker->lock, flags); spin_unlock_irqrestore(&worker->lock, flags);
return ret; return match->nr_running && !match->cancel_all;
} }
static enum io_wq_cancel io_wqe_cancel_work(struct io_wqe *wqe, static void io_wqe_cancel_pending_work(struct io_wqe *wqe,
struct io_cb_cancel_data *match) struct io_cb_cancel_data *match)
{ {
struct io_wq_work_node *node, *prev; struct io_wq_work_node *node, *prev;
struct io_wq_work *work; struct io_wq_work *work;
unsigned long flags; unsigned long flags;
bool found = false;
/* retry:
* First check pending list, if we're lucky we can just remove it
* from there. CANCEL_OK means that the work is returned as-new,
* no completion will be posted for it.
*/
spin_lock_irqsave(&wqe->lock, flags); spin_lock_irqsave(&wqe->lock, flags);
wq_list_for_each(node, prev, &wqe->work_list) { wq_list_for_each(node, prev, &wqe->work_list) {
work = container_of(node, struct io_wq_work, list); work = container_of(node, struct io_wq_work, list);
if (!match->fn(work, match->data))
continue;
if (match->fn(work, match->data)) {
wq_list_del(&wqe->work_list, node, prev); wq_list_del(&wqe->work_list, node, prev);
found = true;
break;
}
}
spin_unlock_irqrestore(&wqe->lock, flags); spin_unlock_irqrestore(&wqe->lock, flags);
if (found) {
io_run_cancel(work, wqe); io_run_cancel(work, wqe);
return IO_WQ_CANCEL_OK; match->nr_pending++;
if (!match->cancel_all)
return;
/* not safe to continue after unlock */
goto retry;
} }
spin_unlock_irqrestore(&wqe->lock, flags);
}
/* static void io_wqe_cancel_running_work(struct io_wqe *wqe,
* Now check if a free (going busy) or busy worker has the work struct io_cb_cancel_data *match)
* currently running. If we find it there, we'll return CANCEL_RUNNING {
* as an indication that we attempt to signal cancellation. The
* completion will run normally in this case.
*/
rcu_read_lock(); rcu_read_lock();
found = io_wq_for_each_worker(wqe, io_wq_worker_cancel, match); io_wq_for_each_worker(wqe, io_wq_worker_cancel, match);
rcu_read_unlock(); rcu_read_unlock();
return found ? IO_WQ_CANCEL_RUNNING : IO_WQ_CANCEL_NOTFOUND;
} }
enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel, enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel,
void *data) void *data, bool cancel_all)
{ {
struct io_cb_cancel_data match = { struct io_cb_cancel_data match = {
.fn = cancel, .fn = cancel,
.data = data, .data = data,
.cancel_all = cancel_all,
}; };
enum io_wq_cancel ret = IO_WQ_CANCEL_NOTFOUND;
int node; int node;
/*
* First check pending list, if we're lucky we can just remove it
* from there. CANCEL_OK means that the work is returned as-new,
* no completion will be posted for it.
*/
for_each_node(node) {
struct io_wqe *wqe = wq->wqes[node];
io_wqe_cancel_pending_work(wqe, &match);
if (match.nr_pending && !match.cancel_all)
return IO_WQ_CANCEL_OK;
}
/*
* Now check if a free (going busy) or busy worker has the work
* currently running. If we find it there, we'll return CANCEL_RUNNING
* as an indication that we attempt to signal cancellation. The
* completion will run normally in this case.
*/
for_each_node(node) { for_each_node(node) {
struct io_wqe *wqe = wq->wqes[node]; struct io_wqe *wqe = wq->wqes[node];
ret = io_wqe_cancel_work(wqe, &match); io_wqe_cancel_running_work(wqe, &match);
if (ret != IO_WQ_CANCEL_NOTFOUND) if (match.nr_running && !match.cancel_all)
break; return IO_WQ_CANCEL_RUNNING;
} }
return ret; if (match.nr_running)
return IO_WQ_CANCEL_RUNNING;
if (match.nr_pending)
return IO_WQ_CANCEL_OK;
return IO_WQ_CANCEL_NOTFOUND;
} }
static bool io_wq_io_cb_cancel_data(struct io_wq_work *work, void *data) static bool io_wq_io_cb_cancel_data(struct io_wq_work *work, void *data)
...@@ -997,21 +1015,7 @@ static bool io_wq_io_cb_cancel_data(struct io_wq_work *work, void *data) ...@@ -997,21 +1015,7 @@ static bool io_wq_io_cb_cancel_data(struct io_wq_work *work, void *data)
enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork) enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork)
{ {
return io_wq_cancel_cb(wq, io_wq_io_cb_cancel_data, (void *)cwork); return io_wq_cancel_cb(wq, io_wq_io_cb_cancel_data, (void *)cwork, false);
}
static bool io_wq_pid_match(struct io_wq_work *work, void *data)
{
pid_t pid = (pid_t) (unsigned long) data;
return work->task_pid == pid;
}
enum io_wq_cancel io_wq_cancel_pid(struct io_wq *wq, pid_t pid)
{
void *data = (void *) (unsigned long) pid;
return io_wq_cancel_cb(wq, io_wq_pid_match, data);
} }
struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data) struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
......
...@@ -90,7 +90,6 @@ struct io_wq_work { ...@@ -90,7 +90,6 @@ struct io_wq_work {
const struct cred *creds; const struct cred *creds;
struct fs_struct *fs; struct fs_struct *fs;
unsigned flags; unsigned flags;
pid_t task_pid;
}; };
static inline struct io_wq_work *wq_next_work(struct io_wq_work *work) static inline struct io_wq_work *wq_next_work(struct io_wq_work *work)
...@@ -125,12 +124,11 @@ static inline bool io_wq_is_hashed(struct io_wq_work *work) ...@@ -125,12 +124,11 @@ static inline bool io_wq_is_hashed(struct io_wq_work *work)
void io_wq_cancel_all(struct io_wq *wq); void io_wq_cancel_all(struct io_wq *wq);
enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork); enum io_wq_cancel io_wq_cancel_work(struct io_wq *wq, struct io_wq_work *cwork);
enum io_wq_cancel io_wq_cancel_pid(struct io_wq *wq, pid_t pid);
typedef bool (work_cancel_fn)(struct io_wq_work *, void *); typedef bool (work_cancel_fn)(struct io_wq_work *, void *);
enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel, enum io_wq_cancel io_wq_cancel_cb(struct io_wq *wq, work_cancel_fn *cancel,
void *data); void *data, bool cancel_all);
struct task_struct *io_wq_get_task(struct io_wq *wq); struct task_struct *io_wq_get_task(struct io_wq *wq);
......
...@@ -541,6 +541,7 @@ enum { ...@@ -541,6 +541,7 @@ enum {
REQ_F_NO_FILE_TABLE_BIT, REQ_F_NO_FILE_TABLE_BIT,
REQ_F_QUEUE_TIMEOUT_BIT, REQ_F_QUEUE_TIMEOUT_BIT,
REQ_F_WORK_INITIALIZED_BIT, REQ_F_WORK_INITIALIZED_BIT,
REQ_F_TASK_PINNED_BIT,
/* not a real bit, just to check we're not overflowing the space */ /* not a real bit, just to check we're not overflowing the space */
__REQ_F_LAST_BIT, __REQ_F_LAST_BIT,
...@@ -598,6 +599,8 @@ enum { ...@@ -598,6 +599,8 @@ enum {
REQ_F_QUEUE_TIMEOUT = BIT(REQ_F_QUEUE_TIMEOUT_BIT), REQ_F_QUEUE_TIMEOUT = BIT(REQ_F_QUEUE_TIMEOUT_BIT),
/* io_wq_work is initialized */ /* io_wq_work is initialized */
REQ_F_WORK_INITIALIZED = BIT(REQ_F_WORK_INITIALIZED_BIT), REQ_F_WORK_INITIALIZED = BIT(REQ_F_WORK_INITIALIZED_BIT),
/* req->task is refcounted */
REQ_F_TASK_PINNED = BIT(REQ_F_TASK_PINNED_BIT),
}; };
struct async_poll { struct async_poll {
...@@ -910,6 +913,21 @@ struct sock *io_uring_get_socket(struct file *file) ...@@ -910,6 +913,21 @@ struct sock *io_uring_get_socket(struct file *file)
} }
EXPORT_SYMBOL(io_uring_get_socket); EXPORT_SYMBOL(io_uring_get_socket);
static void io_get_req_task(struct io_kiocb *req)
{
if (req->flags & REQ_F_TASK_PINNED)
return;
get_task_struct(req->task);
req->flags |= REQ_F_TASK_PINNED;
}
/* not idempotent -- it doesn't clear REQ_F_TASK_PINNED */
static void __io_put_req_task(struct io_kiocb *req)
{
if (req->flags & REQ_F_TASK_PINNED)
put_task_struct(req->task);
}
static void io_file_put_work(struct work_struct *work); static void io_file_put_work(struct work_struct *work);
/* /*
...@@ -1045,8 +1063,6 @@ static inline void io_req_work_grab_env(struct io_kiocb *req, ...@@ -1045,8 +1063,6 @@ static inline void io_req_work_grab_env(struct io_kiocb *req,
} }
spin_unlock(&current->fs->lock); spin_unlock(&current->fs->lock);
} }
if (!req->work.task_pid)
req->work.task_pid = task_pid_vnr(current);
} }
static inline void io_req_work_drop_env(struct io_kiocb *req) static inline void io_req_work_drop_env(struct io_kiocb *req)
...@@ -1087,6 +1103,7 @@ static inline void io_prep_async_work(struct io_kiocb *req, ...@@ -1087,6 +1103,7 @@ static inline void io_prep_async_work(struct io_kiocb *req,
req->work.flags |= IO_WQ_WORK_UNBOUND; req->work.flags |= IO_WQ_WORK_UNBOUND;
} }
io_req_init_async(req);
io_req_work_grab_env(req, def); io_req_work_grab_env(req, def);
*link = io_prep_linked_timeout(req); *link = io_prep_linked_timeout(req);
...@@ -1398,9 +1415,7 @@ static void __io_req_aux_free(struct io_kiocb *req) ...@@ -1398,9 +1415,7 @@ static void __io_req_aux_free(struct io_kiocb *req)
kfree(req->io); kfree(req->io);
if (req->file) if (req->file)
io_put_file(req, req->file, (req->flags & REQ_F_FIXED_FILE)); io_put_file(req, req->file, (req->flags & REQ_F_FIXED_FILE));
if (req->task) __io_put_req_task(req);
put_task_struct(req->task);
io_req_work_drop_env(req); io_req_work_drop_env(req);
} }
...@@ -1727,6 +1742,18 @@ static int io_put_kbuf(struct io_kiocb *req) ...@@ -1727,6 +1742,18 @@ static int io_put_kbuf(struct io_kiocb *req)
return cflags; return cflags;
} }
static void io_iopoll_queue(struct list_head *again)
{
struct io_kiocb *req;
do {
req = list_first_entry(again, struct io_kiocb, list);
list_del(&req->list);
refcount_inc(&req->refs);
io_queue_async_work(req);
} while (!list_empty(again));
}
/* /*
* Find and free completed poll iocbs * Find and free completed poll iocbs
*/ */
...@@ -1735,12 +1762,21 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events, ...@@ -1735,12 +1762,21 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
{ {
struct req_batch rb; struct req_batch rb;
struct io_kiocb *req; struct io_kiocb *req;
LIST_HEAD(again);
/* order with ->result store in io_complete_rw_iopoll() */
smp_rmb();
rb.to_free = rb.need_iter = 0; rb.to_free = rb.need_iter = 0;
while (!list_empty(done)) { while (!list_empty(done)) {
int cflags = 0; int cflags = 0;
req = list_first_entry(done, struct io_kiocb, list); req = list_first_entry(done, struct io_kiocb, list);
if (READ_ONCE(req->result) == -EAGAIN) {
req->iopoll_completed = 0;
list_move_tail(&req->list, &again);
continue;
}
list_del(&req->list); list_del(&req->list);
if (req->flags & REQ_F_BUFFER_SELECTED) if (req->flags & REQ_F_BUFFER_SELECTED)
...@@ -1758,18 +1794,9 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events, ...@@ -1758,18 +1794,9 @@ static void io_iopoll_complete(struct io_ring_ctx *ctx, unsigned int *nr_events,
if (ctx->flags & IORING_SETUP_SQPOLL) if (ctx->flags & IORING_SETUP_SQPOLL)
io_cqring_ev_posted(ctx); io_cqring_ev_posted(ctx);
io_free_req_many(ctx, &rb); io_free_req_many(ctx, &rb);
}
static void io_iopoll_queue(struct list_head *again) if (!list_empty(&again))
{ io_iopoll_queue(&again);
struct io_kiocb *req;
do {
req = list_first_entry(again, struct io_kiocb, list);
list_del(&req->list);
refcount_inc(&req->refs);
io_queue_async_work(req);
} while (!list_empty(again));
} }
static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events, static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
...@@ -1777,7 +1804,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events, ...@@ -1777,7 +1804,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
{ {
struct io_kiocb *req, *tmp; struct io_kiocb *req, *tmp;
LIST_HEAD(done); LIST_HEAD(done);
LIST_HEAD(again);
bool spin; bool spin;
int ret; int ret;
...@@ -1803,13 +1829,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events, ...@@ -1803,13 +1829,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
if (!list_empty(&done)) if (!list_empty(&done))
break; break;
if (req->result == -EAGAIN) {
list_move_tail(&req->list, &again);
continue;
}
if (!list_empty(&again))
break;
ret = kiocb->ki_filp->f_op->iopoll(kiocb, spin); ret = kiocb->ki_filp->f_op->iopoll(kiocb, spin);
if (ret < 0) if (ret < 0)
break; break;
...@@ -1822,9 +1841,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events, ...@@ -1822,9 +1841,6 @@ static int io_do_iopoll(struct io_ring_ctx *ctx, unsigned int *nr_events,
if (!list_empty(&done)) if (!list_empty(&done))
io_iopoll_complete(ctx, nr_events, &done); io_iopoll_complete(ctx, nr_events, &done);
if (!list_empty(&again))
io_iopoll_queue(&again);
return ret; return ret;
} }
...@@ -1973,11 +1989,15 @@ static void io_complete_rw_iopoll(struct kiocb *kiocb, long res, long res2) ...@@ -1973,11 +1989,15 @@ static void io_complete_rw_iopoll(struct kiocb *kiocb, long res, long res2)
if (kiocb->ki_flags & IOCB_WRITE) if (kiocb->ki_flags & IOCB_WRITE)
kiocb_end_write(req); kiocb_end_write(req);
if (res != req->result) if (res != -EAGAIN && res != req->result)
req_set_fail_links(req); req_set_fail_links(req);
req->result = res;
if (res != -EAGAIN) WRITE_ONCE(req->result, res);
/* order with io_poll_complete() checking ->result */
if (res != -EAGAIN) {
smp_wmb();
WRITE_ONCE(req->iopoll_completed, 1); WRITE_ONCE(req->iopoll_completed, 1);
}
} }
/* /*
...@@ -2650,8 +2670,8 @@ static int io_read(struct io_kiocb *req, bool force_nonblock) ...@@ -2650,8 +2670,8 @@ static int io_read(struct io_kiocb *req, bool force_nonblock)
} }
} }
out_free: out_free:
if (!(req->flags & REQ_F_NEED_CLEANUP))
kfree(iovec); kfree(iovec);
req->flags &= ~REQ_F_NEED_CLEANUP;
return ret; return ret;
} }
...@@ -2773,7 +2793,7 @@ static int io_write(struct io_kiocb *req, bool force_nonblock) ...@@ -2773,7 +2793,7 @@ static int io_write(struct io_kiocb *req, bool force_nonblock)
} }
} }
out_free: out_free:
req->flags &= ~REQ_F_NEED_CLEANUP; if (!(req->flags & REQ_F_NEED_CLEANUP))
kfree(iovec); kfree(iovec);
return ret; return ret;
} }
...@@ -4236,6 +4256,28 @@ static void io_async_queue_proc(struct file *file, struct wait_queue_head *head, ...@@ -4236,6 +4256,28 @@ static void io_async_queue_proc(struct file *file, struct wait_queue_head *head,
__io_queue_proc(&pt->req->apoll->poll, pt, head); __io_queue_proc(&pt->req->apoll->poll, pt, head);
} }
static void io_sq_thread_drop_mm(struct io_ring_ctx *ctx)
{
struct mm_struct *mm = current->mm;
if (mm) {
kthread_unuse_mm(mm);
mmput(mm);
}
}
static int io_sq_thread_acquire_mm(struct io_ring_ctx *ctx,
struct io_kiocb *req)
{
if (io_op_defs[req->opcode].needs_mm && !current->mm) {
if (unlikely(!mmget_not_zero(ctx->sqo_mm)))
return -EFAULT;
kthread_use_mm(ctx->sqo_mm);
}
return 0;
}
static void io_async_task_func(struct callback_head *cb) static void io_async_task_func(struct callback_head *cb)
{ {
struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work); struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
...@@ -4270,11 +4312,16 @@ static void io_async_task_func(struct callback_head *cb) ...@@ -4270,11 +4312,16 @@ static void io_async_task_func(struct callback_head *cb)
if (!canceled) { if (!canceled) {
__set_current_state(TASK_RUNNING); __set_current_state(TASK_RUNNING);
if (io_sq_thread_acquire_mm(ctx, req)) {
io_cqring_add_event(req, -EFAULT);
goto end_req;
}
mutex_lock(&ctx->uring_lock); mutex_lock(&ctx->uring_lock);
__io_queue_sqe(req, NULL); __io_queue_sqe(req, NULL);
mutex_unlock(&ctx->uring_lock); mutex_unlock(&ctx->uring_lock);
} else { } else {
io_cqring_ev_posted(ctx); io_cqring_ev_posted(ctx);
end_req:
req_set_fail_links(req); req_set_fail_links(req);
io_double_put_req(req); io_double_put_req(req);
} }
...@@ -4366,8 +4413,7 @@ static bool io_arm_poll_handler(struct io_kiocb *req) ...@@ -4366,8 +4413,7 @@ static bool io_arm_poll_handler(struct io_kiocb *req)
memcpy(&apoll->work, &req->work, sizeof(req->work)); memcpy(&apoll->work, &req->work, sizeof(req->work));
had_io = req->io != NULL; had_io = req->io != NULL;
get_task_struct(current); io_get_req_task(req);
req->task = current;
req->apoll = apoll; req->apoll = apoll;
INIT_HLIST_NODE(&req->hash_node); INIT_HLIST_NODE(&req->hash_node);
...@@ -4555,8 +4601,7 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe ...@@ -4555,8 +4601,7 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe
events = READ_ONCE(sqe->poll_events); events = READ_ONCE(sqe->poll_events);
poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP; poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP;
get_task_struct(current); io_get_req_task(req);
req->task = current;
return 0; return 0;
} }
...@@ -4772,7 +4817,7 @@ static int io_async_cancel_one(struct io_ring_ctx *ctx, void *sqe_addr) ...@@ -4772,7 +4817,7 @@ static int io_async_cancel_one(struct io_ring_ctx *ctx, void *sqe_addr)
enum io_wq_cancel cancel_ret; enum io_wq_cancel cancel_ret;
int ret = 0; int ret = 0;
cancel_ret = io_wq_cancel_cb(ctx->io_wq, io_cancel_cb, sqe_addr); cancel_ret = io_wq_cancel_cb(ctx->io_wq, io_cancel_cb, sqe_addr, false);
switch (cancel_ret) { switch (cancel_ret) {
case IO_WQ_CANCEL_OK: case IO_WQ_CANCEL_OK:
ret = 0; ret = 0;
...@@ -5817,17 +5862,14 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req, ...@@ -5817,17 +5862,14 @@ static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
req->flags = 0; req->flags = 0;
/* one is dropped after submission, the other at completion */ /* one is dropped after submission, the other at completion */
refcount_set(&req->refs, 2); refcount_set(&req->refs, 2);
req->task = NULL; req->task = current;
req->result = 0; req->result = 0;
if (unlikely(req->opcode >= IORING_OP_LAST)) if (unlikely(req->opcode >= IORING_OP_LAST))
return -EINVAL; return -EINVAL;
if (io_op_defs[req->opcode].needs_mm && !current->mm) { if (unlikely(io_sq_thread_acquire_mm(ctx, req)))
if (unlikely(!mmget_not_zero(ctx->sqo_mm)))
return -EFAULT; return -EFAULT;
kthread_use_mm(ctx->sqo_mm);
}
sqe_flags = READ_ONCE(sqe->flags); sqe_flags = READ_ONCE(sqe->flags);
/* enforce forwards compatibility on users */ /* enforce forwards compatibility on users */
...@@ -5936,16 +5978,6 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, ...@@ -5936,16 +5978,6 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
return submitted; return submitted;
} }
static inline void io_sq_thread_drop_mm(struct io_ring_ctx *ctx)
{
struct mm_struct *mm = current->mm;
if (mm) {
kthread_unuse_mm(mm);
mmput(mm);
}
}
static int io_sq_thread(void *data) static int io_sq_thread(void *data)
{ {
struct io_ring_ctx *ctx = data; struct io_ring_ctx *ctx = data;
...@@ -7331,7 +7363,17 @@ static void io_ring_exit_work(struct work_struct *work) ...@@ -7331,7 +7363,17 @@ static void io_ring_exit_work(struct work_struct *work)
if (ctx->rings) if (ctx->rings)
io_cqring_overflow_flush(ctx, true); io_cqring_overflow_flush(ctx, true);
wait_for_completion(&ctx->ref_comp); /*
* If we're doing polled IO and end up having requests being
* submitted async (out-of-line), then completions can come in while
* we're waiting for refs to drop. We need to reap these manually,
* as nobody else will be looking for them.
*/
while (!wait_for_completion_timeout(&ctx->ref_comp, HZ/20)) {
io_iopoll_reap_events(ctx);
if (ctx->rings)
io_cqring_overflow_flush(ctx, true);
}
io_ring_ctx_free(ctx); io_ring_ctx_free(ctx);
} }
...@@ -7365,9 +7407,22 @@ static int io_uring_release(struct inode *inode, struct file *file) ...@@ -7365,9 +7407,22 @@ static int io_uring_release(struct inode *inode, struct file *file)
return 0; return 0;
} }
static bool io_wq_files_match(struct io_wq_work *work, void *data)
{
struct files_struct *files = data;
return work->files == files;
}
static void io_uring_cancel_files(struct io_ring_ctx *ctx, static void io_uring_cancel_files(struct io_ring_ctx *ctx,
struct files_struct *files) struct files_struct *files)
{ {
if (list_empty_careful(&ctx->inflight_list))
return;
/* cancel all at once, should be faster than doing it one by one*/
io_wq_cancel_cb(ctx->io_wq, io_wq_files_match, files, true);
while (!list_empty_careful(&ctx->inflight_list)) { while (!list_empty_careful(&ctx->inflight_list)) {
struct io_kiocb *cancel_req = NULL, *req; struct io_kiocb *cancel_req = NULL, *req;
DEFINE_WAIT(wait); DEFINE_WAIT(wait);
...@@ -7423,6 +7478,14 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx, ...@@ -7423,6 +7478,14 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
} }
} }
static bool io_cancel_task_cb(struct io_wq_work *work, void *data)
{
struct io_kiocb *req = container_of(work, struct io_kiocb, work);
struct task_struct *task = data;
return req->task == task;
}
static int io_uring_flush(struct file *file, void *data) static int io_uring_flush(struct file *file, void *data)
{ {
struct io_ring_ctx *ctx = file->private_data; struct io_ring_ctx *ctx = file->private_data;
...@@ -7433,7 +7496,7 @@ static int io_uring_flush(struct file *file, void *data) ...@@ -7433,7 +7496,7 @@ static int io_uring_flush(struct file *file, void *data)
* If the task is going away, cancel work it may have pending * If the task is going away, cancel work it may have pending
*/ */
if (fatal_signal_pending(current) || (current->flags & PF_EXITING)) if (fatal_signal_pending(current) || (current->flags & PF_EXITING))
io_wq_cancel_pid(ctx->io_wq, task_pid_vnr(current)); io_wq_cancel_cb(ctx->io_wq, io_cancel_task_cb, current, true);
return 0; return 0;
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment