/* sha256.c */
/*
 * Copyright (c) 2012-2015 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/
 * If you redistribute this file, with or without modifications, you must 
 * include this notice in the file.
 */

#include "aiconfig.h"

#include <stdio.h>
#include <string.h>
#include <stdint.h>
#include "sha2.h"
#include <aicrypto/ok_sha2.h>

#define NUMOFSEQ        64	/* number of rounds == number of message-schedule words W[t] */

/*
 * FIPS PUB 180-3
 * 4.2.2 SHA-224 and SHA-256 Constants
 * One round constant per round; indexed by the round number in sha256_trans().
 */
static const uint32_t K256[NUMOFSEQ] = {
	0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
	0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
	0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
	0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
	0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
	0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
	0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
	0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
	0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
	0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
	0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
	0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
	0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
	0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
	0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
	0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
};

/*
 * FIPS PUB 180-3
 * 5.3 Setting the Initial Hash Value (H(0))
 * Both variants share the SHA-256 engine; only these 8-word starting
 * states (and the output length) differ.
 */
/* 5.3.2 SHA-224: loaded by SHA224init() */
static const uint32_t initH224[] = {
	0xc1059ed8, 0x367cd507, 0x3070dd17, 0xf70e5939,
	0xffc00b31, 0x68581511, 0x64f98fa7, 0xbefa4fa4
};

/* 5.3.3 SHA-256: loaded by SHA256init() */
static const uint32_t initH256[] = {
	0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
	0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19
};

/* Shared SHA-224/SHA-256 engine (parameterized only by the initial hash). */
static void _SHA256init(SHA256_CTX *ctx, const uint32_t *initH);
static void _SHA256update(SHA256_CTX *ctx, unsigned char *in, unsigned int len);
static void _SHA256final(unsigned char *ret, SHA256_CTX *ctx);

/* Block processing: load W[0..15], expand W[16..63], compress into H. */
static void sha256_trans(uint32_t *w, uint32_t *H);
static void clear_w(uint32_t *w);
static void set_w(unsigned char *in, uint32_t *w, int max);
static void calc_w(uint32_t *w);
static void set_length(uint32_t len[2], uint32_t *w);

/* Big-endian byte <-> 32-bit word conversion. */
static void uc2ul(unsigned char *in, uint32_t *w, int max);
static void ul2uc(uint32_t *w, unsigned char *ret, int n);

/*-----------------------------------------------
    SHA-224 one-shot: hash len bytes of in and
    write the 28-byte (224-bit) digest to ret.
-----------------------------------------------*/
void OK_SHA224(unsigned int len, unsigned char *in, unsigned char *ret)
{
	SHA256_CTX state;	/* SHA-224 reuses the SHA-256 context */

	SHA224init(&state);
	SHA224update(&state, in, len);
	SHA224final(ret, &state);
}

/*-----------------------------------------------
    SHA224 streaming interface
-----------------------------------------------*/
/* Reset ctx and load the SHA-224 initial hash value. */
void SHA224init(SHA256_CTX *ctx)
{
	const uint32_t *h0 = initH224;

	_SHA256init(ctx, h0);
}

/* Absorb len bytes of in; SHA-224 uses the SHA-256 update path unchanged.
   A NULL in is ignored by the shared engine. */
void SHA224update(SHA256_CTX *ctx, unsigned char *in, unsigned int len)
{
	_SHA256update(ctx, in, len);
}

/* Finish the computation and write the first 7 state words to ret as
   28 big-endian bytes (SHA-224 truncates the 8-word SHA-256 state). */
void SHA224final(unsigned char *ret, SHA256_CTX *ctx)
{
	const int words = 224 / 32;

	_SHA256final(ret, ctx);
	ul2uc(ctx->H, ret, words);
}

/*-----------------------------------------------
    SHA-256 one-shot: hash len bytes of in and
    write the 32-byte (256-bit) digest to ret.
-----------------------------------------------*/
void OK_SHA256(unsigned int len, unsigned char *in, unsigned char *ret)
{
	SHA256_CTX state;

	SHA256init(&state);
	SHA256update(&state, in, len);
	SHA256final(ret, &state);
}

/*-----------------------------------------------
    SHA256 streaming interface
-----------------------------------------------*/
/* Reset ctx and load the SHA-256 initial hash value. */
void SHA256init(SHA256_CTX *ctx)
{
	const uint32_t *h0 = initH256;

	_SHA256init(ctx, h0);
}

/* Absorb len bytes of in into the running hash (a NULL in is ignored
   by the shared engine). */
void SHA256update(SHA256_CTX *ctx, unsigned char *in, unsigned int len)
{
	_SHA256update(ctx, in, len);
}

/* Finish the computation and write all 8 state words to ret as
   32 big-endian bytes. */
void SHA256final(unsigned char *ret, SHA256_CTX *ctx)
{
	const int words = 256 / 32;

	_SHA256final(ret, ctx);
	ul2uc(ctx->H, ret, words);
}

/*
 * Start a new SHA-224/SHA-256 computation: copy the 8-word initial hash
 * value initH into ctx->H, zero the 64-bit bit counter and empty the
 * partial-block buffer.
 */
static void _SHA256init(SHA256_CTX *ctx, const uint32_t *initH)
{
	memcpy(ctx->H, initH, 8 * sizeof(*initH));
	ctx->len[1] = 0;
	ctx->len[0] = 0;
	ctx->mod = 0;
}

/*
 * Absorb len bytes from in into ctx.  A NULL in is ignored.
 *
 * ctx->dat buffers a partial block of ctx->mod bytes.  After a non-empty
 * update, 1..SHA256_BLOCKSIZE bytes remain buffered: a completely filled
 * block is left in the buffer (ctx->mod == SHA256_BLOCKSIZE) rather than
 * compressed, and _SHA256final() handles that case.
 *
 * All byte offsets use unsigned int (the type of len).  A signed bound
 * such as (int)(len - 64) turns negative for inputs larger than INT_MAX
 * bytes, which would skip the block loop and make the tail memcpy
 * overrun ctx->dat.
 */
static void _SHA256update(SHA256_CTX *ctx, unsigned char *in, unsigned int len)
{
	uint32_t w[NUMOFSEQ];
	unsigned int used;	/* bytes already buffered in ctx->dat */
	unsigned int off;	/* bytes of in consumed so far */
	uint32_t bits;

	if (in == NULL)
		return;

	/* Add len*8 to the 64-bit bit counter len[1]:len[0], with carry. */
	bits = (uint32_t)len << 3;
	if ((ctx->len[0] += bits) < bits)
		ctx->len[1]++;
	ctx->len[1] += ((uint32_t)len >> 29);

	used = (unsigned int)ctx->mod;

	/* Input fits in the free space (possibly filling it exactly).
	   Compare against the free space instead of computing used + len,
	   which could wrap for len close to UINT_MAX. */
	if (len <= SHA256_BLOCKSIZE - used) {
		memcpy(&ctx->dat[used], in, len);
		ctx->mod = used + len;
		return;
	}

	/* Top up the buffered block and compress it. */
	off = SHA256_BLOCKSIZE - used;
	memcpy(&ctx->dat[used], in, off);
	set_w(ctx->dat, w, SHA256_BLOCKSIZE);
	calc_w(w);
	sha256_trans(w, ctx->H);

	/* Compress whole blocks straight from in while MORE than one block
	   remains, so the last 1..SHA256_BLOCKSIZE bytes stay buffered.
	   off < len holds here, so len - off never wraps. */
	while (len - off > SHA256_BLOCKSIZE) {
		set_w(&in[off], w, SHA256_BLOCKSIZE);
		calc_w(w);
		sha256_trans(w, ctx->H);
		off += SHA256_BLOCKSIZE;
	}

	memcpy(ctx->dat, &in[off], len - off);
	ctx->mod = len - off;
}

/*
 * Pad the buffered tail (ctx->dat, ctx->mod bytes), append the 64-bit
 * message bit length, and compress the result into ctx->H.
 * ret is not written here; the public *final() wrappers serialize ctx->H.
 *
 * With fewer than 56 buffered bytes (56 bytes = 448 bits = 512 - 64),
 * the 0x80 marker and the length fit in one block.  Otherwise the data
 * block is compressed on its own and the length goes into an extra
 * block.  If the data block was completely full, the 0x80 marker moves
 * into that extra block too.
 */
static void _SHA256final(unsigned char *ret, SHA256_CTX *ctx)
{
	uint32_t w[NUMOFSEQ];
	int tail = ctx->mod;

	(void)ret;	/* unused */

	set_w(ctx->dat, w, tail);
	if (tail < 56) {
		set_length(ctx->len, w);
	} else {
		calc_w(w);
		sha256_trans(w, ctx->H);

		clear_w(w);
		if (tail == SHA256_BLOCKSIZE)
			w[0] = 0x80000000L;
		set_length(ctx->len, w);
	}
	calc_w(w);
	sha256_trans(w, ctx->H);
}

/*-----------------------------------------------
  bytes -> 32-bit words, big-endian
  (max should be a multiple of 4)
-----------------------------------------------*/
static void uc2ul(unsigned char *in, uint32_t *w, int max)
{
	int word;

	for (word = 0; word * 4 < max; word++) {
		const unsigned char *q = in + word * 4;

		w[word] = ((uint32_t)q[0] << 24)
			| ((uint32_t)q[1] << 16)
			| ((uint32_t)q[2] << 8)
			| (uint32_t)q[3];
	}
}

/* Serialize n 32-bit words from H into 4*n bytes at ret, most
   significant byte first (big-endian). */
static void ul2uc(uint32_t *H, unsigned char *ret, int n)
{
	unsigned char *out = ret;
	int k;

	for (k = 0; k < n; k++) {
		uint32_t v = H[k];

		*out++ = (unsigned char)(v >> 24);
		*out++ = (unsigned char)(v >> 16);
		*out++ = (unsigned char)(v >> 8);
		*out++ = (unsigned char)v;
	}
}

/*-----------------------------------------------
  set w[]: load message bytes into W[0..15]
-----------------------------------------------*/
/*
 * in:  message bytes, packed big-endian into words
 * w:   schedule buffer; only W[0..15] are written here
 * max: number of bytes in `in`.  When max == SHA256_BLOCKSIZE the block
 *      is loaded as-is with no padding.  Otherwise the 0x80 padding byte
 *      is placed right after the data, and the rest of W[0..15] is zero.
 */
static void set_w(unsigned char *in, uint32_t *w, int max)
{
	int div, mod;

	/* Clear exactly the 16 message words.  Size by the element type:
	   sizeof(long) is 8 on LP64 and would clear 32 words. */
	memset(w, 0, sizeof(*w) * 16);

	if (max == SHA256_BLOCKSIZE) {
		uc2ul(in, w, max);
		return;
	}

	/* Copy the whole words, then put the 0..3 leftover bytes plus the
	   0x80 marker into the next word. */
	div = max / 4;
	mod = max % 4;
	uc2ul(in, w, max - mod);

	switch (mod) {
	case 0:
		w[div] = (uint32_t) 0x80000000L;
		break;
	case 1:
		w[div] = (uint32_t) in[max-1] << 24 |
			 (uint32_t) 0x800000L;
		break;
	case 2:
		w[div] = (uint32_t) in[max-2] << 24 |
			 (uint32_t) in[max-1] << 16 |
			 (uint32_t) 0x8000L;
		break;
	case 3:
		w[div] = (uint32_t) in[max-3] << 24 |
			 (uint32_t) in[max-2] << 16 |
			 (uint32_t) in[max-1] <<  8 |
			 (uint32_t) 0x80L;
		break;
	default:
		/* unreachable for max >= 0: max % 4 is in 0..3 */
		break;
	}
}

/* Zero the 16 message words W[0..15] of a schedule buffer. */
static void clear_w(uint32_t *w)
{
	/* Size by the element type: sizeof(long) is 8 on LP64 and would clear
	   32 words, overrunning any buffer shorter than that. */
	memset(w, 0, sizeof(*w) * 16);
}

/* Message-schedule expansion: derive W[16..NUMOFSEQ-1] from the
   preceding words using the sum256_0/sum256_1 functions. */
static void calc_w(uint32_t *w)
{
	uint32_t *p;

	for (p = w + 16; p < w + NUMOFSEQ; p++)
		*p = sum256_1(p[-2]) + p[-7] + sum256_0(p[-15]) + p[-16];
}

/* Store the 64-bit message bit count in the last two message words:
   W[14] gets the high half (len[1]) and W[15] the low half (len[0]). */
static void set_length(uint32_t len[2], uint32_t *w)
{
	uint32_t *tail = &w[14];

	tail[0] = len[1];
	tail[1] = len[0];
}

/*-----------------------------------------------
    SHA256 compression (transform).
-----------------------------------------------*/
/*
 * Run NUMOFSEQ rounds over the expanded schedule w[0..NUMOFSEQ-1],
 * starting from the state in H[0..7], then add the result back into H.
 * s[0..7] holds the working variables a..h.
 */
static void sha256_trans(uint32_t *w, uint32_t *H)
{
	uint32_t s[8];
	uint32_t t1, t2;
	int r, k;

	memcpy(s, H, sizeof(s));

	for (r = 0; r < NUMOFSEQ; r++) {
		t1 = s[7] + SUM256_1(s[4]) + f1(s[4], s[5], s[6]) + K256[r] + w[r];
		t2 = SUM256_0(s[0]) + f3(s[0], s[1], s[2]);

		/* h=g, g=f, f=e, e=d, d=c, c=b, b=a */
		for (k = 7; k > 0; k--)
			s[k] = s[k - 1];
		s[4] += t1;		/* e = d + T1 */
		s[0] = t1 + t2;		/* a = T1 + T2 */
	}

	for (k = 0; k < 8; k++)
		H[k] += s[k];
}
