Commit 7cceca8b authored by Ard Biesheuvel, committed by Herbert Xu

crypto: arm64/aes - implement support for XTS ciphertext stealing

Add the missing support for ciphertext stealing to the implementation
of AES-XTS. Ciphertext stealing is part of the XTS specification, but
it was omitted until now because there was no need for it.

The asm helpers are updated so that they can handle any input size, as
long as the last full block and the final partial block are presented
at the same time. The glue code is updated so that the common case of
operating on a sector or a page is handled mostly as before. When CTS
is needed, the walk is split into two pieces, unless the entire input
is covered by a single step.
Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
parent 7c9d65c4
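
For orientation, here is a minimal sketch of how XTS ciphertext stealing
arranges the last full block and the trailing partial block on encryption.
xts_block() and next_tweak() are toy placeholders for the real AES-XTS
single-block operation and the GF(2^128) tweak update; none of these names
or definitions come from the patch itself.

/*
 * Illustration only: XTS ciphertext stealing for the trailing
 * (16 + tail) bytes of a message, 0 < tail < 16.
 */
#include <stdint.h>
#include <string.h>

#define BLK 16

static void xts_block(uint8_t out[BLK], const uint8_t in[BLK],
                      const uint8_t tweak[BLK])
{
        /* stand-in for E_K1(in ^ tweak) ^ tweak -- NOT real AES */
        for (int i = 0; i < BLK; i++)
                out[i] = in[i] ^ tweak[i];
}

static void next_tweak(uint8_t tweak[BLK])
{
        /* stand-in for multiplication by x in GF(2^128) */
        tweak[0]++;
}

/*
 * src points at the last full plaintext block, with the partial block
 * directly behind it; dst is laid out the same way.
 */
static void xts_encrypt_cts_tail(uint8_t *dst, const uint8_t *src, int tail,
                                 uint8_t tweak[BLK])
{
        uint8_t cc[BLK], in[BLK];

        xts_block(cc, src, tweak);      /* last full block, current tweak   */
        next_tweak(tweak);

        memcpy(in, cc, BLK);            /* keep the stolen high bytes of cc */
        memcpy(in, src + BLK, tail);    /* low bytes = partial plaintext    */

        memcpy(dst + BLK, cc, tail);    /* truncated cc is the final block  */
        xts_block(dst, in, tweak);      /* combined block, next tweak, goes
                                         * into the next-to-last slot       */
}

Decryption mirrors this: the block stored in the next-to-last slot is
processed with the next tweak first, and the reassembled block with the
current tweak, which is what the .Lxtsdeccts path in the assembler below
implements.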
@@ -90,10 +90,10 @@ asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
 				int rounds, int blocks, u8 ctr[]);
 
 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
-				int rounds, int blocks, u32 const rk2[], u8 iv[],
+				int rounds, int bytes, u32 const rk2[], u8 iv[],
 				int first);
 asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
-				int rounds, int blocks, u32 const rk2[], u8 iv[],
+				int rounds, int bytes, u32 const rk2[], u8 iv[],
 				int first);
 
 asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
@@ -527,21 +527,71 @@ static int __maybe_unused xts_encrypt(struct skcipher_request *req)
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 	int err, first, rounds = 6 + ctx->key1.key_length / 4;
+	int tail = req->cryptlen % AES_BLOCK_SIZE;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct skcipher_request subreq;
+	struct scatterlist *src, *dst;
 	struct skcipher_walk walk;
-	unsigned int blocks;
+
+	if (req->cryptlen < AES_BLOCK_SIZE)
+		return -EINVAL;
 
 	err = skcipher_walk_virt(&walk, req, false);
 
-	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
+		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+					      AES_BLOCK_SIZE) - 2;
+
+		skcipher_walk_abort(&walk);
+
+		skcipher_request_set_tfm(&subreq, tfm);
+		skcipher_request_set_callback(&subreq,
+					      skcipher_request_flags(req),
+					      NULL, NULL);
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   xts_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+		req = &subreq;
+		err = skcipher_walk_virt(&walk, req, false);
+	} else {
+		tail = 0;
+	}
+
+	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+		int nbytes = walk.nbytes;
+
+		if (walk.nbytes < walk.total)
+			nbytes &= ~(AES_BLOCK_SIZE - 1);
+
 		kernel_neon_begin();
 		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				ctx->key1.key_enc, rounds, blocks,
+				ctx->key1.key_enc, rounds, nbytes,
 				ctx->key2.key_enc, walk.iv, first);
 		kernel_neon_end();
-		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
 	}
 
-	return err;
+	if (err || likely(!tail))
+		return err;
+
+	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+	if (req->dst != req->src)
+		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			ctx->key1.key_enc, rounds, walk.nbytes,
+			ctx->key2.key_enc, walk.iv, first);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
 }
 
 static int __maybe_unused xts_decrypt(struct skcipher_request *req)
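
As a worked example of the split above, using a hypothetical request size
that is not taken from the patch: the split is only taken when tail > 0 and
the walk cannot cover the whole request in a single step; otherwise the
single aes_xts_encrypt() call receives the full byte count and performs the
ciphertext stealing itself.

#include <stdio.h>

#define AES_BLOCK_SIZE	16
#define DIV_ROUND_UP(n, d)	(((n) + (d) - 1) / (d))	/* mirrors the kernel macro */

int main(void)
{
	int cryptlen = 504;	/* hypothetical: 31 full blocks + 8-byte tail */
	int tail = cryptlen % AES_BLOCK_SIZE;				/* 8  */
	int xts_blocks = DIV_ROUND_UP(cryptlen, AES_BLOCK_SIZE) - 2;	/* 30 */

	printf("first pass:  %d bytes\n", xts_blocks * AES_BLOCK_SIZE);	/* 480 */
	printf("second pass: %d bytes\n", AES_BLOCK_SIZE + tail);	/* 24  */
	return 0;
}

The two passes cover all 504 bytes: 30 full blocks first, then the last full
block together with the 8-byte tail in a second walk over the forwarded
scatterlists.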
@@ -549,21 +599,72 @@ static int __maybe_unused xts_decrypt(struct skcipher_request *req)
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
 	int err, first, rounds = 6 + ctx->key1.key_length / 4;
+	int tail = req->cryptlen % AES_BLOCK_SIZE;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct skcipher_request subreq;
+	struct scatterlist *src, *dst;
 	struct skcipher_walk walk;
-	unsigned int blocks;
+
+	if (req->cryptlen < AES_BLOCK_SIZE)
+		return -EINVAL;
 
 	err = skcipher_walk_virt(&walk, req, false);
 
-	for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
+	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
+		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
+					      AES_BLOCK_SIZE) - 2;
+
+		skcipher_walk_abort(&walk);
+
+		skcipher_request_set_tfm(&subreq, tfm);
+		skcipher_request_set_callback(&subreq,
+					      skcipher_request_flags(req),
+					      NULL, NULL);
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   xts_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+		req = &subreq;
+		err = skcipher_walk_virt(&walk, req, false);
+	} else {
+		tail = 0;
+	}
+
+	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
+		int nbytes = walk.nbytes;
+
+		if (walk.nbytes < walk.total)
+			nbytes &= ~(AES_BLOCK_SIZE - 1);
+
 		kernel_neon_begin();
 		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
-				ctx->key1.key_dec, rounds, blocks,
+				ctx->key1.key_dec, rounds, nbytes,
 				ctx->key2.key_enc, walk.iv, first);
 		kernel_neon_end();
-		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
+		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
 	}
 
-	return err;
+	if (err || likely(!tail))
+		return err;
+
+	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+	if (req->dst != req->src)
+		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			ctx->key1.key_dec, rounds, walk.nbytes,
+			ctx->key2.key_enc, walk.iv, first);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
 }
static struct skcipher_alg aes_algs[] = { { static struct skcipher_alg aes_algs[] = { {
...@@ -644,6 +745,7 @@ static struct skcipher_alg aes_algs[] = { { ...@@ -644,6 +745,7 @@ static struct skcipher_alg aes_algs[] = { {
.min_keysize = 2 * AES_MIN_KEY_SIZE, .min_keysize = 2 * AES_MIN_KEY_SIZE,
.max_keysize = 2 * AES_MAX_KEY_SIZE, .max_keysize = 2 * AES_MAX_KEY_SIZE,
.ivsize = AES_BLOCK_SIZE, .ivsize = AES_BLOCK_SIZE,
.walksize = 2 * AES_BLOCK_SIZE,
.setkey = xts_set_key, .setkey = xts_set_key,
.encrypt = xts_encrypt, .encrypt = xts_encrypt,
.decrypt = xts_decrypt, .decrypt = xts_decrypt,
@@ -413,10 +413,10 @@ AES_ENDPROC(aes_ctr_encrypt)
 
 	/*
-	 * aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
-	 *		   int blocks, u8 const rk2[], u8 iv[], int first)
+	 * aes_xts_encrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
+	 *		   int bytes, u8 const rk2[], u8 iv[], int first)
 	 * aes_xts_decrypt(u8 out[], u8 const in[], u8 const rk1[], int rounds,
-	 *		   int blocks, u8 const rk2[], u8 iv[], int first)
+	 *		   int bytes, u8 const rk2[], u8 iv[], int first)
 	 */
 
 	.macro		next_tweak, out, in, tmp
@@ -451,7 +451,7 @@ AES_ENTRY(aes_xts_encrypt)
 .LxtsencloopNx:
 	next_tweak	v4, v4, v8
 .LxtsencNx:
-	subs		w4, w4, #4
+	subs		w4, w4, #64
 	bmi		.Lxtsenc1x
 	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 pt blocks */
 	next_tweak	v5, v4, v8
@@ -468,33 +468,66 @@ AES_ENTRY(aes_xts_encrypt)
 	eor		v2.16b, v2.16b, v6.16b
 	st1		{v0.16b-v3.16b}, [x0], #64
 	mov		v4.16b, v7.16b
-	cbz		w4, .Lxtsencout
+	cbz		w4, .Lxtsencret
 	xts_reload_mask	v8
 	b		.LxtsencloopNx
 .Lxtsenc1x:
-	adds		w4, w4, #4
+	adds		w4, w4, #64
 	beq		.Lxtsencout
+	subs		w4, w4, #16
+	bmi		.LxtsencctsNx
 .Lxtsencloop:
-	ld1		{v1.16b}, [x1], #16
-	eor		v0.16b, v1.16b, v4.16b
+	ld1		{v0.16b}, [x1], #16
+.Lxtsencctsout:
+	eor		v0.16b, v0.16b, v4.16b
 	encrypt_block	v0, w3, x2, x8, w7
 	eor		v0.16b, v0.16b, v4.16b
-	st1		{v0.16b}, [x0], #16
-	subs		w4, w4, #1
-	beq		.Lxtsencout
+	cbz		w4, .Lxtsencout
+	subs		w4, w4, #16
 	next_tweak	v4, v4, v8
+	bmi		.Lxtsenccts
+	st1		{v0.16b}, [x0], #16
 	b		.Lxtsencloop
 .Lxtsencout:
+	st1		{v0.16b}, [x0]
+.Lxtsencret:
 	st1		{v4.16b}, [x6]
 	ldp		x29, x30, [sp], #16
 	ret
-AES_ENDPROC(aes_xts_encrypt)
 
+.LxtsencctsNx:
+	mov		v0.16b, v3.16b
+	sub		x0, x0, #16
+.Lxtsenccts:
+	adr_l		x8, .Lcts_permute_table
+
+	add		x1, x1, w4, sxtw	/* rewind input pointer */
+	add		w4, w4, #16		/* # bytes in final block */
+	add		x9, x8, #32
+	add		x8, x8, x4
+	sub		x9, x9, x4
+	add		x4, x0, x4		/* output address of final block */
+
+	ld1		{v1.16b}, [x1]		/* load final block */
+	ld1		{v2.16b}, [x8]
+	ld1		{v3.16b}, [x9]
+
+	tbl		v2.16b, {v0.16b}, v2.16b
+	tbx		v0.16b, {v1.16b}, v3.16b
+	st1		{v2.16b}, [x4]		/* overlapping stores */
+	mov		w4, wzr
+	b		.Lxtsencctsout
+AES_ENDPROC(aes_xts_encrypt)
 
 AES_ENTRY(aes_xts_decrypt)
 	stp		x29, x30, [sp, #-16]!
 	mov		x29, sp
 
+	/* subtract 16 bytes if we are doing CTS */
+	sub		w8, w4, #0x10
+	tst		w4, #0xf
+	csel		w4, w4, w8, eq
+
 	ld1		{v4.16b}, [x6]
 	xts_load_mask	v8
 	cbz		w7, .Lxtsdecnotfirst
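
The tbl/tbx pair in the .Lxtsenccts path above is easier to follow in C.
The sketch below is only an approximation; it assumes the .Lcts_permute_table
layout used by the pre-existing CBC-CTS helpers in this file (16 bytes of
0xff, then 0x00..0x0f, then 16 bytes of 0xff), and the table itself is not
part of this diff.

/*
 * Hypothetical helper: what the tbl/tbx pair computes.  cc[] is the
 * ciphertext of the last full block; last16[] is the final 16 input
 * bytes (the input pointer was rewound by 16 - tail, so its top 'tail'
 * bytes are the partial plaintext).
 */
#include <stdint.h>
#include <string.h>

static void cts_shuffle(uint8_t cc[16], uint8_t store_vec[16],
                        const uint8_t last16[16], int tail)
{
        /*
         * tbl v2.16b, {v0.16b}, v2.16b: build a vector whose top 'tail'
         * bytes are cc[0..tail-1] and whose low bytes are zero; it is
         * stored 16 - tail bytes before the final partial block.
         */
        memset(store_vec, 0, 16);
        memcpy(store_vec + 16 - tail, cc, tail);

        /*
         * tbx v0.16b, {v1.16b}, v3.16b: replace cc[0..tail-1] with the
         * partial plaintext while keeping the stolen bytes cc[tail..15];
         * the result is then encrypted with the next tweak and stored as
         * the last full ciphertext block.
         */
        memcpy(cc, last16 + 16 - tail, tail);
}

The zero bytes written by the overlapping store fall inside the slot of the
last full ciphertext block and are overwritten by the later st1 at
.Lxtsencout, which is why the partial block can be stored first.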
@@ -509,7 +542,7 @@ AES_ENTRY(aes_xts_decrypt)
 .LxtsdecloopNx:
 	next_tweak	v4, v4, v8
 .LxtsdecNx:
-	subs		w4, w4, #4
+	subs		w4, w4, #64
 	bmi		.Lxtsdec1x
 	ld1		{v0.16b-v3.16b}, [x1], #64	/* get 4 ct blocks */
 	next_tweak	v5, v4, v8
@@ -530,22 +563,52 @@ AES_ENTRY(aes_xts_decrypt)
 	xts_reload_mask	v8
 	b		.LxtsdecloopNx
 .Lxtsdec1x:
-	adds		w4, w4, #4
+	adds		w4, w4, #64
 	beq		.Lxtsdecout
+	subs		w4, w4, #16
 .Lxtsdecloop:
-	ld1		{v1.16b}, [x1], #16
-	eor		v0.16b, v1.16b, v4.16b
+	ld1		{v0.16b}, [x1], #16
+	bmi		.Lxtsdeccts
+.Lxtsdecctsout:
+	eor		v0.16b, v0.16b, v4.16b
 	decrypt_block	v0, w3, x2, x8, w7
 	eor		v0.16b, v0.16b, v4.16b
 	st1		{v0.16b}, [x0], #16
-	subs		w4, w4, #1
-	beq		.Lxtsdecout
+	cbz		w4, .Lxtsdecout
+	subs		w4, w4, #16
 	next_tweak	v4, v4, v8
 	b		.Lxtsdecloop
 .Lxtsdecout:
 	st1		{v4.16b}, [x6]
 	ldp		x29, x30, [sp], #16
 	ret
+
+.Lxtsdeccts:
+	adr_l		x8, .Lcts_permute_table
+
+	add		x1, x1, w4, sxtw	/* rewind input pointer */
+	add		w4, w4, #16		/* # bytes in final block */
+	add		x9, x8, #32
+	add		x8, x8, x4
+	sub		x9, x9, x4
+	add		x4, x0, x4		/* output address of final block */
+
+	next_tweak	v5, v4, v8
+
+	ld1		{v1.16b}, [x1]		/* load final block */
+	ld1		{v2.16b}, [x8]
+	ld1		{v3.16b}, [x9]
+
+	eor		v0.16b, v0.16b, v5.16b
+	decrypt_block	v0, w3, x2, x8, w7
+	eor		v0.16b, v0.16b, v5.16b
+
+	tbl		v2.16b, {v0.16b}, v2.16b
+	tbx		v0.16b, {v1.16b}, v3.16b
+
+	st1		{v2.16b}, [x4]		/* overlapping stores */
+	mov		w4, wzr
+	b		.Lxtsdecctsout
 AES_ENDPROC(aes_xts_decrypt)
 
 	/*