Commit 7c0aee30 authored by David S. Miller

Merge branch 'ktls-use-after-free'

Maxim Mikityanskiy says:

====================
Fix use-after-free after the TLS device goes down and up

This small series fixes a use-after-free bug in the TLS offload code.
The first patch is a preparation for the second one, and the second is
the fix itself.

v2 changes:

Remove unneeded EXPORT_SYMBOL_GPL.
====================
Acked-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
parents f336d0b9 c55dcdd4
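
For orientation, a rough sketch of the use-after-free window this series closes, reconstructed from the lines removed from tls_device_down() below (only the relevant calls are shown):

/* Pre-fix behaviour (sketch, not literal code):
 *
 *   tls_device_down(netdev)                     the socket stays alive and its
 *     tls_dev_del(netdev, ctx, ...)             rx_conf/tx_conf stay TLS_HW
 *     list_del_init(&ctx->list)
 *     if (refcount_dec_and_test(&ctx->refcount))
 *             tls_device_free_ctx(ctx)          <-- TLS context freed here
 *
 *   ... the netdev later comes back up, TCP retransmissions resume the flow,
 *   and the kTLS RX/TX paths dereference the freed context: use-after-free.
 *
 * After this series, the context is instead parked on tls_device_down_list,
 * the data path falls back to software crypto, and the context is only freed
 * from sk_destruct.
 */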
@@ -193,7 +193,11 @@ struct tls_offload_context_tx {
         (sizeof(struct tls_offload_context_tx) + TLS_DRIVER_STATE_SIZE_TX)
 
 enum tls_context_flags {
-        TLS_RX_SYNC_RUNNING = 0,
+        /* tls_device_down was called after the netdev went down, device state
+         * was released, and kTLS works in software, even though rx_conf is
+         * still TLS_HW (needed for transition).
+         */
+        TLS_RX_DEV_DEGRADED = 0,
         /* Unlike RX where resync is driven entirely by the core in TX only
          * the driver knows when things went out of sync, so we need the flag
          * to be atomic.
@@ -266,6 +270,7 @@ struct tls_context {
         /* cache cold stuff */
         struct proto *sk_proto;
+        struct sock *sk;
 
         void (*sk_destruct)(struct sock *sk);
@@ -448,6 +453,9 @@ static inline u16 tls_user_config(struct tls_context *ctx, bool tx)
 struct sk_buff *
 tls_validate_xmit_skb(struct sock *sk, struct net_device *dev,
                       struct sk_buff *skb);
+struct sk_buff *
+tls_validate_xmit_skb_sw(struct sock *sk, struct net_device *dev,
+                         struct sk_buff *skb);
 
 static inline bool tls_is_sk_tx_device_offloaded(struct sock *sk)
 {
...
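
The key point of the tls.h changes above: in the degraded state rx_conf and tx_conf remain TLS_HW, so the offload-versus-software decision can no longer be made from the configuration alone. A hypothetical helper (not part of the patch; the name is made up) that mirrors the combined check added to tls_device_rx_resync_new_rec() and tls_device_decrypted() below:

/* Illustration only: rx_conf alone is not enough once the device went down. */
static inline bool tls_rx_device_offload_active(struct tls_context *tls_ctx)
{
        return tls_ctx->rx_conf == TLS_HW &&
               !test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags);
}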
@@ -50,6 +50,7 @@ static void tls_device_gc_task(struct work_struct *work);
 static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
 static LIST_HEAD(tls_device_gc_list);
 static LIST_HEAD(tls_device_list);
+static LIST_HEAD(tls_device_down_list);
 static DEFINE_SPINLOCK(tls_device_lock);
 
 static void tls_device_free_ctx(struct tls_context *ctx)
@@ -680,15 +681,13 @@ static void tls_device_resync_rx(struct tls_context *tls_ctx,
         struct tls_offload_context_rx *rx_ctx = tls_offload_ctx_rx(tls_ctx);
         struct net_device *netdev;
 
-        if (WARN_ON(test_and_set_bit(TLS_RX_SYNC_RUNNING, &tls_ctx->flags)))
-                return;
-
         trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
+        rcu_read_lock();
         netdev = READ_ONCE(tls_ctx->netdev);
         if (netdev)
                 netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
                                                    TLS_OFFLOAD_CTX_DIR_RX);
-        clear_bit_unlock(TLS_RX_SYNC_RUNNING, &tls_ctx->flags);
+        rcu_read_unlock();
 
         TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICERESYNC);
 }
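
This hunk is the preparatory patch: the TLS_RX_SYNC_RUNNING handshake (set the bit here, busy-wait on it in tls_device_down) is replaced by an RCU read-side critical section around the netdev dereference. The matching writer side appears in tls_device_down() further down; a condensed, hedged sketch of the pairing (the tx_conf/rx_conf guards around tls_dev_del are omitted):

/* Reader, as in tls_device_resync_rx() above: */
rcu_read_lock();
netdev = READ_ONCE(tls_ctx->netdev);
if (netdev)
        netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
                                           TLS_OFFLOAD_CTX_DIR_RX);
rcu_read_unlock();

/* Writer, as in tls_device_down() below: unpublish the pointer, wait out the
 * readers (synchronize_net() implies an RCU grace period), and only then let
 * the driver release its state.
 */
WRITE_ONCE(ctx->netdev, NULL);
synchronize_net();
netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_RX);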
@@ -761,6 +760,8 @@ void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
 
         if (tls_ctx->rx_conf != TLS_HW)
                 return;
+        if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags)))
+                return;
 
         prot = &tls_ctx->prot_info;
         rx_ctx = tls_offload_ctx_rx(tls_ctx);
@@ -963,6 +964,17 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
 
         ctx->sw.decrypted |= is_decrypted;
 
+        if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
+                if (likely(is_encrypted || is_decrypted))
+                        return 0;
+
+                /* After tls_device_down disables the offload, the next SKB will
+                 * likely have initial fragments decrypted, and final ones not
+                 * decrypted. We need to reencrypt that single SKB.
+                 */
+                return tls_device_reencrypt(sk, skb);
+        }
+
 /* Return immediately if the record is either entirely plaintext or
  * entirely ciphertext. Otherwise handle reencrypt partially decrypted
  * record.
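
A small worked example of the case the new branch handles, with hypothetical sizes chosen only for illustration:

/* A 16 KB TLS record is in flight while tls_device_down() runs:
 *
 *   record payload: [ first 4 KB decrypted by the NIC ][ last 12 KB ciphertext ]
 *
 * Walking the fragments, is_decrypted and is_encrypted both end up 0, so
 * neither early return is taken and this one record goes through
 * tls_device_reencrypt(), the same helper the non-degraded path below uses,
 * which makes the record consistent again for software decryption. Every
 * later record arrives fully encrypted, takes the "return 0" branch, and is
 * decrypted purely in software.
 */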
@@ -1292,6 +1304,26 @@ static int tls_device_down(struct net_device *netdev)
         spin_unlock_irqrestore(&tls_device_lock, flags);
 
         list_for_each_entry_safe(ctx, tmp, &list, list) {
+                /* Stop offloaded TX and switch to the fallback.
+                 * tls_is_sk_tx_device_offloaded will return false.
+                 */
+                WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);
+
+                /* Stop the RX and TX resync.
+                 * tls_dev_resync must not be called after tls_dev_del.
+                 */
+                WRITE_ONCE(ctx->netdev, NULL);
+
+                /* Start skipping the RX resync logic completely. */
+                set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);
+
+                /* Sync with inflight packets. After this point:
+                 * TX: no non-encrypted packets will be passed to the driver.
+                 * RX: resync requests from the driver will be ignored.
+                 */
+                synchronize_net();
+
+                /* Release the offload context on the driver side. */
                 if (ctx->tx_conf == TLS_HW)
                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                         TLS_OFFLOAD_CTX_DIR_TX);
@@ -1299,15 +1331,21 @@ static int tls_device_down(struct net_device *netdev)
                     !test_bit(TLS_RX_DEV_CLOSED, &ctx->flags))
                         netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
                                                         TLS_OFFLOAD_CTX_DIR_RX);
-                WRITE_ONCE(ctx->netdev, NULL);
-                smp_mb__before_atomic(); /* pairs with test_and_set_bit() */
-                while (test_bit(TLS_RX_SYNC_RUNNING, &ctx->flags))
-                        usleep_range(10, 200);
+
                 dev_put(netdev);
-                list_del_init(&ctx->list);
 
-                if (refcount_dec_and_test(&ctx->refcount))
-                        tls_device_free_ctx(ctx);
+                /* Move the context to a separate list for two reasons:
+                 * 1. When the context is deallocated, list_del is called.
+                 * 2. It's no longer an offloaded context, so we don't want to
+                 *    run offload-specific code on this context.
+                 */
+                spin_lock_irqsave(&tls_device_lock, flags);
+                list_move_tail(&ctx->list, &tls_device_down_list);
+                spin_unlock_irqrestore(&tls_device_lock, flags);
+
+                /* Device contexts for RX and TX will be freed on sk_destruct
+                 * by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
+                 */
         }
 
         up_write(&device_offload_lock);
...
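
Taking the two tls_device_down() hunks together, the per-context teardown order is roughly the following (a condensed sketch; the tx_conf/rx_conf guards around tls_dev_del, dev_put() and the spinlock around the list move are omitted):

list_for_each_entry_safe(ctx, tmp, &list, list) {
        /* 1. Route TX through the software fallback; from now on
         *    tls_is_sk_tx_device_offloaded() returns false for this socket.
         */
        WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);

        /* 2. Unpublish the netdev so no new resync call can start. */
        WRITE_ONCE(ctx->netdev, NULL);

        /* 3. Make the RX path skip the resync logic entirely. */
        set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);

        /* 4. Wait for in-flight packets and RCU readers using the old state. */
        synchronize_net();

        /* 5. Only now is it safe for the driver to free its offload state. */
        netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);

        /* 6. Park the context; it is freed from sk_destruct, not here. */
        list_move_tail(&ctx->list, &tls_device_down_list);
}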
@@ -431,6 +431,13 @@ struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
 }
 EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);
 
+struct sk_buff *tls_validate_xmit_skb_sw(struct sock *sk,
+                                         struct net_device *dev,
+                                         struct sk_buff *skb)
+{
+        return tls_sw_fallback(sk, skb);
+}
+
 struct sk_buff *tls_encrypt_skb(struct sk_buff *skb)
 {
         return tls_sw_fallback(skb->sk, skb);
...
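
tls_validate_xmit_skb_sw() always encrypts in software, and because sk->sk_validate_xmit_skb no longer points at tls_validate_xmit_skb(), tls_is_sk_tx_device_offloaded() returns false from then on. A hedged sketch of how a typical offload-capable driver observes this on its TX path (my_drv_xmit() and its details are made up):

static netdev_tx_t my_drv_xmit(struct sk_buff *skb, struct net_device *dev)
{
        if (skb->sk && tls_is_sk_tx_device_offloaded(skb->sk)) {
                /* Offloaded socket: set up the HW TLS context/metadata.
                 * After tls_device_down() this branch is no longer taken;
                 * the skb already carries software-encrypted records.
                 */
        }

        /* ... regular descriptor setup and doorbell ... */
        return NETDEV_TX_OK;
}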
@@ -636,6 +636,7 @@ struct tls_context *tls_ctx_create(struct sock *sk)
         mutex_init(&ctx->tx_lock);
         rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
         ctx->sk_proto = READ_ONCE(sk->sk_prot);
+        ctx->sk = sk;
         return ctx;
 }
...
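
The back-pointer set here is what lets tls_device_down(), which walks tls_context structures rather than sockets, reach the owning socket and swap its validate callback; the consumer is this line from the tls_device_down() hunk above:

WRITE_ONCE(ctx->sk->sk_validate_xmit_skb, tls_validate_xmit_skb_sw);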