Commit 63221acb authored by Bob Pearson, committed by Jason Gunthorpe

RDMA/rxe: Fix ref error in rxe_av.c

The commit referenced below can take a reference to the AH which is never
dropped. This only happens in the UD request path. This patch optionally
passes that AH back to the caller so that it can hold the reference while
the AV is being accessed and then drop it. Code to do this is added to
rxe_req.c. The AV is also passed to rxe_prepare in rxe_net.c as an
optimization.

Fixes: e2fe06c9 ("RDMA/rxe: Lookup kernel AH from ah index in UD WQEs")
Link: https://lore.kernel.org/r/20220304000808.225811-2-rpearsonhpe@gmail.com
Signed-off-by: Bob Pearson <rpearsonhpe@gmail.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
parent 70f92521
...@@ -99,11 +99,14 @@ void rxe_av_fill_ip_info(struct rxe_av *av, struct rdma_ah_attr *attr) ...@@ -99,11 +99,14 @@ void rxe_av_fill_ip_info(struct rxe_av *av, struct rdma_ah_attr *attr)
av->network_type = type; av->network_type = type;
} }
struct rxe_av *rxe_get_av(struct rxe_pkt_info *pkt) struct rxe_av *rxe_get_av(struct rxe_pkt_info *pkt, struct rxe_ah **ahp)
{ {
struct rxe_ah *ah; struct rxe_ah *ah;
u32 ah_num; u32 ah_num;
if (ahp)
*ahp = NULL;
if (!pkt || !pkt->qp) if (!pkt || !pkt->qp)
return NULL; return NULL;
...@@ -117,10 +120,22 @@ struct rxe_av *rxe_get_av(struct rxe_pkt_info *pkt) ...@@ -117,10 +120,22 @@ struct rxe_av *rxe_get_av(struct rxe_pkt_info *pkt)
if (ah_num) { if (ah_num) {
/* only new user provider or kernel client */ /* only new user provider or kernel client */
ah = rxe_pool_get_index(&pkt->rxe->ah_pool, ah_num); ah = rxe_pool_get_index(&pkt->rxe->ah_pool, ah_num);
if (!ah || ah->ah_num != ah_num || rxe_ah_pd(ah) != pkt->qp->pd) { if (!ah) {
pr_warn("Unable to find AH matching ah_num\n"); pr_warn("Unable to find AH matching ah_num\n");
return NULL; return NULL;
} }
if (rxe_ah_pd(ah) != pkt->qp->pd) {
pr_warn("PDs don't match for AH and QP\n");
rxe_drop_ref(ah);
return NULL;
}
if (ahp)
*ahp = ah;
else
rxe_drop_ref(ah);
return &ah->av; return &ah->av;
} }
......
...@@ -19,7 +19,7 @@ void rxe_av_to_attr(struct rxe_av *av, struct rdma_ah_attr *attr); ...@@ -19,7 +19,7 @@ void rxe_av_to_attr(struct rxe_av *av, struct rdma_ah_attr *attr);
void rxe_av_fill_ip_info(struct rxe_av *av, struct rdma_ah_attr *attr); void rxe_av_fill_ip_info(struct rxe_av *av, struct rdma_ah_attr *attr);
struct rxe_av *rxe_get_av(struct rxe_pkt_info *pkt); struct rxe_av *rxe_get_av(struct rxe_pkt_info *pkt, struct rxe_ah **ahp);
/* rxe_cq.c */ /* rxe_cq.c */
int rxe_cq_chk_attr(struct rxe_dev *rxe, struct rxe_cq *cq, int rxe_cq_chk_attr(struct rxe_dev *rxe, struct rxe_cq *cq,
...@@ -94,7 +94,8 @@ void rxe_mw_cleanup(struct rxe_pool_elem *arg); ...@@ -94,7 +94,8 @@ void rxe_mw_cleanup(struct rxe_pool_elem *arg);
/* rxe_net.c */ /* rxe_net.c */
struct sk_buff *rxe_init_packet(struct rxe_dev *rxe, struct rxe_av *av, struct sk_buff *rxe_init_packet(struct rxe_dev *rxe, struct rxe_av *av,
int paylen, struct rxe_pkt_info *pkt); int paylen, struct rxe_pkt_info *pkt);
int rxe_prepare(struct rxe_pkt_info *pkt, struct sk_buff *skb); int rxe_prepare(struct rxe_av *av, struct rxe_pkt_info *pkt,
struct sk_buff *skb);
int rxe_xmit_packet(struct rxe_qp *qp, struct rxe_pkt_info *pkt, int rxe_xmit_packet(struct rxe_qp *qp, struct rxe_pkt_info *pkt,
struct sk_buff *skb); struct sk_buff *skb);
const char *rxe_parent_name(struct rxe_dev *rxe, unsigned int port_num); const char *rxe_parent_name(struct rxe_dev *rxe, unsigned int port_num);
......
...@@ -271,13 +271,13 @@ static void prepare_ipv6_hdr(struct dst_entry *dst, struct sk_buff *skb, ...@@ -271,13 +271,13 @@ static void prepare_ipv6_hdr(struct dst_entry *dst, struct sk_buff *skb,
ip6h->payload_len = htons(skb->len - sizeof(*ip6h)); ip6h->payload_len = htons(skb->len - sizeof(*ip6h));
} }
static int prepare4(struct rxe_pkt_info *pkt, struct sk_buff *skb) static int prepare4(struct rxe_av *av, struct rxe_pkt_info *pkt,
struct sk_buff *skb)
{ {
struct rxe_qp *qp = pkt->qp; struct rxe_qp *qp = pkt->qp;
struct dst_entry *dst; struct dst_entry *dst;
bool xnet = false; bool xnet = false;
__be16 df = htons(IP_DF); __be16 df = htons(IP_DF);
struct rxe_av *av = rxe_get_av(pkt);
struct in_addr *saddr = &av->sgid_addr._sockaddr_in.sin_addr; struct in_addr *saddr = &av->sgid_addr._sockaddr_in.sin_addr;
struct in_addr *daddr = &av->dgid_addr._sockaddr_in.sin_addr; struct in_addr *daddr = &av->dgid_addr._sockaddr_in.sin_addr;
...@@ -297,11 +297,11 @@ static int prepare4(struct rxe_pkt_info *pkt, struct sk_buff *skb) ...@@ -297,11 +297,11 @@ static int prepare4(struct rxe_pkt_info *pkt, struct sk_buff *skb)
return 0; return 0;
} }
static int prepare6(struct rxe_pkt_info *pkt, struct sk_buff *skb) static int prepare6(struct rxe_av *av, struct rxe_pkt_info *pkt,
struct sk_buff *skb)
{ {
struct rxe_qp *qp = pkt->qp; struct rxe_qp *qp = pkt->qp;
struct dst_entry *dst; struct dst_entry *dst;
struct rxe_av *av = rxe_get_av(pkt);
struct in6_addr *saddr = &av->sgid_addr._sockaddr_in6.sin6_addr; struct in6_addr *saddr = &av->sgid_addr._sockaddr_in6.sin6_addr;
struct in6_addr *daddr = &av->dgid_addr._sockaddr_in6.sin6_addr; struct in6_addr *daddr = &av->dgid_addr._sockaddr_in6.sin6_addr;
...@@ -322,16 +322,17 @@ static int prepare6(struct rxe_pkt_info *pkt, struct sk_buff *skb) ...@@ -322,16 +322,17 @@ static int prepare6(struct rxe_pkt_info *pkt, struct sk_buff *skb)
return 0; return 0;
} }
int rxe_prepare(struct rxe_pkt_info *pkt, struct sk_buff *skb) int rxe_prepare(struct rxe_av *av, struct rxe_pkt_info *pkt,
struct sk_buff *skb)
{ {
int err = 0; int err = 0;
if (skb->protocol == htons(ETH_P_IP)) if (skb->protocol == htons(ETH_P_IP))
err = prepare4(pkt, skb); err = prepare4(av, pkt, skb);
else if (skb->protocol == htons(ETH_P_IPV6)) else if (skb->protocol == htons(ETH_P_IPV6))
err = prepare6(pkt, skb); err = prepare6(av, pkt, skb);
if (ether_addr_equal(skb->dev->dev_addr, rxe_get_av(pkt)->dmac)) if (ether_addr_equal(skb->dev->dev_addr, av->dmac))
pkt->mask |= RXE_LOOPBACK_MASK; pkt->mask |= RXE_LOOPBACK_MASK;
return err; return err;
......
...@@ -358,6 +358,7 @@ static inline int get_mtu(struct rxe_qp *qp) ...@@ -358,6 +358,7 @@ static inline int get_mtu(struct rxe_qp *qp)
} }
static struct sk_buff *init_req_packet(struct rxe_qp *qp, static struct sk_buff *init_req_packet(struct rxe_qp *qp,
struct rxe_av *av,
struct rxe_send_wqe *wqe, struct rxe_send_wqe *wqe,
int opcode, u32 payload, int opcode, u32 payload,
struct rxe_pkt_info *pkt) struct rxe_pkt_info *pkt)
...@@ -365,7 +366,6 @@ static struct sk_buff *init_req_packet(struct rxe_qp *qp, ...@@ -365,7 +366,6 @@ static struct sk_buff *init_req_packet(struct rxe_qp *qp,
struct rxe_dev *rxe = to_rdev(qp->ibqp.device); struct rxe_dev *rxe = to_rdev(qp->ibqp.device);
struct sk_buff *skb; struct sk_buff *skb;
struct rxe_send_wr *ibwr = &wqe->wr; struct rxe_send_wr *ibwr = &wqe->wr;
struct rxe_av *av;
int pad = (-payload) & 0x3; int pad = (-payload) & 0x3;
int paylen; int paylen;
int solicited; int solicited;
...@@ -374,21 +374,9 @@ static struct sk_buff *init_req_packet(struct rxe_qp *qp, ...@@ -374,21 +374,9 @@ static struct sk_buff *init_req_packet(struct rxe_qp *qp,
/* length from start of bth to end of icrc */ /* length from start of bth to end of icrc */
paylen = rxe_opcode[opcode].length + payload + pad + RXE_ICRC_SIZE; paylen = rxe_opcode[opcode].length + payload + pad + RXE_ICRC_SIZE;
pkt->paylen = paylen;
/* pkt->hdr, port_num and mask are initialized in ifc layer */
pkt->rxe = rxe;
pkt->opcode = opcode;
pkt->qp = qp;
pkt->psn = qp->req.psn;
pkt->mask = rxe_opcode[opcode].mask;
pkt->paylen = paylen;
pkt->wqe = wqe;
/* init skb */ /* init skb */
av = rxe_get_av(pkt);
if (!av)
return NULL;
skb = rxe_init_packet(rxe, av, paylen, pkt); skb = rxe_init_packet(rxe, av, paylen, pkt);
if (unlikely(!skb)) if (unlikely(!skb))
return NULL; return NULL;
...@@ -447,13 +435,13 @@ static struct sk_buff *init_req_packet(struct rxe_qp *qp, ...@@ -447,13 +435,13 @@ static struct sk_buff *init_req_packet(struct rxe_qp *qp,
return skb; return skb;
} }
static int finish_packet(struct rxe_qp *qp, struct rxe_send_wqe *wqe, static int finish_packet(struct rxe_qp *qp, struct rxe_av *av,
struct rxe_pkt_info *pkt, struct sk_buff *skb, struct rxe_send_wqe *wqe, struct rxe_pkt_info *pkt,
u32 paylen) struct sk_buff *skb, u32 paylen)
{ {
int err; int err;
err = rxe_prepare(pkt, skb); err = rxe_prepare(av, pkt, skb);
if (err) if (err)
return err; return err;
...@@ -608,6 +596,7 @@ static int rxe_do_local_ops(struct rxe_qp *qp, struct rxe_send_wqe *wqe) ...@@ -608,6 +596,7 @@ static int rxe_do_local_ops(struct rxe_qp *qp, struct rxe_send_wqe *wqe)
int rxe_requester(void *arg) int rxe_requester(void *arg)
{ {
struct rxe_qp *qp = (struct rxe_qp *)arg; struct rxe_qp *qp = (struct rxe_qp *)arg;
struct rxe_dev *rxe = to_rdev(qp->ibqp.device);
struct rxe_pkt_info pkt; struct rxe_pkt_info pkt;
struct sk_buff *skb; struct sk_buff *skb;
struct rxe_send_wqe *wqe; struct rxe_send_wqe *wqe;
...@@ -619,6 +608,8 @@ int rxe_requester(void *arg) ...@@ -619,6 +608,8 @@ int rxe_requester(void *arg)
struct rxe_send_wqe rollback_wqe; struct rxe_send_wqe rollback_wqe;
u32 rollback_psn; u32 rollback_psn;
struct rxe_queue *q = qp->sq.queue; struct rxe_queue *q = qp->sq.queue;
struct rxe_ah *ah;
struct rxe_av *av;
rxe_add_ref(qp); rxe_add_ref(qp);
...@@ -705,14 +696,28 @@ int rxe_requester(void *arg) ...@@ -705,14 +696,28 @@ int rxe_requester(void *arg)
payload = mtu; payload = mtu;
} }
skb = init_req_packet(qp, wqe, opcode, payload, &pkt); pkt.rxe = rxe;
pkt.opcode = opcode;
pkt.qp = qp;
pkt.psn = qp->req.psn;
pkt.mask = rxe_opcode[opcode].mask;
pkt.wqe = wqe;
av = rxe_get_av(&pkt, &ah);
if (unlikely(!av)) {
pr_err("qp#%d Failed no address vector\n", qp_num(qp));
wqe->status = IB_WC_LOC_QP_OP_ERR;
goto err_drop_ah;
}
skb = init_req_packet(qp, av, wqe, opcode, payload, &pkt);
if (unlikely(!skb)) { if (unlikely(!skb)) {
pr_err("qp#%d Failed allocating skb\n", qp_num(qp)); pr_err("qp#%d Failed allocating skb\n", qp_num(qp));
wqe->status = IB_WC_LOC_QP_OP_ERR; wqe->status = IB_WC_LOC_QP_OP_ERR;
goto err; goto err_drop_ah;
} }
ret = finish_packet(qp, wqe, &pkt, skb, payload); ret = finish_packet(qp, av, wqe, &pkt, skb, payload);
if (unlikely(ret)) { if (unlikely(ret)) {
pr_debug("qp#%d Error during finish packet\n", qp_num(qp)); pr_debug("qp#%d Error during finish packet\n", qp_num(qp));
if (ret == -EFAULT) if (ret == -EFAULT)
...@@ -720,9 +725,12 @@ int rxe_requester(void *arg) ...@@ -720,9 +725,12 @@ int rxe_requester(void *arg)
else else
wqe->status = IB_WC_LOC_QP_OP_ERR; wqe->status = IB_WC_LOC_QP_OP_ERR;
kfree_skb(skb); kfree_skb(skb);
goto err; goto err_drop_ah;
} }
if (ah)
rxe_drop_ref(ah);
/* /*
* To prevent a race on wqe access between requester and completer, * To prevent a race on wqe access between requester and completer,
* wqe members state and psn need to be set before calling * wqe members state and psn need to be set before calling
...@@ -751,6 +759,9 @@ int rxe_requester(void *arg) ...@@ -751,6 +759,9 @@ int rxe_requester(void *arg)
goto next_wqe; goto next_wqe;
err_drop_ah:
if (ah)
rxe_drop_ref(ah);
err: err:
wqe->state = wqe_state_error; wqe->state = wqe_state_error;
__rxe_do_task(&qp->comp.task); __rxe_do_task(&qp->comp.task);
......
...@@ -633,7 +633,7 @@ static struct sk_buff *prepare_ack_packet(struct rxe_qp *qp, ...@@ -633,7 +633,7 @@ static struct sk_buff *prepare_ack_packet(struct rxe_qp *qp,
if (ack->mask & RXE_ATMACK_MASK) if (ack->mask & RXE_ATMACK_MASK)
atmack_set_orig(ack, qp->resp.atomic_orig); atmack_set_orig(ack, qp->resp.atomic_orig);
err = rxe_prepare(ack, skb); err = rxe_prepare(&qp->pri_av, ack, skb);
if (err) { if (err) {
kfree_skb(skb); kfree_skb(skb);
return NULL; return NULL;
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment