Commit ca22354b authored by Jason Gunthorpe's avatar Jason Gunthorpe

RDMA/rxe: Close a race after ib_register_device

Since rxe allows unregistration from other threads the rxe pointer can
become invalid any moment after ib_register_driver returns. This could
cause a user triggered use after free.

Add another driver callback to be called right after the device becomes
registered to complete any device setup required post-registration.  This
callback has enough core locking to prevent the device from becoming
unregistered.
Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
parent 6cc2c8e5
...@@ -803,6 +803,12 @@ static int enable_device_and_get(struct ib_device *device) ...@@ -803,6 +803,12 @@ static int enable_device_and_get(struct ib_device *device)
*/ */
downgrade_write(&devices_rwsem); downgrade_write(&devices_rwsem);
if (device->ops.enable_driver) {
ret = device->ops.enable_driver(device);
if (ret)
goto out;
}
down_read(&clients_rwsem); down_read(&clients_rwsem);
xa_for_each_marked (&clients, index, client, CLIENT_REGISTERED) { xa_for_each_marked (&clients, index, client, CLIENT_REGISTERED) {
ret = add_client_context(device, client); ret = add_client_context(device, client);
...@@ -810,6 +816,8 @@ static int enable_device_and_get(struct ib_device *device) ...@@ -810,6 +816,8 @@ static int enable_device_and_get(struct ib_device *device)
break; break;
} }
up_read(&clients_rwsem); up_read(&clients_rwsem);
out:
up_read(&devices_rwsem); up_read(&devices_rwsem);
return ret; return ret;
} }
...@@ -1775,6 +1783,7 @@ void ib_set_device_ops(struct ib_device *dev, const struct ib_device_ops *ops) ...@@ -1775,6 +1783,7 @@ void ib_set_device_ops(struct ib_device *dev, const struct ib_device_ops *ops)
SET_DEVICE_OP(dev_ops, disassociate_ucontext); SET_DEVICE_OP(dev_ops, disassociate_ucontext);
SET_DEVICE_OP(dev_ops, drain_rq); SET_DEVICE_OP(dev_ops, drain_rq);
SET_DEVICE_OP(dev_ops, drain_sq); SET_DEVICE_OP(dev_ops, drain_sq);
SET_DEVICE_OP(dev_ops, enable_driver);
SET_DEVICE_OP(dev_ops, fill_res_entry); SET_DEVICE_OP(dev_ops, fill_res_entry);
SET_DEVICE_OP(dev_ops, get_dev_fw_str); SET_DEVICE_OP(dev_ops, get_dev_fw_str);
SET_DEVICE_OP(dev_ops, get_dma_mr); SET_DEVICE_OP(dev_ops, get_dma_mr);
......
...@@ -517,24 +517,24 @@ enum rdma_link_layer rxe_link_layer(struct rxe_dev *rxe, unsigned int port_num) ...@@ -517,24 +517,24 @@ enum rdma_link_layer rxe_link_layer(struct rxe_dev *rxe, unsigned int port_num)
return IB_LINK_LAYER_ETHERNET; return IB_LINK_LAYER_ETHERNET;
} }
struct rxe_dev *rxe_net_add(struct net_device *ndev) int rxe_net_add(struct net_device *ndev)
{ {
int err; int err;
struct rxe_dev *rxe = NULL; struct rxe_dev *rxe = NULL;
rxe = ib_alloc_device(rxe_dev, ib_dev); rxe = ib_alloc_device(rxe_dev, ib_dev);
if (!rxe) if (!rxe)
return NULL; return -ENOMEM;
rxe->ndev = ndev; rxe->ndev = ndev;
err = rxe_add(rxe, ndev->mtu); err = rxe_add(rxe, ndev->mtu);
if (err) { if (err) {
ib_dealloc_device(&rxe->ib_dev); ib_dealloc_device(&rxe->ib_dev);
return NULL; return err;
} }
return rxe; return 0;
} }
static void rxe_port_event(struct rxe_dev *rxe, static void rxe_port_event(struct rxe_dev *rxe,
......
...@@ -43,7 +43,7 @@ struct rxe_recv_sockets { ...@@ -43,7 +43,7 @@ struct rxe_recv_sockets {
struct socket *sk6; struct socket *sk6;
}; };
struct rxe_dev *rxe_net_add(struct net_device *ndev); int rxe_net_add(struct net_device *ndev);
int rxe_net_init(void); int rxe_net_init(void);
void rxe_net_exit(void); void rxe_net_exit(void);
......
...@@ -60,7 +60,6 @@ static int rxe_param_set_add(const char *val, const struct kernel_param *kp) ...@@ -60,7 +60,6 @@ static int rxe_param_set_add(const char *val, const struct kernel_param *kp)
char intf[32]; char intf[32];
struct net_device *ndev; struct net_device *ndev;
struct rxe_dev *exists; struct rxe_dev *exists;
struct rxe_dev *rxe;
len = sanitize_arg(val, intf, sizeof(intf)); len = sanitize_arg(val, intf, sizeof(intf));
if (!len) { if (!len) {
...@@ -82,16 +81,12 @@ static int rxe_param_set_add(const char *val, const struct kernel_param *kp) ...@@ -82,16 +81,12 @@ static int rxe_param_set_add(const char *val, const struct kernel_param *kp)
goto err; goto err;
} }
rxe = rxe_net_add(ndev); err = rxe_net_add(ndev);
if (!rxe) { if (err) {
pr_err("failed to add %s\n", intf); pr_err("failed to add %s\n", intf);
err = -EINVAL;
goto err; goto err;
} }
rxe_set_port_state(rxe);
dev_info(&rxe->ib_dev.dev, "added %s\n", intf);
err: err:
dev_put(ndev); dev_put(ndev);
return err; return err;
......
...@@ -1125,6 +1125,15 @@ static const struct attribute_group rxe_attr_group = { ...@@ -1125,6 +1125,15 @@ static const struct attribute_group rxe_attr_group = {
.attrs = rxe_dev_attributes, .attrs = rxe_dev_attributes,
}; };
static int rxe_enable_driver(struct ib_device *ib_dev)
{
struct rxe_dev *rxe = container_of(ib_dev, struct rxe_dev, ib_dev);
rxe_set_port_state(rxe);
dev_info(&rxe->ib_dev.dev, "added %s\n", netdev_name(rxe->ndev));
return 0;
}
static const struct ib_device_ops rxe_dev_ops = { static const struct ib_device_ops rxe_dev_ops = {
.alloc_hw_stats = rxe_ib_alloc_hw_stats, .alloc_hw_stats = rxe_ib_alloc_hw_stats,
.alloc_mr = rxe_alloc_mr, .alloc_mr = rxe_alloc_mr,
...@@ -1144,6 +1153,7 @@ static const struct ib_device_ops rxe_dev_ops = { ...@@ -1144,6 +1153,7 @@ static const struct ib_device_ops rxe_dev_ops = {
.destroy_qp = rxe_destroy_qp, .destroy_qp = rxe_destroy_qp,
.destroy_srq = rxe_destroy_srq, .destroy_srq = rxe_destroy_srq,
.detach_mcast = rxe_detach_mcast, .detach_mcast = rxe_detach_mcast,
.enable_driver = rxe_enable_driver,
.get_dma_mr = rxe_get_dma_mr, .get_dma_mr = rxe_get_dma_mr,
.get_hw_stats = rxe_ib_get_hw_stats, .get_hw_stats = rxe_ib_get_hw_stats,
.get_link_layer = rxe_get_link_layer, .get_link_layer = rxe_get_link_layer,
...@@ -1245,5 +1255,9 @@ int rxe_register_device(struct rxe_dev *rxe) ...@@ -1245,5 +1255,9 @@ int rxe_register_device(struct rxe_dev *rxe)
if (err) if (err)
pr_warn("%s failed with error %d\n", __func__, err); pr_warn("%s failed with error %d\n", __func__, err);
/*
* Note that rxe may be invalid at this point if another thread
* unregistered it.
*/
return err; return err;
} }
...@@ -2539,6 +2539,11 @@ struct ib_device_ops { ...@@ -2539,6 +2539,11 @@ struct ib_device_ops {
struct rdma_restrack_entry *entry); struct rdma_restrack_entry *entry);
/* Device lifecycle callbacks */ /* Device lifecycle callbacks */
/*
* Called after the device becomes registered, before clients are
* attached
*/
int (*enable_driver)(struct ib_device *dev);
/* /*
* This is called as part of ib_dealloc_device(). * This is called as part of ib_dealloc_device().
*/ */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment