Commit 44524fcd authored by Marek Lindner's avatar Marek Lindner

batman-adv: Correct rcu refcounting for neigh_node

It might be possible that 2 threads access the same data in the same
rcu grace period. The first thread calls call_rcu() to decrement the
refcount and free the data while the second thread increases the
refcount to use the data. To avoid this race condition all refcount
operations have to be atomic.
Reported-by: default avatarSven Eckelmann <sven@narfation.org>
Signed-off-by: default avatarMarek Lindner <lindner_marek@yahoo.de>
parent a4c135c5
...@@ -156,7 +156,8 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff, ...@@ -156,7 +156,8 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
struct sk_buff *skb; struct sk_buff *skb;
struct icmp_packet_rr *icmp_packet; struct icmp_packet_rr *icmp_packet;
struct orig_node *orig_node; struct orig_node *orig_node = NULL;
struct neigh_node *neigh_node = NULL;
struct batman_if *batman_if; struct batman_if *batman_if;
size_t packet_len = sizeof(struct icmp_packet); size_t packet_len = sizeof(struct icmp_packet);
uint8_t dstaddr[ETH_ALEN]; uint8_t dstaddr[ETH_ALEN];
...@@ -224,17 +225,25 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff, ...@@ -224,17 +225,25 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
orig_node = ((struct orig_node *)hash_find(bat_priv->orig_hash, orig_node = ((struct orig_node *)hash_find(bat_priv->orig_hash,
compare_orig, choose_orig, compare_orig, choose_orig,
icmp_packet->dst)); icmp_packet->dst));
rcu_read_unlock();
if (!orig_node) if (!orig_node)
goto unlock; goto unlock;
if (!orig_node->router) kref_get(&orig_node->refcount);
neigh_node = orig_node->router;
if (!neigh_node)
goto unlock;
if (!atomic_inc_not_zero(&neigh_node->refcount)) {
neigh_node = NULL;
goto unlock; goto unlock;
}
rcu_read_unlock();
batman_if = orig_node->router->if_incoming; batman_if = orig_node->router->if_incoming;
memcpy(dstaddr, orig_node->router->addr, ETH_ALEN); memcpy(dstaddr, orig_node->router->addr, ETH_ALEN);
spin_unlock_bh(&bat_priv->orig_hash_lock); spin_unlock_bh(&bat_priv->orig_hash_lock);
if (!batman_if) if (!batman_if)
...@@ -247,14 +256,14 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff, ...@@ -247,14 +256,14 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
bat_priv->primary_if->net_dev->dev_addr, ETH_ALEN); bat_priv->primary_if->net_dev->dev_addr, ETH_ALEN);
if (packet_len == sizeof(struct icmp_packet_rr)) if (packet_len == sizeof(struct icmp_packet_rr))
memcpy(icmp_packet->rr, batman_if->net_dev->dev_addr, ETH_ALEN); memcpy(icmp_packet->rr,
batman_if->net_dev->dev_addr, ETH_ALEN);
send_skb_packet(skb, batman_if, dstaddr); send_skb_packet(skb, batman_if, dstaddr);
goto out; goto out;
unlock: unlock:
rcu_read_unlock();
spin_unlock_bh(&bat_priv->orig_hash_lock); spin_unlock_bh(&bat_priv->orig_hash_lock);
dst_unreach: dst_unreach:
icmp_packet->msg_type = DESTINATION_UNREACHABLE; icmp_packet->msg_type = DESTINATION_UNREACHABLE;
...@@ -262,6 +271,10 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff, ...@@ -262,6 +271,10 @@ static ssize_t bat_socket_write(struct file *file, const char __user *buff,
free_skb: free_skb:
kfree_skb(skb); kfree_skb(skb);
out: out:
if (neigh_node)
neigh_node_free_ref(neigh_node);
if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref);
return len; return len;
} }
......
...@@ -59,28 +59,18 @@ int originator_init(struct bat_priv *bat_priv) ...@@ -59,28 +59,18 @@ int originator_init(struct bat_priv *bat_priv)
return 0; return 0;
} }
void neigh_node_free_ref(struct kref *refcount)
{
struct neigh_node *neigh_node;
neigh_node = container_of(refcount, struct neigh_node, refcount);
kfree(neigh_node);
}
static void neigh_node_free_rcu(struct rcu_head *rcu) static void neigh_node_free_rcu(struct rcu_head *rcu)
{ {
struct neigh_node *neigh_node; struct neigh_node *neigh_node;
neigh_node = container_of(rcu, struct neigh_node, rcu); neigh_node = container_of(rcu, struct neigh_node, rcu);
kref_put(&neigh_node->refcount, neigh_node_free_ref); kfree(neigh_node);
} }
void neigh_node_free_rcu_bond(struct rcu_head *rcu) void neigh_node_free_ref(struct neigh_node *neigh_node)
{ {
struct neigh_node *neigh_node; if (atomic_dec_and_test(&neigh_node->refcount))
call_rcu(&neigh_node->rcu, neigh_node_free_rcu);
neigh_node = container_of(rcu, struct neigh_node, rcu_bond);
kref_put(&neigh_node->refcount, neigh_node_free_ref);
} }
struct neigh_node *create_neighbor(struct orig_node *orig_node, struct neigh_node *create_neighbor(struct orig_node *orig_node,
...@@ -104,7 +94,7 @@ struct neigh_node *create_neighbor(struct orig_node *orig_node, ...@@ -104,7 +94,7 @@ struct neigh_node *create_neighbor(struct orig_node *orig_node,
memcpy(neigh_node->addr, neigh, ETH_ALEN); memcpy(neigh_node->addr, neigh, ETH_ALEN);
neigh_node->orig_node = orig_neigh_node; neigh_node->orig_node = orig_neigh_node;
neigh_node->if_incoming = if_incoming; neigh_node->if_incoming = if_incoming;
kref_init(&neigh_node->refcount); atomic_set(&neigh_node->refcount, 1);
spin_lock_bh(&orig_node->neigh_list_lock); spin_lock_bh(&orig_node->neigh_list_lock);
hlist_add_head_rcu(&neigh_node->list, &orig_node->neigh_list); hlist_add_head_rcu(&neigh_node->list, &orig_node->neigh_list);
...@@ -126,14 +116,14 @@ void orig_node_free_ref(struct kref *refcount) ...@@ -126,14 +116,14 @@ void orig_node_free_ref(struct kref *refcount)
list_for_each_entry_safe(neigh_node, tmp_neigh_node, list_for_each_entry_safe(neigh_node, tmp_neigh_node,
&orig_node->bond_list, bonding_list) { &orig_node->bond_list, bonding_list) {
list_del_rcu(&neigh_node->bonding_list); list_del_rcu(&neigh_node->bonding_list);
call_rcu(&neigh_node->rcu_bond, neigh_node_free_rcu_bond); neigh_node_free_ref(neigh_node);
} }
/* for all neighbors towards this originator ... */ /* for all neighbors towards this originator ... */
hlist_for_each_entry_safe(neigh_node, node, node_tmp, hlist_for_each_entry_safe(neigh_node, node, node_tmp,
&orig_node->neigh_list, list) { &orig_node->neigh_list, list) {
hlist_del_rcu(&neigh_node->list); hlist_del_rcu(&neigh_node->list);
call_rcu(&neigh_node->rcu, neigh_node_free_rcu); neigh_node_free_ref(neigh_node);
} }
spin_unlock_bh(&orig_node->neigh_list_lock); spin_unlock_bh(&orig_node->neigh_list_lock);
...@@ -315,7 +305,7 @@ static bool purge_orig_neighbors(struct bat_priv *bat_priv, ...@@ -315,7 +305,7 @@ static bool purge_orig_neighbors(struct bat_priv *bat_priv,
hlist_del_rcu(&neigh_node->list); hlist_del_rcu(&neigh_node->list);
bonding_candidate_del(orig_node, neigh_node); bonding_candidate_del(orig_node, neigh_node);
call_rcu(&neigh_node->rcu, neigh_node_free_rcu); neigh_node_free_ref(neigh_node);
} else { } else {
if ((!*best_neigh_node) || if ((!*best_neigh_node) ||
(neigh_node->tq_avg > (*best_neigh_node)->tq_avg)) (neigh_node->tq_avg > (*best_neigh_node)->tq_avg))
......
...@@ -26,13 +26,12 @@ int originator_init(struct bat_priv *bat_priv); ...@@ -26,13 +26,12 @@ int originator_init(struct bat_priv *bat_priv);
void originator_free(struct bat_priv *bat_priv); void originator_free(struct bat_priv *bat_priv);
void purge_orig_ref(struct bat_priv *bat_priv); void purge_orig_ref(struct bat_priv *bat_priv);
void orig_node_free_ref(struct kref *refcount); void orig_node_free_ref(struct kref *refcount);
void neigh_node_free_rcu_bond(struct rcu_head *rcu);
struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr); struct orig_node *get_orig_node(struct bat_priv *bat_priv, uint8_t *addr);
struct neigh_node *create_neighbor(struct orig_node *orig_node, struct neigh_node *create_neighbor(struct orig_node *orig_node,
struct orig_node *orig_neigh_node, struct orig_node *orig_neigh_node,
uint8_t *neigh, uint8_t *neigh,
struct batman_if *if_incoming); struct batman_if *if_incoming);
void neigh_node_free_ref(struct kref *refcount); void neigh_node_free_ref(struct neigh_node *neigh_node);
int orig_seq_print_text(struct seq_file *seq, void *offset); int orig_seq_print_text(struct seq_file *seq, void *offset);
int orig_hash_add_if(struct batman_if *batman_if, int max_if_num); int orig_hash_add_if(struct batman_if *batman_if, int max_if_num);
int orig_hash_del_if(struct batman_if *batman_if, int max_if_num); int orig_hash_del_if(struct batman_if *batman_if, int max_if_num);
......
This diff is collapsed.
...@@ -117,9 +117,8 @@ struct neigh_node { ...@@ -117,9 +117,8 @@ struct neigh_node {
struct list_head bonding_list; struct list_head bonding_list;
unsigned long last_valid; unsigned long last_valid;
unsigned long real_bits[NUM_WORDS]; unsigned long real_bits[NUM_WORDS];
struct kref refcount; atomic_t refcount;
struct rcu_head rcu; struct rcu_head rcu;
struct rcu_head rcu_bond;
struct orig_node *orig_node; struct orig_node *orig_node;
struct batman_if *if_incoming; struct batman_if *if_incoming;
}; };
......
...@@ -285,38 +285,42 @@ int unicast_send_skb(struct sk_buff *skb, struct bat_priv *bat_priv) ...@@ -285,38 +285,42 @@ int unicast_send_skb(struct sk_buff *skb, struct bat_priv *bat_priv)
struct unicast_packet *unicast_packet; struct unicast_packet *unicast_packet;
struct orig_node *orig_node = NULL; struct orig_node *orig_node = NULL;
struct batman_if *batman_if; struct batman_if *batman_if;
struct neigh_node *router; struct neigh_node *neigh_node;
int data_len = skb->len; int data_len = skb->len;
uint8_t dstaddr[6]; uint8_t dstaddr[6];
int ret = 1;
spin_lock_bh(&bat_priv->orig_hash_lock); spin_lock_bh(&bat_priv->orig_hash_lock);
/* get routing information */ /* get routing information */
if (is_multicast_ether_addr(ethhdr->h_dest)) if (is_multicast_ether_addr(ethhdr->h_dest))
orig_node = (struct orig_node *)gw_get_selected(bat_priv); orig_node = (struct orig_node *)gw_get_selected(bat_priv);
if (orig_node) {
kref_get(&orig_node->refcount);
goto find_router;
}
/* check for hna host */ /* check for hna host - increases orig_node refcount */
if (!orig_node) orig_node = transtable_search(bat_priv, ethhdr->h_dest);
orig_node = transtable_search(bat_priv, ethhdr->h_dest);
find_router:
/* find_router() increases neigh_nodes refcount if found. */ /* find_router() increases neigh_nodes refcount if found. */
router = find_router(bat_priv, orig_node, NULL); neigh_node = find_router(bat_priv, orig_node, NULL);
if (!router) if (!neigh_node)
goto unlock; goto unlock;
/* don't lock while sending the packets ... we therefore if (neigh_node->if_incoming->if_status != IF_ACTIVE)
* copy the required data before sending */ goto unlock;
batman_if = router->if_incoming;
memcpy(dstaddr, router->addr, ETH_ALEN);
spin_unlock_bh(&bat_priv->orig_hash_lock);
if (batman_if->if_status != IF_ACTIVE)
goto dropped;
if (my_skb_head_push(skb, sizeof(struct unicast_packet)) < 0) if (my_skb_head_push(skb, sizeof(struct unicast_packet)) < 0)
goto dropped; goto unlock;
/* don't lock while sending the packets ... we therefore
* copy the required data before sending */
batman_if = neigh_node->if_incoming;
memcpy(dstaddr, neigh_node->addr, ETH_ALEN);
spin_unlock_bh(&bat_priv->orig_hash_lock);
unicast_packet = (struct unicast_packet *)skb->data; unicast_packet = (struct unicast_packet *)skb->data;
...@@ -330,18 +334,25 @@ int unicast_send_skb(struct sk_buff *skb, struct bat_priv *bat_priv) ...@@ -330,18 +334,25 @@ int unicast_send_skb(struct sk_buff *skb, struct bat_priv *bat_priv)
if (atomic_read(&bat_priv->fragmentation) && if (atomic_read(&bat_priv->fragmentation) &&
data_len + sizeof(struct unicast_packet) > data_len + sizeof(struct unicast_packet) >
batman_if->net_dev->mtu) { batman_if->net_dev->mtu) {
/* send frag skb decreases ttl */ /* send frag skb decreases ttl */
unicast_packet->ttl++; unicast_packet->ttl++;
return frag_send_skb(skb, bat_priv, batman_if, ret = frag_send_skb(skb, bat_priv, batman_if, dstaddr);
dstaddr); goto out;
} }
send_skb_packet(skb, batman_if, dstaddr); send_skb_packet(skb, batman_if, dstaddr);
return 0; ret = 0;
goto out;
unlock: unlock:
spin_unlock_bh(&bat_priv->orig_hash_lock); spin_unlock_bh(&bat_priv->orig_hash_lock);
dropped: out:
kfree_skb(skb); if (neigh_node)
return 1; neigh_node_free_ref(neigh_node);
if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref);
if (ret == 1)
kfree_skb(skb);
return ret;
} }
...@@ -764,21 +764,35 @@ static void unicast_vis_packet(struct bat_priv *bat_priv, ...@@ -764,21 +764,35 @@ static void unicast_vis_packet(struct bat_priv *bat_priv,
struct vis_info *info) struct vis_info *info)
{ {
struct orig_node *orig_node; struct orig_node *orig_node;
struct neigh_node *neigh_node = NULL;
struct sk_buff *skb; struct sk_buff *skb;
struct vis_packet *packet; struct vis_packet *packet;
struct batman_if *batman_if; struct batman_if *batman_if;
uint8_t dstaddr[ETH_ALEN]; uint8_t dstaddr[ETH_ALEN];
spin_lock_bh(&bat_priv->orig_hash_lock);
packet = (struct vis_packet *)info->skb_packet->data; packet = (struct vis_packet *)info->skb_packet->data;
spin_lock_bh(&bat_priv->orig_hash_lock);
rcu_read_lock(); rcu_read_lock();
orig_node = ((struct orig_node *)hash_find(bat_priv->orig_hash, orig_node = ((struct orig_node *)hash_find(bat_priv->orig_hash,
compare_orig, choose_orig, compare_orig, choose_orig,
packet->target_orig)); packet->target_orig));
rcu_read_unlock();
if ((!orig_node) || (!orig_node->router)) if (!orig_node)
goto out; goto unlock;
kref_get(&orig_node->refcount);
neigh_node = orig_node->router;
if (!neigh_node)
goto unlock;
if (!atomic_inc_not_zero(&neigh_node->refcount)) {
neigh_node = NULL;
goto unlock;
}
rcu_read_unlock();
/* don't lock while sending the packets ... we therefore /* don't lock while sending the packets ... we therefore
* copy the required data before sending */ * copy the required data before sending */
...@@ -790,10 +804,17 @@ static void unicast_vis_packet(struct bat_priv *bat_priv, ...@@ -790,10 +804,17 @@ static void unicast_vis_packet(struct bat_priv *bat_priv,
if (skb) if (skb)
send_skb_packet(skb, batman_if, dstaddr); send_skb_packet(skb, batman_if, dstaddr);
return; goto out;
out: unlock:
rcu_read_unlock();
spin_unlock_bh(&bat_priv->orig_hash_lock); spin_unlock_bh(&bat_priv->orig_hash_lock);
out:
if (neigh_node)
neigh_node_free_ref(neigh_node);
if (orig_node)
kref_put(&orig_node->refcount, orig_node_free_ref);
return;
} }
/* only send one vis packet. called from send_vis_packets() */ /* only send one vis packet. called from send_vis_packets() */
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment