Commit 9513c5e1 authored by Avri Altman's avatar Avri Altman Committed by Emmanuel Grumbach

iwlwifi: mvm: Avoid dereferencing sta if it was already flushed

Be a little bit more careful when dereferencing sta on key removal,
As it might already get flushed on other thread.
Signed-off-by: default avatarAvri Altman <avri.altman@intel.com>
Signed-off-by: default avatarEmmanuel Grumbach <emmanuel.grumbach@intel.com>
parent d6ee54a9
...@@ -1201,7 +1201,8 @@ static int iwl_mvm_set_fw_key_idx(struct iwl_mvm *mvm) ...@@ -1201,7 +1201,8 @@ static int iwl_mvm_set_fw_key_idx(struct iwl_mvm *mvm)
return max_offs; return max_offs;
} }
static u8 iwl_mvm_get_key_sta_id(struct ieee80211_vif *vif, static u8 iwl_mvm_get_key_sta_id(struct iwl_mvm *mvm,
struct ieee80211_vif *vif,
struct ieee80211_sta *sta) struct ieee80211_sta *sta)
{ {
struct iwl_mvm_vif *mvmvif = iwl_mvm_vif_from_mac80211(vif); struct iwl_mvm_vif *mvmvif = iwl_mvm_vif_from_mac80211(vif);
...@@ -1218,8 +1219,21 @@ static u8 iwl_mvm_get_key_sta_id(struct ieee80211_vif *vif, ...@@ -1218,8 +1219,21 @@ static u8 iwl_mvm_get_key_sta_id(struct ieee80211_vif *vif,
* station ID, then use AP's station ID. * station ID, then use AP's station ID.
*/ */
if (vif->type == NL80211_IFTYPE_STATION && if (vif->type == NL80211_IFTYPE_STATION &&
mvmvif->ap_sta_id != IWL_MVM_STATION_COUNT) mvmvif->ap_sta_id != IWL_MVM_STATION_COUNT) {
return mvmvif->ap_sta_id; u8 sta_id = mvmvif->ap_sta_id;
sta = rcu_dereference_protected(mvm->fw_id_to_mac_id[sta_id],
lockdep_is_held(&mvm->mutex));
/*
* It is possible that the 'sta' parameter is NULL,
* for example when a GTK is removed - the sta_id will then
* be the AP ID, and no station was passed by mac80211.
*/
if (IS_ERR_OR_NULL(sta))
return IWL_MVM_STATION_COUNT;
return sta_id;
}
return IWL_MVM_STATION_COUNT; return IWL_MVM_STATION_COUNT;
} }
...@@ -1445,7 +1459,7 @@ int iwl_mvm_set_sta_key(struct iwl_mvm *mvm, ...@@ -1445,7 +1459,7 @@ int iwl_mvm_set_sta_key(struct iwl_mvm *mvm,
lockdep_assert_held(&mvm->mutex); lockdep_assert_held(&mvm->mutex);
/* Get the station id from the mvm local station table */ /* Get the station id from the mvm local station table */
sta_id = iwl_mvm_get_key_sta_id(vif, sta); sta_id = iwl_mvm_get_key_sta_id(mvm, vif, sta);
if (sta_id == IWL_MVM_STATION_COUNT) { if (sta_id == IWL_MVM_STATION_COUNT) {
IWL_ERR(mvm, "Failed to find station id\n"); IWL_ERR(mvm, "Failed to find station id\n");
return -EINVAL; return -EINVAL;
...@@ -1531,7 +1545,7 @@ int iwl_mvm_remove_sta_key(struct iwl_mvm *mvm, ...@@ -1531,7 +1545,7 @@ int iwl_mvm_remove_sta_key(struct iwl_mvm *mvm,
lockdep_assert_held(&mvm->mutex); lockdep_assert_held(&mvm->mutex);
/* Get the station id from the mvm local station table */ /* Get the station id from the mvm local station table */
sta_id = iwl_mvm_get_key_sta_id(vif, sta); sta_id = iwl_mvm_get_key_sta_id(mvm, vif, sta);
IWL_DEBUG_WEP(mvm, "mvm remove dynamic key: idx=%d sta=%d\n", IWL_DEBUG_WEP(mvm, "mvm remove dynamic key: idx=%d sta=%d\n",
keyconf->keyidx, sta_id); keyconf->keyidx, sta_id);
...@@ -1557,24 +1571,6 @@ int iwl_mvm_remove_sta_key(struct iwl_mvm *mvm, ...@@ -1557,24 +1571,6 @@ int iwl_mvm_remove_sta_key(struct iwl_mvm *mvm,
return 0; return 0;
} }
/*
* It is possible that the 'sta' parameter is NULL, and thus
* there is a need to retrieve the sta from the local station table,
* for example when a GTK is removed (where the sta_id will then be
* the AP ID, and no station was passed by mac80211.)
*/
if (!sta) {
sta = rcu_dereference_protected(mvm->fw_id_to_mac_id[sta_id],
lockdep_is_held(&mvm->mutex));
if (!sta) {
IWL_ERR(mvm, "Invalid station id\n");
return -EINVAL;
}
}
if (WARN_ON_ONCE(iwl_mvm_sta_from_mac80211(sta)->vif != vif))
return -EINVAL;
ret = __iwl_mvm_remove_sta_key(mvm, sta_id, keyconf, mcast); ret = __iwl_mvm_remove_sta_key(mvm, sta_id, keyconf, mcast);
if (ret) if (ret)
return ret; return ret;
...@@ -1594,7 +1590,7 @@ void iwl_mvm_update_tkip_key(struct iwl_mvm *mvm, ...@@ -1594,7 +1590,7 @@ void iwl_mvm_update_tkip_key(struct iwl_mvm *mvm,
u16 *phase1key) u16 *phase1key)
{ {
struct iwl_mvm_sta *mvm_sta; struct iwl_mvm_sta *mvm_sta;
u8 sta_id = iwl_mvm_get_key_sta_id(vif, sta); u8 sta_id = iwl_mvm_get_key_sta_id(mvm, vif, sta);
bool mcast = !(keyconf->flags & IEEE80211_KEY_FLAG_PAIRWISE); bool mcast = !(keyconf->flags & IEEE80211_KEY_FLAG_PAIRWISE);
if (WARN_ON_ONCE(sta_id == IWL_MVM_STATION_COUNT)) if (WARN_ON_ONCE(sta_id == IWL_MVM_STATION_COUNT))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment