Commit a282967c authored by Pavel Begunkov, committed by Jens Axboe

io_uring: encapsulate task_work state

For task works we're passing around a bool pointer for whether the
current ring is locked or not. Let's wrap it in a structure; that
will make it more opaque, preventing abuse, and will also help us
pass more info in the future if needed.
Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Link: https://lore.kernel.org/r/1ecec9483d58696e248d1bfd52cf62b04442df1d.1679931367.git.asml.silence@gmail.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
parent 13bfa6f1
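
The diff below is mechanical, but the pattern it introduces is easy to see in isolation. What follows is a minimal standalone sketch in userspace C, not kernel code: tw_state, tw_callback_t, tw_lock and demo_callback are hypothetical names chosen to mirror the commit's io_tw_state, io_req_tw_func_t and io_tw_lock(), and the real mutex calls are only hinted at in comments. It illustrates why wrapping the bool in a struct keeps the lock state opaque and leaves room for extra fields later.

#include <stdbool.h>
#include <stdio.h>

/* Hypothetical stand-in for struct io_tw_state: callers hand the whole
 * struct around, and callbacks go through a helper instead of poking a
 * raw bool pointer. */
struct tw_state {
	bool locked;
	/* future fields can be added here without touching every callback */
};

/* Mirrors io_req_tw_func_t: callback gets an opaque state, not bool *. */
typedef void (*tw_callback_t)(int req_id, struct tw_state *ts);

/* Mirrors io_tw_lock(): take the "ring lock" only if the runner didn't. */
static void tw_lock(struct tw_state *ts)
{
	if (!ts->locked) {
		/* mutex_lock(&ctx->uring_lock) in the real code */
		ts->locked = true;
	}
}

static void demo_callback(int req_id, struct tw_state *ts)
{
	tw_lock(ts);
	printf("req %d handled, locked=%d\n", req_id, ts->locked);
}

int main(void)
{
	struct tw_state ts = { .locked = false };
	tw_callback_t cb = demo_callback;

	cb(1, &ts);		/* callback grabs the lock on demand */
	if (ts.locked) {
		/* mutex_unlock(&ctx->uring_lock) in the real code */
		ts.locked = false;
	}
	return 0;
}

With the state wrapped like this, adding more per-run information later only touches the struct and the runners, not every callback signature; that is the "pass more info in the future" point from the message above.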
@@ -367,6 +367,11 @@ struct io_ring_ctx {
 	unsigned		evfd_last_cq_tail;
 };
 
+struct io_tw_state {
+	/* ->uring_lock is taken, callbacks can use io_tw_lock to lock it */
+	bool locked;
+};
+
 enum {
 	REQ_F_FIXED_FILE_BIT	= IOSQE_FIXED_FILE_BIT,
 	REQ_F_IO_DRAIN_BIT	= IOSQE_IO_DRAIN_BIT,
@@ -473,7 +478,7 @@ enum {
 	REQ_F_HASH_LOCKED	= BIT(REQ_F_HASH_LOCKED_BIT),
 };
 
-typedef void (*io_req_tw_func_t)(struct io_kiocb *req, bool *locked);
+typedef void (*io_req_tw_func_t)(struct io_kiocb *req, struct io_tw_state *ts);
 
 struct io_task_work {
 	struct llist_node	node;
...
@@ -247,12 +247,12 @@ static __cold void io_fallback_req_func(struct work_struct *work)
 						fallback_work.work);
 	struct llist_node *node = llist_del_all(&ctx->fallback_llist);
 	struct io_kiocb *req, *tmp;
-	bool locked = true;
+	struct io_tw_state ts = { .locked = true, };
 
 	mutex_lock(&ctx->uring_lock);
 	llist_for_each_entry_safe(req, tmp, node, io_task_work.node)
-		req->io_task_work.func(req, &locked);
-	if (WARN_ON_ONCE(!locked))
+		req->io_task_work.func(req, &ts);
+	if (WARN_ON_ONCE(!ts.locked))
 		return;
 	io_submit_flush_completions(ctx);
 	mutex_unlock(&ctx->uring_lock);
@@ -457,7 +457,7 @@ static void io_prep_async_link(struct io_kiocb *req)
 	}
 }
 
-void io_queue_iowq(struct io_kiocb *req, bool *dont_use)
+void io_queue_iowq(struct io_kiocb *req, struct io_tw_state *ts_dont_use)
 {
 	struct io_kiocb *link = io_prep_linked_timeout(req);
 	struct io_uring_task *tctx = req->task->io_uring;
@@ -1153,22 +1153,23 @@ static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
 	return nxt;
 }
 
-static void ctx_flush_and_put(struct io_ring_ctx *ctx, bool *locked)
+static void ctx_flush_and_put(struct io_ring_ctx *ctx, struct io_tw_state *ts)
 {
 	if (!ctx)
 		return;
 	if (ctx->flags & IORING_SETUP_TASKRUN_FLAG)
 		atomic_andnot(IORING_SQ_TASKRUN, &ctx->rings->sq_flags);
-	if (*locked) {
+	if (ts->locked) {
 		io_submit_flush_completions(ctx);
 		mutex_unlock(&ctx->uring_lock);
-		*locked = false;
+		ts->locked = false;
 	}
 	percpu_ref_put(&ctx->refs);
 }
 
 static unsigned int handle_tw_list(struct llist_node *node,
-				   struct io_ring_ctx **ctx, bool *locked,
+				   struct io_ring_ctx **ctx,
+				   struct io_tw_state *ts,
 				   struct llist_node *last)
 {
 	unsigned int count = 0;
@@ -1181,17 +1182,17 @@ static unsigned int handle_tw_list(struct llist_node *node,
 		prefetch(container_of(next, struct io_kiocb, io_task_work.node));
 
 		if (req->ctx != *ctx) {
-			ctx_flush_and_put(*ctx, locked);
+			ctx_flush_and_put(*ctx, ts);
 			*ctx = req->ctx;
 			/* if not contended, grab and improve batching */
-			*locked = mutex_trylock(&(*ctx)->uring_lock);
+			ts->locked = mutex_trylock(&(*ctx)->uring_lock);
 			percpu_ref_get(&(*ctx)->refs);
 		}
-		req->io_task_work.func(req, locked);
+		req->io_task_work.func(req, ts);
 		node = next;
 		count++;
 		if (unlikely(need_resched())) {
-			ctx_flush_and_put(*ctx, locked);
+			ctx_flush_and_put(*ctx, ts);
 			*ctx = NULL;
 			cond_resched();
 		}
@@ -1232,7 +1233,7 @@ static inline struct llist_node *io_llist_cmpxchg(struct llist_head *head,
 
 void tctx_task_work(struct callback_head *cb)
 {
-	bool uring_locked = false;
+	struct io_tw_state ts = {};
 	struct io_ring_ctx *ctx = NULL;
 	struct io_uring_task *tctx = container_of(cb, struct io_uring_task,
 						  task_work);
@@ -1249,12 +1250,12 @@ void tctx_task_work(struct callback_head *cb)
 	do {
 		loops++;
 		node = io_llist_xchg(&tctx->task_list, &fake);
-		count += handle_tw_list(node, &ctx, &uring_locked, &fake);
+		count += handle_tw_list(node, &ctx, &ts, &fake);
 
 		/* skip expensive cmpxchg if there are items in the list */
 		if (READ_ONCE(tctx->task_list.first) != &fake)
 			continue;
-		if (uring_locked && !wq_list_empty(&ctx->submit_state.compl_reqs)) {
+		if (ts.locked && !wq_list_empty(&ctx->submit_state.compl_reqs)) {
 			io_submit_flush_completions(ctx);
 			if (READ_ONCE(tctx->task_list.first) != &fake)
 				continue;
@@ -1262,7 +1263,7 @@ void tctx_task_work(struct callback_head *cb)
 		node = io_llist_cmpxchg(&tctx->task_list, &fake, NULL);
 	} while (node != &fake);
 
-	ctx_flush_and_put(ctx, &uring_locked);
+	ctx_flush_and_put(ctx, &ts);
 
 	/* relaxed read is enough as only the task itself sets ->in_cancel */
 	if (unlikely(atomic_read(&tctx->in_cancel)))
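
The tctx_task_work()/handle_tw_list() hunks above are where the wrapped state earns its keep: the runner opportunistically trylocks the ring, hands the same io_tw_state to every callback so completions can be batched while the lock is held, and flushes and unlocks when it switches contexts or finishes. The standalone sketch below models that flow in plain userspace C under stated assumptions: run_task_work_batch, ring_ctx, tw_state and the deferred counter are hypothetical stand-ins, with comments pointing at the kernel calls they imitate.

#include <pthread.h>
#include <stdbool.h>
#include <stdio.h>

/* Hypothetical mirror of struct io_tw_state. */
struct tw_state {
	bool locked;
};

/* Toy "ring context": a lock plus a count of completions deferred while
 * the lock was held (stands in for ctx->submit_state.compl_reqs). */
struct ring_ctx {
	pthread_mutex_t uring_lock;
	int deferred;
};

static void complete_req(struct ring_ctx *ctx, int req_id, struct tw_state *ts)
{
	if (ts->locked)
		ctx->deferred++;	/* deferred-completion path */
	else
		printf("req %d posted without the lock\n", req_id);
}

/* Models handle_tw_list() + ctx_flush_and_put(): trylock for batching,
 * run the callbacks, then flush whatever was deferred and drop the lock. */
static void run_task_work_batch(struct ring_ctx *ctx, const int *reqs, int nr)
{
	struct tw_state ts = {
		.locked = (pthread_mutex_trylock(&ctx->uring_lock) == 0),
	};

	for (int i = 0; i < nr; i++)
		complete_req(ctx, reqs[i], &ts);

	if (ts.locked) {
		printf("flushing %d deferred completions\n", ctx->deferred);
		ctx->deferred = 0;	/* io_submit_flush_completions() analogue */
		pthread_mutex_unlock(&ctx->uring_lock);
		ts.locked = false;
	}
}

int main(void)
{
	struct ring_ctx ctx = { .uring_lock = PTHREAD_MUTEX_INITIALIZER };
	int reqs[] = { 1, 2, 3 };

	run_task_work_batch(&ctx, reqs, 3);
	return 0;
}

Because the callbacks only ever see the struct, the runner is free to change how "locked" is acquired or add more batch-scoped state without touching the callback bodies.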
@@ -1351,7 +1352,7 @@ static void __cold io_move_task_work_from_local(struct io_ring_ctx *ctx)
 	}
 }
 
-static int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked)
+static int __io_run_local_work(struct io_ring_ctx *ctx, struct io_tw_state *ts)
 {
 	struct llist_node *node;
 	unsigned int loops = 0;
@@ -1368,7 +1369,7 @@ static int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked)
 		struct io_kiocb *req = container_of(node, struct io_kiocb,
 						    io_task_work.node);
 		prefetch(container_of(next, struct io_kiocb, io_task_work.node));
-		req->io_task_work.func(req, locked);
+		req->io_task_work.func(req, ts);
 		ret++;
 		node = next;
 	}
@@ -1376,7 +1377,7 @@ static int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked)
 
 	if (!llist_empty(&ctx->work_llist))
 		goto again;
-	if (*locked) {
+	if (ts->locked) {
 		io_submit_flush_completions(ctx);
 		if (!llist_empty(&ctx->work_llist))
 			goto again;
@@ -1387,46 +1388,46 @@ static int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked)
 
 static inline int io_run_local_work_locked(struct io_ring_ctx *ctx)
 {
-	bool locked;
+	struct io_tw_state ts = { .locked = true, };
 	int ret;
 
 	if (llist_empty(&ctx->work_llist))
 		return 0;
 
-	locked = true;
-	ret = __io_run_local_work(ctx, &locked);
+	ret = __io_run_local_work(ctx, &ts);
 	/* shouldn't happen! */
-	if (WARN_ON_ONCE(!locked))
+	if (WARN_ON_ONCE(!ts.locked))
 		mutex_lock(&ctx->uring_lock);
 	return ret;
 }
 
 static int io_run_local_work(struct io_ring_ctx *ctx)
 {
-	bool locked = mutex_trylock(&ctx->uring_lock);
+	struct io_tw_state ts = {};
 	int ret;
 
-	ret = __io_run_local_work(ctx, &locked);
-	if (locked)
+	ts.locked = mutex_trylock(&ctx->uring_lock);
+	ret = __io_run_local_work(ctx, &ts);
+	if (ts.locked)
 		mutex_unlock(&ctx->uring_lock);
 	return ret;
 }
 
-static void io_req_task_cancel(struct io_kiocb *req, bool *locked)
+static void io_req_task_cancel(struct io_kiocb *req, struct io_tw_state *ts)
 {
-	io_tw_lock(req->ctx, locked);
+	io_tw_lock(req->ctx, ts);
 	io_req_defer_failed(req, req->cqe.res);
 }
 
-void io_req_task_submit(struct io_kiocb *req, bool *locked)
+void io_req_task_submit(struct io_kiocb *req, struct io_tw_state *ts)
 {
-	io_tw_lock(req->ctx, locked);
+	io_tw_lock(req->ctx, ts);
 	/* req->task == current here, checking PF_EXITING is safe */
 	if (unlikely(req->task->flags & PF_EXITING))
 		io_req_defer_failed(req, -EFAULT);
 	else if (req->flags & REQ_F_FORCE_ASYNC)
-		io_queue_iowq(req, locked);
+		io_queue_iowq(req, ts);
 	else
 		io_queue_sqe(req);
 }
 
@@ -1652,9 +1653,9 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, long min)
 	return ret;
 }
 
-void io_req_task_complete(struct io_kiocb *req, bool *locked)
+void io_req_task_complete(struct io_kiocb *req, struct io_tw_state *ts)
 {
-	if (*locked)
+	if (ts->locked)
 		io_req_complete_defer(req);
 	else
 		io_req_complete_post(req, IO_URING_F_UNLOCKED);
@@ -1933,9 +1934,9 @@ static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
 	return 0;
 }
 
-int io_poll_issue(struct io_kiocb *req, bool *locked)
+int io_poll_issue(struct io_kiocb *req, struct io_tw_state *ts)
 {
-	io_tw_lock(req->ctx, locked);
+	io_tw_lock(req->ctx, ts);
 	return io_issue_sqe(req, IO_URING_F_NONBLOCK|IO_URING_F_MULTISHOT|
 				 IO_URING_F_COMPLETE_DEFER);
 }
...
@@ -52,16 +52,16 @@ void __io_req_task_work_add(struct io_kiocb *req, bool allow_local);
 bool io_is_uring_fops(struct file *file);
 bool io_alloc_async_data(struct io_kiocb *req);
 void io_req_task_queue(struct io_kiocb *req);
-void io_queue_iowq(struct io_kiocb *req, bool *dont_use);
-void io_req_task_complete(struct io_kiocb *req, bool *locked);
+void io_queue_iowq(struct io_kiocb *req, struct io_tw_state *ts_dont_use);
+void io_req_task_complete(struct io_kiocb *req, struct io_tw_state *ts);
 void io_req_task_queue_fail(struct io_kiocb *req, int ret);
-void io_req_task_submit(struct io_kiocb *req, bool *locked);
+void io_req_task_submit(struct io_kiocb *req, struct io_tw_state *ts);
 void tctx_task_work(struct callback_head *cb);
 __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd);
 int io_uring_alloc_task_context(struct task_struct *task,
 				struct io_ring_ctx *ctx);
 
-int io_poll_issue(struct io_kiocb *req, bool *locked);
+int io_poll_issue(struct io_kiocb *req, struct io_tw_state *ts);
 int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr);
 int io_do_iopoll(struct io_ring_ctx *ctx, bool force_nonspin);
 void io_free_batch_list(struct io_ring_ctx *ctx, struct io_wq_work_node *node);
@@ -299,11 +299,11 @@ static inline bool io_task_work_pending(struct io_ring_ctx *ctx)
 	return task_work_pending(current) || !wq_list_empty(&ctx->work_llist);
 }
 
-static inline void io_tw_lock(struct io_ring_ctx *ctx, bool *locked)
+static inline void io_tw_lock(struct io_ring_ctx *ctx, struct io_tw_state *ts)
 {
-	if (!*locked) {
+	if (!ts->locked) {
 		mutex_lock(&ctx->uring_lock);
-		*locked = true;
+		ts->locked = true;
 	}
 }
...
@@ -9,7 +9,7 @@
 #include "notif.h"
 #include "rsrc.h"
 
-static void io_notif_complete_tw_ext(struct io_kiocb *notif, bool *locked)
+static void io_notif_complete_tw_ext(struct io_kiocb *notif, struct io_tw_state *ts)
 {
 	struct io_notif_data *nd = io_notif_to_data(notif);
 	struct io_ring_ctx *ctx = notif->ctx;
@@ -21,7 +21,7 @@ static void io_notif_complete_tw_ext(struct io_kiocb *notif, bool *locked)
 		__io_unaccount_mem(ctx->user, nd->account_pages);
 		nd->account_pages = 0;
 	}
-	io_req_task_complete(notif, locked);
+	io_req_task_complete(notif, ts);
 }
 
 static void io_tx_ubuf_callback(struct sk_buff *skb, struct ubuf_info *uarg,
...
@@ -148,7 +148,7 @@ static void io_poll_req_insert_locked(struct io_kiocb *req)
 	hlist_add_head(&req->hash_node, &table->hbs[index].list);
 }
 
-static void io_poll_tw_hash_eject(struct io_kiocb *req, bool *locked)
+static void io_poll_tw_hash_eject(struct io_kiocb *req, struct io_tw_state *ts)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 
@@ -159,7 +159,7 @@ static void io_poll_tw_hash_eject(struct io_kiocb *req, bool *locked)
 		 * already grabbed the mutex for us, but there is a chance it
 		 * failed.
 		 */
-		io_tw_lock(ctx, locked);
+		io_tw_lock(ctx, ts);
 		hash_del(&req->hash_node);
 		req->flags &= ~REQ_F_HASH_LOCKED;
 	} else {
@@ -238,7 +238,7 @@ enum {
  * req->cqe.res. IOU_POLL_REMOVE_POLL_USE_RES indicates to remove multishot
  * poll and that the result is stored in req->cqe.
  */
-static int io_poll_check_events(struct io_kiocb *req, bool *locked)
+static int io_poll_check_events(struct io_kiocb *req, struct io_tw_state *ts)
 {
 	int v;
 
@@ -300,13 +300,13 @@ static int io_poll_check_events(struct io_kiocb *req, bool *locked)
 			__poll_t mask = mangle_poll(req->cqe.res &
 						    req->apoll_events);
 
-			if (!io_aux_cqe(req->ctx, *locked, req->cqe.user_data,
+			if (!io_aux_cqe(req->ctx, ts->locked, req->cqe.user_data,
 					mask, IORING_CQE_F_MORE, false)) {
 				io_req_set_res(req, mask, 0);
 				return IOU_POLL_REMOVE_POLL_USE_RES;
 			}
 		} else {
-			int ret = io_poll_issue(req, locked);
+			int ret = io_poll_issue(req, ts);
 
 			if (ret == IOU_STOP_MULTISHOT)
 				return IOU_POLL_REMOVE_POLL_USE_RES;
 			if (ret < 0)
@@ -326,15 +326,15 @@ static int io_poll_check_events(struct io_kiocb *req, bool *locked)
 	return IOU_POLL_NO_ACTION;
 }
 
-static void io_poll_task_func(struct io_kiocb *req, bool *locked)
+static void io_poll_task_func(struct io_kiocb *req, struct io_tw_state *ts)
 {
 	int ret;
 
-	ret = io_poll_check_events(req, locked);
+	ret = io_poll_check_events(req, ts);
 	if (ret == IOU_POLL_NO_ACTION)
 		return;
 	io_poll_remove_entries(req);
-	io_poll_tw_hash_eject(req, locked);
+	io_poll_tw_hash_eject(req, ts);
 
 	if (req->opcode == IORING_OP_POLL_ADD) {
 		if (ret == IOU_POLL_DONE) {
@@ -343,7 +343,7 @@ static void io_poll_task_func(struct io_kiocb *req, bool *locked)
 			poll = io_kiocb_to_cmd(req, struct io_poll);
 			req->cqe.res = mangle_poll(req->cqe.res & poll->events);
 		} else if (ret == IOU_POLL_REISSUE) {
-			io_req_task_submit(req, locked);
+			io_req_task_submit(req, ts);
 			return;
 		} else if (ret != IOU_POLL_REMOVE_POLL_USE_RES) {
 			req->cqe.res = ret;
@@ -351,14 +351,14 @@ static void io_poll_task_func(struct io_kiocb *req, bool *locked)
 		}
 
 		io_req_set_res(req, req->cqe.res, 0);
-		io_req_task_complete(req, locked);
+		io_req_task_complete(req, ts);
 	} else {
-		io_tw_lock(req->ctx, locked);
+		io_tw_lock(req->ctx, ts);
 
 		if (ret == IOU_POLL_REMOVE_POLL_USE_RES)
-			io_req_task_complete(req, locked);
+			io_req_task_complete(req, ts);
 		else if (ret == IOU_POLL_DONE || ret == IOU_POLL_REISSUE)
-			io_req_task_submit(req, locked);
+			io_req_task_submit(req, ts);
 		else
 			io_req_defer_failed(req, ret);
 	}
@@ -977,7 +977,7 @@ int io_poll_remove(struct io_kiocb *req, unsigned int issue_flags)
 	struct io_hash_bucket *bucket;
 	struct io_kiocb *preq;
 	int ret2, ret = 0;
-	bool locked;
+	struct io_tw_state ts = {};
 
 	preq = io_poll_find(ctx, true, &cd, &ctx->cancel_table, &bucket);
 	ret2 = io_poll_disarm(preq);
@@ -1027,8 +1027,8 @@ int io_poll_remove(struct io_kiocb *req, unsigned int issue_flags)
 
 	req_set_fail(preq);
 	io_req_set_res(preq, -ECANCELED, 0);
-	locked = !(issue_flags & IO_URING_F_UNLOCKED);
-	io_req_task_complete(preq, &locked);
+	ts.locked = !(issue_flags & IO_URING_F_UNLOCKED);
+	io_req_task_complete(preq, &ts);
 out:
 	if (ret < 0) {
 		req_set_fail(req);
...
@@ -283,16 +283,16 @@ static inline int io_fixup_rw_res(struct io_kiocb *req, long res)
 	return res;
 }
 
-static void io_req_rw_complete(struct io_kiocb *req, bool *locked)
+static void io_req_rw_complete(struct io_kiocb *req, struct io_tw_state *ts)
 {
 	io_req_io_end(req);
 
 	if (req->flags & (REQ_F_BUFFER_SELECTED|REQ_F_BUFFER_RING)) {
-		unsigned issue_flags = *locked ? 0 : IO_URING_F_UNLOCKED;
+		unsigned issue_flags = ts->locked ? 0 : IO_URING_F_UNLOCKED;
 
 		req->cqe.flags |= io_put_kbuf(req, issue_flags);
 	}
-	io_req_task_complete(req, locked);
+	io_req_task_complete(req, ts);
 }
 
 static void io_complete_rw(struct kiocb *kiocb, long res)
...
@@ -101,9 +101,9 @@ __cold void io_flush_timeouts(struct io_ring_ctx *ctx)
 	spin_unlock_irq(&ctx->timeout_lock);
 }
 
-static void io_req_tw_fail_links(struct io_kiocb *link, bool *locked)
+static void io_req_tw_fail_links(struct io_kiocb *link, struct io_tw_state *ts)
 {
-	io_tw_lock(link->ctx, locked);
+	io_tw_lock(link->ctx, ts);
 	while (link) {
 		struct io_kiocb *nxt = link->link;
 		long res = -ECANCELED;
@@ -112,7 +112,7 @@ static void io_req_tw_fail_links(struct io_kiocb *link, bool *locked)
 			res = link->cqe.res;
 		link->link = NULL;
 		io_req_set_res(link, res, 0);
-		io_req_task_complete(link, locked);
+		io_req_task_complete(link, ts);
 		link = nxt;
 	}
 }
@@ -265,9 +265,9 @@ int io_timeout_cancel(struct io_ring_ctx *ctx, struct io_cancel_data *cd)
 	return 0;
 }
 
-static void io_req_task_link_timeout(struct io_kiocb *req, bool *locked)
+static void io_req_task_link_timeout(struct io_kiocb *req, struct io_tw_state *ts)
 {
-	unsigned issue_flags = *locked ? 0 : IO_URING_F_UNLOCKED;
+	unsigned issue_flags = ts->locked ? 0 : IO_URING_F_UNLOCKED;
 	struct io_timeout *timeout = io_kiocb_to_cmd(req, struct io_timeout);
 	struct io_kiocb *prev = timeout->prev;
 	int ret = -ENOENT;
@@ -282,11 +282,11 @@ static void io_req_task_link_timeout(struct io_kiocb *req, bool *locked)
 			ret = io_try_cancel(req->task->io_uring, &cd, issue_flags);
 		}
 		io_req_set_res(req, ret ?: -ETIME, 0);
-		io_req_task_complete(req, locked);
+		io_req_task_complete(req, ts);
 		io_put_req(prev);
 	} else {
 		io_req_set_res(req, -ETIME, 0);
-		io_req_task_complete(req, locked);
+		io_req_task_complete(req, ts);
 	}
 }
...
@@ -12,10 +12,10 @@
 #include "rsrc.h"
 #include "uring_cmd.h"
 
-static void io_uring_cmd_work(struct io_kiocb *req, bool *locked)
+static void io_uring_cmd_work(struct io_kiocb *req, struct io_tw_state *ts)
 {
 	struct io_uring_cmd *ioucmd = io_kiocb_to_cmd(req, struct io_uring_cmd);
-	unsigned issue_flags = *locked ? 0 : IO_URING_F_UNLOCKED;
+	unsigned issue_flags = ts->locked ? 0 : IO_URING_F_UNLOCKED;
 
 	ioucmd->task_work_cb(ioucmd, issue_flags);
 }
...
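
Several of the converted callbacks (io_req_rw_complete, io_req_task_link_timeout, io_uring_cmd_work) still feed flag-based helpers, so they translate the wrapped state back into issue flags with ts->locked ? 0 : IO_URING_F_UNLOCKED. Below is a minimal standalone sketch of just that translation; IO_URING_F_UNLOCKED is redefined locally with an arbitrary demo value and derive_issue_flags() is a hypothetical helper name, not kernel API.

#include <stdbool.h>
#include <stdio.h>

/* Arbitrary demo value; the kernel defines the real flag elsewhere. */
#define IO_URING_F_UNLOCKED	(1U << 0)

struct tw_state {
	bool locked;
};

/* Mirrors the pattern used by the callbacks above: holding ->uring_lock
 * means the UNLOCKED flag must not be set. */
static unsigned int derive_issue_flags(const struct tw_state *ts)
{
	return ts->locked ? 0 : IO_URING_F_UNLOCKED;
}

int main(void)
{
	struct tw_state locked_ts = { .locked = true };
	struct tw_state unlocked_ts = { .locked = false };

	printf("locked   -> issue_flags=0x%x\n", derive_issue_flags(&locked_ts));
	printf("unlocked -> issue_flags=0x%x\n", derive_issue_flags(&unlocked_ts));
	return 0;
}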