/* sha512.c */
/*
 * Copyright (c) 2012-2015 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "aiconfig.h"

#include <stdio.h>
#include <string.h>
#include <stdint.h>
#include "sha2.h"
#include <aicrypto/ok_sha2.h>

#define NUMOFSEQ        80	/* rounds per block == words in the message schedule w[] */

/*
 * FIPS PUB 180-4
 * 4.2.3 SHA-384, SHA-512, SHA-512/224 and SHA-512/256 Constants:
 * the eighty 64-bit round constants, indexed by round number in
 * sha512_trans(). Every variant in this file shares this table.
 */
static const uint64_t K512[NUMOFSEQ] = {
	0x428a2f98d728ae22ULL, 0x7137449123ef65cdULL, 0xb5c0fbcfec4d3b2fULL,
	    0xe9b5dba58189dbbcULL,
	0x3956c25bf348b538ULL, 0x59f111f1b605d019ULL, 0x923f82a4af194f9bULL,
	    0xab1c5ed5da6d8118ULL,
	0xd807aa98a3030242ULL, 0x12835b0145706fbeULL, 0x243185be4ee4b28cULL,
	    0x550c7dc3d5ffb4e2ULL,
	0x72be5d74f27b896fULL, 0x80deb1fe3b1696b1ULL, 0x9bdc06a725c71235ULL,
	    0xc19bf174cf692694ULL,
	0xe49b69c19ef14ad2ULL, 0xefbe4786384f25e3ULL, 0x0fc19dc68b8cd5b5ULL,
	    0x240ca1cc77ac9c65ULL,
	0x2de92c6f592b0275ULL, 0x4a7484aa6ea6e483ULL, 0x5cb0a9dcbd41fbd4ULL,
	    0x76f988da831153b5ULL,
	0x983e5152ee66dfabULL, 0xa831c66d2db43210ULL, 0xb00327c898fb213fULL,
	    0xbf597fc7beef0ee4ULL,
	0xc6e00bf33da88fc2ULL, 0xd5a79147930aa725ULL, 0x06ca6351e003826fULL,
	    0x142929670a0e6e70ULL,
	0x27b70a8546d22ffcULL, 0x2e1b21385c26c926ULL, 0x4d2c6dfc5ac42aedULL,
	    0x53380d139d95b3dfULL,
	0x650a73548baf63deULL, 0x766a0abb3c77b2a8ULL, 0x81c2c92e47edaee6ULL,
	    0x92722c851482353bULL,
	0xa2bfe8a14cf10364ULL, 0xa81a664bbc423001ULL, 0xc24b8b70d0f89791ULL,
	    0xc76c51a30654be30ULL,
	0xd192e819d6ef5218ULL, 0xd69906245565a910ULL, 0xf40e35855771202aULL,
	    0x106aa07032bbd1b8ULL,
	0x19a4c116b8d2d0c8ULL, 0x1e376c085141ab53ULL, 0x2748774cdf8eeb99ULL,
	    0x34b0bcb5e19b48a8ULL,
	0x391c0cb3c5c95a63ULL, 0x4ed8aa4ae3418acbULL, 0x5b9cca4f7763e373ULL,
	    0x682e6ff3d6b2b8a3ULL,
	0x748f82ee5defb2fcULL, 0x78a5636f43172f60ULL, 0x84c87814a1f0ab72ULL,
	    0x8cc702081a6439ecULL,
	0x90befffa23631e28ULL, 0xa4506cebde82bde9ULL, 0xbef9a3f7b2c67915ULL,
	    0xc67178f2e372532bULL,
	0xca273eceea26619cULL, 0xd186b8c721c0c207ULL, 0xeada7dd6cde0eb1eULL,
	    0xf57d4f7fee6ed178ULL,
	0x06f067aa72176fbaULL, 0x0a637dc5a2c898a6ULL, 0x113f9804bef90daeULL,
	    0x1b710b35131c471bULL,
	0x28db77f523047d84ULL, 0x32caab7b40c72493ULL, 0x3c9ebe0a15c9bebcULL,
	    0x431d67c49c100d4cULL,
	0x4cc5d4becb3e42b6ULL, 0x597f299cfc657e2aULL, 0x5fcb6fab3ad6faecULL,
	    0x6c44198c4a475817ULL,
};

/*
 * FIPS PUB 180-4, 5.3 "Setting the Initial Hash Value (H(0))".
 * Each variant starts from its own eight 64-bit words, which
 * _SHA512init() copies into the context.
 */

/* 5.3.4 SHA-384 */
static const uint64_t initH384[8] = {
	0xcbbb9d5dc1059ed8ULL, 0x629a292a367cd507ULL,
	0x9159015a3070dd17ULL, 0x152fecd8f70e5939ULL,
	0x67332667ffc00b31ULL, 0x8eb44a8768581511ULL,
	0xdb0c2e0d64f98fa7ULL, 0x47b5481dbefa4fa4ULL,
};

/* 5.3.5 SHA-512: initial hash value (8 words, copied by _SHA512init) */
static const uint64_t initH512[8] = {
	0x6a09e667f3bcc908ULL, 0xbb67ae8584caa73bULL,
	0x3c6ef372fe94f82bULL, 0xa54ff53a5f1d36f1ULL,
	0x510e527fade682d1ULL, 0x9b05688c2b3e6c1fULL,
	0x1f83d9abfb41bd6bULL, 0x5be0cd19137e2179ULL,
};

/*
 * 5.3.6.1 SHA-512/224: initial hash value.
 * Explicitly sized: _SHA512init() always reads exactly 8 words.
 * Hex digits are kept lowercase, like the other tables, so the values
 * can be compared against the standard digit by digit.
 */
static const uint64_t initH512224[8] = {
	0x8c3d37c819544da2ULL, 0x73e1996689dcd4d6ULL,
	0x1dfab7ae32ff9c82ULL, 0x679dd514582f9fcfULL,
	0x0f6d2b697bd44da8ULL, 0x77e36f7304c48942ULL,
	0x3f9d85a86a1d36c8ULL, 0x1112e6ad91d692a1ULL,
};

/*
 * 5.3.6.2 SHA-512/256: initial hash value.
 * Explicitly sized: _SHA512init() always reads exactly 8 words.
 * Hex digits are kept lowercase, like the other tables, so the values
 * can be compared against the standard digit by digit.
 */
static const uint64_t initH512256[8] = {
	0x22312194fc2bf72cULL, 0x9f555fa3c84c64c2ULL,
	0x2393b86b6f53b151ULL, 0x963877195940eabdULL,
	0x96283ee2a88effe3ULL, 0xbe5e1e2553863992ULL,
	0x2b0199fc2c85b8aaULL, 0x0eb72ddc81c52ca2ULL,
};

/* shared core used by every SHA-512 family variant */
static void _SHA512init(SHA512_CTX *ctx, const uint64_t *initH);
static void _SHA512update(SHA512_CTX *ctx, unsigned char *in, unsigned int len);
static void _SHA512final(unsigned char *ret, SHA512_CTX *ctx);

/* message schedule construction and block compression */
static void clear_w(uint64_t *w);
static void set_w(unsigned char *in, uint64_t *w, int max);
static void set_length(uint64_t len[2], uint64_t *w);
static void calc_w(uint64_t *w);
static void sha512_trans(uint64_t *w, uint64_t *H);

/* big-endian byte string <-> 64-bit word conversion */
static void uc2ull(unsigned char *in, uint64_t *w, int max);
static void ull2uc(uint64_t *H, unsigned char *ret, int len);

/*-----------------------------------------------
    One-shot SHA-384: hash len bytes of in and
    write the 48-byte (384-bit) digest to ret.
    If in is NULL nothing is done and ret is
    left untouched.
-----------------------------------------------*/
void OK_SHA384(unsigned int len, unsigned char *in, unsigned char *ret)
{
	SHA512_CTX ctx;

	if (in != NULL) {
		SHA384init(&ctx);
		SHA384update(&ctx, in, len);
		SHA384final(ret, &ctx);
	}
}

/*-----------------------------------------------
    SHA384 streaming interface.
    SHA-384 shares the SHA-512 context and core;
    init only selects the SHA-384 initial value.
-----------------------------------------------*/
void SHA384init(SHA512_CTX *ctx)
{
	_SHA512init(ctx, &initH384[0]);
}

/* Absorb len more bytes of in into a SHA-384 computation (SHA-512 core). */
void SHA384update(SHA512_CTX *ctx, unsigned char *in, unsigned int len)
{
	_SHA512update(ctx, in, len);
}

/*
 * Finish a SHA-384 computation: pad and compress the remaining input,
 * then write the first 384 / 64 = 6 state words big-endian (48 bytes)
 * to ret.
 */
void SHA384final(unsigned char *ret, SHA512_CTX *ctx)
{
	_SHA512final(ret, ctx);
	ull2uc(ctx->H, ret, 384 / 64);
}

/*-----------------------------------------------
    One-shot SHA-512: hash len bytes of in and
    write the 64-byte (512-bit) digest to ret.
    If in is NULL nothing is done and ret is
    left untouched.
-----------------------------------------------*/
void OK_SHA512(unsigned int len, unsigned char *in, unsigned char *ret)
{
	SHA512_CTX ctx;

	if (in != NULL) {
		SHA512init(&ctx);
		SHA512update(&ctx, in, len);
		SHA512final(ret, &ctx);
	}
}

/*-----------------------------------------------
    SHA512 streaming interface.
    init loads the SHA-512 initial hash value and
    resets the buffered-byte and length counters.
-----------------------------------------------*/
void SHA512init(SHA512_CTX *ctx)
{
	_SHA512init(ctx, &initH512[0]);
}

/* Absorb len more bytes of in into a SHA-512 computation. */
void SHA512update(SHA512_CTX *ctx, unsigned char *in, unsigned int len)
{
	_SHA512update(ctx, in, len);
}

/*
 * Finish a SHA-512 computation: pad and compress the remaining input,
 * then write all 512 / 64 = 8 state words big-endian (64 bytes) to ret.
 */
void SHA512final(unsigned char *ret, SHA512_CTX *ctx)
{
	_SHA512final(ret, ctx);
	ull2uc(ctx->H, ret, 512 / 64);
}

/*-----------------------------------------------
    One-shot SHA-512/224: hash len bytes of in and
    write SHA512224_DIGESTSIZE bytes (224 bits) to
    ret. If in is NULL nothing is done and ret is
    left untouched.
-----------------------------------------------*/
void OK_SHA512224(unsigned int len, unsigned char *in, unsigned char *ret)
{
	SHA512_CTX ctx;

	if (in != NULL) {
		SHA512224init(&ctx);
		SHA512224update(&ctx, in, len);
		SHA512224final(ret, &ctx);
	}
}

/*-----------------------------------------------
    SHA512/224 streaming interface.
    Same SHA-512 core; init selects the
    SHA-512/224 initial hash value.
-----------------------------------------------*/
void SHA512224init(SHA512_CTX *ctx)
{
	_SHA512init(ctx, &initH512224[0]);
}

/* Absorb len more bytes of in into a SHA-512/224 computation (SHA-512 core). */
void SHA512224update(SHA512_CTX *ctx, unsigned char *in, unsigned int len)
{
	_SHA512update(ctx, in, len);
}

/*
 * Finish a SHA-512/224 computation.
 * The 224-bit digest is the left-most part of the first four state
 * words, so those four words (256 bits) are serialized into a scratch
 * buffer and only SHA512224_DIGESTSIZE bytes are copied to ret.
 */
void SHA512224final(unsigned char *ret, SHA512_CTX *ctx)
{
	unsigned char full[256 / 8];	/* exactly the 4 words written below */

	_SHA512final(full, ctx);
	ull2uc(ctx->H, full, 256 / 64);

	memcpy(ret, full, SHA512224_DIGESTSIZE);	/* left-most 224 bits */
}

/*-----------------------------------------------
    One-shot SHA-512/256: hash len bytes of in and
    write the 32-byte (256-bit) digest to ret.
    If in is NULL nothing is done and ret is
    left untouched.
-----------------------------------------------*/
void OK_SHA512256(unsigned int len, unsigned char *in, unsigned char *ret)
{
	SHA512_CTX ctx;

	if (in != NULL) {
		SHA512256init(&ctx);
		SHA512256update(&ctx, in, len);
		SHA512256final(ret, &ctx);
	}
}

/*-----------------------------------------------
    SHA512/256 streaming interface.
    Same SHA-512 core; init selects the
    SHA-512/256 initial hash value.
-----------------------------------------------*/
void SHA512256init(SHA512_CTX *ctx)
{
	_SHA512init(ctx, &initH512256[0]);
}

/* Absorb len more bytes of in into a SHA-512/256 computation (SHA-512 core). */
void SHA512256update(SHA512_CTX *ctx, unsigned char *in, unsigned int len)
{
	_SHA512update(ctx, in, len);
}

/*
 * Finish a SHA-512/256 computation: pad and compress the remaining
 * input, then write the first 256 / 64 = 4 state words big-endian
 * (32 bytes) to ret.
 */
void SHA512256final(unsigned char *ret, SHA512_CTX *ctx)
{
	_SHA512final(ret, ctx);
	ull2uc(ctx->H, ret, 256 / 64);
}

/* private functions */

/*
 * Reset ctx for a new message: load the 8-word initial hash value
 * initH (variant specific) and clear the bit-length counter and the
 * number of buffered bytes.
 */
static void _SHA512init(SHA512_CTX *ctx, const uint64_t *initH)
{
	const uint64_t *src = initH;
	uint64_t *dst = ctx->H;
	int left;

	for (left = 8; left > 0; left--)
		*dst++ = *src++;

	ctx->len[1] = 0;
	ctx->len[0] = 0;
	ctx->mod = 0;
}

/*
 * Absorb len bytes from in into the running hash state.
 *
 * Input is buffered in ctx->dat (ctx->mod bytes used). A buffered block
 * is only compressed once more input arrives, so a completely full
 * buffer (ctx->mod == SHA512_BLOCKSIZE) is a valid state that
 * _SHA512final() handles. ctx->len keeps the total message length in
 * bits as a 128-bit value: len[1] is the high word, len[0] the low word.
 *
 * All offset arithmetic is done in unsigned int. Two problems are avoided:
 *  - "len + mod" wraps for len close to UINT_MAX, which would let a huge
 *    memcpy into ctx->dat through.
 *  - Converting "len - SHA512_BLOCKSIZE" to int (implementation-defined
 *    when negative) or stepping a signed index past INT_MAX (UB).
 * Does nothing if in is NULL.
 */
static void _SHA512update(SHA512_CTX *ctx, unsigned char *in, unsigned int len)
{
	uint64_t w[NUMOFSEQ];
	uint64_t *H;
	unsigned char *dat;
	unsigned int mod, room, off, rest;

	if (in == NULL)
		return;

	H = ctx->H;
	/* update the length in bits, carrying into the high word */
	if ((ctx->len[0] += ((uint64_t)len << 3)) < ((uint64_t)len << 3))
		ctx->len[1]++;
	ctx->len[1] += ((uint64_t)len >> 61);
	dat = ctx->dat;
	mod = (unsigned int)ctx->mod;	/* 0..SHA512_BLOCKSIZE */
	room = SHA512_BLOCKSIZE - mod;	/* free bytes left in dat */

	/* everything fits in the buffer: just append (no wrap-prone len + mod) */
	if (len <= room) {
		memcpy(&dat[mod], in, len);
		ctx->mod = mod + len;
		return;
	}

	/* top up the buffered block and compress it */
	memcpy(&dat[mod], in, room);
	set_w(dat, w, SHA512_BLOCKSIZE);
	calc_w(w);
	sha512_trans(w, H);

	/*
	 * Compress whole blocks straight from the input while more than one
	 * block remains, so 1..SHA512_BLOCKSIZE bytes are always left over.
	 * Invariant: off < len, hence len - off never wraps.
	 */
	off = room;
	while (len - off > SHA512_BLOCKSIZE) {
		set_w(&in[off], w, SHA512_BLOCKSIZE);
		calc_w(w);
		sha512_trans(w, H);
		off += SHA512_BLOCKSIZE;
	}

	rest = len - off;
	memcpy(dat, &in[off], rest);
	ctx->mod = rest;
}

/*
 * Pad the buffered tail (ctx->dat, ctx->mod bytes), append the 128-bit
 * message bit length and compress the final one or two blocks into
 * ctx->H. ret is not written here; the public *final() wrappers
 * serialize ctx->H themselves.
 */
static void _SHA512final(unsigned char *ret, SHA512_CTX *ctx)
{
	uint64_t w[NUMOFSEQ];
	int tail = ctx->mod;

	/* buffered data, plus the 0x80 marker unless the buffer is full */
	set_w(ctx->dat, w, tail);

	/*
	 * The length takes the last 16 bytes of a block, so it only fits after
	 * the data when fewer than 112 bytes (896 bits == 1024 - 64*2) are
	 * buffered. Otherwise flush this block and build an extra one.
	 */
	if (tail >= 112) {
		calc_w(w);
		sha512_trans(w, ctx->H);

		clear_w(w);
		/* a full buffer had no room for the marker: it opens the extra block */
		if (tail == SHA512_BLOCKSIZE)
			w[0] = 0x8000000000000000ULL;
	}

	set_length(ctx->len, w);
	calc_w(w);
	sha512_trans(w, ctx->H);
}

/*-----------------------------------------------
  uc2ull: big-endian bytes -> 64-bit words.
  Packs in[0..max-1] into w[0], w[1], ... with
  eight bytes per word, most significant first.
  max must be a multiple of 8 (a partial trailing
  word would still read a full 8 bytes).
-----------------------------------------------*/
static void uc2ull(unsigned char *in, uint64_t *w, int max)
{
	int word, k;

	for (word = 0; word * 8 < max; word++) {
		const unsigned char *p = in + word * 8;
		uint64_t v = 0;

		for (k = 0; k < 8; k++)
			v = (v << 8) | p[k];
		w[word] = v;
	}
}

/*-----------------------------------------------
  ull2uc: 64-bit words -> big-endian bytes.
  Serializes len words of H into 8 * len bytes
  at ret, most significant byte first.
-----------------------------------------------*/
static void ull2uc(uint64_t *H, unsigned char *ret, int len)
{
	unsigned char *out = ret;
	int n, shift;

	for (n = 0; n < len; n++) {
		for (shift = 56; shift >= 0; shift -= 8)
			*out++ = (unsigned char)(H[n] >> shift);
	}
}

/*-----------------------------------------------
  set w[]: load message bytes into the first 16
  schedule words (big-endian); w[0..15] are
  overwritten, later words are not touched.
  max: number of valid bytes in "in",
       0..SHA512_BLOCKSIZE.
   - max == SHA512_BLOCKSIZE: full block, loaded
     as-is with no padding.
   - otherwise: the 0x80 padding byte follows the
     data directly and the rest stays zero.
-----------------------------------------------*/
static void set_w(unsigned char *in, uint64_t *w, int max)
{
	int whole, tail, k;
	uint64_t last;

	memset(w, 0, sizeof(uint64_t) * 16);

	if (max == SHA512_BLOCKSIZE) {
		uc2ull(in, w, max);
		return;
	}

	whole = max / 8;	/* complete 8-byte words */
	tail = max % 8;		/* leftover bytes, 0..7 */
	uc2ull(in, w, max - tail);

	/* leftover bytes fill the word from the top, then the 0x80 marker */
	last = 0;
	for (k = 0; k < tail; k++)
		last |= (uint64_t) in[whole * 8 + k] << (56 - 8 * k);
	last |= 0x80ULL << (56 - 8 * tail);
	w[whole] = last;
}

/* Zero the 16 block words w[0..15]; expanded words w[16..] are left as-is. */
static void clear_w(uint64_t *w)
{
	int k;

	for (k = 0; k < 16; k++)
		w[k] = 0;
}

/*
 * Expand the message schedule: derive w[16..NUMOFSEQ-1] from the 16
 * block words. sum512_0/sum512_1 come from sha2.h (presumably the FIPS
 * 180-4 small sigma functions; confirm against that header).
 */
static void calc_w(uint64_t *w)
{
	uint64_t s0, s1;
	int t;

	for (t = 16; t < NUMOFSEQ; t++) {
		s0 = sum512_0(w[t - 15]);
		s1 = sum512_1(w[t - 2]);
		w[t] = w[t - 16] + s0 + w[t - 7] + s1;
	}
}

/*
 * Store the 128-bit message bit length in the last two block words:
 * high word len[1] -> w[14], low word len[0] -> w[15].
 */
static void set_length(uint64_t len[2], uint64_t *w)
{
	uint64_t lo = len[0];
	uint64_t hi = len[1];

	w[15] = lo;
	w[14] = hi;
}

/*-----------------------------------------------
    SHA512 compression ("transform").
    Runs NUMOFSEQ rounds over one expanded message
    schedule and adds the result into H.
    w: schedule words w[0..NUMOFSEQ-1]
       (filled by set_w + calc_w)
    H: 8-word chaining value, updated in place
    SUM512_0/1, f1 and f3 are macros from sha2.h;
    presumably the FIPS 180-4 Sigma, Ch and Maj
    functions (confirm against that header).
-----------------------------------------------*/
static void sha512_trans(uint64_t *w, uint64_t *H)
{
	uint64_t a, b, c, d, e, f, g, h, T1, T2;
	int i;

	a = H[0];
	b = H[1];
	c = H[2];
	d = H[3];
	e = H[4];
	f = H[5];
	g = H[6];
	h = H[7];
	for (i = 0; i < NUMOFSEQ; i++) {
		T1 = h + SUM512_1(e) + f1(e, f, g) + K512[i] + w[i];
		T2 = SUM512_0(a) + f3(a, b, c);
		h = g;
		g = f;
		f = e;
		e = d + T1;
		d = c;
		c = b;
		b = a;
		a = T1 + T2;
	}

	/* feed-forward: add the working variables back into the chaining value */
	H[0] += a;
	H[1] += b;
	H[2] += c;
	H[3] += d;
	H[4] += e;
	H[5] += f;
	H[6] += g;
	H[7] += h;
}
