Commit 7d723065 authored by Jens Axboe

io_wq: add get/put_work handlers to io_wq_create()

For cancellation, we need to ensure that the work item stays valid for
as long as ->cur_work is valid. Right now we can't safely dereference
the work item even under the wqe->lock, because while the ->cur_work
pointer will remain valid, the work could be completing and be freed
in parallel.

Only invoke ->get/put_work() on items that the caller queued itself. Add
IO_WQ_WORK_INTERNAL for work that io-wq queues on its own behalf, such as
a flush item, so the hooks are skipped for it.
Signed-off-by: Jens Axboe <axboe@kernel.dk>
parent 15dff286
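
The pattern the handlers enable is pin/unpin: the worker takes an extra reference on a work item via ->get_work() before publishing it as ->cur_work, and drops it via ->put_work() only after the item has been unpublished again, so a canceller that inspects ->cur_work under wqe->lock can dereference it safely. Below is a minimal user-space sketch of that idea, with hypothetical names (worker_run, cancel_matching, work_get/work_put) and a pthread mutex standing in for wqe->lock; it illustrates the lifetime rule, not the actual kernel code.

#include <pthread.h>
#include <stdatomic.h>
#include <stdlib.h>

struct work {
	atomic_int refs;		/* item is freed when this hits zero */
	void *tag;			/* identity a canceller can match on */
	void (*func)(struct work *);
};

struct worker {
	pthread_mutex_t lock;		/* stand-in for wqe->lock */
	struct work *cur_work;		/* item currently being executed */
};

void work_get(struct work *w)		/* plays the role of ->get_work() */
{
	atomic_fetch_add(&w->refs, 1);
}

void work_put(struct work *w)		/* plays the role of ->put_work() */
{
	if (atomic_fetch_sub(&w->refs, 1) == 1)
		free(w);		/* last reference gone */
}

/* Worker side: pin the item before publishing it as cur_work, and
 * unpin only after it has been unpublished again. */
void worker_run(struct worker *wrk, struct work *w)
{
	work_get(w);
	pthread_mutex_lock(&wrk->lock);
	wrk->cur_work = w;
	pthread_mutex_unlock(&wrk->lock);

	w->func(w);			/* may drop the queue's reference */

	pthread_mutex_lock(&wrk->lock);
	wrk->cur_work = NULL;
	pthread_mutex_unlock(&wrk->lock);
	work_put(w);			/* now the item may really go away */
}

/* Canceller side: because of the worker's pin, dereferencing cur_work
 * under the lock cannot race with the item being freed. */
int cancel_matching(struct worker *wrk, void *tag)
{
	int found = 0;

	pthread_mutex_lock(&wrk->lock);
	if (wrk->cur_work && wrk->cur_work->tag == tag)
		found = 1;		/* safe dereference */
	pthread_mutex_unlock(&wrk->lock);
	return found;
}

void work_func(struct work *w)
{
	work_put(w);			/* completion drops the queue's ref */
}

int main(void)
{
	struct worker wrk = { PTHREAD_MUTEX_INITIALIZER, NULL };
	struct work *w = malloc(sizeof(*w));

	atomic_init(&w->refs, 1);	/* queue owns the initial reference */
	w->tag = w;
	w->func = work_func;
	worker_run(&wrk, w);		/* w stays valid across func() */
	return 0;
}

Note that work io-wq queues on its own behalf, such as the on-stack flush item in io_wq_flush(), must bypass the hooks: the io_uring handlers in the diff below assume the item is embedded in a refcounted io_kiocb, which is not true for internal work. That is what the new IO_WQ_WORK_INTERNAL flag opts out of.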
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -106,6 +106,9 @@ struct io_wq {
 	unsigned long state;
 	unsigned nr_wqes;
 
+	get_work_fn *get_work;
+	put_work_fn *put_work;
+
 	struct task_struct *manager;
 	struct user_struct *user;
 	struct mm_struct *mm;
@@ -392,7 +395,7 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
 static void io_worker_handle_work(struct io_worker *worker)
 	__releases(wqe->lock)
 {
-	struct io_wq_work *work, *old_work;
+	struct io_wq_work *work, *old_work = NULL, *put_work = NULL;
 	struct io_wqe *wqe = worker->wqe;
 	struct io_wq *wq = wqe->wq;
 
@@ -424,6 +427,8 @@ static void io_worker_handle_work(struct io_worker *worker)
 			wqe->flags |= IO_WQE_FLAG_STALLED;
 
 		spin_unlock_irq(&wqe->lock);
+		if (put_work && wq->put_work)
+			wq->put_work(old_work);
 		if (!work)
 			break;
 next:
@@ -444,6 +449,11 @@ static void io_worker_handle_work(struct io_worker *worker)
 		if (worker->mm)
 			work->flags |= IO_WQ_WORK_HAS_MM;
 
+		if (wq->get_work && !(work->flags & IO_WQ_WORK_INTERNAL)) {
+			put_work = work;
+			wq->get_work(work);
+		}
+
 		old_work = work;
 		work->func(&work);
 
@@ -455,6 +465,12 @@ static void io_worker_handle_work(struct io_worker *worker)
 		}
 		if (work && work != old_work) {
 			spin_unlock_irq(&wqe->lock);
+
+			if (put_work && wq->put_work) {
+				wq->put_work(put_work);
+				put_work = NULL;
+			}
+
 			/* dependent work not hashed */
 			hash = -1U;
 			goto next;
@@ -950,13 +966,15 @@ void io_wq_flush(struct io_wq *wq)
 
 		init_completion(&data.done);
 		INIT_IO_WORK(&data.work, io_wq_flush_func);
+		data.work.flags |= IO_WQ_WORK_INTERNAL;
 		io_wqe_enqueue(wqe, &data.work);
 		wait_for_completion(&data.done);
 	}
 }
 
 struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
-			   struct user_struct *user)
+			   struct user_struct *user, get_work_fn *get_work,
+			   put_work_fn *put_work)
 {
 	int ret = -ENOMEM, i, node;
 	struct io_wq *wq;
@@ -972,6 +990,9 @@ struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
 		return ERR_PTR(-ENOMEM);
 	}
 
+	wq->get_work = get_work;
+	wq->put_work = put_work;
+
 	/* caller must already hold a reference to this */
 	wq->user = user;
--- a/fs/io-wq.h
+++ b/fs/io-wq.h
@@ -10,6 +10,7 @@ enum {
 	IO_WQ_WORK_NEEDS_USER	= 8,
 	IO_WQ_WORK_NEEDS_FILES	= 16,
 	IO_WQ_WORK_UNBOUND	= 32,
+	IO_WQ_WORK_INTERNAL	= 64,
 
 	IO_WQ_HASH_SHIFT	= 24,	/* upper 8 bits are used for hash key */
 };
@@ -34,8 +35,12 @@ struct io_wq_work {
 		(work)->files = NULL;			\
 	} while (0)					\
 
+typedef void (get_work_fn)(struct io_wq_work *);
+typedef void (put_work_fn)(struct io_wq_work *);
+
 struct io_wq *io_wq_create(unsigned bounded, struct mm_struct *mm,
-			   struct user_struct *user);
+			   struct user_struct *user,
+			   get_work_fn *get_work, put_work_fn *put_work);
 void io_wq_destroy(struct io_wq *wq);
 
 void io_wq_enqueue(struct io_wq *wq, struct io_wq_work *work);
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -3822,6 +3822,20 @@ static int io_sqe_files_update(struct io_ring_ctx *ctx, void __user *arg,
 	return done ? done : err;
 }
 
+static void io_put_work(struct io_wq_work *work)
+{
+	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+
+	io_put_req(req);
+}
+
+static void io_get_work(struct io_wq_work *work)
+{
+	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+
+	refcount_inc(&req->refs);
+}
+
 static int io_sq_offload_start(struct io_ring_ctx *ctx,
 			       struct io_uring_params *p)
 {
@@ -3871,7 +3885,8 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
 
 	/* Do QD, or 4 * CPUS, whatever is smallest */
 	concurrency = min(ctx->sq_entries, 4 * num_online_cpus());
-	ctx->io_wq = io_wq_create(concurrency, ctx->sqo_mm, ctx->user);
+	ctx->io_wq = io_wq_create(concurrency, ctx->sqo_mm, ctx->user,
+					io_get_work, io_put_work);
 	if (IS_ERR(ctx->io_wq)) {
 		ret = PTR_ERR(ctx->io_wq);
 		ctx->io_wq = NULL;