Commit 8bc54824 authored by Xiyu Yang via iommu's avatar Xiyu Yang via iommu Committed by Joerg Roedel

iommu/amd: Convert from atomic_t to refcount_t on pasid_state->count

refcount_t type and corresponding API can protect refcounters from
accidental underflow and overflow and further use-after-free situations.
Signed-off-by: default avatarXiyu Yang <xiyuyang19@fudan.edu.cn>
Signed-off-by: default avatarXin Tan <tanxin.ctf@gmail.com>
Reviewed-by: default avatarSuravee Suthikulpanit <suravee.suthikulpanit@amd.com>
Link: https://lore.kernel.org/r/1626683578-64214-1-git-send-email-xiyuyang19@fudan.edu.cnSigned-off-by: default avatarJoerg Roedel <jroedel@suse.de>
parent ff117646
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#define pr_fmt(fmt) "AMD-Vi: " fmt #define pr_fmt(fmt) "AMD-Vi: " fmt
#include <linux/refcount.h>
#include <linux/mmu_notifier.h> #include <linux/mmu_notifier.h>
#include <linux/amd-iommu.h> #include <linux/amd-iommu.h>
#include <linux/mm_types.h> #include <linux/mm_types.h>
...@@ -33,7 +34,7 @@ struct pri_queue { ...@@ -33,7 +34,7 @@ struct pri_queue {
struct pasid_state { struct pasid_state {
struct list_head list; /* For global state-list */ struct list_head list; /* For global state-list */
atomic_t count; /* Reference count */ refcount_t count; /* Reference count */
unsigned mmu_notifier_count; /* Counting nested mmu_notifier unsigned mmu_notifier_count; /* Counting nested mmu_notifier
calls */ calls */
struct mm_struct *mm; /* mm_struct for the faults */ struct mm_struct *mm; /* mm_struct for the faults */
...@@ -242,7 +243,7 @@ static struct pasid_state *get_pasid_state(struct device_state *dev_state, ...@@ -242,7 +243,7 @@ static struct pasid_state *get_pasid_state(struct device_state *dev_state,
ret = *ptr; ret = *ptr;
if (ret) if (ret)
atomic_inc(&ret->count); refcount_inc(&ret->count);
out_unlock: out_unlock:
spin_unlock_irqrestore(&dev_state->lock, flags); spin_unlock_irqrestore(&dev_state->lock, flags);
...@@ -257,14 +258,14 @@ static void free_pasid_state(struct pasid_state *pasid_state) ...@@ -257,14 +258,14 @@ static void free_pasid_state(struct pasid_state *pasid_state)
static void put_pasid_state(struct pasid_state *pasid_state) static void put_pasid_state(struct pasid_state *pasid_state)
{ {
if (atomic_dec_and_test(&pasid_state->count)) if (refcount_dec_and_test(&pasid_state->count))
wake_up(&pasid_state->wq); wake_up(&pasid_state->wq);
} }
static void put_pasid_state_wait(struct pasid_state *pasid_state) static void put_pasid_state_wait(struct pasid_state *pasid_state)
{ {
atomic_dec(&pasid_state->count); refcount_dec(&pasid_state->count);
wait_event(pasid_state->wq, !atomic_read(&pasid_state->count)); wait_event(pasid_state->wq, !refcount_read(&pasid_state->count));
free_pasid_state(pasid_state); free_pasid_state(pasid_state);
} }
...@@ -624,7 +625,7 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid, ...@@ -624,7 +625,7 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, u32 pasid,
goto out; goto out;
atomic_set(&pasid_state->count, 1); refcount_set(&pasid_state->count, 1);
init_waitqueue_head(&pasid_state->wq); init_waitqueue_head(&pasid_state->wq);
spin_lock_init(&pasid_state->lock); spin_lock_init(&pasid_state->lock);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment