Commit de79d9aa authored by Tianjia Zhang, committed by Herbert Xu

crypto: x86/sm4 - export reusable AESNI/AVX functions

Export the reusable functions in the SM4 AESNI/AVX implementation,
mainly the public mode-handling functions, so that they can be used to
build the SM4 AESNI/AVX2 implementation and eliminate unnecessary code
duplication.

At the same time, minor fixes were made so that the exported functions
are generic enough to be shared.
Signed-off-by: Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
parent ff1469a2
arch/x86/crypto/sm4-avx.h (new file)

/* SPDX-License-Identifier: GPL-2.0-or-later */
#ifndef ASM_X86_SM4_AVX_H
#define ASM_X86_SM4_AVX_H

#include <linux/types.h>
#include <crypto/sm4.h>

typedef void (*sm4_crypt_func)(const u32 *rk, u8 *dst, const u8 *src, u8 *iv);

int sm4_avx_ecb_encrypt(struct skcipher_request *req);
int sm4_avx_ecb_decrypt(struct skcipher_request *req);

int sm4_cbc_encrypt(struct skcipher_request *req);
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
                        unsigned int bsize, sm4_crypt_func func);

int sm4_cfb_encrypt(struct skcipher_request *req);
int sm4_avx_cfb_decrypt(struct skcipher_request *req,
                        unsigned int bsize, sm4_crypt_func func);

int sm4_avx_ctr_crypt(struct skcipher_request *req,
                      unsigned int bsize, sm4_crypt_func func);

#endif
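The sm4_crypt_func typedef above is what makes the exported mode walkers reusable: any bulk routine with the (rk, dst, src, iv) argument layout that processes a fixed number of blocks per call can be plugged in through the bsize/func parameters. As a minimal sketch of the intended reuse (not part of this patch), a hypothetical AVX2 glue module could wrap the exported CBC-decrypt walker as follows; the 16-block constant and the sm4_aesni_avx2_cbc_dec_blk16 symbol are illustrative assumptions, not symbols introduced by this commit.

/* Sketch only: SM4_CRYPT16_BLOCK_SIZE and sm4_aesni_avx2_cbc_dec_blk16
 * are assumptions for illustration, not part of this patch.
 */
#include <linux/linkage.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT16_BLOCK_SIZE  (SM4_BLOCK_SIZE * 16)   /* assumed 16-block bulk width */

/* Assumed AVX2 assembly routine following the sm4_crypt_func layout. */
asmlinkage void sm4_aesni_avx2_cbc_dec_blk16(const u32 *rk, u8 *dst,
                                             const u8 *src, u8 *iv);

static int cbc_decrypt(struct skcipher_request *req)
{
        /* Reuse the exported walker; only the bulk width and the
         * per-chunk crypt function differ from the AVX glue code.
         */
        return sm4_avx_cbc_decrypt(req, SM4_CRYPT16_BLOCK_SIZE,
                                   sm4_aesni_avx2_cbc_dec_blk16);
}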
arch/x86/crypto/sm4_aesni_avx_glue.c

@@ -15,6 +15,7 @@
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
 #include <crypto/sm4.h>
+#include "sm4-avx.h"
 #define SM4_CRYPT8_BLOCK_SIZE   (SM4_BLOCK_SIZE * 8)
@@ -71,23 +72,25 @@ static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
         return err;
 }
-static int ecb_encrypt(struct skcipher_request *req)
+int sm4_avx_ecb_encrypt(struct skcipher_request *req)
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
         return ecb_do_crypt(req, ctx->rkey_enc);
 }
+EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
-static int ecb_decrypt(struct skcipher_request *req)
+int sm4_avx_ecb_decrypt(struct skcipher_request *req)
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
         return ecb_do_crypt(req, ctx->rkey_dec);
 }
+EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
-static int cbc_encrypt(struct skcipher_request *req)
+int sm4_cbc_encrypt(struct skcipher_request *req)
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -118,8 +121,10 @@ static int cbc_encrypt(struct skcipher_request *req)
         return err;
 }
+EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
-static int cbc_decrypt(struct skcipher_request *req)
+int sm4_avx_cbc_decrypt(struct skcipher_request *req,
+                        unsigned int bsize, sm4_crypt_func func)
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -135,15 +140,14 @@ static int cbc_decrypt(struct skcipher_request *req)
                 kernel_fpu_begin();
-                while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
-                        sm4_aesni_avx_cbc_dec_blk8(ctx->rkey_dec, dst,
-                                                   src, walk.iv);
-                        dst += SM4_CRYPT8_BLOCK_SIZE;
-                        src += SM4_CRYPT8_BLOCK_SIZE;
-                        nbytes -= SM4_CRYPT8_BLOCK_SIZE;
+                while (nbytes >= bsize) {
+                        func(ctx->rkey_dec, dst, src, walk.iv);
+                        dst += bsize;
+                        src += bsize;
+                        nbytes -= bsize;
                 }
-                if (nbytes >= SM4_BLOCK_SIZE) {
+                while (nbytes >= SM4_BLOCK_SIZE) {
                         u8 keystream[SM4_BLOCK_SIZE * 8];
                         u8 iv[SM4_BLOCK_SIZE];
                         unsigned int nblocks = min(nbytes >> 4, 8u);
@@ -165,6 +169,8 @@ static int cbc_decrypt(struct skcipher_request *req)
                         }
                         crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
                         memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
+                        dst += nblocks * SM4_BLOCK_SIZE;
+                        src += (nblocks + 1) * SM4_BLOCK_SIZE;
                         nbytes -= nblocks * SM4_BLOCK_SIZE;
                 }
@@ -174,8 +180,15 @@ static int cbc_decrypt(struct skcipher_request *req)
         return err;
 }
+EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
+
+static int cbc_decrypt(struct skcipher_request *req)
+{
+        return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
+                                   sm4_aesni_avx_cbc_dec_blk8);
+}
-static int cfb_encrypt(struct skcipher_request *req)
+int sm4_cfb_encrypt(struct skcipher_request *req)
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -214,8 +227,10 @@ static int cfb_encrypt(struct skcipher_request *req)
         return err;
 }
+EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);
-static int cfb_decrypt(struct skcipher_request *req)
+int sm4_avx_cfb_decrypt(struct skcipher_request *req,
+                        unsigned int bsize, sm4_crypt_func func)
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -231,15 +246,14 @@ static int cfb_decrypt(struct skcipher_request *req)
                 kernel_fpu_begin();
-                while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
-                        sm4_aesni_avx_cfb_dec_blk8(ctx->rkey_enc, dst,
-                                                   src, walk.iv);
-                        dst += SM4_CRYPT8_BLOCK_SIZE;
-                        src += SM4_CRYPT8_BLOCK_SIZE;
-                        nbytes -= SM4_CRYPT8_BLOCK_SIZE;
+                while (nbytes >= bsize) {
+                        func(ctx->rkey_enc, dst, src, walk.iv);
+                        dst += bsize;
+                        src += bsize;
+                        nbytes -= bsize;
                 }
-                if (nbytes >= SM4_BLOCK_SIZE) {
+                while (nbytes >= SM4_BLOCK_SIZE) {
                         u8 keystream[SM4_BLOCK_SIZE * 8];
                         unsigned int nblocks = min(nbytes >> 4, 8u);
@@ -276,8 +290,16 @@ static int cfb_decrypt(struct skcipher_request *req)
         return err;
 }
+EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);
+
+static int cfb_decrypt(struct skcipher_request *req)
+{
+        return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
+                                   sm4_aesni_avx_cfb_dec_blk8);
+}
-static int ctr_crypt(struct skcipher_request *req)
+int sm4_avx_ctr_crypt(struct skcipher_request *req,
+                      unsigned int bsize, sm4_crypt_func func)
 {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
         struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
@@ -293,15 +315,14 @@ static int ctr_crypt(struct skcipher_request *req)
                 kernel_fpu_begin();
-                while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
-                        sm4_aesni_avx_ctr_enc_blk8(ctx->rkey_enc, dst,
-                                                   src, walk.iv);
-                        dst += SM4_CRYPT8_BLOCK_SIZE;
-                        src += SM4_CRYPT8_BLOCK_SIZE;
-                        nbytes -= SM4_CRYPT8_BLOCK_SIZE;
+                while (nbytes >= bsize) {
+                        func(ctx->rkey_enc, dst, src, walk.iv);
+                        dst += bsize;
+                        src += bsize;
+                        nbytes -= bsize;
                 }
-                if (nbytes >= SM4_BLOCK_SIZE) {
+                while (nbytes >= SM4_BLOCK_SIZE) {
                         u8 keystream[SM4_BLOCK_SIZE * 8];
                         unsigned int nblocks = min(nbytes >> 4, 8u);
                         int i;
@@ -343,6 +364,13 @@ static int ctr_crypt(struct skcipher_request *req)
         return err;
 }
+EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
+
+static int ctr_crypt(struct skcipher_request *req)
+{
+        return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
+                                 sm4_aesni_avx_ctr_enc_blk8);
+}
 static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
         {
@@ -359,8 +387,8 @@ static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
                 .max_keysize    = SM4_KEY_SIZE,
                 .walksize       = 8 * SM4_BLOCK_SIZE,
                 .setkey         = sm4_skcipher_setkey,
-                .encrypt        = ecb_encrypt,
-                .decrypt        = ecb_decrypt,
+                .encrypt        = sm4_avx_ecb_encrypt,
+                .decrypt        = sm4_avx_ecb_decrypt,
         }, {
                 .base = {
                         .cra_name               = "__cbc(sm4)",
@@ -376,7 +404,7 @@ static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
                 .ivsize         = SM4_BLOCK_SIZE,
                 .walksize       = 8 * SM4_BLOCK_SIZE,
                 .setkey         = sm4_skcipher_setkey,
-                .encrypt        = cbc_encrypt,
+                .encrypt        = sm4_cbc_encrypt,
                 .decrypt        = cbc_decrypt,
         }, {
                 .base = {
@@ -394,7 +422,7 @@ static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
                 .chunksize      = SM4_BLOCK_SIZE,
                 .walksize       = 8 * SM4_BLOCK_SIZE,
                 .setkey         = sm4_skcipher_setkey,
-                .encrypt        = cfb_encrypt,
+                .encrypt        = sm4_cfb_encrypt,
                 .decrypt        = cfb_decrypt,
         }, {
                 .base = {
......
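The skcipher_alg table above also shows the consumption pattern the exports enable: ECB and the inherently serial CBC/CFB encrypt paths can reference the exported functions directly, while the parallelizable decrypt/CTR paths go through thin per-implementation wrappers. As a hedged sketch (again not part of this patch), an AVX2 module's algorithm entry might look like the following; the driver name, priority, 16-block walksize and the local cbc_decrypt/sm4_skcipher_setkey helpers are assumptions for illustration.

/* Hypothetical AVX2 algorithm entry built on the exported helpers.
 * cbc_decrypt is the local 16-block wrapper sketched earlier;
 * sm4_skcipher_setkey is assumed to be a local helper around sm4_expandkey().
 */
static struct skcipher_alg sm4_aesni_avx2_skciphers[] = {
        {
                .base = {
                        .cra_name               = "__cbc(sm4)",
                        .cra_driver_name        = "__cbc-sm4-aesni-avx2",
                        .cra_priority           = 500,
                        .cra_flags              = CRYPTO_ALG_INTERNAL,
                        .cra_blocksize          = SM4_BLOCK_SIZE,
                        .cra_ctxsize            = sizeof(struct sm4_ctx),
                        .cra_module             = THIS_MODULE,
                },
                .min_keysize    = SM4_KEY_SIZE,
                .max_keysize    = SM4_KEY_SIZE,
                .ivsize         = SM4_BLOCK_SIZE,
                .walksize       = 16 * SM4_BLOCK_SIZE,
                .setkey         = sm4_skcipher_setkey,
                /* Encryption stays serial, so the exported routine is used
                 * as-is; decryption uses the wider AVX2 bulk wrapper.
                 */
                .encrypt        = sm4_cbc_encrypt,
                .decrypt        = cbc_decrypt,
        },
};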