Commit fd94fcf0 authored by Nathan Huckleberry, committed by Herbert Xu

crypto: x86/aesni-xctr - Add accelerated implementation of XCTR

Add a hardware-accelerated version of XCTR for x86-64 CPUs with AESNI
support.

More information on XCTR can be found in the HCTR2 paper:
"Length-preserving encryption with HCTR2":
https://eprint.iacr.org/2021/1441.pdf

Signed-off-by: Nathan Huckleberry <nhuck@google.com>
Reviewed-by: Ard Biesheuvel <ardb@kernel.org>
Reviewed-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
parent 7ff554ce
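
For orientation: XCTR derives keystream block i by encrypting the IV XORed with the little-endian encoding of i, rather than encrypting a big-endian incrementing counter as CTR does, which is why the XCTR path below needs no byte swap or carry handling. The C fragment that follows is only an illustrative sketch of that construction and is not part of the patch; aes_encrypt_block() is a hypothetical stand-in for a real AES primitive, and XORing only the low 32 counter bits mirrors the partial-block fallback in the glue code further down.

/*
 * Sketch only, not part of the patch: keystream block i (1-based) is
 * AES_K(IV XOR le(i)), XORed into the data. aes_encrypt_block() is assumed.
 */
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#define AES_BLOCK_SIZE 16

/* Assumed to exist elsewhere: single-block AES encryption under 'key'. */
void aes_encrypt_block(const void *key, uint8_t out[AES_BLOCK_SIZE],
                       const uint8_t in[AES_BLOCK_SIZE]);

static void xctr_sketch(const void *key, uint8_t *dst, const uint8_t *src,
                        size_t nblocks, const uint8_t iv[AES_BLOCK_SIZE])
{
    uint8_t ctrblk[AES_BLOCK_SIZE], keystream[AES_BLOCK_SIZE];

    for (size_t i = 0; i < nblocks; i++) {
        uint32_t ctr = (uint32_t)i + 1;  /* block counter starts at 1 */

        memcpy(ctrblk, iv, AES_BLOCK_SIZE);
        /* XOR the counter in as a little-endian integer (low 32 bits shown). */
        ctrblk[0] ^= (uint8_t)ctr;
        ctrblk[1] ^= (uint8_t)(ctr >> 8);
        ctrblk[2] ^= (uint8_t)(ctr >> 16);
        ctrblk[3] ^= (uint8_t)(ctr >> 24);

        aes_encrypt_block(key, keystream, ctrblk);
        for (int j = 0; j < AES_BLOCK_SIZE; j++)
            dst[i * AES_BLOCK_SIZE + j] = src[i * AES_BLOCK_SIZE + j] ^ keystream[j];
    }
}
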
arch/x86/crypto/aes_ctrby8_avx-x86_64.S

@@ -23,6 +23,11 @@
 #define VMOVDQ vmovdqu
+/*
+ * Note: the "x" prefix in these aliases means "this is an xmm register". The
+ * alias prefixes have no relation to XCTR where the "X" prefix means "XOR
+ * counter".
+ */
 #define xdata0 %xmm0
 #define xdata1 %xmm1
 #define xdata2 %xmm2
@@ -31,8 +36,10 @@
 #define xdata5 %xmm5
 #define xdata6 %xmm6
 #define xdata7 %xmm7
-#define xcounter %xmm8
-#define xbyteswap %xmm9
+#define xcounter %xmm8 // CTR mode only
+#define xiv %xmm8 // XCTR mode only
+#define xbyteswap %xmm9 // CTR mode only
+#define xtmp %xmm9 // XCTR mode only
 #define xkey0 %xmm10
 #define xkey4 %xmm11
 #define xkey8 %xmm12
@@ -45,7 +52,7 @@
 #define p_keys %rdx
 #define p_out %rcx
 #define num_bytes %r8
-
+#define counter %r9 // XCTR mode only
 #define tmp %r10
 #define DDQ_DATA 0
 #define XDATA 1
@@ -102,7 +109,7 @@ ddq_add_8:
  * do_aes num_in_par load_keys key_len
  * This increments p_in, but not p_out
  */
-.macro do_aes b, k, key_len
+.macro do_aes b, k, key_len, xctr
     .set by, \b
     .set load_keys, \k
     .set klen, \key_len
@@ -111,29 +118,48 @@ ddq_add_8:
     vmovdqa 0*16(p_keys), xkey0
     .endif
-    vpshufb xbyteswap, xcounter, xdata0
-    .set i, 1
-    .rept (by - 1)
-        club XDATA, i
-        vpaddq (ddq_add_1 + 16 * (i - 1))(%rip), xcounter, var_xdata
-        vptest ddq_low_msk(%rip), var_xdata
-        jnz 1f
-        vpaddq ddq_high_add_1(%rip), var_xdata, var_xdata
-        vpaddq ddq_high_add_1(%rip), xcounter, xcounter
-        1:
-        vpshufb xbyteswap, var_xdata, var_xdata
-        .set i, (i +1)
-    .endr
+    .if \xctr
+        movq counter, xtmp
+        .set i, 0
+        .rept (by)
+            club XDATA, i
+            vpaddq (ddq_add_1 + 16 * i)(%rip), xtmp, var_xdata
+            .set i, (i +1)
+        .endr
+        .set i, 0
+        .rept (by)
+            club XDATA, i
+            vpxor xiv, var_xdata, var_xdata
+            .set i, (i +1)
+        .endr
+    .else
+        vpshufb xbyteswap, xcounter, xdata0
+        .set i, 1
+        .rept (by - 1)
+            club XDATA, i
+            vpaddq (ddq_add_1 + 16 * (i - 1))(%rip), xcounter, var_xdata
+            vptest ddq_low_msk(%rip), var_xdata
+            jnz 1f
+            vpaddq ddq_high_add_1(%rip), var_xdata, var_xdata
+            vpaddq ddq_high_add_1(%rip), xcounter, xcounter
+            1:
+            vpshufb xbyteswap, var_xdata, var_xdata
+            .set i, (i +1)
+        .endr
+    .endif
 
     vmovdqa 1*16(p_keys), xkeyA
 
     vpxor xkey0, xdata0, xdata0
-    vpaddq (ddq_add_1 + 16 * (by - 1))(%rip), xcounter, xcounter
-    vptest ddq_low_msk(%rip), xcounter
-    jnz 1f
-    vpaddq ddq_high_add_1(%rip), xcounter, xcounter
-    1:
+    .if \xctr
+        add $by, counter
+    .else
+        vpaddq (ddq_add_1 + 16 * (by - 1))(%rip), xcounter, xcounter
+        vptest ddq_low_msk(%rip), xcounter
+        jnz 1f
+        vpaddq ddq_high_add_1(%rip), xcounter, xcounter
+        1:
+    .endif
 
     .set i, 1
     .rept (by - 1)
@@ -371,94 +397,99 @@ ddq_add_8:
     .endr
 .endm
 
-.macro do_aes_load val, key_len
-    do_aes \val, 1, \key_len
+.macro do_aes_load val, key_len, xctr
+    do_aes \val, 1, \key_len, \xctr
 .endm
 
-.macro do_aes_noload val, key_len
-    do_aes \val, 0, \key_len
+.macro do_aes_noload val, key_len, xctr
+    do_aes \val, 0, \key_len, \xctr
 .endm
 
 /* main body of aes ctr load */
-.macro do_aes_ctrmain key_len
+.macro do_aes_ctrmain key_len, xctr
     cmp $16, num_bytes
-    jb .Ldo_return2\key_len
+    jb .Ldo_return2\xctr\key_len
 
-    vmovdqa byteswap_const(%rip), xbyteswap
-    vmovdqu (p_iv), xcounter
-    vpshufb xbyteswap, xcounter, xcounter
+    .if \xctr
+        shr $4, counter
+        vmovdqu (p_iv), xiv
+    .else
+        vmovdqa byteswap_const(%rip), xbyteswap
+        vmovdqu (p_iv), xcounter
+        vpshufb xbyteswap, xcounter, xcounter
+    .endif
 
     mov num_bytes, tmp
     and $(7*16), tmp
-    jz .Lmult_of_8_blks\key_len
+    jz .Lmult_of_8_blks\xctr\key_len
 
     /* 1 <= tmp <= 7 */
     cmp $(4*16), tmp
-    jg .Lgt4\key_len
-    je .Leq4\key_len
+    jg .Lgt4\xctr\key_len
+    je .Leq4\xctr\key_len
 
-.Llt4\key_len:
+.Llt4\xctr\key_len:
     cmp $(2*16), tmp
-    jg .Leq3\key_len
-    je .Leq2\key_len
+    jg .Leq3\xctr\key_len
+    je .Leq2\xctr\key_len
 
-.Leq1\key_len:
-    do_aes_load 1, \key_len
+.Leq1\xctr\key_len:
+    do_aes_load 1, \key_len, \xctr
     add $(1*16), p_out
     and $(~7*16), num_bytes
-    jz .Ldo_return2\key_len
-    jmp .Lmain_loop2\key_len
+    jz .Ldo_return2\xctr\key_len
+    jmp .Lmain_loop2\xctr\key_len
 
-.Leq2\key_len:
-    do_aes_load 2, \key_len
+.Leq2\xctr\key_len:
+    do_aes_load 2, \key_len, \xctr
     add $(2*16), p_out
     and $(~7*16), num_bytes
-    jz .Ldo_return2\key_len
-    jmp .Lmain_loop2\key_len
+    jz .Ldo_return2\xctr\key_len
+    jmp .Lmain_loop2\xctr\key_len
 
-.Leq3\key_len:
-    do_aes_load 3, \key_len
+.Leq3\xctr\key_len:
+    do_aes_load 3, \key_len, \xctr
     add $(3*16), p_out
     and $(~7*16), num_bytes
-    jz .Ldo_return2\key_len
-    jmp .Lmain_loop2\key_len
+    jz .Ldo_return2\xctr\key_len
+    jmp .Lmain_loop2\xctr\key_len
 
-.Leq4\key_len:
-    do_aes_load 4, \key_len
+.Leq4\xctr\key_len:
+    do_aes_load 4, \key_len, \xctr
     add $(4*16), p_out
     and $(~7*16), num_bytes
-    jz .Ldo_return2\key_len
-    jmp .Lmain_loop2\key_len
+    jz .Ldo_return2\xctr\key_len
+    jmp .Lmain_loop2\xctr\key_len
 
-.Lgt4\key_len:
+.Lgt4\xctr\key_len:
     cmp $(6*16), tmp
-    jg .Leq7\key_len
-    je .Leq6\key_len
+    jg .Leq7\xctr\key_len
+    je .Leq6\xctr\key_len
 
-.Leq5\key_len:
-    do_aes_load 5, \key_len
+.Leq5\xctr\key_len:
+    do_aes_load 5, \key_len, \xctr
     add $(5*16), p_out
     and $(~7*16), num_bytes
-    jz .Ldo_return2\key_len
-    jmp .Lmain_loop2\key_len
+    jz .Ldo_return2\xctr\key_len
+    jmp .Lmain_loop2\xctr\key_len
 
-.Leq6\key_len:
-    do_aes_load 6, \key_len
+.Leq6\xctr\key_len:
+    do_aes_load 6, \key_len, \xctr
     add $(6*16), p_out
     and $(~7*16), num_bytes
-    jz .Ldo_return2\key_len
-    jmp .Lmain_loop2\key_len
+    jz .Ldo_return2\xctr\key_len
+    jmp .Lmain_loop2\xctr\key_len
 
-.Leq7\key_len:
-    do_aes_load 7, \key_len
+.Leq7\xctr\key_len:
+    do_aes_load 7, \key_len, \xctr
     add $(7*16), p_out
     and $(~7*16), num_bytes
-    jz .Ldo_return2\key_len
-    jmp .Lmain_loop2\key_len
+    jz .Ldo_return2\xctr\key_len
+    jmp .Lmain_loop2\xctr\key_len
 
-.Lmult_of_8_blks\key_len:
+.Lmult_of_8_blks\xctr\key_len:
     .if (\key_len != KEY_128)
     vmovdqa 0*16(p_keys), xkey0
     vmovdqa 4*16(p_keys), xkey4
@@ -471,17 +502,19 @@
     vmovdqa 9*16(p_keys), xkey12
     .endif
 .align 16
-.Lmain_loop2\key_len:
+.Lmain_loop2\xctr\key_len:
     /* num_bytes is a multiple of 8 and >0 */
-    do_aes_noload 8, \key_len
+    do_aes_noload 8, \key_len, \xctr
     add $(8*16), p_out
     sub $(8*16), num_bytes
-    jne .Lmain_loop2\key_len
+    jne .Lmain_loop2\xctr\key_len
 
-.Ldo_return2\key_len:
-    /* return updated IV */
-    vpshufb xbyteswap, xcounter, xcounter
-    vmovdqu xcounter, (p_iv)
+.Ldo_return2\xctr\key_len:
+    .if !\xctr
+        /* return updated IV */
+        vpshufb xbyteswap, xcounter, xcounter
+        vmovdqu xcounter, (p_iv)
+    .endif
     RET
 .endm
@@ -494,7 +527,7 @@
  */
 SYM_FUNC_START(aes_ctr_enc_128_avx_by8)
     /* call the aes main loop */
-    do_aes_ctrmain KEY_128
+    do_aes_ctrmain KEY_128 0
 SYM_FUNC_END(aes_ctr_enc_128_avx_by8)
@@ -507,7 +540,7 @@ SYM_FUNC_END(aes_ctr_enc_128_avx_by8)
  */
 SYM_FUNC_START(aes_ctr_enc_192_avx_by8)
     /* call the aes main loop */
-    do_aes_ctrmain KEY_192
+    do_aes_ctrmain KEY_192 0
 SYM_FUNC_END(aes_ctr_enc_192_avx_by8)
@@ -520,6 +553,45 @@ SYM_FUNC_END(aes_ctr_enc_192_avx_by8)
  */
 SYM_FUNC_START(aes_ctr_enc_256_avx_by8)
     /* call the aes main loop */
-    do_aes_ctrmain KEY_256
+    do_aes_ctrmain KEY_256 0
 SYM_FUNC_END(aes_ctr_enc_256_avx_by8)
+
+/*
+ * routine to do AES128 XCTR enc/decrypt "by8"
+ * XMM registers are clobbered.
+ * Saving/restoring must be done at a higher level
+ * aes_xctr_enc_128_avx_by8(const u8 *in, const u8 *iv, const void *keys,
+ *     u8* out, unsigned int num_bytes, unsigned int byte_ctr)
+ */
+SYM_FUNC_START(aes_xctr_enc_128_avx_by8)
+    /* call the aes main loop */
+    do_aes_ctrmain KEY_128 1
+SYM_FUNC_END(aes_xctr_enc_128_avx_by8)
+
+/*
+ * routine to do AES192 XCTR enc/decrypt "by8"
+ * XMM registers are clobbered.
+ * Saving/restoring must be done at a higher level
+ * aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv, const void *keys,
+ *     u8* out, unsigned int num_bytes, unsigned int byte_ctr)
+ */
+SYM_FUNC_START(aes_xctr_enc_192_avx_by8)
+    /* call the aes main loop */
+    do_aes_ctrmain KEY_192 1
+SYM_FUNC_END(aes_xctr_enc_192_avx_by8)
+
+/*
+ * routine to do AES256 XCTR enc/decrypt "by8"
+ * XMM registers are clobbered.
+ * Saving/restoring must be done at a higher level
+ * aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv, const void *keys,
+ *     u8* out, unsigned int num_bytes, unsigned int byte_ctr)
+ */
+SYM_FUNC_START(aes_xctr_enc_256_avx_by8)
+    /* call the aes main loop */
+    do_aes_ctrmain KEY_256 1
+SYM_FUNC_END(aes_xctr_enc_256_avx_by8)
arch/x86/crypto/aesni-intel_glue.c

@@ -135,6 +135,20 @@ asmlinkage void aes_ctr_enc_192_avx_by8(const u8 *in, u8 *iv,
         void *keys, u8 *out, unsigned int num_bytes);
 asmlinkage void aes_ctr_enc_256_avx_by8(const u8 *in, u8 *iv,
         void *keys, u8 *out, unsigned int num_bytes);
+
+asmlinkage void aes_xctr_enc_128_avx_by8(const u8 *in, const u8 *iv,
+    const void *keys, u8 *out, unsigned int num_bytes,
+    unsigned int byte_ctr);
+
+asmlinkage void aes_xctr_enc_192_avx_by8(const u8 *in, const u8 *iv,
+    const void *keys, u8 *out, unsigned int num_bytes,
+    unsigned int byte_ctr);
+
+asmlinkage void aes_xctr_enc_256_avx_by8(const u8 *in, const u8 *iv,
+    const void *keys, u8 *out, unsigned int num_bytes,
+    unsigned int byte_ctr);
+
 /*
  * asmlinkage void aesni_gcm_init_avx_gen2()
  * gcm_data *my_ctx_data, context data
@@ -527,6 +541,59 @@ static int ctr_crypt(struct skcipher_request *req)
     return err;
 }
 
+static void aesni_xctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
+                   const u8 *in, unsigned int len, u8 *iv,
+                   unsigned int byte_ctr)
+{
+    if (ctx->key_length == AES_KEYSIZE_128)
+        aes_xctr_enc_128_avx_by8(in, iv, (void *)ctx, out, len,
+                     byte_ctr);
+    else if (ctx->key_length == AES_KEYSIZE_192)
+        aes_xctr_enc_192_avx_by8(in, iv, (void *)ctx, out, len,
+                     byte_ctr);
+    else
+        aes_xctr_enc_256_avx_by8(in, iv, (void *)ctx, out, len,
+                     byte_ctr);
+}
+
+static int xctr_crypt(struct skcipher_request *req)
+{
+    struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+    struct crypto_aes_ctx *ctx = aes_ctx(crypto_skcipher_ctx(tfm));
+    u8 keystream[AES_BLOCK_SIZE];
+    struct skcipher_walk walk;
+    unsigned int nbytes;
+    unsigned int byte_ctr = 0;
+    int err;
+    __le32 block[AES_BLOCK_SIZE / sizeof(__le32)];
+
+    err = skcipher_walk_virt(&walk, req, false);
+
+    while ((nbytes = walk.nbytes) > 0) {
+        kernel_fpu_begin();
+        if (nbytes & AES_BLOCK_MASK)
+            aesni_xctr_enc_avx_tfm(ctx, walk.dst.virt.addr,
+                walk.src.virt.addr, nbytes & AES_BLOCK_MASK,
+                walk.iv, byte_ctr);
+        nbytes &= ~AES_BLOCK_MASK;
+        byte_ctr += walk.nbytes - nbytes;
+
+        if (walk.nbytes == walk.total && nbytes > 0) {
+            memcpy(block, walk.iv, AES_BLOCK_SIZE);
+            block[0] ^= cpu_to_le32(1 + byte_ctr / AES_BLOCK_SIZE);
+            aesni_enc(ctx, keystream, (u8 *)block);
+            crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes -
+                       nbytes, walk.src.virt.addr + walk.nbytes
+                       - nbytes, keystream, nbytes);
+            byte_ctr += nbytes;
+            nbytes = 0;
+        }
+        kernel_fpu_end();
+        err = skcipher_walk_done(&walk, nbytes);
+    }
+    return err;
+}
+
 static int
 rfc4106_set_hash_subkey(u8 *hash_subkey, const u8 *key, unsigned int key_len)
 {
@@ -1050,6 +1117,33 @@ static struct skcipher_alg aesni_skciphers[] = {
 static
 struct simd_skcipher_alg *aesni_simd_skciphers[ARRAY_SIZE(aesni_skciphers)];
 
+#ifdef CONFIG_X86_64
+/*
+ * XCTR does not have a non-AVX implementation, so it must be enabled
+ * conditionally.
+ */
+static struct skcipher_alg aesni_xctr = {
+    .base = {
+        .cra_name = "__xctr(aes)",
+        .cra_driver_name = "__xctr-aes-aesni",
+        .cra_priority = 400,
+        .cra_flags = CRYPTO_ALG_INTERNAL,
+        .cra_blocksize = 1,
+        .cra_ctxsize = CRYPTO_AES_CTX_SIZE,
+        .cra_module = THIS_MODULE,
+    },
+    .min_keysize = AES_MIN_KEY_SIZE,
+    .max_keysize = AES_MAX_KEY_SIZE,
+    .ivsize = AES_BLOCK_SIZE,
+    .chunksize = AES_BLOCK_SIZE,
+    .setkey = aesni_skcipher_setkey,
+    .encrypt = xctr_crypt,
+    .decrypt = xctr_crypt,
+};
+
+static struct simd_skcipher_alg *aesni_simd_xctr;
+#endif /* CONFIG_X86_64 */
+
 #ifdef CONFIG_X86_64
 static int generic_gcmaes_set_key(struct crypto_aead *aead, const u8 *key,
                   unsigned int key_len)
@@ -1163,7 +1257,7 @@ static int __init aesni_init(void)
         static_call_update(aesni_ctr_enc_tfm, aesni_ctr_enc_avx_tfm);
         pr_info("AES CTR mode by8 optimization enabled\n");
     }
-#endif
+#endif /* CONFIG_X86_64 */
 
     err = crypto_register_alg(&aesni_cipher_alg);
     if (err)
@@ -1180,8 +1274,22 @@ static int __init aesni_init(void)
     if (err)
         goto unregister_skciphers;
 
+#ifdef CONFIG_X86_64
+    if (boot_cpu_has(X86_FEATURE_AVX))
+        err = simd_register_skciphers_compat(&aesni_xctr, 1,
+                             &aesni_simd_xctr);
+    if (err)
+        goto unregister_aeads;
+#endif /* CONFIG_X86_64 */
+
     return 0;
 
+#ifdef CONFIG_X86_64
+unregister_aeads:
+    simd_unregister_aeads(aesni_aeads, ARRAY_SIZE(aesni_aeads),
+                aesni_simd_aeads);
+#endif /* CONFIG_X86_64 */
+
 unregister_skciphers:
     simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
                   aesni_simd_skciphers);
@@ -1197,6 +1305,10 @@ static void __exit aesni_exit(void)
     simd_unregister_skciphers(aesni_skciphers, ARRAY_SIZE(aesni_skciphers),
                   aesni_simd_skciphers);
     crypto_unregister_alg(&aesni_cipher_alg);
+#ifdef CONFIG_X86_64
+    if (boot_cpu_has(X86_FEATURE_AVX))
+        simd_unregister_skciphers(&aesni_xctr, 1, &aesni_simd_xctr);
+#endif /* CONFIG_X86_64 */
 }
 late_initcall(aesni_init);
crypto/Kconfig

@@ -1169,7 +1169,7 @@ config CRYPTO_AES_NI_INTEL
       In addition to AES cipher algorithm support, the acceleration
       for some popular block cipher mode is supported too, including
       ECB, CBC, LRW, XTS. The 64 bit version has additional
-      acceleration for CTR.
+      acceleration for CTR and XCTR.
 
 config CRYPTO_AES_SPARC64
     tristate "AES cipher algorithms (SPARC64)"