Commit 95bf6df0 authored by Vishal Verma's avatar Vishal Verma

Merge branch 'for-6.5/dax-cleanups' into nvdimm-for-next

The reference counting of dax_region objects is needlessly complicated,
has lead to confusion [1], and has hidden a bug [2]. While testing the
cleanup for those issues, a CONFIG_DEBUG_KOBJECT_RELEASE test run
uncovered a use-after-free in dax_mapping_release(). Clean all of that
up.

Thanks to Yongqiang, Paul, and Ira for their analysis.

Additionally, clean up a redundant variable in fsdax, and fix memory
hotplug registration in the kmem driver.

[1]: http://lore.kernel.org/r/20221203095858.612027-1-liuyongqiang13@huawei.com
[2]: http://lore.kernel.org/r/3cf0890b-4eb0-e70e-cd9c-2ecc3d496263@hpe.com
parents 0e796e3e 46e66dab
......@@ -446,18 +446,33 @@ static void unregister_dev_dax(void *dev)
put_device(dev);
}
static void dax_region_free(struct kref *kref)
{
struct dax_region *dax_region;
dax_region = container_of(kref, struct dax_region, kref);
kfree(dax_region);
}
static void dax_region_put(struct dax_region *dax_region)
{
kref_put(&dax_region->kref, dax_region_free);
}
/* a return value >= 0 indicates this invocation invalidated the id */
static int __free_dev_dax_id(struct dev_dax *dev_dax)
{
struct dax_region *dax_region = dev_dax->region;
struct device *dev = &dev_dax->dev;
struct dax_region *dax_region;
int rc = dev_dax->id;
device_lock_assert(dev);
if (is_static(dax_region) || dev_dax->id < 0)
if (!dev_dax->dyn_id || dev_dax->id < 0)
return -1;
dax_region = dev_dax->region;
ida_free(&dax_region->ida, dev_dax->id);
dax_region_put(dax_region);
dev_dax->id = -1;
return rc;
}
......@@ -473,6 +488,20 @@ static int free_dev_dax_id(struct dev_dax *dev_dax)
return rc;
}
static int alloc_dev_dax_id(struct dev_dax *dev_dax)
{
struct dax_region *dax_region = dev_dax->region;
int id;
id = ida_alloc(&dax_region->ida, GFP_KERNEL);
if (id < 0)
return id;
kref_get(&dax_region->kref);
dev_dax->dyn_id = true;
dev_dax->id = id;
return id;
}
static ssize_t delete_store(struct device *dev, struct device_attribute *attr,
const char *buf, size_t len)
{
......@@ -560,20 +589,6 @@ static const struct attribute_group *dax_region_attribute_groups[] = {
NULL,
};
static void dax_region_free(struct kref *kref)
{
struct dax_region *dax_region;
dax_region = container_of(kref, struct dax_region, kref);
kfree(dax_region);
}
void dax_region_put(struct dax_region *dax_region)
{
kref_put(&dax_region->kref, dax_region_free);
}
EXPORT_SYMBOL_GPL(dax_region_put);
static void dax_region_unregister(void *region)
{
struct dax_region *dax_region = region;
......@@ -625,7 +640,6 @@ struct dax_region *alloc_dax_region(struct device *parent, int region_id,
return NULL;
}
kref_get(&dax_region->kref);
if (devm_add_action_or_reset(parent, dax_region_unregister, dax_region))
return NULL;
return dax_region;
......@@ -635,10 +649,12 @@ EXPORT_SYMBOL_GPL(alloc_dax_region);
static void dax_mapping_release(struct device *dev)
{
struct dax_mapping *mapping = to_dax_mapping(dev);
struct dev_dax *dev_dax = to_dev_dax(dev->parent);
struct device *parent = dev->parent;
struct dev_dax *dev_dax = to_dev_dax(parent);
ida_free(&dev_dax->ida, mapping->id);
kfree(mapping);
put_device(parent);
}
static void unregister_dax_mapping(void *data)
......@@ -655,8 +671,7 @@ static void unregister_dax_mapping(void *data)
dev_dax->ranges[mapping->range_id].mapping = NULL;
mapping->range_id = -1;
device_del(dev);
put_device(dev);
device_unregister(dev);
}
static struct dev_dax_range *get_dax_range(struct device *dev)
......@@ -778,6 +793,7 @@ static int devm_register_dax_mapping(struct dev_dax *dev_dax, int range_id)
dev = &mapping->dev;
device_initialize(dev);
dev->parent = &dev_dax->dev;
get_device(dev->parent);
dev->type = &dax_mapping_type;
dev_set_name(dev, "mapping%d", mapping->id);
rc = device_add(dev);
......@@ -1295,12 +1311,10 @@ static const struct attribute_group *dax_attribute_groups[] = {
static void dev_dax_release(struct device *dev)
{
struct dev_dax *dev_dax = to_dev_dax(dev);
struct dax_region *dax_region = dev_dax->region;
struct dax_device *dax_dev = dev_dax->dax_dev;
put_dax(dax_dev);
free_dev_dax_id(dev_dax);
dax_region_put(dax_region);
kfree(dev_dax->pgmap);
kfree(dev_dax);
}
......@@ -1324,6 +1338,7 @@ struct dev_dax *devm_create_dev_dax(struct dev_dax_data *data)
if (!dev_dax)
return ERR_PTR(-ENOMEM);
dev_dax->region = dax_region;
if (is_static(dax_region)) {
if (dev_WARN_ONCE(parent, data->id < 0,
"dynamic id specified to static region\n")) {
......@@ -1339,13 +1354,11 @@ struct dev_dax *devm_create_dev_dax(struct dev_dax_data *data)
goto err_id;
}
rc = ida_alloc(&dax_region->ida, GFP_KERNEL);
rc = alloc_dev_dax_id(dev_dax);
if (rc < 0)
goto err_id;
dev_dax->id = rc;
}
dev_dax->region = dax_region;
dev = &dev_dax->dev;
device_initialize(dev);
dev_set_name(dev, "dax%d.%d", dax_region->id, dev_dax->id);
......@@ -1386,7 +1399,6 @@ struct dev_dax *devm_create_dev_dax(struct dev_dax_data *data)
dev_dax->target_node = dax_region->target_node;
dev_dax->align = dax_region->align;
ida_init(&dev_dax->ida);
kref_get(&dax_region->kref);
inode = dax_inode(dax_dev);
dev->devt = inode->i_rdev;
......
......@@ -9,7 +9,6 @@ struct dev_dax;
struct resource;
struct dax_device;
struct dax_region;
void dax_region_put(struct dax_region *dax_region);
/* dax bus specific ioresource flags */
#define IORESOURCE_DAX_STATIC BIT(0)
......
......@@ -13,7 +13,6 @@ static int cxl_dax_region_probe(struct device *dev)
struct cxl_region *cxlr = cxlr_dax->cxlr;
struct dax_region *dax_region;
struct dev_dax_data data;
struct dev_dax *dev_dax;
if (nid == NUMA_NO_NODE)
nid = memory_add_physaddr_to_nid(cxlr_dax->hpa_range.start);
......@@ -28,13 +27,8 @@ static int cxl_dax_region_probe(struct device *dev)
.id = -1,
.size = range_len(&cxlr_dax->hpa_range),
};
dev_dax = devm_create_dev_dax(&data);
if (IS_ERR(dev_dax))
return PTR_ERR(dev_dax);
/* child dev_dax instances now own the lifetime of the dax_region */
dax_region_put(dax_region);
return 0;
return PTR_ERR_OR_ZERO(devm_create_dev_dax(&data));
}
static struct cxl_driver cxl_dax_region_driver = {
......
......@@ -52,7 +52,8 @@ struct dax_mapping {
* @region - parent region
* @dax_dev - core dax functionality
* @target_node: effective numa node if dev_dax memory range is onlined
* @id: ida allocated id
* @dyn_id: is this a dynamic or statically created instance
* @id: ida allocated id when the dax_region is not static
* @ida: mapping id allocator
* @dev - device core
* @pgmap - pgmap for memmap setup / lifetime (driver owned)
......@@ -64,6 +65,7 @@ struct dev_dax {
struct dax_device *dax_dev;
unsigned int align;
int target_node;
bool dyn_id;
int id;
struct ida ida;
struct device dev;
......
......@@ -16,7 +16,6 @@ static int dax_hmem_probe(struct platform_device *pdev)
struct dax_region *dax_region;
struct memregion_info *mri;
struct dev_dax_data data;
struct dev_dax *dev_dax;
/*
* @region_idle == true indicates that an administrative agent
......@@ -38,13 +37,8 @@ static int dax_hmem_probe(struct platform_device *pdev)
.id = -1,
.size = region_idle ? 0 : range_len(&mri->range),
};
dev_dax = devm_create_dev_dax(&data);
if (IS_ERR(dev_dax))
return PTR_ERR(dev_dax);
/* child dev_dax instances now own the lifetime of the dax_region */
dax_region_put(dax_region);
return 0;
return PTR_ERR_OR_ZERO(devm_create_dev_dax(&data));
}
static struct platform_driver dax_hmem_driver = {
......
......@@ -99,7 +99,7 @@ static int dev_dax_kmem_probe(struct dev_dax *dev_dax)
if (!data->res_name)
goto err_res_name;
rc = memory_group_register_static(numa_node, total_len);
rc = memory_group_register_static(numa_node, PFN_UP(total_len));
if (rc < 0)
goto err_reg_mgid;
data->mgid = rc;
......
......@@ -13,7 +13,6 @@ static struct dev_dax *__dax_pmem_probe(struct device *dev)
int rc, id, region_id;
resource_size_t offset;
struct nd_pfn_sb *pfn_sb;
struct dev_dax *dev_dax;
struct dev_dax_data data;
struct nd_namespace_io *nsio;
struct dax_region *dax_region;
......@@ -65,12 +64,8 @@ static struct dev_dax *__dax_pmem_probe(struct device *dev)
.pgmap = &pgmap,
.size = range_len(&range),
};
dev_dax = devm_create_dev_dax(&data);
/* child dev_dax instances now own the lifetime of the dax_region */
dax_region_put(dax_region);
return dev_dax;
return devm_create_dev_dax(&data);
}
static int dax_pmem_probe(struct device *dev)
......
......@@ -1830,7 +1830,6 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
vm_fault_t ret = VM_FAULT_FALLBACK;
pgoff_t max_pgoff;
void *entry;
int error;
if (vmf->flags & FAULT_FLAG_WRITE)
iter.flags |= IOMAP_WRITE;
......@@ -1877,7 +1876,7 @@ static vm_fault_t dax_iomap_pmd_fault(struct vm_fault *vmf, pfn_t *pfnp,
}
iter.pos = (loff_t)xas.xa_index << PAGE_SHIFT;
while ((error = iomap_iter(&iter, ops)) > 0) {
while (iomap_iter(&iter, ops) > 0) {
if (iomap_length(&iter) < PMD_SIZE)
continue; /* actually breaks out of the loop */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment