Commit 62508017 authored by Tianjia Zhang's avatar Tianjia Zhang Committed by Herbert Xu

crypto: arm64/sm4 - refactor and simplify NEON implementation

This patch does not add new features. The main work is to refactor and
simplify the implementation of SM4 NEON, which is reflected in the
following aspects:

The accelerated implementation supports the arbitrary number of blocks,
not just multiples of 8, which simplifies the implementation and brings
some optimization acceleration for data that is not aligned by 8 blocks.

When loading the input data, use the ld4 instruction to replace the
original ld1 instruction as much as possible, which will save the cost
of matrix transposition of the input data.

Use 8-block parallelism whenever possible to speed up matrix transpose
and rotation operations, instead of up to 4-block parallelism.
Signed-off-by: default avatarTianjia Zhang <tianjia.zhang@linux.alibaba.com>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent a41b2129
This diff is collapsed.
......@@ -18,19 +18,14 @@
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#define BYTES2BLKS(nbytes) ((nbytes) >> 4)
#define BYTES2BLK8(nbytes) (((nbytes) >> 4) & ~(8 - 1))
asmlinkage void sm4_neon_crypt_blk1_8(const u32 *rkey, u8 *dst, const u8 *src,
unsigned int nblks);
asmlinkage void sm4_neon_crypt_blk8(const u32 *rkey, u8 *dst, const u8 *src,
unsigned int nblks);
asmlinkage void sm4_neon_cbc_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_cfb_dec_blk8(const u32 *rkey, u8 *dst, const u8 *src,
u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_ctr_enc_blk8(const u32 *rkey, u8 *dst, const u8 *src,
u8 *iv, unsigned int nblks);
asmlinkage void sm4_neon_crypt(const u32 *rkey, u8 *dst, const u8 *src,
unsigned int nblocks);
asmlinkage void sm4_neon_cbc_dec(const u32 *rkey_dec, u8 *dst, const u8 *src,
u8 *iv, unsigned int nblocks);
asmlinkage void sm4_neon_cfb_dec(const u32 *rkey_enc, u8 *dst, const u8 *src,
u8 *iv, unsigned int nblocks);
asmlinkage void sm4_neon_ctr_crypt(const u32 *rkey_enc, u8 *dst, const u8 *src,
u8 *iv, unsigned int nblocks);
static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
unsigned int key_len)
......@@ -51,27 +46,18 @@ static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
while ((nbytes = walk.nbytes) > 0) {
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
unsigned int nblks;
unsigned int nblocks;
nblocks = nbytes / SM4_BLOCK_SIZE;
if (nblocks) {
kernel_neon_begin();
nblks = BYTES2BLK8(nbytes);
if (nblks) {
sm4_neon_crypt_blk8(rkey, dst, src, nblks);
dst += nblks * SM4_BLOCK_SIZE;
src += nblks * SM4_BLOCK_SIZE;
nbytes -= nblks * SM4_BLOCK_SIZE;
}
nblks = BYTES2BLKS(nbytes);
if (nblks) {
sm4_neon_crypt_blk1_8(rkey, dst, src, nblks);
nbytes -= nblks * SM4_BLOCK_SIZE;
}
sm4_neon_crypt(rkey, dst, src, nblocks);
kernel_neon_end();
}
err = skcipher_walk_done(&walk, nbytes);
err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
}
return err;
......@@ -138,48 +124,19 @@ static int sm4_cbc_decrypt(struct skcipher_request *req)
while ((nbytes = walk.nbytes) > 0) {
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
unsigned int nblks;
unsigned int nblocks;
nblocks = nbytes / SM4_BLOCK_SIZE;
if (nblocks) {
kernel_neon_begin();
nblks = BYTES2BLK8(nbytes);
if (nblks) {
sm4_neon_cbc_dec_blk8(ctx->rkey_dec, dst, src,
walk.iv, nblks);
dst += nblks * SM4_BLOCK_SIZE;
src += nblks * SM4_BLOCK_SIZE;
nbytes -= nblks * SM4_BLOCK_SIZE;
}
nblks = BYTES2BLKS(nbytes);
if (nblks) {
u8 keystream[SM4_BLOCK_SIZE * 8];
u8 iv[SM4_BLOCK_SIZE];
int i;
sm4_neon_crypt_blk1_8(ctx->rkey_dec, keystream,
src, nblks);
src += ((int)nblks - 2) * SM4_BLOCK_SIZE;
dst += (nblks - 1) * SM4_BLOCK_SIZE;
memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
for (i = nblks - 1; i > 0; i--) {
crypto_xor_cpy(dst, src,
&keystream[i * SM4_BLOCK_SIZE],
SM4_BLOCK_SIZE);
src -= SM4_BLOCK_SIZE;
dst -= SM4_BLOCK_SIZE;
}
crypto_xor_cpy(dst, walk.iv,
keystream, SM4_BLOCK_SIZE);
memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
nbytes -= nblks * SM4_BLOCK_SIZE;
}
sm4_neon_cbc_dec(ctx->rkey_dec, dst, src,
walk.iv, nblocks);
kernel_neon_end();
}
err = skcipher_walk_done(&walk, nbytes);
err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
}
return err;
......@@ -238,42 +195,22 @@ static int sm4_cfb_decrypt(struct skcipher_request *req)
while ((nbytes = walk.nbytes) > 0) {
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
unsigned int nblks;
unsigned int nblocks;
nblocks = nbytes / SM4_BLOCK_SIZE;
if (nblocks) {
kernel_neon_begin();
nblks = BYTES2BLK8(nbytes);
if (nblks) {
sm4_neon_cfb_dec_blk8(ctx->rkey_enc, dst, src,
walk.iv, nblks);
dst += nblks * SM4_BLOCK_SIZE;
src += nblks * SM4_BLOCK_SIZE;
nbytes -= nblks * SM4_BLOCK_SIZE;
}
nblks = BYTES2BLKS(nbytes);
if (nblks) {
u8 keystream[SM4_BLOCK_SIZE * 8];
memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
if (nblks > 1)
memcpy(&keystream[SM4_BLOCK_SIZE], src,
(nblks - 1) * SM4_BLOCK_SIZE);
memcpy(walk.iv, src + (nblks - 1) * SM4_BLOCK_SIZE,
SM4_BLOCK_SIZE);
sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
keystream, nblks);
crypto_xor_cpy(dst, src, keystream,
nblks * SM4_BLOCK_SIZE);
dst += nblks * SM4_BLOCK_SIZE;
src += nblks * SM4_BLOCK_SIZE;
nbytes -= nblks * SM4_BLOCK_SIZE;
}
sm4_neon_cfb_dec(ctx->rkey_enc, dst, src,
walk.iv, nblocks);
kernel_neon_end();
dst += nblocks * SM4_BLOCK_SIZE;
src += nblocks * SM4_BLOCK_SIZE;
nbytes -= nblocks * SM4_BLOCK_SIZE;
}
/* tail */
if (walk.nbytes == walk.total && nbytes > 0) {
u8 keystream[SM4_BLOCK_SIZE];
......@@ -302,41 +239,22 @@ static int sm4_ctr_crypt(struct skcipher_request *req)
while ((nbytes = walk.nbytes) > 0) {
const u8 *src = walk.src.virt.addr;
u8 *dst = walk.dst.virt.addr;
unsigned int nblks;
unsigned int nblocks;
nblocks = nbytes / SM4_BLOCK_SIZE;
if (nblocks) {
kernel_neon_begin();
nblks = BYTES2BLK8(nbytes);
if (nblks) {
sm4_neon_ctr_enc_blk8(ctx->rkey_enc, dst, src,
walk.iv, nblks);
dst += nblks * SM4_BLOCK_SIZE;
src += nblks * SM4_BLOCK_SIZE;
nbytes -= nblks * SM4_BLOCK_SIZE;
}
sm4_neon_ctr_crypt(ctx->rkey_enc, dst, src,
walk.iv, nblocks);
nblks = BYTES2BLKS(nbytes);
if (nblks) {
u8 keystream[SM4_BLOCK_SIZE * 8];
int i;
kernel_neon_end();
for (i = 0; i < nblks; i++) {
memcpy(&keystream[i * SM4_BLOCK_SIZE],
walk.iv, SM4_BLOCK_SIZE);
crypto_inc(walk.iv, SM4_BLOCK_SIZE);
}
sm4_neon_crypt_blk1_8(ctx->rkey_enc, keystream,
keystream, nblks);
crypto_xor_cpy(dst, src, keystream,
nblks * SM4_BLOCK_SIZE);
dst += nblks * SM4_BLOCK_SIZE;
src += nblks * SM4_BLOCK_SIZE;
nbytes -= nblks * SM4_BLOCK_SIZE;
dst += nblocks * SM4_BLOCK_SIZE;
src += nblocks * SM4_BLOCK_SIZE;
nbytes -= nblocks * SM4_BLOCK_SIZE;
}
kernel_neon_end();
/* tail */
if (walk.nbytes == walk.total && nbytes > 0) {
u8 keystream[SM4_BLOCK_SIZE];
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment