summaryrefslogtreecommitdiffstats
path: root/fs/io_uring.c
diff options
context:
space:
mode:
Diffstat (limited to 'fs/io_uring.c')
-rw-r--r--fs/io_uring.c101
1 files changed, 94 insertions, 7 deletions
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 87a4b727fe1c..bfc8fcd93504 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -721,6 +721,11 @@ struct async_poll {
struct io_poll_iocb *double_poll;
};
+struct io_task_work {
+ struct io_wq_work_node node;
+ task_work_func_t func;
+};
+
/*
* NOTE! Each of the iocb union members has the file pointer
* as the first entry in their struct definition. So you can
@@ -779,7 +784,10 @@ struct io_kiocb {
* 2. to track reqs with ->files (see io_op_def::file_table)
*/
struct list_head inflight_entry;
- struct callback_head task_work;
+ union {
+ struct io_task_work io_task_work;
+ struct callback_head task_work;
+ };
/* for polled requests, i.e. IORING_OP_POLL_ADD and async armed poll */
struct hlist_node hash_node;
struct async_poll *apoll;
@@ -2129,6 +2137,81 @@ static inline struct io_kiocb *io_req_find_next(struct io_kiocb *req)
return __io_req_find_next(req);
}
+static bool __tctx_task_work(struct io_uring_task *tctx)
+{
+ struct io_wq_work_list list;
+ struct io_wq_work_node *node;
+
+ if (wq_list_empty(&tctx->task_list))
+ return false;
+
+ spin_lock(&tctx->task_lock);
+ list = tctx->task_list;
+ INIT_WQ_LIST(&tctx->task_list);
+ spin_unlock(&tctx->task_lock);
+
+ node = list.first;
+ while (node) {
+ struct io_wq_work_node *next = node->next;
+ struct io_kiocb *req;
+
+ req = container_of(node, struct io_kiocb, io_task_work.node);
+ req->task_work.func(&req->task_work);
+ node = next;
+ }
+
+ return list.first != NULL;
+}
+
+static void tctx_task_work(struct callback_head *cb)
+{
+ struct io_uring_task *tctx = container_of(cb, struct io_uring_task, task_work);
+
+ while (__tctx_task_work(tctx))
+ cond_resched();
+
+ clear_bit(0, &tctx->task_state);
+}
+
+static int io_task_work_add(struct task_struct *tsk, struct io_kiocb *req,
+ enum task_work_notify_mode notify)
+{
+ struct io_uring_task *tctx = tsk->io_uring;
+ struct io_wq_work_node *node, *prev;
+ int ret;
+
+ WARN_ON_ONCE(!tctx);
+
+ spin_lock(&tctx->task_lock);
+ wq_list_add_tail(&req->io_task_work.node, &tctx->task_list);
+ spin_unlock(&tctx->task_lock);
+
+ /* task_work already pending, we're done */
+ if (test_bit(0, &tctx->task_state) ||
+ test_and_set_bit(0, &tctx->task_state))
+ return 0;
+
+ if (!task_work_add(tsk, &tctx->task_work, notify))
+ return 0;
+
+ /*
+ * Slow path - we failed, find and delete work. if the work is not
+ * in the list, it got run and we're fine.
+ */
+ ret = 0;
+ spin_lock(&tctx->task_lock);
+ wq_list_for_each(node, prev, &tctx->task_list) {
+ if (&req->io_task_work.node == node) {
+ wq_list_del(&tctx->task_list, node, prev);
+ ret = 1;
+ break;
+ }
+ }
+ spin_unlock(&tctx->task_lock);
+ clear_bit(0, &tctx->task_state);
+ return ret;
+}
+
static int io_req_task_work_add(struct io_kiocb *req)
{
struct task_struct *tsk = req->task;
@@ -2149,7 +2232,7 @@ static int io_req_task_work_add(struct io_kiocb *req)
if (!(ctx->flags & IORING_SETUP_SQPOLL))
notify = TWA_SIGNAL;
- ret = task_work_add(tsk, &req->task_work, notify);
+ ret = io_task_work_add(tsk, req, notify);
if (!ret)
wake_up_process(tsk);
@@ -2157,7 +2240,7 @@ static int io_req_task_work_add(struct io_kiocb *req)
}
static void io_req_task_work_add_fallback(struct io_kiocb *req,
- void (*cb)(struct callback_head *))
+ task_work_func_t cb)
{
struct task_struct *tsk = io_wq_get_task(req->ctx->io_wq);
@@ -2216,7 +2299,7 @@ static void io_req_task_queue(struct io_kiocb *req)
{
int ret;
- init_task_work(&req->task_work, io_req_task_submit);
+ req->task_work.func = io_req_task_submit;
percpu_ref_get(&req->ctx->refs);
ret = io_req_task_work_add(req);
@@ -2347,7 +2430,7 @@ static void io_free_req_deferred(struct io_kiocb *req)
{
int ret;
- init_task_work(&req->task_work, io_put_req_deferred_cb);
+ req->task_work.func = io_put_req_deferred_cb;
ret = io_req_task_work_add(req);
if (unlikely(ret))
io_req_task_work_add_fallback(req, io_put_req_deferred_cb);
@@ -3392,7 +3475,7 @@ static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
req->rw.kiocb.ki_flags &= ~IOCB_WAITQ;
list_del_init(&wait->entry);
- init_task_work(&req->task_work, io_req_task_submit);
+ req->task_work.func = io_req_task_submit;
percpu_ref_get(&req->ctx->refs);
/* submit ref gets dropped, acquire a new one */
@@ -5083,7 +5166,7 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
list_del_init(&poll->wait.entry);
req->result = mask;
- init_task_work(&req->task_work, func);
+ req->task_work.func = func;
percpu_ref_get(&req->ctx->refs);
/*
@@ -8086,6 +8169,10 @@ static int io_uring_alloc_task_context(struct task_struct *task)
io_init_identity(&tctx->__identity);
tctx->identity = &tctx->__identity;
task->io_uring = tctx;
+ spin_lock_init(&tctx->task_lock);
+ INIT_WQ_LIST(&tctx->task_list);
+ tctx->task_state = 0;
+ init_task_work(&tctx->task_work, tctx_task_work);
return 0;
}