Commit 98f67156 authored by Jason Gunthorpe

RDMA/cm: Simplify establishing a listen cm_id

Any manipulation of cm_id->state must be done under the cm_id_priv->lock,
the two routines that added listens did not follow this rule, because they
never participate in any concurrent access around the state.

However, since this exception makes the code hard to understand, simplify
the flow so that it can be fully locked:
 - Move manipulation of listen_sharecount into cm_insert_listen() so it is
   trivially under the cm.lock without having to expose the cm.lock to the
   caller.
 - Push the cm.lock down into cm_insert_listen() and have the function
   increment the reference count before returning an existing pointer.
 - Split ib_cm_listen() into a cm_init_listen() helper and do not call
   ib_cm_listen() from ib_cm_insert_listen()
 - Make both ib_cm_listen() and ib_cm_insert_listen() directly call
   cm_insert_listen() under their cm_id_priv->lock which does both a
   collision detect and, if needed, the insert (atomically)
 - Enclose all state manipulation within the cm_id_priv->lock, notice this
   set can be done safely after cm_insert_listen() as no reader is allowed
   to read the state without holding the lock.
 - Do not set the listen cm_id in the xarray, as it is never correct to
   look it up. This makes the concurrency simpler to understand.

Many needless error unwinds are removed in the process.

Link: https://lore.kernel.org/r/20200310092545.251365-6-leon@kernel.org
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parent 2305d686
...@@ -620,22 +620,44 @@ static int be64_gt(__be64 a, __be64 b) ...@@ -620,22 +620,44 @@ static int be64_gt(__be64 a, __be64 b)
return (__force u64) a > (__force u64) b; return (__force u64) a > (__force u64) b;
} }
static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv) /*
* Inserts a new cm_id_priv into the listen_service_table. Returns cm_id_priv
* if the new ID was inserted, NULL if it could not be inserted due to a
* collision, or the existing cm_id_priv ready for shared usage.
*/
static struct cm_id_private *cm_insert_listen(struct cm_id_private *cm_id_priv,
ib_cm_handler shared_handler)
{ {
struct rb_node **link = &cm.listen_service_table.rb_node; struct rb_node **link = &cm.listen_service_table.rb_node;
struct rb_node *parent = NULL; struct rb_node *parent = NULL;
struct cm_id_private *cur_cm_id_priv; struct cm_id_private *cur_cm_id_priv;
__be64 service_id = cm_id_priv->id.service_id; __be64 service_id = cm_id_priv->id.service_id;
__be64 service_mask = cm_id_priv->id.service_mask; __be64 service_mask = cm_id_priv->id.service_mask;
unsigned long flags;
spin_lock_irqsave(&cm.lock, flags);
while (*link) { while (*link) {
parent = *link; parent = *link;
cur_cm_id_priv = rb_entry(parent, struct cm_id_private, cur_cm_id_priv = rb_entry(parent, struct cm_id_private,
service_node); service_node);
if ((cur_cm_id_priv->id.service_mask & service_id) == if ((cur_cm_id_priv->id.service_mask & service_id) ==
(service_mask & cur_cm_id_priv->id.service_id) && (service_mask & cur_cm_id_priv->id.service_id) &&
(cm_id_priv->id.device == cur_cm_id_priv->id.device)) (cm_id_priv->id.device == cur_cm_id_priv->id.device)) {
/*
* Sharing an ib_cm_id with different handlers is not
* supported
*/
if (cur_cm_id_priv->id.cm_handler != shared_handler ||
cur_cm_id_priv->id.context ||
WARN_ON(!cur_cm_id_priv->id.cm_handler)) {
spin_unlock_irqrestore(&cm.lock, flags);
return NULL;
}
refcount_inc(&cur_cm_id_priv->refcount);
cur_cm_id_priv->listen_sharecount++;
spin_unlock_irqrestore(&cm.lock, flags);
return cur_cm_id_priv; return cur_cm_id_priv;
}
if (cm_id_priv->id.device < cur_cm_id_priv->id.device) if (cm_id_priv->id.device < cur_cm_id_priv->id.device)
link = &(*link)->rb_left; link = &(*link)->rb_left;
...@@ -648,9 +670,11 @@ static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv) ...@@ -648,9 +670,11 @@ static struct cm_id_private * cm_insert_listen(struct cm_id_private *cm_id_priv)
else else
link = &(*link)->rb_right; link = &(*link)->rb_right;
} }
cm_id_priv->listen_sharecount++;
rb_link_node(&cm_id_priv->service_node, parent, link); rb_link_node(&cm_id_priv->service_node, parent, link);
rb_insert_color(&cm_id_priv->service_node, &cm.listen_service_table); rb_insert_color(&cm_id_priv->service_node, &cm.listen_service_table);
return NULL; spin_unlock_irqrestore(&cm.lock, flags);
return cm_id_priv;
} }
static struct cm_id_private * cm_find_listen(struct ib_device *device, static struct cm_id_private * cm_find_listen(struct ib_device *device,
...@@ -807,7 +831,7 @@ static void cm_reject_sidr_req(struct cm_id_private *cm_id_priv, ...@@ -807,7 +831,7 @@ static void cm_reject_sidr_req(struct cm_id_private *cm_id_priv,
ib_send_cm_sidr_rep(&cm_id_priv->id, &param); ib_send_cm_sidr_rep(&cm_id_priv->id, &param);
} }
struct ib_cm_id *ib_create_cm_id(struct ib_device *device, static struct cm_id_private *cm_alloc_id_priv(struct ib_device *device,
ib_cm_handler cm_handler, ib_cm_handler cm_handler,
void *context) void *context)
{ {
...@@ -840,15 +864,37 @@ struct ib_cm_id *ib_create_cm_id(struct ib_device *device, ...@@ -840,15 +864,37 @@ struct ib_cm_id *ib_create_cm_id(struct ib_device *device,
if (ret) if (ret)
goto error; goto error;
cm_id_priv->id.local_id = (__force __be32)id ^ cm.random_id_operand; cm_id_priv->id.local_id = (__force __be32)id ^ cm.random_id_operand;
xa_store_irq(&cm.local_id_table, cm_local_id(cm_id_priv->id.local_id),
cm_id_priv, GFP_KERNEL);
return &cm_id_priv->id; return cm_id_priv;
error: error:
kfree(cm_id_priv); kfree(cm_id_priv);
return ERR_PTR(ret); return ERR_PTR(ret);
} }
/*
 * Make the ID visible to the MAD handlers and other threads that use the
 * xarray.
 *
 * Until this is called the cm_id_priv is private to its creator; afterwards
 * any lookup by local_id can observe it, so all fields the handlers read
 * must be initialized before calling this.
 */
static void cm_finalize_id(struct cm_id_private *cm_id_priv)
{
/* Store into the local_id slot assigned earlier in cm_alloc_id_priv() —
 * presumably reserved there via xa_alloc; confirm against full file. */
xa_store_irq(&cm.local_id_table, cm_local_id(cm_id_priv->id.local_id),
cm_id_priv, GFP_KERNEL);
}
/*
 * ib_create_cm_id - Allocate a communication identifier and immediately
 * publish it in the local_id xarray.
 *
 * Thin composition of cm_alloc_id_priv() (which may sleep) and
 * cm_finalize_id(). Callers that must finish initializing the ID before it
 * becomes visible (e.g. listen IDs) call the two steps separately instead.
 *
 * Returns the embedded ib_cm_id on success or an ERR_PTR on failure.
 */
struct ib_cm_id *ib_create_cm_id(struct ib_device *device,
ib_cm_handler cm_handler,
void *context)
{
struct cm_id_private *cm_id_priv;
cm_id_priv = cm_alloc_id_priv(device, cm_handler, context);
if (IS_ERR(cm_id_priv))
/* Propagate the error, re-cast to the public ib_cm_id type */
return ERR_CAST(cm_id_priv);
cm_finalize_id(cm_id_priv);
return &cm_id_priv->id;
}
EXPORT_SYMBOL(ib_create_cm_id); EXPORT_SYMBOL(ib_create_cm_id);
static struct cm_work * cm_dequeue_work(struct cm_id_private *cm_id_priv) static struct cm_work * cm_dequeue_work(struct cm_id_private *cm_id_priv)
...@@ -1092,8 +1138,27 @@ void ib_destroy_cm_id(struct ib_cm_id *cm_id) ...@@ -1092,8 +1138,27 @@ void ib_destroy_cm_id(struct ib_cm_id *cm_id)
} }
EXPORT_SYMBOL(ib_destroy_cm_id); EXPORT_SYMBOL(ib_destroy_cm_id);
/*
 * cm_init_listen - Validate and record the service ID/mask for a listen.
 *
 * A zero service_mask means "match exactly" (expanded to all-ones).
 * IB_CM_ASSIGN_SERVICE_ID requests an auto-assigned ID from the kernel's
 * listen_service_id counter. Returns 0 on success or -EINVAL if service_id
 * falls in the auto-assign range without being the exact assign token.
 *
 * NOTE(review): caller is expected to hold cm_id_priv->lock (both callers
 * in this commit do) — confirm against the full file.
 */
static int cm_init_listen(struct cm_id_private *cm_id_priv, __be64 service_id,
__be64 service_mask)
{
/* 0 mask -> full mask, i.e. require an exact service_id match */
service_mask = service_mask ? service_mask : ~cpu_to_be64(0);
service_id &= service_mask;
/* Reject IDs inside the kernel-assigned range unless auto-assign was asked */
if ((service_id & IB_SERVICE_ID_AGN_MASK) == IB_CM_ASSIGN_SERVICE_ID &&
(service_id != IB_CM_ASSIGN_SERVICE_ID))
return -EINVAL;
if (service_id == IB_CM_ASSIGN_SERVICE_ID) {
/* Hand out the next kernel-assigned service ID; exact-match mask */
cm_id_priv->id.service_id = cpu_to_be64(cm.listen_service_id++);
cm_id_priv->id.service_mask = ~cpu_to_be64(0);
} else {
cm_id_priv->id.service_id = service_id;
cm_id_priv->id.service_mask = service_mask;
}
return 0;
}
/** /**
* __ib_cm_listen - Initiates listening on the specified service ID for * ib_cm_listen - Initiates listening on the specified service ID for
* connection and service ID resolution requests. * connection and service ID resolution requests.
* @cm_id: Connection identifier associated with the listen request. * @cm_id: Connection identifier associated with the listen request.
* @service_id: Service identifier matched against incoming connection * @service_id: Service identifier matched against incoming connection
...@@ -1105,51 +1170,33 @@ EXPORT_SYMBOL(ib_destroy_cm_id); ...@@ -1105,51 +1170,33 @@ EXPORT_SYMBOL(ib_destroy_cm_id);
* exactly. This parameter is ignored if %service_id is set to * exactly. This parameter is ignored if %service_id is set to
* IB_CM_ASSIGN_SERVICE_ID. * IB_CM_ASSIGN_SERVICE_ID.
*/ */
static int __ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask)
__be64 service_mask)
{ {
struct cm_id_private *cm_id_priv, *cur_cm_id_priv; struct cm_id_private *cm_id_priv =
int ret = 0; container_of(cm_id, struct cm_id_private, id);
unsigned long flags;
service_mask = service_mask ? service_mask : ~cpu_to_be64(0); int ret;
service_id &= service_mask;
if ((service_id & IB_SERVICE_ID_AGN_MASK) == IB_CM_ASSIGN_SERVICE_ID &&
(service_id != IB_CM_ASSIGN_SERVICE_ID))
return -EINVAL;
cm_id_priv = container_of(cm_id, struct cm_id_private, id);
if (cm_id->state != IB_CM_IDLE)
return -EINVAL;
cm_id->state = IB_CM_LISTEN;
++cm_id_priv->listen_sharecount;
if (service_id == IB_CM_ASSIGN_SERVICE_ID) { spin_lock_irqsave(&cm_id_priv->lock, flags);
cm_id->service_id = cpu_to_be64(cm.listen_service_id++); if (cm_id_priv->id.state != IB_CM_IDLE) {
cm_id->service_mask = ~cpu_to_be64(0); ret = -EINVAL;
} else { goto out;
cm_id->service_id = service_id;
cm_id->service_mask = service_mask;
} }
cur_cm_id_priv = cm_insert_listen(cm_id_priv);
if (cur_cm_id_priv) { ret = cm_init_listen(cm_id_priv, service_id, service_mask);
cm_id->state = IB_CM_IDLE; if (ret)
--cm_id_priv->listen_sharecount; goto out;
if (!cm_insert_listen(cm_id_priv, NULL)) {
ret = -EBUSY; ret = -EBUSY;
goto out;
} }
return ret;
}
int ib_cm_listen(struct ib_cm_id *cm_id, __be64 service_id, __be64 service_mask) cm_id_priv->id.state = IB_CM_LISTEN;
{ ret = 0;
unsigned long flags;
int ret;
spin_lock_irqsave(&cm.lock, flags);
ret = __ib_cm_listen(cm_id, service_id, service_mask);
spin_unlock_irqrestore(&cm.lock, flags);
out:
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return ret; return ret;
} }
EXPORT_SYMBOL(ib_cm_listen); EXPORT_SYMBOL(ib_cm_listen);
...@@ -1174,52 +1221,38 @@ struct ib_cm_id *ib_cm_insert_listen(struct ib_device *device, ...@@ -1174,52 +1221,38 @@ struct ib_cm_id *ib_cm_insert_listen(struct ib_device *device,
ib_cm_handler cm_handler, ib_cm_handler cm_handler,
__be64 service_id) __be64 service_id)
{ {
struct cm_id_private *listen_id_priv;
struct cm_id_private *cm_id_priv; struct cm_id_private *cm_id_priv;
struct ib_cm_id *cm_id;
unsigned long flags;
int err = 0; int err = 0;
/* Create an ID in advance, since the creation may sleep */ /* Create an ID in advance, since the creation may sleep */
cm_id = ib_create_cm_id(device, cm_handler, NULL); cm_id_priv = cm_alloc_id_priv(device, cm_handler, NULL);
if (IS_ERR(cm_id)) if (IS_ERR(cm_id_priv))
return cm_id; return ERR_CAST(cm_id_priv);
spin_lock_irqsave(&cm.lock, flags);
if (service_id == IB_CM_ASSIGN_SERVICE_ID) err = cm_init_listen(cm_id_priv, service_id, 0);
goto new_id; if (err)
return ERR_PTR(err);
/* Find an existing ID */ spin_lock_irq(&cm_id_priv->lock);
cm_id_priv = cm_find_listen(device, service_id); listen_id_priv = cm_insert_listen(cm_id_priv, cm_handler);
if (cm_id_priv) { if (listen_id_priv != cm_id_priv) {
if (cm_id_priv->id.cm_handler != cm_handler || spin_unlock_irq(&cm_id_priv->lock);
cm_id_priv->id.context) { ib_destroy_cm_id(&cm_id_priv->id);
/* Sharing an ib_cm_id with different handlers is not if (!listen_id_priv)
* supported */
spin_unlock_irqrestore(&cm.lock, flags);
ib_destroy_cm_id(cm_id);
return ERR_PTR(-EINVAL); return ERR_PTR(-EINVAL);
return &listen_id_priv->id;
} }
refcount_inc(&cm_id_priv->refcount); cm_id_priv->id.state = IB_CM_LISTEN;
++cm_id_priv->listen_sharecount; spin_unlock_irq(&cm_id_priv->lock);
spin_unlock_irqrestore(&cm.lock, flags);
ib_destroy_cm_id(cm_id);
cm_id = &cm_id_priv->id;
return cm_id;
}
new_id:
/* Use newly created ID */
err = __ib_cm_listen(cm_id, service_id, 0);
spin_unlock_irqrestore(&cm.lock, flags); /*
* A listen ID does not need to be in the xarray since it does not
* receive mads, is not placed in the remote_id or remote_qpn rbtree,
* and does not enter timewait.
*/
if (err) { return &cm_id_priv->id;
ib_destroy_cm_id(cm_id);
return ERR_PTR(err);
}
return cm_id;
} }
EXPORT_SYMBOL(ib_cm_insert_listen); EXPORT_SYMBOL(ib_cm_insert_listen);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment