Commit 2cc5e4ca authored by David S. Miller's avatar David S. Miller

Merge branch 'sctp-transport-races'

Xin Long says:

====================
fix the transport dead race check by using atomic_add_unless on refcnt

  sctp: fix the transport dead race check by using atomic_add_unless on
    refcnt
  sctp: hold transport before we access t->asoc in sctp proc
  sctp: remove the dead field of sctp_transport
====================
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parents 2baaa2d1 47faa1e4
...@@ -756,7 +756,6 @@ struct sctp_transport { ...@@ -756,7 +756,6 @@ struct sctp_transport {
/* Reference counting. */ /* Reference counting. */
atomic_t refcnt; atomic_t refcnt;
__u32 dead:1,
/* RTO-Pending : A flag used to track if one of the DATA /* RTO-Pending : A flag used to track if one of the DATA
* chunks sent to this address is currently being * chunks sent to this address is currently being
* used to compute a RTT. If this flag is 0, * used to compute a RTT. If this flag is 0,
...@@ -766,7 +765,7 @@ struct sctp_transport { ...@@ -766,7 +765,7 @@ struct sctp_transport {
* calculation completes (i.e. the DATA chunk * calculation completes (i.e. the DATA chunk
* is SACK'd) clear this flag. * is SACK'd) clear this flag.
*/ */
rto_pending:1, __u32 rto_pending:1,
/* /*
* hb_sent : a flag that signals that we have a pending * hb_sent : a flag that signals that we have a pending
...@@ -955,7 +954,7 @@ void sctp_transport_route(struct sctp_transport *, union sctp_addr *, ...@@ -955,7 +954,7 @@ void sctp_transport_route(struct sctp_transport *, union sctp_addr *,
void sctp_transport_pmtu(struct sctp_transport *, struct sock *sk); void sctp_transport_pmtu(struct sctp_transport *, struct sock *sk);
void sctp_transport_free(struct sctp_transport *); void sctp_transport_free(struct sctp_transport *);
void sctp_transport_reset_timers(struct sctp_transport *); void sctp_transport_reset_timers(struct sctp_transport *);
void sctp_transport_hold(struct sctp_transport *); int sctp_transport_hold(struct sctp_transport *);
void sctp_transport_put(struct sctp_transport *); void sctp_transport_put(struct sctp_transport *);
void sctp_transport_update_rto(struct sctp_transport *, __u32); void sctp_transport_update_rto(struct sctp_transport *, __u32);
void sctp_transport_raise_cwnd(struct sctp_transport *, __u32, __u32); void sctp_transport_raise_cwnd(struct sctp_transport *, __u32, __u32);
......
...@@ -935,15 +935,22 @@ static struct sctp_association *__sctp_lookup_association( ...@@ -935,15 +935,22 @@ static struct sctp_association *__sctp_lookup_association(
struct sctp_transport **pt) struct sctp_transport **pt)
{ {
struct sctp_transport *t; struct sctp_transport *t;
struct sctp_association *asoc = NULL;
rcu_read_lock();
t = sctp_addrs_lookup_transport(net, local, peer); t = sctp_addrs_lookup_transport(net, local, peer);
if (!t || t->dead) if (!t || !sctp_transport_hold(t))
return NULL; goto out;
sctp_association_hold(t->asoc); asoc = t->asoc;
sctp_association_hold(asoc);
*pt = t; *pt = t;
return t->asoc; sctp_transport_put(t);
out:
rcu_read_unlock();
return asoc;
} }
/* Look up an association. protected by RCU read lock */ /* Look up an association. protected by RCU read lock */
...@@ -955,9 +962,7 @@ struct sctp_association *sctp_lookup_association(struct net *net, ...@@ -955,9 +962,7 @@ struct sctp_association *sctp_lookup_association(struct net *net,
{ {
struct sctp_association *asoc; struct sctp_association *asoc;
rcu_read_lock();
asoc = __sctp_lookup_association(net, laddr, paddr, transportp); asoc = __sctp_lookup_association(net, laddr, paddr, transportp);
rcu_read_unlock();
return asoc; return asoc;
} }
......
...@@ -165,8 +165,6 @@ static void sctp_seq_dump_remote_addrs(struct seq_file *seq, struct sctp_associa ...@@ -165,8 +165,6 @@ static void sctp_seq_dump_remote_addrs(struct seq_file *seq, struct sctp_associa
list_for_each_entry_rcu(transport, &assoc->peer.transport_addr_list, list_for_each_entry_rcu(transport, &assoc->peer.transport_addr_list,
transports) { transports) {
addr = &transport->ipaddr; addr = &transport->ipaddr;
if (transport->dead)
continue;
af = sctp_get_af_specific(addr->sa.sa_family); af = sctp_get_af_specific(addr->sa.sa_family);
if (af->cmp_addr(addr, primary)) { if (af->cmp_addr(addr, primary)) {
...@@ -380,6 +378,8 @@ static int sctp_assocs_seq_show(struct seq_file *seq, void *v) ...@@ -380,6 +378,8 @@ static int sctp_assocs_seq_show(struct seq_file *seq, void *v)
} }
transport = (struct sctp_transport *)v; transport = (struct sctp_transport *)v;
if (!sctp_transport_hold(transport))
return 0;
assoc = transport->asoc; assoc = transport->asoc;
epb = &assoc->base; epb = &assoc->base;
sk = epb->sk; sk = epb->sk;
...@@ -412,6 +412,8 @@ static int sctp_assocs_seq_show(struct seq_file *seq, void *v) ...@@ -412,6 +412,8 @@ static int sctp_assocs_seq_show(struct seq_file *seq, void *v)
sk->sk_rcvbuf); sk->sk_rcvbuf);
seq_printf(seq, "\n"); seq_printf(seq, "\n");
sctp_transport_put(transport);
return 0; return 0;
} }
...@@ -489,12 +491,12 @@ static int sctp_remaddr_seq_show(struct seq_file *seq, void *v) ...@@ -489,12 +491,12 @@ static int sctp_remaddr_seq_show(struct seq_file *seq, void *v)
} }
tsp = (struct sctp_transport *)v; tsp = (struct sctp_transport *)v;
if (!sctp_transport_hold(tsp))
return 0;
assoc = tsp->asoc; assoc = tsp->asoc;
list_for_each_entry_rcu(tsp, &assoc->peer.transport_addr_list, list_for_each_entry_rcu(tsp, &assoc->peer.transport_addr_list,
transports) { transports) {
if (tsp->dead)
continue;
/* /*
* The remote address (ADDR) * The remote address (ADDR)
*/ */
...@@ -544,6 +546,8 @@ static int sctp_remaddr_seq_show(struct seq_file *seq, void *v) ...@@ -544,6 +546,8 @@ static int sctp_remaddr_seq_show(struct seq_file *seq, void *v)
seq_printf(seq, "\n"); seq_printf(seq, "\n");
} }
sctp_transport_put(tsp);
return 0; return 0;
} }
......
...@@ -259,12 +259,6 @@ void sctp_generate_t3_rtx_event(unsigned long peer) ...@@ -259,12 +259,6 @@ void sctp_generate_t3_rtx_event(unsigned long peer)
goto out_unlock; goto out_unlock;
} }
/* Is this transport really dead and just waiting around for
* the timer to let go of the reference?
*/
if (transport->dead)
goto out_unlock;
/* Run through the state machine. */ /* Run through the state machine. */
error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT, error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT,
SCTP_ST_TIMEOUT(SCTP_EVENT_TIMEOUT_T3_RTX), SCTP_ST_TIMEOUT(SCTP_EVENT_TIMEOUT_T3_RTX),
...@@ -380,12 +374,6 @@ void sctp_generate_heartbeat_event(unsigned long data) ...@@ -380,12 +374,6 @@ void sctp_generate_heartbeat_event(unsigned long data)
goto out_unlock; goto out_unlock;
} }
/* Is this structure just waiting around for us to actually
* get destroyed?
*/
if (transport->dead)
goto out_unlock;
error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT, error = sctp_do_sm(net, SCTP_EVENT_T_TIMEOUT,
SCTP_ST_TIMEOUT(SCTP_EVENT_TIMEOUT_HEARTBEAT), SCTP_ST_TIMEOUT(SCTP_EVENT_TIMEOUT_HEARTBEAT),
asoc->state, asoc->ep, asoc, asoc->state, asoc->ep, asoc,
......
...@@ -132,8 +132,6 @@ struct sctp_transport *sctp_transport_new(struct net *net, ...@@ -132,8 +132,6 @@ struct sctp_transport *sctp_transport_new(struct net *net,
*/ */
void sctp_transport_free(struct sctp_transport *transport) void sctp_transport_free(struct sctp_transport *transport)
{ {
transport->dead = 1;
/* Try to delete the heartbeat timer. */ /* Try to delete the heartbeat timer. */
if (del_timer(&transport->hb_timer)) if (del_timer(&transport->hb_timer))
sctp_transport_put(transport); sctp_transport_put(transport);
...@@ -169,7 +167,7 @@ static void sctp_transport_destroy_rcu(struct rcu_head *head) ...@@ -169,7 +167,7 @@ static void sctp_transport_destroy_rcu(struct rcu_head *head)
*/ */
static void sctp_transport_destroy(struct sctp_transport *transport) static void sctp_transport_destroy(struct sctp_transport *transport)
{ {
if (unlikely(!transport->dead)) { if (unlikely(atomic_read(&transport->refcnt))) {
WARN(1, "Attempt to destroy undead transport %p!\n", transport); WARN(1, "Attempt to destroy undead transport %p!\n", transport);
return; return;
} }
...@@ -296,9 +294,9 @@ void sctp_transport_route(struct sctp_transport *transport, ...@@ -296,9 +294,9 @@ void sctp_transport_route(struct sctp_transport *transport,
} }
/* Hold a reference to a transport. */ /* Hold a reference to a transport. */
void sctp_transport_hold(struct sctp_transport *transport) int sctp_transport_hold(struct sctp_transport *transport)
{ {
atomic_inc(&transport->refcnt); return atomic_add_unless(&transport->refcnt, 1, 0);
} }
/* Release a reference to a transport and clean up /* Release a reference to a transport and clean up
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment