diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index 3214a4c17c6b3b0883742be3f30e081e6c7f7d32..bcbcd6d94062401c01c27b63163cd65ba1e40f48 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -327,8 +327,9 @@ static int iommufd_group_setup_msi(struct iommufd_group *igroup,
 	return 0;
 }
 
-static int iommufd_hwpt_paging_attach(struct iommufd_hwpt_paging *hwpt_paging,
-				      struct iommufd_device *idev)
+static int
+iommufd_device_attach_reserved_iova(struct iommufd_device *idev,
+				    struct iommufd_hwpt_paging *hwpt_paging)
 {
 	int rc;
 
@@ -354,6 +355,7 @@ static int iommufd_hwpt_paging_attach(struct iommufd_hwpt_paging *hwpt_paging,
 int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
 				struct iommufd_device *idev)
 {
+	struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt);
 	int rc;
 
 	mutex_lock(&idev->igroup->lock);
@@ -363,8 +365,8 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
 		goto err_unlock;
 	}
 
-	if (hwpt_is_paging(hwpt)) {
-		rc = iommufd_hwpt_paging_attach(to_hwpt_paging(hwpt), idev);
+	if (hwpt_paging) {
+		rc = iommufd_device_attach_reserved_iova(idev, hwpt_paging);
 		if (rc)
 			goto err_unlock;
 	}
@@ -387,9 +389,8 @@ int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
 	mutex_unlock(&idev->igroup->lock);
 	return 0;
 err_unresv:
-	if (hwpt_is_paging(hwpt))
-		iopt_remove_reserved_iova(&to_hwpt_paging(hwpt)->ioas->iopt,
-					  idev->dev);
+	if (hwpt_paging)
+		iopt_remove_reserved_iova(&hwpt_paging->ioas->iopt, idev->dev);
 err_unlock:
 	mutex_unlock(&idev->igroup->lock);
 	return rc;
@@ -399,6 +400,7 @@ struct iommufd_hw_pagetable *
 iommufd_hw_pagetable_detach(struct iommufd_device *idev)
 {
 	struct iommufd_hw_pagetable *hwpt = idev->igroup->hwpt;
+	struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt);
 
 	mutex_lock(&idev->igroup->lock);
 	list_del(&idev->group_item);
@@ -406,9 +408,8 @@ iommufd_hw_pagetable_detach(struct iommufd_device *idev)
 		iommufd_hwpt_detach_device(hwpt, idev);
 		idev->igroup->hwpt = NULL;
 	}
-	if (hwpt_is_paging(hwpt))
-		iopt_remove_reserved_iova(&to_hwpt_paging(hwpt)->ioas->iopt,
-					  idev->dev);
+	if (hwpt_paging)
+		iopt_remove_reserved_iova(&hwpt_paging->ioas->iopt, idev->dev);
 	mutex_unlock(&idev->igroup->lock);
 
 	/* Caller must destroy hwpt */
@@ -440,17 +441,17 @@ iommufd_group_remove_reserved_iova(struct iommufd_group *igroup,
 }
 
 static int
-iommufd_group_do_replace_paging(struct iommufd_group *igroup,
-				struct iommufd_hwpt_paging *hwpt_paging)
+iommufd_group_do_replace_reserved_iova(struct iommufd_group *igroup,
+				       struct iommufd_hwpt_paging *hwpt_paging)
 {
-	struct iommufd_hw_pagetable *old_hwpt = igroup->hwpt;
+	struct iommufd_hwpt_paging *old_hwpt_paging;
 	struct iommufd_device *cur;
 	int rc;
 
 	lockdep_assert_held(&igroup->lock);
 
-	if (!hwpt_is_paging(old_hwpt) ||
-	    hwpt_paging->ioas != to_hwpt_paging(old_hwpt)->ioas) {
+	old_hwpt_paging = find_hwpt_paging(igroup->hwpt);
+	if (!old_hwpt_paging || hwpt_paging->ioas != old_hwpt_paging->ioas) {
 		list_for_each_entry(cur, &igroup->device_list, group_item) {
 			rc = iopt_table_enforce_dev_resv_regions(
 				&hwpt_paging->ioas->iopt, cur->dev, NULL);
@@ -473,6 +474,8 @@ static struct iommufd_hw_pagetable *
 iommufd_device_do_replace(struct iommufd_device *idev,
 			  struct iommufd_hw_pagetable *hwpt)
 {
+	struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt);
+	struct iommufd_hwpt_paging *old_hwpt_paging;
 	struct iommufd_group *igroup = idev->igroup;
 	struct iommufd_hw_pagetable *old_hwpt;
 	unsigned int num_devices;
@@ -491,9 +494,8 @@ iommufd_device_do_replace(struct iommufd_device *idev,
 	}
 
 	old_hwpt = igroup->hwpt;
-	if (hwpt_is_paging(hwpt)) {
-		rc = iommufd_group_do_replace_paging(igroup,
-						     to_hwpt_paging(hwpt));
+	if (hwpt_paging) {
+		rc = iommufd_group_do_replace_reserved_iova(igroup, hwpt_paging);
 		if (rc)
 			goto err_unlock;
 	}
@@ -502,11 +504,10 @@ iommufd_device_do_replace(struct iommufd_device *idev,
 	if (rc)
 		goto err_unresv;
 
-	if (hwpt_is_paging(old_hwpt) &&
-	    (!hwpt_is_paging(hwpt) ||
-	     to_hwpt_paging(hwpt)->ioas != to_hwpt_paging(old_hwpt)->ioas))
-		iommufd_group_remove_reserved_iova(igroup,
-						   to_hwpt_paging(old_hwpt));
+	old_hwpt_paging = find_hwpt_paging(old_hwpt);
+	if (old_hwpt_paging &&
+	    (!hwpt_paging || hwpt_paging->ioas != old_hwpt_paging->ioas))
+		iommufd_group_remove_reserved_iova(igroup, old_hwpt_paging);
 
 	igroup->hwpt = hwpt;
 
@@ -524,9 +525,8 @@ iommufd_device_do_replace(struct iommufd_device *idev,
 	/* Caller must destroy old_hwpt */
 	return old_hwpt;
 err_unresv:
-	if (hwpt_is_paging(hwpt))
-		iommufd_group_remove_reserved_iova(igroup,
-						   to_hwpt_paging(hwpt));
+	if (hwpt_paging)
+		iommufd_group_remove_reserved_iova(igroup, hwpt_paging);
 err_unlock:
 	mutex_unlock(&idev->igroup->lock);
 	return ERR_PTR(rc);
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 92efe30a8f0d00fa92197261b26c69c09fd4fd9a..0bfacaf40c05576a41257ba0194d615f83fcaaf6 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -324,6 +324,25 @@ to_hwpt_paging(struct iommufd_hw_pagetable *hwpt)
 	return container_of(hwpt, struct iommufd_hwpt_paging, common);
 }
 
+static inline struct iommufd_hwpt_nested *
+to_hwpt_nested(struct iommufd_hw_pagetable *hwpt)
+{
+	return container_of(hwpt, struct iommufd_hwpt_nested, common);
+}
+
+static inline struct iommufd_hwpt_paging *
+find_hwpt_paging(struct iommufd_hw_pagetable *hwpt)
+{
+	switch (hwpt->obj.type) {
+	case IOMMUFD_OBJ_HWPT_PAGING:
+		return to_hwpt_paging(hwpt);
+	case IOMMUFD_OBJ_HWPT_NESTED:
+		return to_hwpt_nested(hwpt)->parent;
+	default:
+		return NULL;
+	}
+}
+
 static inline struct iommufd_hwpt_paging *
 iommufd_get_hwpt_paging(struct iommufd_ucmd *ucmd, u32 id)
 {