Commit d53c6659, authored by Ofir Bitton, committed by Oded Gabbay

habanalabs: fix potential race in interrupt wait ioctl

We have a potential race where a user interrupt can be received
after the user thread's value comparison but before the request is
added to the wait list. This means that if no subsequent interrupt
is received, the user thread will time out and fail.

The solution is to add the request to the wait list before we
perform the comparison.
Signed-off-by: Ofir Bitton <obitton@habana.ai>
Reviewed-by: Oded Gabbay <ogabbay@kernel.org>
Signed-off-by: Oded Gabbay <ogabbay@kernel.org>
parent 25a14332
...@@ -2740,10 +2740,20 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx, ...@@ -2740,10 +2740,20 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
else else
interrupt = &hdev->user_interrupt[interrupt_offset]; interrupt = &hdev->user_interrupt[interrupt_offset];
/* Add pending user interrupt to relevant list for the interrupt
* handler to monitor
*/
spin_lock_irqsave(&interrupt->wait_list_lock, flags);
list_add_tail(&pend->wait_list_node, &interrupt->wait_list_head);
spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
/* We check for completion value as interrupt could have been received
* before we added the node to the wait list
*/
if (copy_from_user(&completion_value, u64_to_user_ptr(user_address), 4)) { if (copy_from_user(&completion_value, u64_to_user_ptr(user_address), 4)) {
dev_err(hdev->dev, "Failed to copy completion value from user\n"); dev_err(hdev->dev, "Failed to copy completion value from user\n");
rc = -EFAULT; rc = -EFAULT;
goto free_fence; goto remove_pending_user_interrupt;
} }
if (completion_value >= target_value) if (completion_value >= target_value)
...@@ -2752,14 +2762,7 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx, ...@@ -2752,14 +2762,7 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
*status = CS_WAIT_STATUS_BUSY; *status = CS_WAIT_STATUS_BUSY;
if (!timeout_us || (*status == CS_WAIT_STATUS_COMPLETED)) if (!timeout_us || (*status == CS_WAIT_STATUS_COMPLETED))
goto free_fence; goto remove_pending_user_interrupt;
/* Add pending user interrupt to relevant list for the interrupt
* handler to monitor
*/
spin_lock_irqsave(&interrupt->wait_list_lock, flags);
list_add_tail(&pend->wait_list_node, &interrupt->wait_list_head);
spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
wait_again: wait_again:
/* Wait for interrupt handler to signal completion */ /* Wait for interrupt handler to signal completion */
...@@ -2770,6 +2773,15 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx, ...@@ -2770,6 +2773,15 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
* If comparison fails, keep waiting until timeout expires * If comparison fails, keep waiting until timeout expires
*/ */
if (completion_rc > 0) { if (completion_rc > 0) {
spin_lock_irqsave(&interrupt->wait_list_lock, flags);
/* reinit_completion must be called before we check for user
* completion value, otherwise, if interrupt is received after
* the comparison and before the next wait_for_completion,
* we will reach timeout and fail
*/
reinit_completion(&pend->fence.completion);
spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
if (copy_from_user(&completion_value, u64_to_user_ptr(user_address), 4)) { if (copy_from_user(&completion_value, u64_to_user_ptr(user_address), 4)) {
dev_err(hdev->dev, "Failed to copy completion value from user\n"); dev_err(hdev->dev, "Failed to copy completion value from user\n");
rc = -EFAULT; rc = -EFAULT;
...@@ -2780,11 +2792,7 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx, ...@@ -2780,11 +2792,7 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
if (completion_value >= target_value) { if (completion_value >= target_value) {
*status = CS_WAIT_STATUS_COMPLETED; *status = CS_WAIT_STATUS_COMPLETED;
} else { } else {
spin_lock_irqsave(&interrupt->wait_list_lock, flags);
reinit_completion(&pend->fence.completion);
timeout = completion_rc; timeout = completion_rc;
spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
goto wait_again; goto wait_again;
} }
} else if (completion_rc == -ERESTARTSYS) { } else if (completion_rc == -ERESTARTSYS) {
...@@ -2802,7 +2810,6 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx, ...@@ -2802,7 +2810,6 @@ static int _hl_interrupt_wait_ioctl(struct hl_device *hdev, struct hl_ctx *ctx,
list_del(&pend->wait_list_node); list_del(&pend->wait_list_node);
spin_unlock_irqrestore(&interrupt->wait_list_lock, flags); spin_unlock_irqrestore(&interrupt->wait_list_lock, flags);
free_fence:
kfree(pend); kfree(pend);
hl_ctx_put(ctx); hl_ctx_put(ctx);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment