Commit d8b2993c authored by Patrick Mochel's avatar Patrick Mochel

Update device model locking

This updates the device model locking to use device_lock when accessing all
lists (the global list, the bus' lists and the drivers' lists). Before the latter
two would use their own rwlocks. 

This also updates get_device() to return a pointer to the struct device if it 
can successfully increment the reference count. 

Between these two changes, this should prevent anything gaining an invalid 
reference to a device that is in the process of being removed:

If a device is being removed, it's reference count is 0, but it hasn't 
necessarily hasn't been removed from its bus's list. If the bus list iterator
attempts to access the device, it will take the lock, but will continue on to 
the next device because the refcount is 0 (and drop the lock).

Well, theoretically; the bus iterators still need to be changed, but that's 
coming next..
parent 44013500
...@@ -9,6 +9,8 @@ ...@@ -9,6 +9,8 @@
extern struct device device_root; extern struct device device_root;
extern spinlock_t device_lock; extern spinlock_t device_lock;
extern struct device * get_device_locked(struct device *);
extern int bus_add_device(struct device * dev); extern int bus_add_device(struct device * dev);
extern void bus_remove_device(struct device * dev); extern void bus_remove_device(struct device * dev);
......
...@@ -42,12 +42,12 @@ int bus_for_each_dev(struct bus_type * bus, void * data, ...@@ -42,12 +42,12 @@ int bus_for_each_dev(struct bus_type * bus, void * data,
int error = 0; int error = 0;
get_bus(bus); get_bus(bus);
read_lock(&bus->lock); spin_lock(&device_lock);
node = bus->devices.next; node = bus->devices.next;
while (node != &bus->devices) { while (node != &bus->devices) {
next = list_entry(node,struct device,bus_list); next = list_entry(node,struct device,bus_list);
get_device(next); get_device_locked(next);
read_unlock(&bus->lock); spin_unlock(&device_lock);
if (dev) if (dev)
put_device(dev); put_device(dev);
...@@ -56,10 +56,10 @@ int bus_for_each_dev(struct bus_type * bus, void * data, ...@@ -56,10 +56,10 @@ int bus_for_each_dev(struct bus_type * bus, void * data,
put_device(dev); put_device(dev);
break; break;
} }
read_lock(&bus->lock); spin_lock(&device_lock);
node = dev->bus_list.next; node = dev->bus_list.next;
} }
read_unlock(&bus->lock); spin_unlock(&device_lock);
if (dev) if (dev)
put_device(dev); put_device(dev);
put_bus(bus); put_bus(bus);
...@@ -77,12 +77,12 @@ int bus_for_each_drv(struct bus_type * bus, void * data, ...@@ -77,12 +77,12 @@ int bus_for_each_drv(struct bus_type * bus, void * data,
/* pin bus in memory */ /* pin bus in memory */
get_bus(bus); get_bus(bus);
read_lock(&bus->lock); spin_lock(&device_lock);
node = bus->drivers.next; node = bus->drivers.next;
while (node != &bus->drivers) { while (node != &bus->drivers) {
next = list_entry(node,struct device_driver,bus_list); next = list_entry(node,struct device_driver,bus_list);
get_driver(next); get_driver(next);
read_unlock(&bus->lock); spin_unlock(&device_lock);
if (drv) if (drv)
put_driver(drv); put_driver(drv);
...@@ -91,10 +91,10 @@ int bus_for_each_drv(struct bus_type * bus, void * data, ...@@ -91,10 +91,10 @@ int bus_for_each_drv(struct bus_type * bus, void * data,
put_driver(drv); put_driver(drv);
break; break;
} }
read_lock(&bus->lock); spin_lock(&device_lock);
node = drv->bus_list.next; node = drv->bus_list.next;
} }
read_unlock(&bus->lock); spin_unlock(&device_lock);
if (drv) if (drv)
put_driver(drv); put_driver(drv);
put_bus(bus); put_bus(bus);
...@@ -115,9 +115,9 @@ int bus_add_device(struct device * dev) ...@@ -115,9 +115,9 @@ int bus_add_device(struct device * dev)
if (dev->bus) { if (dev->bus) {
pr_debug("registering %s with bus '%s'\n",dev->bus_id,dev->bus->name); pr_debug("registering %s with bus '%s'\n",dev->bus_id,dev->bus->name);
get_bus(dev->bus); get_bus(dev->bus);
write_lock(&dev->bus->lock); spin_lock(&device_lock);
list_add_tail(&dev->bus_list,&dev->bus->devices); list_add_tail(&dev->bus_list,&dev->bus->devices);
write_unlock(&dev->bus->lock); spin_unlock(&device_lock);
device_bus_link(dev); device_bus_link(dev);
} }
return 0; return 0;
...@@ -134,9 +134,9 @@ void bus_remove_device(struct device * dev) ...@@ -134,9 +134,9 @@ void bus_remove_device(struct device * dev)
{ {
if (dev->bus) { if (dev->bus) {
device_remove_symlink(&dev->bus->device_dir,dev->bus_id); device_remove_symlink(&dev->bus->device_dir,dev->bus_id);
write_lock(&dev->bus->lock); spin_lock(&device_lock);
list_del_init(&dev->bus_list); list_del_init(&dev->bus_list);
write_unlock(&dev->bus->lock); spin_unlock(&device_lock);
put_bus(dev->bus); put_bus(dev->bus);
} }
} }
......
...@@ -53,9 +53,9 @@ static int found_match(struct device * dev, struct device_driver * drv) ...@@ -53,9 +53,9 @@ static int found_match(struct device * dev, struct device_driver * drv)
pr_debug("bound device '%s' to driver '%s'\n", pr_debug("bound device '%s' to driver '%s'\n",
dev->bus_id,drv->name); dev->bus_id,drv->name);
write_lock(&drv->lock); spin_lock(&device_lock);
list_add_tail(&dev->driver_list,&drv->devices); list_add_tail(&dev->driver_list,&drv->devices);
write_unlock(&drv->lock); spin_unlock(&device_lock);
goto Done; goto Done;
...@@ -154,13 +154,13 @@ void driver_detach(struct device_driver * drv) ...@@ -154,13 +154,13 @@ void driver_detach(struct device_driver * drv)
struct list_head * node; struct list_head * node;
int error = 0; int error = 0;
write_lock(&drv->lock); spin_lock(&device_lock);
node = drv->devices.next; node = drv->devices.next;
while (node != &drv->devices) { while (node != &drv->devices) {
next = list_entry(node,struct device,driver_list); next = list_entry(node,struct device,driver_list);
get_device(next); get_device_locked(next);
list_del_init(&next->driver_list); list_del_init(&next->driver_list);
write_unlock(&drv->lock); spin_unlock(&device_lock);
if (dev) if (dev)
put_device(dev); put_device(dev);
...@@ -169,10 +169,10 @@ void driver_detach(struct device_driver * drv) ...@@ -169,10 +169,10 @@ void driver_detach(struct device_driver * drv)
put_device(dev); put_device(dev);
break; break;
} }
write_lock(&drv->lock); spin_lock(&device_lock);
node = drv->devices.next; node = drv->devices.next;
} }
write_unlock(&drv->lock); spin_unlock(&device_lock);
if (dev) if (dev)
put_device(dev); put_device(dev);
} }
...@@ -202,12 +202,12 @@ int device_register(struct device *dev) ...@@ -202,12 +202,12 @@ int device_register(struct device *dev)
spin_lock_init(&dev->lock); spin_lock_init(&dev->lock);
atomic_set(&dev->refcount,2); atomic_set(&dev->refcount,2);
spin_lock(&device_lock);
if (dev != &device_root) { if (dev != &device_root) {
if (!dev->parent) if (!dev->parent)
dev->parent = &device_root; dev->parent = &device_root;
get_device(dev->parent); get_device(dev->parent);
spin_lock(&device_lock);
if (list_empty(&dev->parent->children)) if (list_empty(&dev->parent->children))
prev_dev = dev->parent; prev_dev = dev->parent;
else else
...@@ -215,8 +215,8 @@ int device_register(struct device *dev) ...@@ -215,8 +215,8 @@ int device_register(struct device *dev)
list_add(&dev->g_list, &prev_dev->g_list); list_add(&dev->g_list, &prev_dev->g_list);
list_add_tail(&dev->node,&dev->parent->children); list_add_tail(&dev->node,&dev->parent->children);
spin_unlock(&device_lock);
} }
spin_unlock(&device_lock);
pr_debug("DEV: registering device: ID = '%s', name = %s\n", pr_debug("DEV: registering device: ID = '%s', name = %s\n",
dev->bus_id, dev->name); dev->bus_id, dev->name);
...@@ -240,6 +240,25 @@ int device_register(struct device *dev) ...@@ -240,6 +240,25 @@ int device_register(struct device *dev)
return error; return error;
} }
struct device * get_device_locked(struct device * dev)
{
struct device * ret = dev;
if (dev && atomic_read(&dev->refcount))
atomic_inc(&dev->refcount);
else
ret = NULL;
return ret;
}
struct device * get_device(struct device * dev)
{
struct device * ret;
spin_lock(&device_lock);
ret = get_device_locked(dev);
spin_unlock(&device_lock);
return ret;
}
/** /**
* put_device - decrement reference count, and clean up when it hits 0 * put_device - decrement reference count, and clean up when it hits 0
* @dev: device in question * @dev: device in question
...@@ -296,4 +315,5 @@ static int __init device_init(void) ...@@ -296,4 +315,5 @@ static int __init device_init(void)
core_initcall(device_init); core_initcall(device_init);
EXPORT_SYMBOL(device_register); EXPORT_SYMBOL(device_register);
EXPORT_SYMBOL(get_device);
EXPORT_SYMBOL(put_device); EXPORT_SYMBOL(put_device);
...@@ -19,12 +19,12 @@ int driver_for_each_dev(struct device_driver * drv, void * data, int (*callback) ...@@ -19,12 +19,12 @@ int driver_for_each_dev(struct device_driver * drv, void * data, int (*callback)
int error = 0; int error = 0;
get_driver(drv); get_driver(drv);
read_lock(&drv->lock); spin_lock(&device_lock);
node = drv->devices.next; node = drv->devices.next;
while (node != &drv->devices) { while (node != &drv->devices) {
next = list_entry(node,struct device,driver_list); next = list_entry(node,struct device,driver_list);
get_device(next); get_device_locked(next);
read_unlock(&drv->lock); spin_unlock(&device_lock);
if (dev) if (dev)
put_device(dev); put_device(dev);
...@@ -33,10 +33,10 @@ int driver_for_each_dev(struct device_driver * drv, void * data, int (*callback) ...@@ -33,10 +33,10 @@ int driver_for_each_dev(struct device_driver * drv, void * data, int (*callback)
put_device(dev); put_device(dev);
break; break;
} }
read_lock(&drv->lock); spin_lock(&device_lock);
node = dev->driver_list.next; node = dev->driver_list.next;
} }
read_unlock(&drv->lock); spin_unlock(&device_lock);
if (dev) if (dev)
put_device(dev); put_device(dev);
put_driver(drv); put_driver(drv);
...@@ -60,9 +60,9 @@ int driver_register(struct device_driver * drv) ...@@ -60,9 +60,9 @@ int driver_register(struct device_driver * drv)
atomic_set(&drv->refcount,2); atomic_set(&drv->refcount,2);
rwlock_init(&drv->lock); rwlock_init(&drv->lock);
INIT_LIST_HEAD(&drv->devices); INIT_LIST_HEAD(&drv->devices);
write_lock(&drv->bus->lock); spin_lock(&device_lock);
list_add(&drv->bus_list,&drv->bus->drivers); list_add(&drv->bus_list,&drv->bus->drivers);
write_unlock(&drv->bus->lock); spin_unlock(&device_lock);
driver_make_dir(drv); driver_make_dir(drv);
driver_attach(drv); driver_attach(drv);
put_driver(drv); put_driver(drv);
...@@ -81,10 +81,10 @@ static void __remove_driver(struct device_driver * drv) ...@@ -81,10 +81,10 @@ static void __remove_driver(struct device_driver * drv)
void remove_driver(struct device_driver * drv) void remove_driver(struct device_driver * drv)
{ {
write_lock(&drv->bus->lock); spin_lock(&device_lock);
atomic_set(&drv->refcount,0); atomic_set(&drv->refcount,0);
list_del_init(&drv->bus_list); list_del_init(&drv->bus_list);
write_unlock(&drv->bus->lock); spin_unlock(&device_lock);
__remove_driver(drv); __remove_driver(drv);
} }
...@@ -94,13 +94,10 @@ void remove_driver(struct device_driver * drv) ...@@ -94,13 +94,10 @@ void remove_driver(struct device_driver * drv)
*/ */
void put_driver(struct device_driver * drv) void put_driver(struct device_driver * drv)
{ {
write_lock(&drv->bus->lock); if (!atomic_dec_and_lock(&drv->refcount,&device_lock))
if (!atomic_dec_and_test(&drv->refcount)) {
write_unlock(&drv->bus->lock);
return; return;
}
list_del_init(&drv->bus_list); list_del_init(&drv->bus_list);
write_unlock(&drv->bus->lock); spin_unlock(&device_lock);
__remove_driver(drv); __remove_driver(drv);
} }
......
...@@ -36,7 +36,7 @@ int device_suspend(u32 state, u32 level) ...@@ -36,7 +36,7 @@ int device_suspend(u32 state, u32 level)
spin_lock(&device_lock); spin_lock(&device_lock);
dev = g_list_to_dev(prev->g_list.next); dev = g_list_to_dev(prev->g_list.next);
while(dev != &device_root && !error) { while(dev != &device_root && !error) {
get_device(dev); get_device_locked(dev);
spin_unlock(&device_lock); spin_unlock(&device_lock);
put_device(prev); put_device(prev);
...@@ -71,7 +71,7 @@ void device_resume(u32 level) ...@@ -71,7 +71,7 @@ void device_resume(u32 level)
spin_lock(&device_lock); spin_lock(&device_lock);
dev = g_list_to_dev(prev->g_list.prev); dev = g_list_to_dev(prev->g_list.prev);
while(dev != &device_root) { while(dev != &device_root) {
get_device(dev); get_device_locked(dev);
spin_unlock(&device_lock); spin_unlock(&device_lock);
put_device(prev); put_device(prev);
...@@ -108,7 +108,7 @@ void device_shutdown(void) ...@@ -108,7 +108,7 @@ void device_shutdown(void)
spin_lock(&device_lock); spin_lock(&device_lock);
dev = g_list_to_dev(prev->g_list.next); dev = g_list_to_dev(prev->g_list.next);
while(dev != &device_root) { while(dev != &device_root) {
get_device(dev); dev = get_device_locked(dev);
spin_unlock(&device_lock); spin_unlock(&device_lock);
put_device(prev); put_device(prev);
......
...@@ -261,12 +261,7 @@ static inline void unlock_device(struct device * dev) ...@@ -261,12 +261,7 @@ static inline void unlock_device(struct device * dev)
* get_device - atomically increment the reference count for the device. * get_device - atomically increment the reference count for the device.
* *
*/ */
static inline void get_device(struct device * dev) extern struct device * get_device(struct device * dev);
{
BUG_ON(!atomic_read(&dev->refcount));
atomic_inc(&dev->refcount);
}
extern void put_device(struct device * dev); extern void put_device(struct device * dev);
/* drivers/base/sys.c */ /* drivers/base/sys.c */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment