Commit bbf73830 authored by Vlad Buslov's avatar Vlad Buslov Committed by David S. Miller

net: sched: traverse chains in block with tcf_get_next_chain()

All users of block->chain_list rely on rtnl lock and assume that no new
chains are added when traversing the list. Use tcf_get_next_chain() to
traverse chain list without relying on rtnl mutex. This function iterates
over chains by taking reference to current iterator chain only and doesn't
assume external synchronization of chain list.

Don't take reference to all chains in block when flushing and use
tcf_get_next_chain() to safely iterate over chain list instead. Remove
tcf_block_put_all_chains() that is no longer used.
Signed-off-by: default avatarVlad Buslov <vladbu@mellanox.com>
Acked-by: default avatarJiri Pirko <jiri@mellanox.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 165f0135
...@@ -44,6 +44,8 @@ bool tcf_queue_work(struct rcu_work *rwork, work_func_t func); ...@@ -44,6 +44,8 @@ bool tcf_queue_work(struct rcu_work *rwork, work_func_t func);
struct tcf_chain *tcf_chain_get_by_act(struct tcf_block *block, struct tcf_chain *tcf_chain_get_by_act(struct tcf_block *block,
u32 chain_index); u32 chain_index);
void tcf_chain_put_by_act(struct tcf_chain *chain); void tcf_chain_put_by_act(struct tcf_chain *chain);
struct tcf_chain *tcf_get_next_chain(struct tcf_block *block,
struct tcf_chain *chain);
void tcf_block_netif_keep_dst(struct tcf_block *block); void tcf_block_netif_keep_dst(struct tcf_block *block);
int tcf_block_get(struct tcf_block **p_block, int tcf_block_get(struct tcf_block **p_block,
struct tcf_proto __rcu **p_filter_chain, struct Qdisc *q, struct tcf_proto __rcu **p_filter_chain, struct Qdisc *q,
......
...@@ -883,28 +883,62 @@ static struct tcf_block *tcf_block_refcnt_get(struct net *net, u32 block_index) ...@@ -883,28 +883,62 @@ static struct tcf_block *tcf_block_refcnt_get(struct net *net, u32 block_index)
return block; return block;
} }
static void tcf_block_flush_all_chains(struct tcf_block *block) static struct tcf_chain *
__tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain)
{ {
struct tcf_chain *chain; mutex_lock(&block->lock);
if (chain)
chain = list_is_last(&chain->list, &block->chain_list) ?
NULL : list_next_entry(chain, list);
else
chain = list_first_entry_or_null(&block->chain_list,
struct tcf_chain, list);
/* Hold a refcnt for all chains, so that they don't disappear /* skip all action-only chains */
* while we are iterating. while (chain && tcf_chain_held_by_acts_only(chain))
*/ chain = list_is_last(&chain->list, &block->chain_list) ?
list_for_each_entry(chain, &block->chain_list, list) NULL : list_next_entry(chain, list);
if (chain)
tcf_chain_hold(chain); tcf_chain_hold(chain);
mutex_unlock(&block->lock);
list_for_each_entry(chain, &block->chain_list, list) return chain;
tcf_chain_flush(chain);
} }
static void tcf_block_put_all_chains(struct tcf_block *block) /* Function to be used by all clients that want to iterate over all chains on
* block. It properly obtains block->lock and takes reference to chain before
* returning it. Users of this function must be tolerant to concurrent chain
* insertion/deletion or ensure that no concurrent chain modification is
* possible. Note that all netlink dump callbacks cannot guarantee to provide
* consistent dump because rtnl lock is released each time skb is filled with
* data and sent to user-space.
*/
struct tcf_chain *
tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain)
{ {
struct tcf_chain *chain, *tmp; struct tcf_chain *chain_next = __tcf_get_next_chain(block, chain);
/* At this point, all the chains should have refcnt >= 1. */ if (chain)
list_for_each_entry_safe(chain, tmp, &block->chain_list, list) {
tcf_chain_put_explicitly_created(chain);
tcf_chain_put(chain); tcf_chain_put(chain);
return chain_next;
}
EXPORT_SYMBOL(tcf_get_next_chain);
static void tcf_block_flush_all_chains(struct tcf_block *block)
{
struct tcf_chain *chain;
/* Last reference to block. At this point chains cannot be added or
* removed concurrently.
*/
for (chain = tcf_get_next_chain(block, NULL);
chain;
chain = tcf_get_next_chain(block, chain)) {
tcf_chain_put_explicitly_created(chain);
tcf_chain_flush(chain);
} }
} }
...@@ -923,8 +957,6 @@ static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q, ...@@ -923,8 +957,6 @@ static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q,
mutex_unlock(&block->lock); mutex_unlock(&block->lock);
if (tcf_block_shared(block)) if (tcf_block_shared(block))
tcf_block_remove(block, block->net); tcf_block_remove(block, block->net);
if (!free_block)
tcf_block_flush_all_chains(block);
if (q) if (q)
tcf_block_offload_unbind(block, q, ei); tcf_block_offload_unbind(block, q, ei);
...@@ -932,7 +964,7 @@ static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q, ...@@ -932,7 +964,7 @@ static void __tcf_block_put(struct tcf_block *block, struct Qdisc *q,
if (free_block) if (free_block)
tcf_block_destroy(block); tcf_block_destroy(block);
else else
tcf_block_put_all_chains(block); tcf_block_flush_all_chains(block);
} else if (q) { } else if (q) {
tcf_block_offload_unbind(block, q, ei); tcf_block_offload_unbind(block, q, ei);
} }
...@@ -1266,11 +1298,15 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, ...@@ -1266,11 +1298,15 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb,
void *cb_priv, bool add, bool offload_in_use, void *cb_priv, bool add, bool offload_in_use,
struct netlink_ext_ack *extack) struct netlink_ext_ack *extack)
{ {
struct tcf_chain *chain; struct tcf_chain *chain, *chain_prev;
struct tcf_proto *tp; struct tcf_proto *tp;
int err; int err;
list_for_each_entry(chain, &block->chain_list, list) { for (chain = __tcf_get_next_chain(block, NULL);
chain;
chain_prev = chain,
chain = __tcf_get_next_chain(block, chain),
tcf_chain_put(chain_prev)) {
for (tp = rtnl_dereference(chain->filter_chain); tp; for (tp = rtnl_dereference(chain->filter_chain); tp;
tp = rtnl_dereference(tp->next)) { tp = rtnl_dereference(tp->next)) {
if (tp->ops->reoffload) { if (tp->ops->reoffload) {
...@@ -1289,6 +1325,7 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, ...@@ -1289,6 +1325,7 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb,
return 0; return 0;
err_playback_remove: err_playback_remove:
tcf_chain_put(chain);
tcf_block_playback_offloads(block, cb, cb_priv, false, offload_in_use, tcf_block_playback_offloads(block, cb, cb_priv, false, offload_in_use,
extack); extack);
return err; return err;
...@@ -2023,11 +2060,11 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, ...@@ -2023,11 +2060,11 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent,
/* called with RTNL */ /* called with RTNL */
static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb) static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb)
{ {
struct tcf_chain *chain, *chain_prev;
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct nlattr *tca[TCA_MAX + 1]; struct nlattr *tca[TCA_MAX + 1];
struct Qdisc *q = NULL; struct Qdisc *q = NULL;
struct tcf_block *block; struct tcf_block *block;
struct tcf_chain *chain;
struct tcmsg *tcm = nlmsg_data(cb->nlh); struct tcmsg *tcm = nlmsg_data(cb->nlh);
long index_start; long index_start;
long index; long index;
...@@ -2091,12 +2128,17 @@ static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -2091,12 +2128,17 @@ static int tc_dump_tfilter(struct sk_buff *skb, struct netlink_callback *cb)
index_start = cb->args[0]; index_start = cb->args[0];
index = 0; index = 0;
list_for_each_entry(chain, &block->chain_list, list) { for (chain = __tcf_get_next_chain(block, NULL);
chain;
chain_prev = chain,
chain = __tcf_get_next_chain(block, chain),
tcf_chain_put(chain_prev)) {
if (tca[TCA_CHAIN] && if (tca[TCA_CHAIN] &&
nla_get_u32(tca[TCA_CHAIN]) != chain->index) nla_get_u32(tca[TCA_CHAIN]) != chain->index)
continue; continue;
if (!tcf_chain_dump(chain, q, parent, skb, cb, if (!tcf_chain_dump(chain, q, parent, skb, cb,
index_start, &index)) { index_start, &index)) {
tcf_chain_put(chain);
err = -EMSGSIZE; err = -EMSGSIZE;
break; break;
} }
...@@ -2364,11 +2406,11 @@ static int tc_ctl_chain(struct sk_buff *skb, struct nlmsghdr *n, ...@@ -2364,11 +2406,11 @@ static int tc_ctl_chain(struct sk_buff *skb, struct nlmsghdr *n,
/* called with RTNL */ /* called with RTNL */
static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb)
{ {
struct tcf_chain *chain, *chain_prev;
struct net *net = sock_net(skb->sk); struct net *net = sock_net(skb->sk);
struct nlattr *tca[TCA_MAX + 1]; struct nlattr *tca[TCA_MAX + 1];
struct Qdisc *q = NULL; struct Qdisc *q = NULL;
struct tcf_block *block; struct tcf_block *block;
struct tcf_chain *chain;
struct tcmsg *tcm = nlmsg_data(cb->nlh); struct tcmsg *tcm = nlmsg_data(cb->nlh);
long index_start; long index_start;
long index; long index;
...@@ -2432,7 +2474,11 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -2432,7 +2474,11 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb)
index_start = cb->args[0]; index_start = cb->args[0];
index = 0; index = 0;
list_for_each_entry(chain, &block->chain_list, list) { for (chain = __tcf_get_next_chain(block, NULL);
chain;
chain_prev = chain,
chain = __tcf_get_next_chain(block, chain),
tcf_chain_put(chain_prev)) {
if ((tca[TCA_CHAIN] && if ((tca[TCA_CHAIN] &&
nla_get_u32(tca[TCA_CHAIN]) != chain->index)) nla_get_u32(tca[TCA_CHAIN]) != chain->index))
continue; continue;
...@@ -2440,14 +2486,14 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb) ...@@ -2440,14 +2486,14 @@ static int tc_dump_chain(struct sk_buff *skb, struct netlink_callback *cb)
index++; index++;
continue; continue;
} }
if (tcf_chain_held_by_acts_only(chain))
continue;
err = tc_chain_fill_node(chain, net, skb, block, err = tc_chain_fill_node(chain, net, skb, block,
NETLINK_CB(cb->skb).portid, NETLINK_CB(cb->skb).portid,
cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh->nlmsg_seq, NLM_F_MULTI,
RTM_NEWCHAIN); RTM_NEWCHAIN);
if (err <= 0) if (err <= 0) {
tcf_chain_put(chain);
break; break;
}
index++; index++;
} }
......
...@@ -1909,7 +1909,9 @@ static void tc_bind_tclass(struct Qdisc *q, u32 portid, u32 clid, ...@@ -1909,7 +1909,9 @@ static void tc_bind_tclass(struct Qdisc *q, u32 portid, u32 clid,
block = cops->tcf_block(q, cl, NULL); block = cops->tcf_block(q, cl, NULL);
if (!block) if (!block)
return; return;
list_for_each_entry(chain, &block->chain_list, list) { for (chain = tcf_get_next_chain(block, NULL);
chain;
chain = tcf_get_next_chain(block, chain)) {
struct tcf_proto *tp; struct tcf_proto *tp;
for (tp = rtnl_dereference(chain->filter_chain); for (tp = rtnl_dereference(chain->filter_chain);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment