#include "rijndael_platform.h"

#if defined(RIJNDAEL_BITSLICE)
#include "rijndael_ct64_enc.h"


/****** API functions *******/
#if defined(RIJNDAEL_OPT_ARMV7M)
/* When RIJNDAEL_OPT_ARMV7M is defined, we use the assembly-optimized, T-table based
 * key schedule from https://eprint.iacr.org/2016/714.pdf
 * NOTE: we do not need a constant-time key schedule as only public data is used */
extern void AES_128_keyschedule(const uint8_t *, uint8_t *);

/*
 * AES-128 key setup for the ARMv7-M optimized build.
 *
 * Marks the context as AES128, copies the 16 key bytes to the start of
 * ctx->rk and lets the assembly routine AES_128_keyschedule() write its
 * output right after them (from byte offset 16 on).
 *
 * ctx: context to initialize.
 * key: the 16-byte AES-128 key.
 *
 * Returns 0 on success, -1 when ctx or key is NULL.
 */
WEAK int aes128_ct64_setkey_enc(rijndael_ct64_ctx *ctx, const uint8_t key[16])
{
        int ret = -1;
        uint8_t *rk;

        /* Refuse NULL arguments before either of them is dereferenced */
        if((ctx == NULL) || (key == NULL)){
                goto err;
        }
        ctx->rtype = AES128;
        /* The round keys are accessed as a flat byte buffer here */
        rk = (uint8_t*)(ctx->rk);

        /* The first 16 bytes hold the key itself ... */
        memcpy(&rk[0], key, 16);
        /* ... and the key schedule output is stored right behind them */
        AES_128_keyschedule(key, &rk[16]);

        ret = 0;
err:
        return ret;
}
#else
/* Portable AES-128 key setup: forwards ctx and key to br_aes_ct64_keysched()
 * with the AES128 type and returns its status unchanged. */
WEAK int aes128_ct64_setkey_enc(rijndael_ct64_ctx *ctx, const uint8_t key[16])
{
	const int status = br_aes_ct64_keysched(ctx, key, AES128);

	return status;
}
#endif

/* AES-256 key setup: forwards ctx and the 32-byte key to
 * br_aes_ct64_keysched() with the AES256 type and returns its status. */
WEAK int aes256_ct64_setkey_enc(rijndael_ct64_ctx *ctx, const uint8_t key[32])
{
	const int status = br_aes_ct64_keysched(ctx, key, AES256);

	return status;
}

/* Rijndael-256 key setup: forwards ctx and the 32-byte key to
 * br_aes_ct64_keysched() with the RIJNDAEL_256_256 type and returns its status. */
WEAK int rijndael256_ct64_setkey_enc(rijndael_ct64_ctx *ctx, const uint8_t key[32])
{
	const int status = br_aes_ct64_keysched(ctx, key, RIJNDAEL_256_256);

	return status;
}


// === AES-128 enc
#if defined(RIJNDAEL_OPT_ARMV7M)
/* When RIJNDAEL_OPT_ARMV7M is defined, we use the assembly optimized "fixsliced" based implementation
 * from https://eprint.iacr.org/2020/1123.pdf */

/* Helper to interleave two round keys, stolen from https://github.com/aadomn/aes/blob/master/opt32/fixslicing/aes_encrypt.c
 *
 * SWAPMOVE(a, b, mask, n) exchanges the bits of b selected by mask with the
 * bits of a selected by (mask << n).
 *
 * - a and b must be lvalues; every argument is evaluated more than once, so
 *   none of them may have side effects.
 * - A scratch variable named tmp must be declared in the enclosing scope.
 * - All arguments are parenthesized so that compound expressions (such as
 *   a mask built with |) are applied as a whole.
 * - The body is wrapped in do { } while(0) so that the macro behaves as a
 *   single statement without relying on compiler extensions. */
#define SWAPMOVE(a, b, mask, n)	do {							\
	tmp = ((b) ^ ((a) >> (n))) & (mask);						\
	(b) ^= tmp;									\
	(a) ^= (tmp << (n));								\
} while(0)
/* Builds a 32-bit word from the four bytes x[0..3], x[0] being the least
 * significant byte (little-endian load, independent of host byte order). */
#define LE_LOAD_32(x) 									\
    ( ((uint32_t)((x)[0]))        | 							\
     (((uint32_t)((x)[1])) << 8)  | 							\
     (((uint32_t)((x)[2])) << 16) | 							\
     (((uint32_t)((x)[3])) << 24))

/******************************************************************************
* ShiftRows^(-1), applied in place to each of the 8 words of a packed round
* key so that it matches the fixsliced representation.
******************************************************************************/
static void inv_shiftrows_1(uint32_t* rkey) {
	uint32_t tmp;	/* scratch word used by SWAPMOVE */
	uint32_t *word = rkey;
	const uint32_t *const end = rkey + 8;

	while(word != end) {
		/* Both operands are the same word: bits are moved within it */
		SWAPMOVE(*word, *word, 0x0c0f0300, 4);
		SWAPMOVE(*word, *word, 0x33003300, 2);
		word++;
	}
}
/******************************************************************************
* ShiftRows^(-2), applied in place to each of the 8 words of a packed round
* key so that it matches the fixsliced representation.
******************************************************************************/
static void inv_shiftrows_2(uint32_t* rkey) {
	uint32_t tmp;	/* scratch word used by SWAPMOVE */
	unsigned int idx;

	for(idx = 0; idx != 8; ++idx) {
		/* Both operands are the same word: bits are moved within it */
		SWAPMOVE(rkey[idx], rkey[idx], 0x0f000f00, 4);
	}
}

/******************************************************************************
* ShiftRows^(-3), applied in place to each of the 8 words of a packed round
* key so that it matches the fixsliced representation.
******************************************************************************/
static void inv_shiftrows_3(uint32_t* rkey) {
	uint32_t tmp;	/* scratch word used by SWAPMOVE */
	uint32_t *word = rkey;
	unsigned int remaining = 8;

	while(remaining-- > 0) {
		/* Both operands are the same word: bits are moved within it */
		SWAPMOVE(*word, *word, 0x030f0c00, 4);
		SWAPMOVE(*word, *word, 0x33003300, 2);
		word++;
	}
}
/******************************************************************************
* Packs two 128-bit input blocks in0, in1 into the 256-bit internal state out
* where the bits are packed as follows:
* out[0] = b_24 b_56 b_88 b_120 || ... || b_0 b_32 b_64 b_96
* out[1] = b_25 b_57 b_89 b_121 || ... || b_1 b_33 b_65 b_97
* out[2] = b_26 b_58 b_90 b_122 || ... || b_2 b_34 b_66 b_98
* out[3] = b_27 b_59 b_91 b_123 || ... || b_3 b_35 b_67 b_99
* out[4] = b_28 b_60 b_92 b_124 || ... || b_4 b_36 b_68 b_100
* out[5] = b_29 b_61 b_93 b_125 || ... || b_5 b_37 b_69 b_101
* out[6] = b_30 b_62 b_94 b_126 || ... || b_6 b_38 b_70 b_102
* out[7] = b_31 b_63 b_95 b_127 || ... || b_7 b_39 b_71 b_103
*
* out: receives the 8 packed words.
* in0: first 16-byte input (a round key in this file).
* in1: second 16-byte input.
* i:   round key index; it selects which inverse ShiftRows (if any) is applied
*      to the packed words and, when non-zero, triggers the complement of
*      out[1], out[2], out[6] and out[7].
******************************************************************************/
void keys_packing(uint32_t* out, const unsigned char* in0,
		const unsigned char* in1, unsigned int i) {
	/* Words that get complemented for every round key but the first one */
	static const unsigned int flipped[4] = { 1, 2, 6, 7 };
	uint32_t tmp;	/* scratch word used by SWAPMOVE */
	unsigned int k;

	/* Little-endian loads, alternating between the two inputs */
	for(k = 0; k < 4; k++){
		out[2*k] = LE_LOAD_32(in0 + 4*k);
		out[2*k + 1] = LE_LOAD_32(in1 + 4*k);
	}

	/* First stage: neighbouring words, distance 1 */
	for(k = 0; k < 4; k++){
		SWAPMOVE(out[2*k + 1], out[2*k], 0x55555555, 1);
	}
	/* Second stage: words two apart, distance 2 */
	for(k = 0; k < 8; k += 4){
		SWAPMOVE(out[k + 2], out[k], 0x33333333, 2);
		SWAPMOVE(out[k + 3], out[k + 1], 0x33333333, 2);
	}
	/* Third stage: words four apart, distance 4 */
	for(k = 0; k < 4; k++){
		SWAPMOVE(out[k + 4], out[k], 0x0f0f0f0f, 4);
	}

	/* Apply inverse shiftrows on some round keys to match fixslicing */
	if((i == 1) || (i == 5) || (i == 9)){
		inv_shiftrows_1(out);
	}
	else if((i == 2) || (i == 6)){
		inv_shiftrows_2(out);
	}
	else if((i == 3) || (i == 7)){
		inv_shiftrows_3(out);
	}

	if(i != 0){
		/* Apply the xor with 0xffffffff since it is expected by the bitslice
		 * encryption to speedup the SBox */
		for(k = 0; k < 4; k++){
			out[flipped[k]] ^= 0xffffffff;
		}
	}
}
extern void aes128_encrypt_ffs(uint8_t* ctext, uint8_t* ctext_bis, const uint8_t* ptext,
                               const uint8_t* ptext_bis, const uint32_t* rkey);

/*
 * Encrypts one 16-byte block with AES-128 using the fixsliced assembly routine
 * (ARMv7-M optimized build).
 *
 * The assembly routine works on two blocks with two interleaved key schedules,
 * so the single block and its round keys are fed to both slots.
 *
 * ctx:      context, must have been set up as AES128.
 * data_in:  16-byte plaintext.
 * data_out: receives the 16-byte ciphertext.
 *
 * Returns 0 on success, -1 when ctx is NULL or not an AES128 context, or when
 * data_in or data_out is NULL.
 */
WEAK int aes128_ct64_enc(const rijndael_ct64_ctx *ctx, const uint8_t data_in[16], uint8_t data_out[16])
{
	int ret = -1;
	unsigned int i;
	/* 11 round keys, 8 packed words each */
	uint32_t interleaved_rkeys[88];

	if((ctx == NULL) || (ctx->rtype != AES128)){
		goto err;
	}
	/* NULL buffers would otherwise be handed as-is to the assembly routine */
	if((data_in == NULL) || (data_out == NULL)){
		goto err;
	}

	/* The "fixsliced" implementation expects two interleaved keys: use the same key as a dummy one */
	for(i = 0; i < 11; i++){
		/* The context is only read: keep the const qualifier in the casts */
		keys_packing(&interleaved_rkeys[8*i], (const uint8_t*)(&ctx->rk[4*i]), (const uint8_t*)(&ctx->rk[4*i]), i);
	}
	aes128_encrypt_ffs(data_out, data_out, data_in, data_in, interleaved_rkeys);

	ret = 0;
err:
	return ret;
}

/*
 * Encrypts two independent 16-byte blocks with AES-128, each under its own
 * context, in a single call to the fixsliced assembly routine (ARMv7-M
 * optimized build).
 *
 * ctx1, ctx2:               contexts, both must have been set up as AES128.
 * plainText1, plainText2:   16-byte plaintexts.
 * cipherText1, cipherText2: receive the 16-byte ciphertexts.
 *
 * Returns 0 on success, -1 when a context is NULL or not an AES128 context,
 * or when one of the data buffers is NULL.
 */
WEAK int aes128_ct64_enc_x2(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const uint8_t plainText1[16], const uint8_t plainText2[16], uint8_t cipherText1[16], uint8_t cipherText2[16])
{
	int ret = -1;
	unsigned int i;
	/* 11 round keys, 8 packed words each */
	uint32_t interleaved_rkeys[88];

	if((ctx1 == NULL) || (ctx1->rtype != AES128)){
		goto err;
	}
	if((ctx2 == NULL) || (ctx2->rtype != AES128)){
		goto err;
	}
	/* NULL buffers would otherwise be handed as-is to the assembly routine */
	if((plainText1 == NULL) || (plainText2 == NULL) || (cipherText1 == NULL) || (cipherText2 == NULL)){
		goto err;
	}

	/* The "fixsliced" implementation expects two interleaved keys: interleave
	 * the round keys of the two contexts */
	for(i = 0; i < 11; i++){
		/* The contexts are only read: keep the const qualifier in the casts */
		keys_packing(&interleaved_rkeys[8*i], (const uint8_t*)(&ctx1->rk[4*i]), (const uint8_t*)(&ctx2->rk[4*i]), i);
	}
	aes128_encrypt_ffs(cipherText1, cipherText2, plainText1, plainText2, interleaved_rkeys);

	ret = 0;
err:
	return ret;
}

/* Four-block AES-128 encryption for the ARMv7-M build, done as two calls to
 * aes128_ct64_enc_x2(). Both calls are always made; the result is the bitwise
 * OR of their statuses, so it is 0 only when both succeeded. */
WEAK int aes128_ct64_enc_x4(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const rijndael_ct64_ctx *ctx3, const rijndael_ct64_ctx *ctx4,
                const uint8_t plainText1[16], const uint8_t plainText2[16], const uint8_t plainText3[16], const uint8_t plainText4[16],
                uint8_t cipherText1[16], uint8_t cipherText2[16], uint8_t cipherText3[16], uint8_t cipherText4[16])
{
	const int first_pair = aes128_ct64_enc_x2(ctx1, ctx2, plainText1, plainText2, cipherText1, cipherText2);
	const int second_pair = aes128_ct64_enc_x2(ctx3, ctx4, plainText3, plainText4, cipherText3, cipherText4);

	return first_pair | second_pair;
}

#else
/* Single-block AES-128 encryption (portable bitsliced path).
 * Returns -1 when ctx is NULL or was not set up as AES128; otherwise passes
 * the block to core_ct64_bitslice_encrypt() in the first slot (the three
 * other slots are NULL) and returns its status. */
WEAK int aes128_ct64_enc(const rijndael_ct64_ctx *ctx, const uint8_t data_in[16], uint8_t data_out[16])
{
	int status = -1;

	if((ctx != NULL) && (ctx->rtype == AES128)){
		status = core_ct64_bitslice_encrypt(ctx, NULL, NULL, NULL,
				data_in, NULL, NULL, NULL,
				data_out, NULL, NULL, NULL);
	}
	return status;
}

/* Two-block AES-128 encryption (portable bitsliced path).
 * Returns -1 when either context is NULL or was not set up as AES128;
 * otherwise passes both blocks to core_ct64_bitslice_encrypt() (the last two
 * slots are NULL) and returns its status. */
WEAK int aes128_ct64_enc_x2(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const uint8_t plainText1[16], const uint8_t plainText2[16], uint8_t cipherText1[16], uint8_t cipherText2[16])
{
	/* Contexts are validated in order; the first bad one ends the call */
	if((ctx1 == NULL) || (ctx1->rtype != AES128) ||
	   (ctx2 == NULL) || (ctx2->rtype != AES128)){
		return -1;
	}

	return core_ct64_bitslice_encrypt(ctx1, ctx2, NULL, NULL,
			plainText1, plainText2, NULL, NULL,
			cipherText1, cipherText2, NULL, NULL);
}

/* Four-block AES-128 encryption (portable bitsliced path).
 * Returns -1 when any context is NULL or was not set up as AES128; otherwise
 * passes the four blocks to core_ct64_bitslice_encrypt() and returns its
 * status. */
WEAK int aes128_ct64_enc_x4(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const rijndael_ct64_ctx *ctx3, const rijndael_ct64_ctx *ctx4,
                const uint8_t plainText1[16], const uint8_t plainText2[16], const uint8_t plainText3[16], const uint8_t plainText4[16],
                uint8_t cipherText1[16], uint8_t cipherText2[16], uint8_t cipherText3[16], uint8_t cipherText4[16])
{
	const rijndael_ct64_ctx *const ctxs[4] = { ctx1, ctx2, ctx3, ctx4 };
	unsigned int k;

	/* Every context must be present and of the AES128 type */
	for(k = 0; k < 4; k++){
		if((ctxs[k] == NULL) || (ctxs[k]->rtype != AES128)){
			return -1;
		}
	}

	return core_ct64_bitslice_encrypt(ctx1, ctx2, ctx3, ctx4,
			plainText1, plainText2, plainText3, plainText4,
			cipherText1, cipherText2, cipherText3, cipherText4);
}
#endif

/* Eight-block AES-128 encryption, done as two calls to aes128_ct64_enc_x4()
 * (blocks 1-4, then blocks 5-8). Both calls are always made; the result is
 * the bitwise OR of their statuses, so it is 0 only when both succeeded. */
WEAK int aes128_ct64_enc_x8(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const rijndael_ct64_ctx *ctx3, const rijndael_ct64_ctx *ctx4,
                  const rijndael_ct64_ctx *ctx5, const rijndael_ct64_ctx *ctx6, const rijndael_ct64_ctx *ctx7, const rijndael_ct64_ctx *ctx8,
                const uint8_t plainText1[16], const uint8_t plainText2[16], const uint8_t plainText3[16], const uint8_t plainText4[16],
                const uint8_t plainText5[16], const uint8_t plainText6[16], const uint8_t plainText7[16], const uint8_t plainText8[16],
                uint8_t cipherText1[16], uint8_t cipherText2[16], uint8_t cipherText3[16], uint8_t cipherText4[16],
                uint8_t cipherText5[16], uint8_t cipherText6[16], uint8_t cipherText7[16], uint8_t cipherText8[16])
{
	const int lower = aes128_ct64_enc_x4(ctx1, ctx2, ctx3, ctx4, plainText1, plainText2, plainText3, plainText4, cipherText1, cipherText2, cipherText3, cipherText4);
	const int upper = aes128_ct64_enc_x4(ctx5, ctx6, ctx7, ctx8, plainText5, plainText6, plainText7, plainText8, cipherText5, cipherText6, cipherText7, cipherText8);

	return lower | upper;
}

// === AES-256 enc
/* Single-block AES-256 encryption.
 * Returns -1 when ctx is NULL or was not set up as AES256; otherwise passes
 * the block to core_ct64_bitslice_encrypt() in the first slot (the three
 * other slots are NULL) and returns its status. */
WEAK int aes256_ct64_enc(const rijndael_ct64_ctx *ctx, const uint8_t data_in[16], uint8_t data_out[16])
{
	int status = -1;

	if((ctx != NULL) && (ctx->rtype == AES256)){
		status = core_ct64_bitslice_encrypt(ctx, NULL, NULL, NULL,
				data_in, NULL, NULL, NULL,
				data_out, NULL, NULL, NULL);
	}
	return status;
}

/* Two-block AES-256 encryption.
 * Returns -1 when either context is NULL or was not set up as AES256;
 * otherwise passes both blocks to core_ct64_bitslice_encrypt() (the last two
 * slots are NULL) and returns its status. */
WEAK int aes256_ct64_enc_x2(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const uint8_t plainText1[16], const uint8_t plainText2[16], uint8_t cipherText1[16], uint8_t cipherText2[16])
{
	/* Contexts are validated in order; the first bad one ends the call */
	if((ctx1 == NULL) || (ctx1->rtype != AES256) ||
	   (ctx2 == NULL) || (ctx2->rtype != AES256)){
		return -1;
	}

	return core_ct64_bitslice_encrypt(ctx1, ctx2, NULL, NULL,
			plainText1, plainText2, NULL, NULL,
			cipherText1, cipherText2, NULL, NULL);
}

/* Four-block AES-256 encryption.
 * Returns -1 when any context is NULL or was not set up as AES256; otherwise
 * passes the four blocks to core_ct64_bitslice_encrypt() and returns its
 * status. */
WEAK int aes256_ct64_enc_x4(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const rijndael_ct64_ctx *ctx3, const rijndael_ct64_ctx *ctx4,
                const uint8_t plainText1[16], const uint8_t plainText2[16], const uint8_t plainText3[16], const uint8_t plainText4[16],
                uint8_t cipherText1[16], uint8_t cipherText2[16], uint8_t cipherText3[16], uint8_t cipherText4[16])
{
	const rijndael_ct64_ctx *const ctxs[4] = { ctx1, ctx2, ctx3, ctx4 };
	unsigned int k;

	/* Every context must be present and of the AES256 type */
	for(k = 0; k < 4; k++){
		if((ctxs[k] == NULL) || (ctxs[k]->rtype != AES256)){
			return -1;
		}
	}

	return core_ct64_bitslice_encrypt(ctx1, ctx2, ctx3, ctx4,
			plainText1, plainText2, plainText3, plainText4,
			cipherText1, cipherText2, cipherText3, cipherText4);
}

/* Eight-block AES-256 encryption, done as two calls to aes256_ct64_enc_x4()
 * (blocks 1-4, then blocks 5-8). Both calls are always made; the result is
 * the bitwise OR of their statuses, so it is 0 only when both succeeded. */
WEAK int aes256_ct64_enc_x8(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const rijndael_ct64_ctx *ctx3, const rijndael_ct64_ctx *ctx4,
                  const rijndael_ct64_ctx *ctx5, const rijndael_ct64_ctx *ctx6, const rijndael_ct64_ctx *ctx7, const rijndael_ct64_ctx *ctx8,
                const uint8_t plainText1[16], const uint8_t plainText2[16], const uint8_t plainText3[16], const uint8_t plainText4[16],
                const uint8_t plainText5[16], const uint8_t plainText6[16], const uint8_t plainText7[16], const uint8_t plainText8[16],
                uint8_t cipherText1[16], uint8_t cipherText2[16], uint8_t cipherText3[16], uint8_t cipherText4[16],
                uint8_t cipherText5[16], uint8_t cipherText6[16], uint8_t cipherText7[16], uint8_t cipherText8[16])
{
	const int lower = aes256_ct64_enc_x4(ctx1, ctx2, ctx3, ctx4, plainText1, plainText2, plainText3, plainText4, cipherText1, cipherText2, cipherText3, cipherText4);
	const int upper = aes256_ct64_enc_x4(ctx5, ctx6, ctx7, ctx8, plainText5, plainText6, plainText7, plainText8, cipherText5, cipherText6, cipherText7, cipherText8);

	return lower | upper;
}

// === Rijndael-256 enc
/* Single-block (32-byte) Rijndael-256 encryption.
 * Returns -1 when ctx is NULL or was not set up as RIJNDAEL_256_256;
 * otherwise passes the block to core_ct64_bitslice_encrypt() in the first
 * slot (the three other slots are NULL) and returns its status. */
WEAK int rijndael256_ct64_enc(const rijndael_ct64_ctx *ctx, const uint8_t data_in[32], uint8_t data_out[32])
{
	int status = -1;

	if((ctx != NULL) && (ctx->rtype == RIJNDAEL_256_256)){
		status = core_ct64_bitslice_encrypt(ctx, NULL, NULL, NULL,
				data_in, NULL, NULL, NULL,
				data_out, NULL, NULL, NULL);
	}
	return status;
}

/* Two-block Rijndael-256 encryption (32-byte blocks).
 * Returns -1 when either context is NULL or was not set up as
 * RIJNDAEL_256_256; otherwise passes both blocks to
 * core_ct64_bitslice_encrypt() (the last two slots are NULL) and returns its
 * status. */
WEAK int rijndael256_ct64_enc_x2(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const uint8_t plainText1[32], const uint8_t plainText2[32], uint8_t cipherText1[32], uint8_t cipherText2[32])
{
	/* Contexts are validated in order; the first bad one ends the call */
	if((ctx1 == NULL) || (ctx1->rtype != RIJNDAEL_256_256) ||
	   (ctx2 == NULL) || (ctx2->rtype != RIJNDAEL_256_256)){
		return -1;
	}

	return core_ct64_bitslice_encrypt(ctx1, ctx2, NULL, NULL,
			plainText1, plainText2, NULL, NULL,
			cipherText1, cipherText2, NULL, NULL);
}

/* Four-block Rijndael-256 encryption (32-byte blocks).
 *
 * All four contexts are validated first (non-NULL and of the RIJNDAEL_256_256
 * type). The blocks are then processed as two calls to
 * core_ct64_bitslice_encrypt() with two blocks each: blocks 1-2, then
 * blocks 3-4. The second call is skipped when the first one fails.
 *
 * Returns 0 on success, -1 on an invalid context or when one of the core
 * calls reports a non-zero status. */
WEAK int rijndael256_ct64_enc_x4(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const rijndael_ct64_ctx *ctx3, const rijndael_ct64_ctx *ctx4,
                const uint8_t plainText1[32], const uint8_t plainText2[32], const uint8_t plainText3[32], const uint8_t plainText4[32],
                uint8_t cipherText1[32], uint8_t cipherText2[32], uint8_t cipherText3[32], uint8_t cipherText4[32])
{
	const rijndael_ct64_ctx *const ctxs[4] = { ctx1, ctx2, ctx3, ctx4 };
	unsigned int k;

	for(k = 0; k < 4; k++){
		if((ctxs[k] == NULL) || (ctxs[k]->rtype != RIJNDAEL_256_256)){
			return -1;
		}
	}

	/* First pair of blocks */
	if(core_ct64_bitslice_encrypt(ctx1, ctx2, NULL, NULL,
			plainText1, plainText2, NULL, NULL,
			cipherText1, cipherText2, NULL, NULL)){
		return -1;
	}
	/* Second pair of blocks */
	if(core_ct64_bitslice_encrypt(ctx3, ctx4, NULL, NULL,
			plainText3, plainText4, NULL, NULL,
			cipherText3, cipherText4, NULL, NULL)){
		return -1;
	}

	return 0;
}

/* Eight-block Rijndael-256 encryption, done as two calls to
 * rijndael256_ct64_enc_x4() (blocks 1-4, then blocks 5-8). Both calls are
 * always made; the result is the bitwise OR of their statuses, so it is 0
 * only when both succeeded. */
WEAK int rijndael256_ct64_enc_x8(const rijndael_ct64_ctx *ctx1, const rijndael_ct64_ctx *ctx2, const rijndael_ct64_ctx *ctx3, const rijndael_ct64_ctx *ctx4,
                  const rijndael_ct64_ctx *ctx5, const rijndael_ct64_ctx *ctx6, const rijndael_ct64_ctx *ctx7, const rijndael_ct64_ctx *ctx8,
                const uint8_t plainText1[32], const uint8_t plainText2[32], const uint8_t plainText3[32], const uint8_t plainText4[32],
                const uint8_t plainText5[32], const uint8_t plainText6[32], const uint8_t plainText7[32], const uint8_t plainText8[32],
                uint8_t cipherText1[32], uint8_t cipherText2[32], uint8_t cipherText3[32], uint8_t cipherText4[32],
                uint8_t cipherText5[32], uint8_t cipherText6[32], uint8_t cipherText7[32], uint8_t cipherText8[32])
{
	const int lower = rijndael256_ct64_enc_x4(ctx1, ctx2, ctx3, ctx4, plainText1, plainText2, plainText3, plainText4, cipherText1, cipherText2, cipherText3, cipherText4);
	const int upper = rijndael256_ct64_enc_x4(ctx5, ctx6, ctx7, ctx8, plainText5, plainText6, plainText7, plainText8, cipherText5, cipherText6, cipherText7, cipherText8);

	return lower | upper;
}

#endif /* RIJNDAEL_BITSLICE */

