Commit 67cfa5d3 authored by Ard Biesheuvel, committed by Herbert Xu

crypto: arm64/aes-neonbs - implement ciphertext stealing for XTS

Update the AES-XTS implementation based on NEON instructions so that it
can deal with inputs whose size is not a multiple of the cipher block
size. This is part of the original XTS specification, but was never
implemented before in the Linux kernel.

Since the bit slicing driver is only faster if it can operate on at
least 7 blocks of input at the same time, let's reuse the alternate
path we are adding for CTS to process any data tail whose size is
not a multiple of 128 bytes.
Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
parent 7cceca8b
...@@ -21,6 +21,9 @@ ...@@ -21,6 +21,9 @@
.macro xts_reload_mask, tmp .macro xts_reload_mask, tmp
.endm .endm
	/*
	 * No-op stub: this implementation never skips the first-tweak
	 * encryption for the CTS path.
	 * NOTE(review): presumably the non-stub variant lives in the NEON
	 * include that the bit-sliced driver calls into — confirm against
	 * the counterpart definition in this commit.
	 */
	.macro		xts_cts_skip_tw, reg, lbl
	.endm
/* preload all round keys */ /* preload all round keys */
.macro load_round_keys, rounds, rk .macro load_round_keys, rounds, rk
cmp \rounds, #12 cmp \rounds, #12
......
...@@ -1071,5 +1071,7 @@ module_cpu_feature_match(AES, aes_init); ...@@ -1071,5 +1071,7 @@ module_cpu_feature_match(AES, aes_init);
module_init(aes_init); module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt); EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt); EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif #endif
module_exit(aes_exit); module_exit(aes_exit);
...@@ -442,6 +442,7 @@ AES_ENTRY(aes_xts_encrypt) ...@@ -442,6 +442,7 @@ AES_ENTRY(aes_xts_encrypt)
cbz w7, .Lxtsencnotfirst cbz w7, .Lxtsencnotfirst
enc_prepare w3, x5, x8 enc_prepare w3, x5, x8
xts_cts_skip_tw w7, .LxtsencNx
encrypt_block v4, w3, x5, x8, w7 /* first tweak */ encrypt_block v4, w3, x5, x8, w7 /* first tweak */
enc_switch_key w3, x2, x8 enc_switch_key w3, x2, x8
b .LxtsencNx b .LxtsencNx
...@@ -530,10 +531,12 @@ AES_ENTRY(aes_xts_decrypt) ...@@ -530,10 +531,12 @@ AES_ENTRY(aes_xts_decrypt)
ld1 {v4.16b}, [x6] ld1 {v4.16b}, [x6]
xts_load_mask v8 xts_load_mask v8
xts_cts_skip_tw w7, .Lxtsdecskiptw
cbz w7, .Lxtsdecnotfirst cbz w7, .Lxtsdecnotfirst
enc_prepare w3, x5, x8 enc_prepare w3, x5, x8
encrypt_block v4, w3, x5, x8, w7 /* first tweak */ encrypt_block v4, w3, x5, x8, w7 /* first tweak */
.Lxtsdecskiptw:
dec_prepare w3, x2, x8 dec_prepare w3, x2, x8
b .LxtsdecNx b .LxtsdecNx
......
...@@ -19,6 +19,11 @@ ...@@ -19,6 +19,11 @@
xts_load_mask \tmp xts_load_mask \tmp
.endm .endm
	/* special case for the neon-bs driver calling into this one for CTS */
	/*
	 * Branch to \lbl when bit 1 of \reg is set.  The C glue passes
	 * "first ?: 2" as this argument, so the value 2 (bit 1 set) means
	 * "not the first call — the tweak in v4 was already computed by the
	 * caller and must not be encrypted again".
	 */
	.macro		xts_cts_skip_tw, reg, lbl
	tbnz		\reg, #1, \lbl
	.endm
/* multiply by polynomial 'x' in GF(2^8) */ /* multiply by polynomial 'x' in GF(2^8) */
.macro mul_by_x, out, in, temp, const .macro mul_by_x, out, in, temp, const
sshr \temp, \in, #7 sshr \temp, \in, #7
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <crypto/ctr.h> #include <crypto/ctr.h>
#include <crypto/internal/simd.h> #include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h> #include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h> #include <crypto/xts.h>
#include <linux/module.h> #include <linux/module.h>
...@@ -45,6 +46,12 @@ asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[], ...@@ -45,6 +46,12 @@ asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
int rounds, int blocks); int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[], asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
int rounds, int blocks, u8 iv[]); int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
u32 const rk1[], int rounds, int bytes,
u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
u32 const rk1[], int rounds, int bytes,
u32 const rk2[], u8 iv[], int first);
struct aesbs_ctx { struct aesbs_ctx {
u8 rk[13 * (8 * AES_BLOCK_SIZE) + 32]; u8 rk[13 * (8 * AES_BLOCK_SIZE) + 32];
...@@ -64,6 +71,7 @@ struct aesbs_ctr_ctx { ...@@ -64,6 +71,7 @@ struct aesbs_ctr_ctx {
struct aesbs_xts_ctx { struct aesbs_xts_ctx {
struct aesbs_ctx key; struct aesbs_ctx key;
u32 twkey[AES_MAX_KEYLENGTH_U32]; u32 twkey[AES_MAX_KEYLENGTH_U32];
struct crypto_aes_ctx cts;
}; };
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key, static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
...@@ -270,6 +278,10 @@ static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key, ...@@ -270,6 +278,10 @@ static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
return err; return err;
key_len /= 2; key_len /= 2;
err = aes_expandkey(&ctx->cts, in_key, key_len);
if (err)
return err;
err = aes_expandkey(&rk, in_key + key_len, key_len); err = aes_expandkey(&rk, in_key + key_len, key_len);
if (err) if (err)
return err; return err;
...@@ -302,48 +314,119 @@ static int ctr_encrypt_sync(struct skcipher_request *req) ...@@ -302,48 +314,119 @@ static int ctr_encrypt_sync(struct skcipher_request *req)
return ctr_encrypt(req); return ctr_encrypt(req);
} }
/*
 * __xts_crypt() - XTS en/decryption using the bit-sliced NEON core, with
 * ciphertext stealing (CTS) for inputs that are not a multiple of the AES
 * block size.
 *
 * @req:     skcipher request (cryptlen must be >= AES_BLOCK_SIZE)
 * @encrypt: true for encryption, false for decryption
 * @fn:      bit-sliced core routine operating on whole blocks
 *
 * The bit-sliced core is only a win for >= 7 blocks at a time; smaller
 * tails (and the CTS tail itself) are handed to the plain NEON
 * neon_aes_xts_{en,de}crypt() helpers instead.
 *
 * Returns 0 on success or a negative errno from the skcipher walk.
 */
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	u8 *out, *in;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		/*
		 * Process everything except the last two blocks via a
		 * subrequest, so the CTS tail (one full block + partial
		 * block) can be handled in one plain-NEON call below.
		 */
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total || walk.nbytes % AES_BLOCK_SIZE)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		kernel_neon_begin();
		if (likely(blocks > 6)) { /* plain NEON is faster otherwise */
			if (first)
				neon_aes_ecb_encrypt(walk.iv, walk.iv,
						     ctx->twkey,
						     ctx->key.rounds, 1);
			first = 0;

			fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
			   walk.iv);

			out += blocks * AES_BLOCK_SIZE;
			in += blocks * AES_BLOCK_SIZE;
			nbytes -= blocks * AES_BLOCK_SIZE;
		}
		if (walk.nbytes == walk.total && nbytes > 0)
			goto xts_tail;

		kernel_neon_end();
		/*
		 * Fix: capture the return value — the original dropped it,
		 * so a walk error could never be observed by the
		 * "if (err || ...)" check below.
		 */
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	kernel_neon_begin();
xts_tail:
	/*
	 * first ?: 2 — "2" (bit 1 set) tells the asm to skip re-encrypting
	 * the tweak when the main loop already produced it.
	 */
	if (encrypt)
		neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first ?: 2);
	else
		neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
				     nbytes, ctx->twkey, walk.iv, first ?: 2);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
static int xts_encrypt(struct skcipher_request *req) static int xts_encrypt(struct skcipher_request *req)
{ {
return __xts_crypt(req, aesbs_xts_encrypt); return __xts_crypt(req, true, aesbs_xts_encrypt);
} }
static int xts_decrypt(struct skcipher_request *req) static int xts_decrypt(struct skcipher_request *req)
{ {
return __xts_crypt(req, aesbs_xts_decrypt); return __xts_crypt(req, false, aesbs_xts_decrypt);
} }
static struct skcipher_alg aes_algs[] = { { static struct skcipher_alg aes_algs[] = { {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment