Commit dbe42559 authored by Dave Watson, committed by David S. Miller

tls: Move cipher info to a separate struct

Move the tx crypto parameters into a separate cipher_context struct.
The same struct will be reused for the rx parameters.

tls_advance_record_sn is modified to take only the cipher info.
Signed-off-by: Dave Watson <davejwatson@fb.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
parent 69ca9293
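
In short: the per-record crypto fields move out of struct tls_context into a new struct cipher_context, tls_context embeds one instance for the transmit path (ctx->tx), and tls_advance_record_sn() now takes the cipher context directly. A condensed sketch of the resulting layout, assembled from the diff below (the field comments summarize how tls_set_sw_offload() initializes these fields in that diff):

#include <linux/types.h>

/* Per-direction record-layer crypto parameters (tx now, rx later). */
struct cipher_context {
	u16 prepend_size;	/* TLS_HEADER_SIZE + nonce_size */
	u16 tag_size;		/* AEAD authentication tag length */
	u16 overhead_size;	/* prepend_size + tag_size */
	u16 iv_size;
	char *iv;		/* salt followed by the explicit IV */
	u16 rec_seq_size;
	char *rec_seq;		/* record sequence number */
};

/* struct tls_context now carries "struct cipher_context tx;" in place of the
 * open-coded fields, so call sites change from ctx->iv to ctx->tx.iv, and the
 * sequence-number helper is called with the cipher info alone, e.g.:
 *
 *	tls_advance_record_sn(sk, &tls_ctx->tx);
 */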
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -81,6 +81,16 @@ enum {
 	TLS_PENDING_CLOSED_RECORD
 };
 
+struct cipher_context {
+	u16 prepend_size;
+	u16 tag_size;
+	u16 overhead_size;
+	u16 iv_size;
+	char *iv;
+	u16 rec_seq_size;
+	char *rec_seq;
+};
+
 struct tls_context {
 	union {
 		struct tls_crypto_info crypto_send;
@@ -91,13 +101,7 @@ struct tls_context {
 
 	u8 tx_conf:2;
 
-	u16 prepend_size;
-	u16 tag_size;
-	u16 overhead_size;
-	u16 iv_size;
-	char *iv;
-	u16 rec_seq_size;
-	char *rec_seq;
+	struct cipher_context tx;
 
 	struct scatterlist *partially_sent_record;
 	u16 partially_sent_offset;
@@ -190,7 +194,7 @@ static inline bool tls_bigint_increment(unsigned char *seq, int len)
 }
 
 static inline void tls_advance_record_sn(struct sock *sk,
-					 struct tls_context *ctx)
+					 struct cipher_context *ctx)
 {
 	if (tls_bigint_increment(ctx->rec_seq, ctx->rec_seq_size))
 		tls_err_abort(sk);
@@ -203,9 +207,9 @@ static inline void tls_fill_prepend(struct tls_context *ctx,
 			     size_t plaintext_len,
 			     unsigned char record_type)
 {
-	size_t pkt_len, iv_size = ctx->iv_size;
+	size_t pkt_len, iv_size = ctx->tx.iv_size;
 
-	pkt_len = plaintext_len + iv_size + ctx->tag_size;
+	pkt_len = plaintext_len + iv_size + ctx->tx.tag_size;
 
 	/* we cover nonce explicit here as well, so buf should be of
 	 * size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE
@@ -217,7 +221,7 @@ static inline void tls_fill_prepend(struct tls_context *ctx,
 	buf[3] = pkt_len >> 8;
 	buf[4] = pkt_len & 0xFF;
 	memcpy(buf + TLS_NONCE_OFFSET,
-	       ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv_size);
+	       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv_size);
 }
 
 static inline void tls_make_aad(char *buf,
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -259,8 +259,8 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
 		}
 	}
 
-	kfree(ctx->rec_seq);
-	kfree(ctx->iv);
+	kfree(ctx->tx.rec_seq);
+	kfree(ctx->tx.iv);
 
 	if (ctx->tx_conf == TLS_SW_TX)
 		tls_sw_free_tx_resources(sk);
@@ -319,9 +319,9 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
 		}
 		lock_sock(sk);
 		memcpy(crypto_info_aes_gcm_128->iv,
-		       ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+		       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
-		memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->rec_seq,
+		memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq,
 		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 		release_sock(sk);
 		if (copy_to_user(optval,
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -79,7 +79,7 @@ static void trim_both_sgl(struct sock *sk, int target_size)
 		target_size);
 
 	if (target_size > 0)
-		target_size += tls_ctx->overhead_size;
+		target_size += tls_ctx->tx.overhead_size;
 
 	trim_sg(sk, ctx->sg_encrypted_data,
 		&ctx->sg_encrypted_num_elem,
@@ -152,21 +152,21 @@ static int tls_do_encryption(struct tls_context *tls_ctx,
 	if (!aead_req)
 		return -ENOMEM;
 
-	ctx->sg_encrypted_data[0].offset += tls_ctx->prepend_size;
-	ctx->sg_encrypted_data[0].length -= tls_ctx->prepend_size;
+	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
+	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
 
 	aead_request_set_tfm(aead_req, ctx->aead_send);
 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
 	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
-			       data_len, tls_ctx->iv);
+			       data_len, tls_ctx->tx.iv);
 
 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
 				  crypto_req_done, &ctx->async_wait);
 
 	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
 
-	ctx->sg_encrypted_data[0].offset -= tls_ctx->prepend_size;
-	ctx->sg_encrypted_data[0].length += tls_ctx->prepend_size;
+	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
+	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
 
 	kfree(aead_req);
 	return rc;
@@ -183,7 +183,7 @@ static int tls_push_record(struct sock *sk, int flags,
 	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
 
 	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
-		     tls_ctx->rec_seq, tls_ctx->rec_seq_size,
+		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
 		     record_type);
 
 	tls_fill_prepend(tls_ctx,
@@ -216,7 +216,7 @@ static int tls_push_record(struct sock *sk, int flags,
 	if (rc < 0 && rc != -EAGAIN)
 		tls_err_abort(sk);
 
-	tls_advance_record_sn(sk, tls_ctx);
+	tls_advance_record_sn(sk, &tls_ctx->tx);
 	return rc;
 }
 
@@ -357,7 +357,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 		}
 
 		required_size = ctx->sg_plaintext_size + try_to_copy +
-				tls_ctx->overhead_size;
+				tls_ctx->tx.overhead_size;
 
 		if (!sk_stream_memory_free(sk))
 			goto wait_for_sndbuf;
@@ -420,7 +420,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 				&ctx->sg_encrypted_num_elem,
 				&ctx->sg_encrypted_size,
 				ctx->sg_plaintext_size +
-				tls_ctx->overhead_size);
+				tls_ctx->tx.overhead_size);
 		}
 
 		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
@@ -512,7 +512,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
 			full_record = true;
 		}
 		required_size = ctx->sg_plaintext_size + copy +
-				tls_ctx->overhead_size;
+				tls_ctx->tx.overhead_size;
 
 		if (!sk_stream_memory_free(sk))
 			goto wait_for_sndbuf;
@@ -644,24 +644,26 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 		goto free_priv;
 	}
 
-	ctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
-	ctx->tag_size = tag_size;
-	ctx->overhead_size = ctx->prepend_size + ctx->tag_size;
-	ctx->iv_size = iv_size;
-	ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, GFP_KERNEL);
-	if (!ctx->iv) {
+	ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
+	ctx->tx.tag_size = tag_size;
+	ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
+	ctx->tx.iv_size = iv_size;
+	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+			     GFP_KERNEL);
+	if (!ctx->tx.iv) {
 		rc = -ENOMEM;
 		goto free_priv;
 	}
-	memcpy(ctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-	memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
-	ctx->rec_seq_size = rec_seq_size;
-	ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
-	if (!ctx->rec_seq) {
+	memcpy(ctx->tx.iv, gcm_128_info->salt,
+	       TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+	ctx->tx.rec_seq_size = rec_seq_size;
+	ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+	if (!ctx->tx.rec_seq) {
 		rc = -ENOMEM;
 		goto free_iv;
 	}
-	memcpy(ctx->rec_seq, rec_seq, rec_seq_size);
+	memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size);
 
 	sg_init_table(sw_ctx->sg_encrypted_data,
 		      ARRAY_SIZE(sw_ctx->sg_encrypted_data));
@@ -697,7 +699,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 	if (rc)
 		goto free_aead;
 
-	rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tag_size);
+	rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tx.tag_size);
 	if (!rc)
 		return 0;
 
@@ -705,11 +707,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
 	crypto_free_aead(sw_ctx->aead_send);
 	sw_ctx->aead_send = NULL;
 free_rec_seq:
-	kfree(ctx->rec_seq);
-	ctx->rec_seq = NULL;
+	kfree(ctx->tx.rec_seq);
+	ctx->tx.rec_seq = NULL;
 free_iv:
-	kfree(ctx->iv);
-	ctx->iv = NULL;
+	kfree(ctx->tx.iv);
+	ctx->tx.iv = NULL;
 free_priv:
 	kfree(ctx->priv_ctx);
 	ctx->priv_ctx = NULL;