Commit a2286a44 authored by Linus Torvalds's avatar Linus Torvalds

Merge tag 'io_uring-5.7-2020-04-17' of git://git.kernel.dk/linux-block

Pull io_uring fixes from Jens Axboe:

 - wrap up the init/setup cleanup (Pavel)

 - fix some issues around deferral sequences (Pavel)

 - fix splice punt check using the wrong struct file member

 - apply poll re-arm logic for pollable retry too

 - pollable retry should honor cancelation

 - fix setup time error handling syzbot reported crash

 - restore work state when poll is canceled

* tag 'io_uring-5.7-2020-04-17' of git://git.kernel.dk/linux-block:
  io_uring: don't count rqs failed after current one
  io_uring: kill already cached timeout.seq_offset
  io_uring: fix cached_sq_head in io_timeout()
  io_uring: only post events in io_poll_remove_all() if we completed some
  io_uring: io_async_task_func() should check and honor cancelation
  io_uring: check for need to re-wait in polled async handling
  io_uring: correct O_NONBLOCK check for splice punt
  io_uring: restore req->work when canceling poll request
  io_uring: move all request init code in one place
  io_uring: keep all sqe->flags in req->flags
  io_uring: early submission req fail code
  io_uring: track mm through current->mm
  io_uring: remove obsolete @mm_fault
parents bf9196d5 31af27c7
...@@ -357,7 +357,6 @@ struct io_timeout_data { ...@@ -357,7 +357,6 @@ struct io_timeout_data {
struct hrtimer timer; struct hrtimer timer;
struct timespec64 ts; struct timespec64 ts;
enum hrtimer_mode mode; enum hrtimer_mode mode;
u32 seq_offset;
}; };
struct io_accept { struct io_accept {
...@@ -385,7 +384,7 @@ struct io_timeout { ...@@ -385,7 +384,7 @@ struct io_timeout {
struct file *file; struct file *file;
u64 addr; u64 addr;
int flags; int flags;
unsigned count; u32 count;
}; };
struct io_rw { struct io_rw {
...@@ -508,6 +507,7 @@ enum { ...@@ -508,6 +507,7 @@ enum {
REQ_F_FORCE_ASYNC_BIT = IOSQE_ASYNC_BIT, REQ_F_FORCE_ASYNC_BIT = IOSQE_ASYNC_BIT,
REQ_F_BUFFER_SELECT_BIT = IOSQE_BUFFER_SELECT_BIT, REQ_F_BUFFER_SELECT_BIT = IOSQE_BUFFER_SELECT_BIT,
REQ_F_LINK_HEAD_BIT,
REQ_F_LINK_NEXT_BIT, REQ_F_LINK_NEXT_BIT,
REQ_F_FAIL_LINK_BIT, REQ_F_FAIL_LINK_BIT,
REQ_F_INFLIGHT_BIT, REQ_F_INFLIGHT_BIT,
...@@ -543,6 +543,8 @@ enum { ...@@ -543,6 +543,8 @@ enum {
/* IOSQE_BUFFER_SELECT */ /* IOSQE_BUFFER_SELECT */
REQ_F_BUFFER_SELECT = BIT(REQ_F_BUFFER_SELECT_BIT), REQ_F_BUFFER_SELECT = BIT(REQ_F_BUFFER_SELECT_BIT),
/* head of a link */
REQ_F_LINK_HEAD = BIT(REQ_F_LINK_HEAD_BIT),
/* already grabbed next link */ /* already grabbed next link */
REQ_F_LINK_NEXT = BIT(REQ_F_LINK_NEXT_BIT), REQ_F_LINK_NEXT = BIT(REQ_F_LINK_NEXT_BIT),
/* fail rest of links */ /* fail rest of links */
...@@ -955,7 +957,7 @@ static inline bool __req_need_defer(struct io_kiocb *req) ...@@ -955,7 +957,7 @@ static inline bool __req_need_defer(struct io_kiocb *req)
{ {
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
return req->sequence != ctx->cached_cq_tail + ctx->cached_sq_dropped return req->sequence != ctx->cached_cq_tail
+ atomic_read(&ctx->cached_cq_overflow); + atomic_read(&ctx->cached_cq_overflow);
} }
...@@ -1437,7 +1439,7 @@ static bool io_link_cancel_timeout(struct io_kiocb *req) ...@@ -1437,7 +1439,7 @@ static bool io_link_cancel_timeout(struct io_kiocb *req)
if (ret != -1) { if (ret != -1) {
io_cqring_fill_event(req, -ECANCELED); io_cqring_fill_event(req, -ECANCELED);
io_commit_cqring(ctx); io_commit_cqring(ctx);
req->flags &= ~REQ_F_LINK; req->flags &= ~REQ_F_LINK_HEAD;
io_put_req(req); io_put_req(req);
return true; return true;
} }
...@@ -1473,7 +1475,7 @@ static void io_req_link_next(struct io_kiocb *req, struct io_kiocb **nxtptr) ...@@ -1473,7 +1475,7 @@ static void io_req_link_next(struct io_kiocb *req, struct io_kiocb **nxtptr)
list_del_init(&req->link_list); list_del_init(&req->link_list);
if (!list_empty(&nxt->link_list)) if (!list_empty(&nxt->link_list))
nxt->flags |= REQ_F_LINK; nxt->flags |= REQ_F_LINK_HEAD;
*nxtptr = nxt; *nxtptr = nxt;
break; break;
} }
...@@ -1484,7 +1486,7 @@ static void io_req_link_next(struct io_kiocb *req, struct io_kiocb **nxtptr) ...@@ -1484,7 +1486,7 @@ static void io_req_link_next(struct io_kiocb *req, struct io_kiocb **nxtptr)
} }
/* /*
* Called if REQ_F_LINK is set, and we fail the head request * Called if REQ_F_LINK_HEAD is set, and we fail the head request
*/ */
static void io_fail_links(struct io_kiocb *req) static void io_fail_links(struct io_kiocb *req)
{ {
...@@ -1517,7 +1519,7 @@ static void io_fail_links(struct io_kiocb *req) ...@@ -1517,7 +1519,7 @@ static void io_fail_links(struct io_kiocb *req)
static void io_req_find_next(struct io_kiocb *req, struct io_kiocb **nxt) static void io_req_find_next(struct io_kiocb *req, struct io_kiocb **nxt)
{ {
if (likely(!(req->flags & REQ_F_LINK))) if (likely(!(req->flags & REQ_F_LINK_HEAD)))
return; return;
/* /*
...@@ -1669,7 +1671,7 @@ static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx) ...@@ -1669,7 +1671,7 @@ static inline unsigned int io_sqring_entries(struct io_ring_ctx *ctx)
static inline bool io_req_multi_free(struct req_batch *rb, struct io_kiocb *req) static inline bool io_req_multi_free(struct req_batch *rb, struct io_kiocb *req)
{ {
if ((req->flags & REQ_F_LINK) || io_is_fallback_req(req)) if ((req->flags & REQ_F_LINK_HEAD) || io_is_fallback_req(req))
return false; return false;
if (!(req->flags & REQ_F_FIXED_FILE) || req->io) if (!(req->flags & REQ_F_FIXED_FILE) || req->io)
...@@ -2562,7 +2564,7 @@ static int io_read(struct io_kiocb *req, bool force_nonblock) ...@@ -2562,7 +2564,7 @@ static int io_read(struct io_kiocb *req, bool force_nonblock)
req->result = 0; req->result = 0;
io_size = ret; io_size = ret;
if (req->flags & REQ_F_LINK) if (req->flags & REQ_F_LINK_HEAD)
req->result = io_size; req->result = io_size;
/* /*
...@@ -2653,7 +2655,7 @@ static int io_write(struct io_kiocb *req, bool force_nonblock) ...@@ -2653,7 +2655,7 @@ static int io_write(struct io_kiocb *req, bool force_nonblock)
req->result = 0; req->result = 0;
io_size = ret; io_size = ret;
if (req->flags & REQ_F_LINK) if (req->flags & REQ_F_LINK_HEAD)
req->result = io_size; req->result = io_size;
/* /*
...@@ -2760,7 +2762,7 @@ static bool io_splice_punt(struct file *file) ...@@ -2760,7 +2762,7 @@ static bool io_splice_punt(struct file *file)
return false; return false;
if (!io_file_supports_async(file)) if (!io_file_supports_async(file))
return true; return true;
return !(file->f_mode & O_NONBLOCK); return !(file->f_flags & O_NONBLOCK);
} }
static int io_splice(struct io_kiocb *req, bool force_nonblock) static int io_splice(struct io_kiocb *req, bool force_nonblock)
...@@ -4153,20 +4155,57 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll, ...@@ -4153,20 +4155,57 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
return 1; return 1;
} }
static bool io_poll_rewait(struct io_kiocb *req, struct io_poll_iocb *poll)
__acquires(&req->ctx->completion_lock)
{
struct io_ring_ctx *ctx = req->ctx;
if (!req->result && !READ_ONCE(poll->canceled)) {
struct poll_table_struct pt = { ._key = poll->events };
req->result = vfs_poll(req->file, &pt) & poll->events;
}
spin_lock_irq(&ctx->completion_lock);
if (!req->result && !READ_ONCE(poll->canceled)) {
add_wait_queue(poll->head, &poll->wait);
return true;
}
return false;
}
static void io_async_task_func(struct callback_head *cb) static void io_async_task_func(struct callback_head *cb)
{ {
struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work); struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work);
struct async_poll *apoll = req->apoll; struct async_poll *apoll = req->apoll;
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
bool canceled;
trace_io_uring_task_run(req->ctx, req->opcode, req->user_data); trace_io_uring_task_run(req->ctx, req->opcode, req->user_data);
WARN_ON_ONCE(!list_empty(&req->apoll->poll.wait.entry)); if (io_poll_rewait(req, &apoll->poll)) {
spin_unlock_irq(&ctx->completion_lock);
return;
}
if (hash_hashed(&req->hash_node)) { if (hash_hashed(&req->hash_node))
spin_lock_irq(&ctx->completion_lock);
hash_del(&req->hash_node); hash_del(&req->hash_node);
canceled = READ_ONCE(apoll->poll.canceled);
if (canceled) {
io_cqring_fill_event(req, -ECANCELED);
io_commit_cqring(ctx);
}
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
if (canceled) {
kfree(apoll);
io_cqring_ev_posted(ctx);
req_set_fail_links(req);
io_put_req(req);
return;
} }
/* restore ->work in case we need to retry again */ /* restore ->work in case we need to retry again */
...@@ -4315,11 +4354,13 @@ static bool __io_poll_remove_one(struct io_kiocb *req, ...@@ -4315,11 +4354,13 @@ static bool __io_poll_remove_one(struct io_kiocb *req,
static bool io_poll_remove_one(struct io_kiocb *req) static bool io_poll_remove_one(struct io_kiocb *req)
{ {
struct async_poll *apoll = NULL;
bool do_complete; bool do_complete;
if (req->opcode == IORING_OP_POLL_ADD) { if (req->opcode == IORING_OP_POLL_ADD) {
do_complete = __io_poll_remove_one(req, &req->poll); do_complete = __io_poll_remove_one(req, &req->poll);
} else { } else {
apoll = req->apoll;
/* non-poll requests have submit ref still */ /* non-poll requests have submit ref still */
do_complete = __io_poll_remove_one(req, &req->apoll->poll); do_complete = __io_poll_remove_one(req, &req->apoll->poll);
if (do_complete) if (do_complete)
...@@ -4328,6 +4369,14 @@ static bool io_poll_remove_one(struct io_kiocb *req) ...@@ -4328,6 +4369,14 @@ static bool io_poll_remove_one(struct io_kiocb *req)
hash_del(&req->hash_node); hash_del(&req->hash_node);
if (apoll) {
/*
* restore ->work because we need to call io_req_work_drop_env.
*/
memcpy(&req->work, &apoll->work, sizeof(req->work));
kfree(apoll);
}
if (do_complete) { if (do_complete) {
io_cqring_fill_event(req, -ECANCELED); io_cqring_fill_event(req, -ECANCELED);
io_commit_cqring(req->ctx); io_commit_cqring(req->ctx);
...@@ -4342,7 +4391,7 @@ static void io_poll_remove_all(struct io_ring_ctx *ctx) ...@@ -4342,7 +4391,7 @@ static void io_poll_remove_all(struct io_ring_ctx *ctx)
{ {
struct hlist_node *tmp; struct hlist_node *tmp;
struct io_kiocb *req; struct io_kiocb *req;
int i; int posted = 0, i;
spin_lock_irq(&ctx->completion_lock); spin_lock_irq(&ctx->completion_lock);
for (i = 0; i < (1U << ctx->cancel_hash_bits); i++) { for (i = 0; i < (1U << ctx->cancel_hash_bits); i++) {
...@@ -4350,10 +4399,11 @@ static void io_poll_remove_all(struct io_ring_ctx *ctx) ...@@ -4350,10 +4399,11 @@ static void io_poll_remove_all(struct io_ring_ctx *ctx)
list = &ctx->cancel_hash[i]; list = &ctx->cancel_hash[i];
hlist_for_each_entry_safe(req, tmp, list, hash_node) hlist_for_each_entry_safe(req, tmp, list, hash_node)
io_poll_remove_one(req); posted += io_poll_remove_one(req);
} }
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
if (posted)
io_cqring_ev_posted(ctx); io_cqring_ev_posted(ctx);
} }
...@@ -4423,18 +4473,11 @@ static void io_poll_task_handler(struct io_kiocb *req, struct io_kiocb **nxt) ...@@ -4423,18 +4473,11 @@ static void io_poll_task_handler(struct io_kiocb *req, struct io_kiocb **nxt)
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
struct io_poll_iocb *poll = &req->poll; struct io_poll_iocb *poll = &req->poll;
if (!req->result && !READ_ONCE(poll->canceled)) { if (io_poll_rewait(req, poll)) {
struct poll_table_struct pt = { ._key = poll->events };
req->result = vfs_poll(req->file, &pt) & poll->events;
}
spin_lock_irq(&ctx->completion_lock);
if (!req->result && !READ_ONCE(poll->canceled)) {
add_wait_queue(poll->head, &poll->wait);
spin_unlock_irq(&ctx->completion_lock); spin_unlock_irq(&ctx->completion_lock);
return; return;
} }
hash_del(&req->hash_node); hash_del(&req->hash_node);
io_poll_complete(req, req->result, 0); io_poll_complete(req, req->result, 0);
req->flags |= REQ_F_COMP_LOCKED; req->flags |= REQ_F_COMP_LOCKED;
...@@ -4665,11 +4708,12 @@ static int io_timeout_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -4665,11 +4708,12 @@ static int io_timeout_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe,
static int io_timeout(struct io_kiocb *req) static int io_timeout(struct io_kiocb *req)
{ {
unsigned count;
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
struct io_timeout_data *data; struct io_timeout_data *data;
struct list_head *entry; struct list_head *entry;
unsigned span = 0; unsigned span = 0;
u32 count = req->timeout.count;
u32 seq = req->sequence;
data = &req->io->timeout; data = &req->io->timeout;
...@@ -4678,7 +4722,6 @@ static int io_timeout(struct io_kiocb *req) ...@@ -4678,7 +4722,6 @@ static int io_timeout(struct io_kiocb *req)
* timeout event to be satisfied. If it isn't set, then this is * timeout event to be satisfied. If it isn't set, then this is
* a pure timeout request, sequence isn't used. * a pure timeout request, sequence isn't used.
*/ */
count = req->timeout.count;
if (!count) { if (!count) {
req->flags |= REQ_F_TIMEOUT_NOSEQ; req->flags |= REQ_F_TIMEOUT_NOSEQ;
spin_lock_irq(&ctx->completion_lock); spin_lock_irq(&ctx->completion_lock);
...@@ -4686,8 +4729,7 @@ static int io_timeout(struct io_kiocb *req) ...@@ -4686,8 +4729,7 @@ static int io_timeout(struct io_kiocb *req)
goto add; goto add;
} }
req->sequence = ctx->cached_sq_head + count - 1; req->sequence = seq + count;
data->seq_offset = count;
/* /*
* Insertion sort, ensuring the first entry in the list is always * Insertion sort, ensuring the first entry in the list is always
...@@ -4696,26 +4738,26 @@ static int io_timeout(struct io_kiocb *req) ...@@ -4696,26 +4738,26 @@ static int io_timeout(struct io_kiocb *req)
spin_lock_irq(&ctx->completion_lock); spin_lock_irq(&ctx->completion_lock);
list_for_each_prev(entry, &ctx->timeout_list) { list_for_each_prev(entry, &ctx->timeout_list) {
struct io_kiocb *nxt = list_entry(entry, struct io_kiocb, list); struct io_kiocb *nxt = list_entry(entry, struct io_kiocb, list);
unsigned nxt_sq_head; unsigned nxt_seq;
long long tmp, tmp_nxt; long long tmp, tmp_nxt;
u32 nxt_offset = nxt->io->timeout.seq_offset; u32 nxt_offset = nxt->timeout.count;
if (nxt->flags & REQ_F_TIMEOUT_NOSEQ) if (nxt->flags & REQ_F_TIMEOUT_NOSEQ)
continue; continue;
/* /*
* Since cached_sq_head + count - 1 can overflow, use type long * Since seq + count can overflow, use type long
* long to store it. * long to store it.
*/ */
tmp = (long long)ctx->cached_sq_head + count - 1; tmp = (long long)seq + count;
nxt_sq_head = nxt->sequence - nxt_offset + 1; nxt_seq = nxt->sequence - nxt_offset;
tmp_nxt = (long long)nxt_sq_head + nxt_offset - 1; tmp_nxt = (long long)nxt_seq + nxt_offset;
/* /*
* cached_sq_head may overflow, and it will never overflow twice * cached_sq_head may overflow, and it will never overflow twice
* once there is some timeout req still be valid. * once there is some timeout req still be valid.
*/ */
if (ctx->cached_sq_head < nxt_sq_head) if (seq < nxt_seq)
tmp += UINT_MAX; tmp += UINT_MAX;
if (tmp > tmp_nxt) if (tmp > tmp_nxt)
...@@ -5476,7 +5518,7 @@ static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req) ...@@ -5476,7 +5518,7 @@ static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req)
{ {
struct io_kiocb *nxt; struct io_kiocb *nxt;
if (!(req->flags & REQ_F_LINK)) if (!(req->flags & REQ_F_LINK_HEAD))
return NULL; return NULL;
/* for polled retry, if flag is set, we already went through here */ /* for polled retry, if flag is set, we already went through here */
if (req->flags & REQ_F_POLLED) if (req->flags & REQ_F_POLLED)
...@@ -5604,54 +5646,11 @@ static inline void io_queue_link_head(struct io_kiocb *req) ...@@ -5604,54 +5646,11 @@ static inline void io_queue_link_head(struct io_kiocb *req)
io_queue_sqe(req, NULL); io_queue_sqe(req, NULL);
} }
#define SQE_VALID_FLAGS (IOSQE_FIXED_FILE|IOSQE_IO_DRAIN|IOSQE_IO_LINK| \ static int io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
IOSQE_IO_HARDLINK | IOSQE_ASYNC | \
IOSQE_BUFFER_SELECT)
static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
struct io_submit_state *state, struct io_kiocb **link) struct io_submit_state *state, struct io_kiocb **link)
{ {
struct io_ring_ctx *ctx = req->ctx; struct io_ring_ctx *ctx = req->ctx;
unsigned int sqe_flags; int ret;
int ret, id, fd;
sqe_flags = READ_ONCE(sqe->flags);
/* enforce forwards compatibility on users */
if (unlikely(sqe_flags & ~SQE_VALID_FLAGS)) {
ret = -EINVAL;
goto err_req;
}
if ((sqe_flags & IOSQE_BUFFER_SELECT) &&
!io_op_defs[req->opcode].buffer_select) {
ret = -EOPNOTSUPP;
goto err_req;
}
id = READ_ONCE(sqe->personality);
if (id) {
req->work.creds = idr_find(&ctx->personality_idr, id);
if (unlikely(!req->work.creds)) {
ret = -EINVAL;
goto err_req;
}
get_cred(req->work.creds);
}
/* same numerical values with corresponding REQ_F_*, safe to copy */
req->flags |= sqe_flags & (IOSQE_IO_DRAIN | IOSQE_IO_HARDLINK |
IOSQE_ASYNC | IOSQE_FIXED_FILE |
IOSQE_BUFFER_SELECT);
fd = READ_ONCE(sqe->fd);
ret = io_req_set_file(state, req, fd, sqe_flags);
if (unlikely(ret)) {
err_req:
io_cqring_add_event(req, ret);
io_double_put_req(req);
return false;
}
/* /*
* If we already have a head request, queue this one for async * If we already have a head request, queue this one for async
...@@ -5670,42 +5669,39 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -5670,42 +5669,39 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
* next after the link request. The last one is done via * next after the link request. The last one is done via
* drain_next flag to persist the effect across calls. * drain_next flag to persist the effect across calls.
*/ */
if (sqe_flags & IOSQE_IO_DRAIN) { if (req->flags & REQ_F_IO_DRAIN) {
head->flags |= REQ_F_IO_DRAIN; head->flags |= REQ_F_IO_DRAIN;
ctx->drain_next = 1; ctx->drain_next = 1;
} }
if (io_alloc_async_ctx(req)) { if (io_alloc_async_ctx(req))
ret = -EAGAIN; return -EAGAIN;
goto err_req;
}
ret = io_req_defer_prep(req, sqe); ret = io_req_defer_prep(req, sqe);
if (ret) { if (ret) {
/* fail even hard links since we don't submit */ /* fail even hard links since we don't submit */
head->flags |= REQ_F_FAIL_LINK; head->flags |= REQ_F_FAIL_LINK;
goto err_req; return ret;
} }
trace_io_uring_link(ctx, req, head); trace_io_uring_link(ctx, req, head);
list_add_tail(&req->link_list, &head->link_list); list_add_tail(&req->link_list, &head->link_list);
/* last request of a link, enqueue the link */ /* last request of a link, enqueue the link */
if (!(sqe_flags & (IOSQE_IO_LINK|IOSQE_IO_HARDLINK))) { if (!(req->flags & (REQ_F_LINK | REQ_F_HARDLINK))) {
io_queue_link_head(head); io_queue_link_head(head);
*link = NULL; *link = NULL;
} }
} else { } else {
if (unlikely(ctx->drain_next)) { if (unlikely(ctx->drain_next)) {
req->flags |= REQ_F_IO_DRAIN; req->flags |= REQ_F_IO_DRAIN;
req->ctx->drain_next = 0; ctx->drain_next = 0;
} }
if (sqe_flags & (IOSQE_IO_LINK|IOSQE_IO_HARDLINK)) { if (req->flags & (REQ_F_LINK | REQ_F_HARDLINK)) {
req->flags |= REQ_F_LINK; req->flags |= REQ_F_LINK_HEAD;
INIT_LIST_HEAD(&req->link_list); INIT_LIST_HEAD(&req->link_list);
if (io_alloc_async_ctx(req)) { if (io_alloc_async_ctx(req))
ret = -EAGAIN; return -EAGAIN;
goto err_req;
}
ret = io_req_defer_prep(req, sqe); ret = io_req_defer_prep(req, sqe);
if (ret) if (ret)
req->flags |= REQ_F_FAIL_LINK; req->flags |= REQ_F_FAIL_LINK;
...@@ -5715,7 +5711,7 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe, ...@@ -5715,7 +5711,7 @@ static bool io_submit_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe,
} }
} }
return true; return 0;
} }
/* /*
...@@ -5789,15 +5785,23 @@ static inline void io_consume_sqe(struct io_ring_ctx *ctx) ...@@ -5789,15 +5785,23 @@ static inline void io_consume_sqe(struct io_ring_ctx *ctx)
ctx->cached_sq_head++; ctx->cached_sq_head++;
} }
static void io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req, #define SQE_VALID_FLAGS (IOSQE_FIXED_FILE|IOSQE_IO_DRAIN|IOSQE_IO_LINK| \
const struct io_uring_sqe *sqe) IOSQE_IO_HARDLINK | IOSQE_ASYNC | \
IOSQE_BUFFER_SELECT)
static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
const struct io_uring_sqe *sqe,
struct io_submit_state *state, bool async)
{ {
unsigned int sqe_flags;
int id, fd;
/* /*
* All io need record the previous position, if LINK vs DARIN, * All io need record the previous position, if LINK vs DARIN,
* it can be used to mark the position of the first IO in the * it can be used to mark the position of the first IO in the
* link list. * link list.
*/ */
req->sequence = ctx->cached_sq_head; req->sequence = ctx->cached_sq_head - ctx->cached_sq_dropped;
req->opcode = READ_ONCE(sqe->opcode); req->opcode = READ_ONCE(sqe->opcode);
req->user_data = READ_ONCE(sqe->user_data); req->user_data = READ_ONCE(sqe->user_data);
req->io = NULL; req->io = NULL;
...@@ -5808,17 +5812,50 @@ static void io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req, ...@@ -5808,17 +5812,50 @@ static void io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
refcount_set(&req->refs, 2); refcount_set(&req->refs, 2);
req->task = NULL; req->task = NULL;
req->result = 0; req->result = 0;
req->needs_fixed_file = async;
INIT_IO_WORK(&req->work, io_wq_submit_work); INIT_IO_WORK(&req->work, io_wq_submit_work);
if (unlikely(req->opcode >= IORING_OP_LAST))
return -EINVAL;
if (io_op_defs[req->opcode].needs_mm && !current->mm) {
if (unlikely(!mmget_not_zero(ctx->sqo_mm)))
return -EFAULT;
use_mm(ctx->sqo_mm);
}
sqe_flags = READ_ONCE(sqe->flags);
/* enforce forwards compatibility on users */
if (unlikely(sqe_flags & ~SQE_VALID_FLAGS))
return -EINVAL;
if ((sqe_flags & IOSQE_BUFFER_SELECT) &&
!io_op_defs[req->opcode].buffer_select)
return -EOPNOTSUPP;
id = READ_ONCE(sqe->personality);
if (id) {
req->work.creds = idr_find(&ctx->personality_idr, id);
if (unlikely(!req->work.creds))
return -EINVAL;
get_cred(req->work.creds);
}
/* same numerical values with corresponding REQ_F_*, safe to copy */
req->flags |= sqe_flags & (IOSQE_IO_DRAIN | IOSQE_IO_HARDLINK |
IOSQE_ASYNC | IOSQE_FIXED_FILE |
IOSQE_BUFFER_SELECT | IOSQE_IO_LINK);
fd = READ_ONCE(sqe->fd);
return io_req_set_file(state, req, fd, sqe_flags);
} }
static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
struct file *ring_file, int ring_fd, struct file *ring_file, int ring_fd, bool async)
struct mm_struct **mm, bool async)
{ {
struct io_submit_state state, *statep = NULL; struct io_submit_state state, *statep = NULL;
struct io_kiocb *link = NULL; struct io_kiocb *link = NULL;
int i, submitted = 0; int i, submitted = 0;
bool mm_fault = false;
/* if we have a backlog and couldn't flush it all, return BUSY */ /* if we have a backlog and couldn't flush it all, return BUSY */
if (test_bit(0, &ctx->sq_check_overflow)) { if (test_bit(0, &ctx->sq_check_overflow)) {
...@@ -5858,34 +5895,23 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, ...@@ -5858,34 +5895,23 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
break; break;
} }
io_init_req(ctx, req, sqe); err = io_init_req(ctx, req, sqe, statep, async);
io_consume_sqe(ctx); io_consume_sqe(ctx);
/* will complete beyond this point, count as submitted */ /* will complete beyond this point, count as submitted */
submitted++; submitted++;
if (unlikely(req->opcode >= IORING_OP_LAST)) { if (unlikely(err)) {
err = -EINVAL;
fail_req: fail_req:
io_cqring_add_event(req, err); io_cqring_add_event(req, err);
io_double_put_req(req); io_double_put_req(req);
break; break;
} }
if (io_op_defs[req->opcode].needs_mm && !*mm) {
mm_fault = mm_fault || !mmget_not_zero(ctx->sqo_mm);
if (unlikely(mm_fault)) {
err = -EFAULT;
goto fail_req;
}
use_mm(ctx->sqo_mm);
*mm = ctx->sqo_mm;
}
req->needs_fixed_file = async;
trace_io_uring_submit_sqe(ctx, req->opcode, req->user_data, trace_io_uring_submit_sqe(ctx, req->opcode, req->user_data,
true, async); true, async);
if (!io_submit_sqe(req, sqe, statep, &link)) err = io_submit_sqe(req, sqe, statep, &link);
break; if (err)
goto fail_req;
} }
if (unlikely(submitted != nr)) { if (unlikely(submitted != nr)) {
...@@ -5904,10 +5930,19 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, ...@@ -5904,10 +5930,19 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr,
return submitted; return submitted;
} }
static inline void io_sq_thread_drop_mm(struct io_ring_ctx *ctx)
{
struct mm_struct *mm = current->mm;
if (mm) {
unuse_mm(mm);
mmput(mm);
}
}
static int io_sq_thread(void *data) static int io_sq_thread(void *data)
{ {
struct io_ring_ctx *ctx = data; struct io_ring_ctx *ctx = data;
struct mm_struct *cur_mm = NULL;
const struct cred *old_cred; const struct cred *old_cred;
mm_segment_t old_fs; mm_segment_t old_fs;
DEFINE_WAIT(wait); DEFINE_WAIT(wait);
...@@ -5948,11 +5983,7 @@ static int io_sq_thread(void *data) ...@@ -5948,11 +5983,7 @@ static int io_sq_thread(void *data)
* adding ourselves to the waitqueue, as the unuse/drop * adding ourselves to the waitqueue, as the unuse/drop
* may sleep. * may sleep.
*/ */
if (cur_mm) { io_sq_thread_drop_mm(ctx);
unuse_mm(cur_mm);
mmput(cur_mm);
cur_mm = NULL;
}
/* /*
* We're polling. If we're within the defined idle * We're polling. If we're within the defined idle
...@@ -6016,7 +6047,7 @@ static int io_sq_thread(void *data) ...@@ -6016,7 +6047,7 @@ static int io_sq_thread(void *data)
} }
mutex_lock(&ctx->uring_lock); mutex_lock(&ctx->uring_lock);
ret = io_submit_sqes(ctx, to_submit, NULL, -1, &cur_mm, true); ret = io_submit_sqes(ctx, to_submit, NULL, -1, true);
mutex_unlock(&ctx->uring_lock); mutex_unlock(&ctx->uring_lock);
timeout = jiffies + ctx->sq_thread_idle; timeout = jiffies + ctx->sq_thread_idle;
} }
...@@ -6025,10 +6056,7 @@ static int io_sq_thread(void *data) ...@@ -6025,10 +6056,7 @@ static int io_sq_thread(void *data)
task_work_run(); task_work_run();
set_fs(old_fs); set_fs(old_fs);
if (cur_mm) { io_sq_thread_drop_mm(ctx);
unuse_mm(cur_mm);
mmput(cur_mm);
}
revert_creds(old_cred); revert_creds(old_cred);
kthread_parkme(); kthread_parkme();
...@@ -7509,13 +7537,8 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit, ...@@ -7509,13 +7537,8 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
wake_up(&ctx->sqo_wait); wake_up(&ctx->sqo_wait);
submitted = to_submit; submitted = to_submit;
} else if (to_submit) { } else if (to_submit) {
struct mm_struct *cur_mm;
mutex_lock(&ctx->uring_lock); mutex_lock(&ctx->uring_lock);
/* already have mm, so io_submit_sqes() won't try to grab it */ submitted = io_submit_sqes(ctx, to_submit, f.file, fd, false);
cur_mm = ctx->sqo_mm;
submitted = io_submit_sqes(ctx, to_submit, f.file, fd,
&cur_mm, false);
mutex_unlock(&ctx->uring_lock); mutex_unlock(&ctx->uring_lock);
if (submitted != to_submit) if (submitted != to_submit)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment