Commit b3311b06 authored by Joerg Roedel's avatar Joerg Roedel

iommu/amd: Use container_of to get dma_ops_domain

This is better than storing an extra pointer in struct
protection_domain, because this pointer can now be removed
from the struct.
Signed-off-by: default avatarJoerg Roedel <jroedel@suse.de>
parent 281e8ccb
...@@ -231,6 +231,12 @@ static struct protection_domain *to_pdomain(struct iommu_domain *dom) ...@@ -231,6 +231,12 @@ static struct protection_domain *to_pdomain(struct iommu_domain *dom)
return container_of(dom, struct protection_domain, domain); return container_of(dom, struct protection_domain, domain);
} }
static struct dma_ops_domain* to_dma_ops_domain(struct protection_domain *domain)
{
BUG_ON(domain->flags != PD_DMA_OPS_MASK);
return container_of(domain, struct dma_ops_domain, domain);
}
static struct iommu_dev_data *alloc_dev_data(u16 devid) static struct iommu_dev_data *alloc_dev_data(u16 devid)
{ {
struct iommu_dev_data *dev_data; struct iommu_dev_data *dev_data;
...@@ -1670,7 +1676,6 @@ static struct dma_ops_domain *dma_ops_domain_alloc(void) ...@@ -1670,7 +1676,6 @@ static struct dma_ops_domain *dma_ops_domain_alloc(void)
dma_dom->domain.mode = PAGE_MODE_2_LEVEL; dma_dom->domain.mode = PAGE_MODE_2_LEVEL;
dma_dom->domain.pt_root = (void *)get_zeroed_page(GFP_KERNEL); dma_dom->domain.pt_root = (void *)get_zeroed_page(GFP_KERNEL);
dma_dom->domain.flags = PD_DMA_OPS_MASK; dma_dom->domain.flags = PD_DMA_OPS_MASK;
dma_dom->domain.priv = dma_dom;
if (!dma_dom->domain.pt_root) if (!dma_dom->domain.pt_root)
goto free_dma_dom; goto free_dma_dom;
...@@ -2367,6 +2372,7 @@ static dma_addr_t map_page(struct device *dev, struct page *page, ...@@ -2367,6 +2372,7 @@ static dma_addr_t map_page(struct device *dev, struct page *page,
{ {
phys_addr_t paddr = page_to_phys(page) + offset; phys_addr_t paddr = page_to_phys(page) + offset;
struct protection_domain *domain; struct protection_domain *domain;
struct dma_ops_domain *dma_dom;
u64 dma_mask; u64 dma_mask;
domain = get_domain(dev); domain = get_domain(dev);
...@@ -2376,8 +2382,9 @@ static dma_addr_t map_page(struct device *dev, struct page *page, ...@@ -2376,8 +2382,9 @@ static dma_addr_t map_page(struct device *dev, struct page *page,
return DMA_ERROR_CODE; return DMA_ERROR_CODE;
dma_mask = *dev->dma_mask; dma_mask = *dev->dma_mask;
dma_dom = to_dma_ops_domain(domain);
return __map_single(dev, domain->priv, paddr, size, dir, dma_mask); return __map_single(dev, dma_dom, paddr, size, dir, dma_mask);
} }
/* /*
...@@ -2387,12 +2394,15 @@ static void unmap_page(struct device *dev, dma_addr_t dma_addr, size_t size, ...@@ -2387,12 +2394,15 @@ static void unmap_page(struct device *dev, dma_addr_t dma_addr, size_t size,
enum dma_data_direction dir, struct dma_attrs *attrs) enum dma_data_direction dir, struct dma_attrs *attrs)
{ {
struct protection_domain *domain; struct protection_domain *domain;
struct dma_ops_domain *dma_dom;
domain = get_domain(dev); domain = get_domain(dev);
if (IS_ERR(domain)) if (IS_ERR(domain))
return; return;
__unmap_single(domain->priv, dma_addr, size, dir); dma_dom = to_dma_ops_domain(domain);
__unmap_single(dma_dom, dma_addr, size, dir);
} }
static int sg_num_pages(struct device *dev, static int sg_num_pages(struct device *dev,
...@@ -2440,7 +2450,7 @@ static int map_sg(struct device *dev, struct scatterlist *sglist, ...@@ -2440,7 +2450,7 @@ static int map_sg(struct device *dev, struct scatterlist *sglist,
if (IS_ERR(domain)) if (IS_ERR(domain))
return 0; return 0;
dma_dom = domain->priv; dma_dom = to_dma_ops_domain(domain);
dma_mask = *dev->dma_mask; dma_mask = *dev->dma_mask;
npages = sg_num_pages(dev, sglist, nelems); npages = sg_num_pages(dev, sglist, nelems);
...@@ -2511,6 +2521,7 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist, ...@@ -2511,6 +2521,7 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist,
struct dma_attrs *attrs) struct dma_attrs *attrs)
{ {
struct protection_domain *domain; struct protection_domain *domain;
struct dma_ops_domain *dma_dom;
unsigned long startaddr; unsigned long startaddr;
int npages = 2; int npages = 2;
...@@ -2519,9 +2530,10 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist, ...@@ -2519,9 +2530,10 @@ static void unmap_sg(struct device *dev, struct scatterlist *sglist,
return; return;
startaddr = sg_dma_address(sglist) & PAGE_MASK; startaddr = sg_dma_address(sglist) & PAGE_MASK;
dma_dom = to_dma_ops_domain(domain);
npages = sg_num_pages(dev, sglist, nelems); npages = sg_num_pages(dev, sglist, nelems);
__unmap_single(domain->priv, startaddr, npages << PAGE_SHIFT, dir); __unmap_single(dma_dom, startaddr, npages << PAGE_SHIFT, dir);
} }
/* /*
...@@ -2533,6 +2545,7 @@ static void *alloc_coherent(struct device *dev, size_t size, ...@@ -2533,6 +2545,7 @@ static void *alloc_coherent(struct device *dev, size_t size,
{ {
u64 dma_mask = dev->coherent_dma_mask; u64 dma_mask = dev->coherent_dma_mask;
struct protection_domain *domain; struct protection_domain *domain;
struct dma_ops_domain *dma_dom;
struct page *page; struct page *page;
domain = get_domain(dev); domain = get_domain(dev);
...@@ -2543,6 +2556,7 @@ static void *alloc_coherent(struct device *dev, size_t size, ...@@ -2543,6 +2556,7 @@ static void *alloc_coherent(struct device *dev, size_t size,
} else if (IS_ERR(domain)) } else if (IS_ERR(domain))
return NULL; return NULL;
dma_dom = to_dma_ops_domain(domain);
size = PAGE_ALIGN(size); size = PAGE_ALIGN(size);
dma_mask = dev->coherent_dma_mask; dma_mask = dev->coherent_dma_mask;
flag &= ~(__GFP_DMA | __GFP_HIGHMEM | __GFP_DMA32); flag &= ~(__GFP_DMA | __GFP_HIGHMEM | __GFP_DMA32);
...@@ -2562,7 +2576,7 @@ static void *alloc_coherent(struct device *dev, size_t size, ...@@ -2562,7 +2576,7 @@ static void *alloc_coherent(struct device *dev, size_t size,
if (!dma_mask) if (!dma_mask)
dma_mask = *dev->dma_mask; dma_mask = *dev->dma_mask;
*dma_addr = __map_single(dev, domain->priv, page_to_phys(page), *dma_addr = __map_single(dev, dma_dom, page_to_phys(page),
size, DMA_BIDIRECTIONAL, dma_mask); size, DMA_BIDIRECTIONAL, dma_mask);
if (*dma_addr == DMA_ERROR_CODE) if (*dma_addr == DMA_ERROR_CODE)
...@@ -2586,6 +2600,7 @@ static void free_coherent(struct device *dev, size_t size, ...@@ -2586,6 +2600,7 @@ static void free_coherent(struct device *dev, size_t size,
struct dma_attrs *attrs) struct dma_attrs *attrs)
{ {
struct protection_domain *domain; struct protection_domain *domain;
struct dma_ops_domain *dma_dom;
struct page *page; struct page *page;
page = virt_to_page(virt_addr); page = virt_to_page(virt_addr);
...@@ -2595,7 +2610,9 @@ static void free_coherent(struct device *dev, size_t size, ...@@ -2595,7 +2610,9 @@ static void free_coherent(struct device *dev, size_t size,
if (IS_ERR(domain)) if (IS_ERR(domain))
goto free_mem; goto free_mem;
__unmap_single(domain->priv, dma_addr, size, DMA_BIDIRECTIONAL); dma_dom = to_dma_ops_domain(domain);
__unmap_single(dma_dom, dma_addr, size, DMA_BIDIRECTIONAL);
free_mem: free_mem:
if (!dma_release_from_contiguous(dev, page, size >> PAGE_SHIFT)) if (!dma_release_from_contiguous(dev, page, size >> PAGE_SHIFT))
...@@ -2888,7 +2905,7 @@ static void amd_iommu_domain_free(struct iommu_domain *dom) ...@@ -2888,7 +2905,7 @@ static void amd_iommu_domain_free(struct iommu_domain *dom)
queue_flush_all(); queue_flush_all();
/* Now release the domain */ /* Now release the domain */
dma_dom = domain->priv; dma_dom = to_dma_ops_domain(domain);
dma_ops_domain_free(dma_dom); dma_ops_domain_free(dma_dom);
break; break;
default: default:
...@@ -3076,8 +3093,7 @@ static void amd_iommu_apply_dm_region(struct device *dev, ...@@ -3076,8 +3093,7 @@ static void amd_iommu_apply_dm_region(struct device *dev,
struct iommu_domain *domain, struct iommu_domain *domain,
struct iommu_dm_region *region) struct iommu_dm_region *region)
{ {
struct protection_domain *pdomain = to_pdomain(domain); struct dma_ops_domain *dma_dom = to_dma_ops_domain(to_pdomain(domain));
struct dma_ops_domain *dma_dom = pdomain->priv;
unsigned long start, end; unsigned long start, end;
start = IOVA_PFN(region->start); start = IOVA_PFN(region->start);
......
...@@ -421,7 +421,6 @@ struct protection_domain { ...@@ -421,7 +421,6 @@ struct protection_domain {
bool updated; /* complete domain flush required */ bool updated; /* complete domain flush required */
unsigned dev_cnt; /* devices assigned to this domain */ unsigned dev_cnt; /* devices assigned to this domain */
unsigned dev_iommu[MAX_IOMMUS]; /* per-IOMMU reference count */ unsigned dev_iommu[MAX_IOMMUS]; /* per-IOMMU reference count */
void *priv; /* private data */
}; };
/* /*
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment