Commit 1b52fa98 authored by Sean Hefty's avatar Sean Hefty Committed by Roland Dreier

IB: refcount race fixes

Fix race condition during destruction calls to avoid possibility of
accessing object after it has been freed.  Instead of waking up a wait
queue directly, which is susceptible to a race where the object is
freed between the reference count going to 0 and the wake_up(), use a
completion to wait in the function doing the freeing.
Signed-off-by: default avatarSean Hefty <sean.hefty@intel.com>
Signed-off-by: default avatarRoland Dreier <rolandd@cisco.com>
parent 6f4bb3d8
...@@ -34,6 +34,8 @@ ...@@ -34,6 +34,8 @@
* *
* $Id: cm.c 2821 2005-07-08 17:07:28Z sean.hefty $ * $Id: cm.c 2821 2005-07-08 17:07:28Z sean.hefty $
*/ */
#include <linux/completion.h>
#include <linux/dma-mapping.h> #include <linux/dma-mapping.h>
#include <linux/err.h> #include <linux/err.h>
#include <linux/idr.h> #include <linux/idr.h>
...@@ -122,7 +124,7 @@ struct cm_id_private { ...@@ -122,7 +124,7 @@ struct cm_id_private {
struct rb_node service_node; struct rb_node service_node;
struct rb_node sidr_id_node; struct rb_node sidr_id_node;
spinlock_t lock; /* Do not acquire inside cm.lock */ spinlock_t lock; /* Do not acquire inside cm.lock */
wait_queue_head_t wait; struct completion comp;
atomic_t refcount; atomic_t refcount;
struct ib_mad_send_buf *msg; struct ib_mad_send_buf *msg;
...@@ -159,7 +161,7 @@ static void cm_work_handler(void *data); ...@@ -159,7 +161,7 @@ static void cm_work_handler(void *data);
static inline void cm_deref_id(struct cm_id_private *cm_id_priv) static inline void cm_deref_id(struct cm_id_private *cm_id_priv)
{ {
if (atomic_dec_and_test(&cm_id_priv->refcount)) if (atomic_dec_and_test(&cm_id_priv->refcount))
wake_up(&cm_id_priv->wait); complete(&cm_id_priv->comp);
} }
static int cm_alloc_msg(struct cm_id_private *cm_id_priv, static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
...@@ -559,7 +561,7 @@ struct ib_cm_id *ib_create_cm_id(struct ib_device *device, ...@@ -559,7 +561,7 @@ struct ib_cm_id *ib_create_cm_id(struct ib_device *device,
goto error; goto error;
spin_lock_init(&cm_id_priv->lock); spin_lock_init(&cm_id_priv->lock);
init_waitqueue_head(&cm_id_priv->wait); init_completion(&cm_id_priv->comp);
INIT_LIST_HEAD(&cm_id_priv->work_list); INIT_LIST_HEAD(&cm_id_priv->work_list);
atomic_set(&cm_id_priv->work_count, -1); atomic_set(&cm_id_priv->work_count, -1);
atomic_set(&cm_id_priv->refcount, 1); atomic_set(&cm_id_priv->refcount, 1);
...@@ -724,8 +726,8 @@ void ib_destroy_cm_id(struct ib_cm_id *cm_id) ...@@ -724,8 +726,8 @@ void ib_destroy_cm_id(struct ib_cm_id *cm_id)
} }
cm_free_id(cm_id->local_id); cm_free_id(cm_id->local_id);
atomic_dec(&cm_id_priv->refcount); cm_deref_id(cm_id_priv);
wait_event(cm_id_priv->wait, !atomic_read(&cm_id_priv->refcount)); wait_for_completion(&cm_id_priv->comp);
while ((work = cm_dequeue_work(cm_id_priv)) != NULL) while ((work = cm_dequeue_work(cm_id_priv)) != NULL)
cm_free_work(work); cm_free_work(work);
if (cm_id_priv->private_data && cm_id_priv->private_data_len) if (cm_id_priv->private_data && cm_id_priv->private_data_len)
......
...@@ -352,7 +352,7 @@ struct ib_mad_agent *ib_register_mad_agent(struct ib_device *device, ...@@ -352,7 +352,7 @@ struct ib_mad_agent *ib_register_mad_agent(struct ib_device *device,
INIT_WORK(&mad_agent_priv->local_work, local_completions, INIT_WORK(&mad_agent_priv->local_work, local_completions,
mad_agent_priv); mad_agent_priv);
atomic_set(&mad_agent_priv->refcount, 1); atomic_set(&mad_agent_priv->refcount, 1);
init_waitqueue_head(&mad_agent_priv->wait); init_completion(&mad_agent_priv->comp);
return &mad_agent_priv->agent; return &mad_agent_priv->agent;
...@@ -467,7 +467,7 @@ struct ib_mad_agent *ib_register_mad_snoop(struct ib_device *device, ...@@ -467,7 +467,7 @@ struct ib_mad_agent *ib_register_mad_snoop(struct ib_device *device,
mad_snoop_priv->agent.qp = port_priv->qp_info[qpn].qp; mad_snoop_priv->agent.qp = port_priv->qp_info[qpn].qp;
mad_snoop_priv->agent.port_num = port_num; mad_snoop_priv->agent.port_num = port_num;
mad_snoop_priv->mad_snoop_flags = mad_snoop_flags; mad_snoop_priv->mad_snoop_flags = mad_snoop_flags;
init_waitqueue_head(&mad_snoop_priv->wait); init_completion(&mad_snoop_priv->comp);
mad_snoop_priv->snoop_index = register_snoop_agent( mad_snoop_priv->snoop_index = register_snoop_agent(
&port_priv->qp_info[qpn], &port_priv->qp_info[qpn],
mad_snoop_priv); mad_snoop_priv);
...@@ -486,6 +486,18 @@ struct ib_mad_agent *ib_register_mad_snoop(struct ib_device *device, ...@@ -486,6 +486,18 @@ struct ib_mad_agent *ib_register_mad_snoop(struct ib_device *device,
} }
EXPORT_SYMBOL(ib_register_mad_snoop); EXPORT_SYMBOL(ib_register_mad_snoop);
static inline void deref_mad_agent(struct ib_mad_agent_private *mad_agent_priv)
{
if (atomic_dec_and_test(&mad_agent_priv->refcount))
complete(&mad_agent_priv->comp);
}
static inline void deref_snoop_agent(struct ib_mad_snoop_private *mad_snoop_priv)
{
if (atomic_dec_and_test(&mad_snoop_priv->refcount))
complete(&mad_snoop_priv->comp);
}
static void unregister_mad_agent(struct ib_mad_agent_private *mad_agent_priv) static void unregister_mad_agent(struct ib_mad_agent_private *mad_agent_priv)
{ {
struct ib_mad_port_private *port_priv; struct ib_mad_port_private *port_priv;
...@@ -509,9 +521,8 @@ static void unregister_mad_agent(struct ib_mad_agent_private *mad_agent_priv) ...@@ -509,9 +521,8 @@ static void unregister_mad_agent(struct ib_mad_agent_private *mad_agent_priv)
flush_workqueue(port_priv->wq); flush_workqueue(port_priv->wq);
ib_cancel_rmpp_recvs(mad_agent_priv); ib_cancel_rmpp_recvs(mad_agent_priv);
atomic_dec(&mad_agent_priv->refcount); deref_mad_agent(mad_agent_priv);
wait_event(mad_agent_priv->wait, wait_for_completion(&mad_agent_priv->comp);
!atomic_read(&mad_agent_priv->refcount));
kfree(mad_agent_priv->reg_req); kfree(mad_agent_priv->reg_req);
ib_dereg_mr(mad_agent_priv->agent.mr); ib_dereg_mr(mad_agent_priv->agent.mr);
...@@ -529,9 +540,8 @@ static void unregister_mad_snoop(struct ib_mad_snoop_private *mad_snoop_priv) ...@@ -529,9 +540,8 @@ static void unregister_mad_snoop(struct ib_mad_snoop_private *mad_snoop_priv)
atomic_dec(&qp_info->snoop_count); atomic_dec(&qp_info->snoop_count);
spin_unlock_irqrestore(&qp_info->snoop_lock, flags); spin_unlock_irqrestore(&qp_info->snoop_lock, flags);
atomic_dec(&mad_snoop_priv->refcount); deref_snoop_agent(mad_snoop_priv);
wait_event(mad_snoop_priv->wait, wait_for_completion(&mad_snoop_priv->comp);
!atomic_read(&mad_snoop_priv->refcount));
kfree(mad_snoop_priv); kfree(mad_snoop_priv);
} }
...@@ -600,8 +610,7 @@ static void snoop_send(struct ib_mad_qp_info *qp_info, ...@@ -600,8 +610,7 @@ static void snoop_send(struct ib_mad_qp_info *qp_info,
spin_unlock_irqrestore(&qp_info->snoop_lock, flags); spin_unlock_irqrestore(&qp_info->snoop_lock, flags);
mad_snoop_priv->agent.snoop_handler(&mad_snoop_priv->agent, mad_snoop_priv->agent.snoop_handler(&mad_snoop_priv->agent,
send_buf, mad_send_wc); send_buf, mad_send_wc);
if (atomic_dec_and_test(&mad_snoop_priv->refcount)) deref_snoop_agent(mad_snoop_priv);
wake_up(&mad_snoop_priv->wait);
spin_lock_irqsave(&qp_info->snoop_lock, flags); spin_lock_irqsave(&qp_info->snoop_lock, flags);
} }
spin_unlock_irqrestore(&qp_info->snoop_lock, flags); spin_unlock_irqrestore(&qp_info->snoop_lock, flags);
...@@ -626,8 +635,7 @@ static void snoop_recv(struct ib_mad_qp_info *qp_info, ...@@ -626,8 +635,7 @@ static void snoop_recv(struct ib_mad_qp_info *qp_info,
spin_unlock_irqrestore(&qp_info->snoop_lock, flags); spin_unlock_irqrestore(&qp_info->snoop_lock, flags);
mad_snoop_priv->agent.recv_handler(&mad_snoop_priv->agent, mad_snoop_priv->agent.recv_handler(&mad_snoop_priv->agent,
mad_recv_wc); mad_recv_wc);
if (atomic_dec_and_test(&mad_snoop_priv->refcount)) deref_snoop_agent(mad_snoop_priv);
wake_up(&mad_snoop_priv->wait);
spin_lock_irqsave(&qp_info->snoop_lock, flags); spin_lock_irqsave(&qp_info->snoop_lock, flags);
} }
spin_unlock_irqrestore(&qp_info->snoop_lock, flags); spin_unlock_irqrestore(&qp_info->snoop_lock, flags);
...@@ -968,8 +976,7 @@ void ib_free_send_mad(struct ib_mad_send_buf *send_buf) ...@@ -968,8 +976,7 @@ void ib_free_send_mad(struct ib_mad_send_buf *send_buf)
free_send_rmpp_list(mad_send_wr); free_send_rmpp_list(mad_send_wr);
kfree(send_buf->mad); kfree(send_buf->mad);
if (atomic_dec_and_test(&mad_agent_priv->refcount)) deref_mad_agent(mad_agent_priv);
wake_up(&mad_agent_priv->wait);
} }
EXPORT_SYMBOL(ib_free_send_mad); EXPORT_SYMBOL(ib_free_send_mad);
...@@ -1757,8 +1764,7 @@ static void ib_mad_complete_recv(struct ib_mad_agent_private *mad_agent_priv, ...@@ -1757,8 +1764,7 @@ static void ib_mad_complete_recv(struct ib_mad_agent_private *mad_agent_priv,
mad_recv_wc = ib_process_rmpp_recv_wc(mad_agent_priv, mad_recv_wc = ib_process_rmpp_recv_wc(mad_agent_priv,
mad_recv_wc); mad_recv_wc);
if (!mad_recv_wc) { if (!mad_recv_wc) {
if (atomic_dec_and_test(&mad_agent_priv->refcount)) deref_mad_agent(mad_agent_priv);
wake_up(&mad_agent_priv->wait);
return; return;
} }
} }
...@@ -1770,8 +1776,7 @@ static void ib_mad_complete_recv(struct ib_mad_agent_private *mad_agent_priv, ...@@ -1770,8 +1776,7 @@ static void ib_mad_complete_recv(struct ib_mad_agent_private *mad_agent_priv,
if (!mad_send_wr) { if (!mad_send_wr) {
spin_unlock_irqrestore(&mad_agent_priv->lock, flags); spin_unlock_irqrestore(&mad_agent_priv->lock, flags);
ib_free_recv_mad(mad_recv_wc); ib_free_recv_mad(mad_recv_wc);
if (atomic_dec_and_test(&mad_agent_priv->refcount)) deref_mad_agent(mad_agent_priv);
wake_up(&mad_agent_priv->wait);
return; return;
} }
ib_mark_mad_done(mad_send_wr); ib_mark_mad_done(mad_send_wr);
...@@ -1790,8 +1795,7 @@ static void ib_mad_complete_recv(struct ib_mad_agent_private *mad_agent_priv, ...@@ -1790,8 +1795,7 @@ static void ib_mad_complete_recv(struct ib_mad_agent_private *mad_agent_priv,
} else { } else {
mad_agent_priv->agent.recv_handler(&mad_agent_priv->agent, mad_agent_priv->agent.recv_handler(&mad_agent_priv->agent,
mad_recv_wc); mad_recv_wc);
if (atomic_dec_and_test(&mad_agent_priv->refcount)) deref_mad_agent(mad_agent_priv);
wake_up(&mad_agent_priv->wait);
} }
} }
...@@ -2021,8 +2025,7 @@ void ib_mad_complete_send_wr(struct ib_mad_send_wr_private *mad_send_wr, ...@@ -2021,8 +2025,7 @@ void ib_mad_complete_send_wr(struct ib_mad_send_wr_private *mad_send_wr,
mad_send_wc); mad_send_wc);
/* Release reference on agent taken when sending */ /* Release reference on agent taken when sending */
if (atomic_dec_and_test(&mad_agent_priv->refcount)) deref_mad_agent(mad_agent_priv);
wake_up(&mad_agent_priv->wait);
return; return;
done: done:
spin_unlock_irqrestore(&mad_agent_priv->lock, flags); spin_unlock_irqrestore(&mad_agent_priv->lock, flags);
......
...@@ -37,6 +37,7 @@ ...@@ -37,6 +37,7 @@
#ifndef __IB_MAD_PRIV_H__ #ifndef __IB_MAD_PRIV_H__
#define __IB_MAD_PRIV_H__ #define __IB_MAD_PRIV_H__
#include <linux/completion.h>
#include <linux/pci.h> #include <linux/pci.h>
#include <linux/kthread.h> #include <linux/kthread.h>
#include <linux/workqueue.h> #include <linux/workqueue.h>
...@@ -108,7 +109,7 @@ struct ib_mad_agent_private { ...@@ -108,7 +109,7 @@ struct ib_mad_agent_private {
struct list_head rmpp_list; struct list_head rmpp_list;
atomic_t refcount; atomic_t refcount;
wait_queue_head_t wait; struct completion comp;
}; };
struct ib_mad_snoop_private { struct ib_mad_snoop_private {
...@@ -117,7 +118,7 @@ struct ib_mad_snoop_private { ...@@ -117,7 +118,7 @@ struct ib_mad_snoop_private {
int snoop_index; int snoop_index;
int mad_snoop_flags; int mad_snoop_flags;
atomic_t refcount; atomic_t refcount;
wait_queue_head_t wait; struct completion comp;
}; };
struct ib_mad_send_wr_private { struct ib_mad_send_wr_private {
......
...@@ -49,7 +49,7 @@ struct mad_rmpp_recv { ...@@ -49,7 +49,7 @@ struct mad_rmpp_recv {
struct list_head list; struct list_head list;
struct work_struct timeout_work; struct work_struct timeout_work;
struct work_struct cleanup_work; struct work_struct cleanup_work;
wait_queue_head_t wait; struct completion comp;
enum rmpp_state state; enum rmpp_state state;
spinlock_t lock; spinlock_t lock;
atomic_t refcount; atomic_t refcount;
...@@ -69,10 +69,16 @@ struct mad_rmpp_recv { ...@@ -69,10 +69,16 @@ struct mad_rmpp_recv {
u8 method; u8 method;
}; };
static inline void deref_rmpp_recv(struct mad_rmpp_recv *rmpp_recv)
{
if (atomic_dec_and_test(&rmpp_recv->refcount))
complete(&rmpp_recv->comp);
}
static void destroy_rmpp_recv(struct mad_rmpp_recv *rmpp_recv) static void destroy_rmpp_recv(struct mad_rmpp_recv *rmpp_recv)
{ {
atomic_dec(&rmpp_recv->refcount); deref_rmpp_recv(rmpp_recv);
wait_event(rmpp_recv->wait, !atomic_read(&rmpp_recv->refcount)); wait_for_completion(&rmpp_recv->comp);
ib_destroy_ah(rmpp_recv->ah); ib_destroy_ah(rmpp_recv->ah);
kfree(rmpp_recv); kfree(rmpp_recv);
} }
...@@ -253,7 +259,7 @@ create_rmpp_recv(struct ib_mad_agent_private *agent, ...@@ -253,7 +259,7 @@ create_rmpp_recv(struct ib_mad_agent_private *agent,
goto error; goto error;
rmpp_recv->agent = agent; rmpp_recv->agent = agent;
init_waitqueue_head(&rmpp_recv->wait); init_completion(&rmpp_recv->comp);
INIT_WORK(&rmpp_recv->timeout_work, recv_timeout_handler, rmpp_recv); INIT_WORK(&rmpp_recv->timeout_work, recv_timeout_handler, rmpp_recv);
INIT_WORK(&rmpp_recv->cleanup_work, recv_cleanup_handler, rmpp_recv); INIT_WORK(&rmpp_recv->cleanup_work, recv_cleanup_handler, rmpp_recv);
spin_lock_init(&rmpp_recv->lock); spin_lock_init(&rmpp_recv->lock);
...@@ -279,12 +285,6 @@ error: kfree(rmpp_recv); ...@@ -279,12 +285,6 @@ error: kfree(rmpp_recv);
return NULL; return NULL;
} }
static inline void deref_rmpp_recv(struct mad_rmpp_recv *rmpp_recv)
{
if (atomic_dec_and_test(&rmpp_recv->refcount))
wake_up(&rmpp_recv->wait);
}
static struct mad_rmpp_recv * static struct mad_rmpp_recv *
find_rmpp_recv(struct ib_mad_agent_private *agent, find_rmpp_recv(struct ib_mad_agent_private *agent,
struct ib_mad_recv_wc *mad_recv_wc) struct ib_mad_recv_wc *mad_recv_wc)
......
...@@ -32,6 +32,8 @@ ...@@ -32,6 +32,8 @@
* *
* $Id: ucm.c 2594 2005-06-13 19:46:02Z libor $ * $Id: ucm.c 2594 2005-06-13 19:46:02Z libor $
*/ */
#include <linux/completion.h>
#include <linux/init.h> #include <linux/init.h>
#include <linux/fs.h> #include <linux/fs.h>
#include <linux/module.h> #include <linux/module.h>
...@@ -72,7 +74,7 @@ struct ib_ucm_file { ...@@ -72,7 +74,7 @@ struct ib_ucm_file {
struct ib_ucm_context { struct ib_ucm_context {
int id; int id;
wait_queue_head_t wait; struct completion comp;
atomic_t ref; atomic_t ref;
int events_reported; int events_reported;
...@@ -138,7 +140,7 @@ static struct ib_ucm_context *ib_ucm_ctx_get(struct ib_ucm_file *file, int id) ...@@ -138,7 +140,7 @@ static struct ib_ucm_context *ib_ucm_ctx_get(struct ib_ucm_file *file, int id)
static void ib_ucm_ctx_put(struct ib_ucm_context *ctx) static void ib_ucm_ctx_put(struct ib_ucm_context *ctx)
{ {
if (atomic_dec_and_test(&ctx->ref)) if (atomic_dec_and_test(&ctx->ref))
wake_up(&ctx->wait); complete(&ctx->comp);
} }
static inline int ib_ucm_new_cm_id(int event) static inline int ib_ucm_new_cm_id(int event)
...@@ -178,7 +180,7 @@ static struct ib_ucm_context *ib_ucm_ctx_alloc(struct ib_ucm_file *file) ...@@ -178,7 +180,7 @@ static struct ib_ucm_context *ib_ucm_ctx_alloc(struct ib_ucm_file *file)
return NULL; return NULL;
atomic_set(&ctx->ref, 1); atomic_set(&ctx->ref, 1);
init_waitqueue_head(&ctx->wait); init_completion(&ctx->comp);
ctx->file = file; ctx->file = file;
INIT_LIST_HEAD(&ctx->events); INIT_LIST_HEAD(&ctx->events);
...@@ -586,8 +588,8 @@ static ssize_t ib_ucm_destroy_id(struct ib_ucm_file *file, ...@@ -586,8 +588,8 @@ static ssize_t ib_ucm_destroy_id(struct ib_ucm_file *file,
if (IS_ERR(ctx)) if (IS_ERR(ctx))
return PTR_ERR(ctx); return PTR_ERR(ctx);
atomic_dec(&ctx->ref); ib_ucm_ctx_put(ctx);
wait_event(ctx->wait, !atomic_read(&ctx->ref)); wait_for_completion(&ctx->comp);
/* No new events will be generated after destroying the cm_id. */ /* No new events will be generated after destroying the cm_id. */
ib_destroy_cm_id(ctx->cm_id); ib_destroy_cm_id(ctx->cm_id);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment