/* gcm.c */
/*
 * Copyright (c) 2016 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include <string.h>
#include <stdlib.h>
#include <assert.h>

#include <aicrypto/nrg_modes.h>
#include <aicrypto/ok_err.h>

#include "bs128.h"

/**
 * Create the 128-bit hash subkey H.
 *
 * In Step 1, the hash subkey for the GHASH function is generated
 * by applying the block cipher to the "zero" block.
 *
 * Returns a newly allocated bit string (to be released with bs_free())
 * or NULL if the allocation fails.
 */
static bs128_t *new_hash_subkey128(Key *key, ciph128 ciph);

/**
 * Generate the pre-counter block J0.
 *
 * iv: IV
 * H:  hash subkey; only used when the IV is not 96 bits long.
 *
 * Returns a newly allocated bit string or NULL on failure.
 */
static bs128_t *gen_pre_counter_block(bitstr_t *iv, bs128_t *H);

/**
 * Incrementing function.
 *
 * Increments bytes 12..15 of x (the least significant 32 bits of the
 * block).  'len' is currently unused.  Always returns 0.
 */
static int inc_counter_block(bs128_t *x, unsigned int len);

/**
 * Compute bit length from byte length.
 */
static inline int32_t bit_length(int32_t byte);

/**
 * Compute the number of 128-bit blocks needed to hold 'bit_length'
 * bits (a partial trailing block counts as one block).
 */
static inline int32_t bit128_block(int32_t bit_length);

/**
 * Increment the counter block and execute GCTR after that.
 *
 * The increment is applied to a clone of J0.
 *
 * If 'in' is plaintext it will be encrypted.
 * If 'in' is ciphertext it will be decrypted.
 */
static int inccb_gctr(Key *key, bs128_t *J0, const uint8_t *in, int32_t byte,
		      uint8_t *out, ciph128 ciph);

/**
 * GHASH(H, A||0||C||0||[len(A)]64||[len(C)]64)
 *
 * sizea: byte length of a
 * sizec: byte length of c
 *
 * Returns a newly allocated bit string holding the hash.
 */
static bs128_t *ghash(bs128_t *H, uint8_t *a, unsigned int sizea,
		      const uint8_t *c, unsigned int sizec);

/**
 * [len(A)]64||[len(C)]64
 *
 * Writes the first 'len' bytes of the 16-byte length block to buf.
 */
static int _ghash_len64(uint64_t lena, uint64_t lenc,
			uint8_t *buf, uint16_t len);

/**
 * W  = Yi-1 xor Xi
 * Yi = W . H
 * return Ym
 * len: bit length
 */
static int bs_ghash(bs128_t *h, uint8_t *data, int len, bs128_t *y);

/**
 * Block multiplication used by GHASH: ret = x . y
 */
static int bs_product(bs128_t *x, bs128_t *y, bs128_t *ret);

/**
 * GCTR function.
 *
 * XORs 'in' (inbyte bytes) with the cipher output of successive
 * counter blocks starting at cb and stores the result in 'out'.
 * cb is incremented in place between blocks.
 */
static int gctr(Key *key, bs128_t *cb, const uint8_t *in, uint16_t inbyte,
		uint8_t *out, ciph128 ciph);

/**
 * Convert a length expressed in bytes into a length expressed in bits.
 */
static inline int32_t bit_length(int32_t byte)
{
	enum { BITS_PER_BYTE = 8 };

	return BITS_PER_BYTE * byte;
}

/**
 * Number of 128-bit blocks required to hold 'bit_length' bits.
 * A trailing partial block is counted as a whole block.
 */
static inline int32_t bit128_block(int32_t bit_length)
{
	const int32_t whole = bit_length / 128;
	const int32_t partial = (bit_length % 128 != 0) ? 1 : 0;

	return whole + partial;
}

/**
 * Increment a 4-byte uint8_t array.
 *
 * The array is treated as a 32-bit unsigned integer stored most
 * significant byte first.  Always returns 0.
 */
int inc32(uint8_t *p);

/**
 * Copy unsigned int 64 data to an 8-byte uint8_t array,
 * most significant byte first.  Always returns 0.
 */
int copy_uint64_to_uint8(uint8_t *p, uint64_t q);


/**
 * GCM authenticated encryption.
 *
 * param: key, IV and additional authenticated data (AAD).
 * byte:  byte length of the plaintext 'in' (must not be negative).
 * in:    plaintext.
 * out:   receives the ciphertext ('byte' bytes) immediately followed by
 *        the 16-byte authentication tag, so it must have room for
 *        byte + 16 bytes.
 * ciph:  128-bit block cipher applied with param's key.
 *
 * Returns 0 on success and -1 on failure.  On failure the contents of
 * 'out' are unspecified.
 */
int gcm_encrypt(gcm_param_t *param, int32_t byte, const uint8_t *in,
		uint8_t *out, ciph128 ciph)
{
	bs128_t *H = NULL;  /* Hash subkey */
	bs128_t *J0 = NULL; /* pre-counter block */
	bs128_t *S = NULL;  /* ghash */
	int rc = 0;
	uint8_t s[16]; /* ghash (bit string) */
	uint8_t t[16]; /* authentication tag */

	/* A negative length can never describe a valid plaintext. */
	if (byte < 0) {
		return -1;
	}

	/* Step 1. */
	/* H : cipher(K, 0^128) */
	H = new_hash_subkey128(param->ciph_key, ciph);
	if (H == NULL) {
		OK_set_error(ERR_ST_MEMALLOC, ERR_LC_MODES, ERR_PT_GCMENC,
			     NULL);
		rc = -1;
		goto done;
	}

	/* Step 2. */
	J0 = gen_pre_counter_block(&(param->iv), H);
	if (J0 == NULL) {
		OK_set_error(ERR_ST_MEMALLOC, ERR_LC_MODES, ERR_PT_GCMENC + 1,
			     NULL);
		rc = -1;
		goto done;
	}

	/* Step 3. */
	/* ICB: inc32(J0) */
	/*
	 * From plaintext to ciphertext.  The ciphertext is produced
	 * directly in 'out': no variable-length stack buffer is needed
	 * (its size would be invalid for byte == 0 and unbounded for
	 * large inputs).  gctr() reads its whole input before it writes
	 * any output, so in == out keeps working.
	 */
	if (inccb_gctr(param->ciph_key, J0, in, byte, out, ciph) != 0) {
		rc = -1;
		goto done;
	}

	/* Step 4. */
	/* GHASH(H, A||0||C||0||[len(A)]64||[len(C)]64) */
	S = ghash(H, param->aad.buf, param->aad.byte, out, byte);
	if (S == NULL) {
		OK_set_error(ERR_ST_MEMALLOC, ERR_LC_MODES, ERR_PT_GCMENC + 2,
			     NULL);
		rc = -1;
		goto done;
	}
	bs_get_data(S, s, S->len);

	/* Step 6. */
	/* T = MSBt(GCTR(K, J0, S)) */
	if (gctr(param->ciph_key, J0, s, sizeof(S->buf), t, ciph) != 0) {
		rc = -1;
		goto done;
	}

	/*
	 * out = ciphertext + T
	 * (the ciphertext is already in place)
	 */
	memcpy(&out[byte], t, sizeof(t));

done:
	if (H != NULL) {
		bs_free(H);
	}
	if (J0 != NULL) {
		bs_free(J0);
	}
	if (S != NULL) {
		bs_free(S);
	}
	return rc;
}

/**
 * GCM authenticated decryption.
 *
 * param: key, IV and additional authenticated data (AAD).
 * byte:  byte length of 'in', i.e. ciphertext length + 16-byte tag.
 * in:    ciphertext immediately followed by the authentication tag.
 * out:   receives the plaintext (byte - 16 bytes).  It is written only
 *        after the tag has been verified.
 * ciph:  128-bit block cipher applied with param's key.
 *
 * Returns 0 on success, -1 on an internal error or when 'in' is too
 * short to contain a tag, and -2 when the authentication tag does not
 * match.
 */
int gcm_decrypt(gcm_param_t *param, int32_t byte, const uint8_t *in,
		uint8_t *out, ciph128 ciph)
{
	bs128_t *H = NULL;  /* Hash subkey */
	bs128_t *J0 = NULL; /* pre-counter block */
	bs128_t *S = NULL;  /* ghash */
	int32_t clen; /* ciphertext length */
	const uint8_t *t; /* authentication tag */
	int rc = 0;
	uint8_t s[16]; /* ghash (bit string) */
	uint8_t t2[16]; /* authenticate tag (bit string) */
	uint8_t diff = 0; /* accumulated tag difference */
	size_t i;

	/* The input must at least contain the 16-byte tag. */
	if (byte < 16) {
		return -1;
	}
	clen = byte - 16;

	/* Step 1. */
	/* H : cipher(K, 0^128) */
	H = new_hash_subkey128(param->ciph_key, ciph);
	if (H == NULL) {
		OK_set_error(ERR_ST_MEMALLOC, ERR_LC_MODES, ERR_PT_GCMDEC,
			     NULL);
		rc = -1;
		goto done;
	}

	/* Step 2. */
	J0 = gen_pre_counter_block(&(param->iv), H);
	if (J0 == NULL) {
		OK_set_error(ERR_ST_MEMALLOC, ERR_LC_MODES, ERR_PT_GCMDEC+1,
			     NULL);
		rc = -1;
		goto done;
	}

	/* Set the pointer at the beginning of the authentication tag. */
	t = in + clen;

	/* Step 5. */
	/* GHASH(H, A||0||C||0||[len(A)]64||[len(C)]64) */
	S = ghash(H, param->aad.buf, param->aad.byte, in, clen);
	if (S == NULL) {
		/* Same handling as in gcm_encrypt(); S is dereferenced below. */
		OK_set_error(ERR_ST_MEMALLOC, ERR_LC_MODES, ERR_PT_GCMDEC+2,
			     NULL);
		rc = -1;
		goto done;
	}
	bs_get_data(S, s, S->len);

	/* Step 6. */
	/* T = MSBt(GCTR(K, J0, S)) */
	if (gctr(param->ciph_key, J0, s, sizeof(S->buf), t2, ciph) != 0) {
		rc = -1;
		goto done;
	}

	/*
	 * Compare the tags without an early exit so that the time taken
	 * does not reveal how many leading bytes of a forged tag matched.
	 */
	for (i = 0; i < sizeof(S->buf); i++) {
		diff |= (uint8_t)(t[i] ^ t2[i]);
	}
	if (diff != 0) {
		/* return FAIL */
		rc = -2;
		goto done;
	}

	/* Step 3. */
	/* ICB: inc32(J0) */
	/*
	 * From ciphertext to plaintext, written directly to 'out'.
	 * gctr() reads its whole input before it writes any output, so
	 * no intermediate variable-length stack buffer is required.
	 */
	if (inccb_gctr(param->ciph_key, J0, in, clen, out, ciph) != 0) {
		rc = -1;
		goto done;
	}

done:
	if (H != NULL) {
		bs_free(H);
	}
	if (J0 != NULL) {
		bs_free(J0);
	}
	if (S != NULL) {
		bs_free(S);
	}
	return rc;
}

/**
 * Set the block cipher key used by GCM.
 *
 * Only the pointer is stored in param; the key itself is not copied.
 * key must not be NULL (checked with assert()).
 */
void gcm_param_set_key(gcm_param_t *param, Key *key)
{
	assert(key != NULL);

	param->ciph_key = key;
}

/**
 * Set the initialization vector used by GCM.
 *
 * iv:   IV data; only the pointer is stored, the data is not copied.
 * byte: byte length of the IV.
 *
 * iv must not be NULL and byte must not be 0 (checked with assert()).
 */
void gcm_param_set_iv(gcm_param_t *param, void *iv, int32_t byte)
{
	assert(iv != NULL && byte != 0);

	param->iv.buf = iv;
	param->iv.byte = byte;
}

/**
 * Set the additional authenticated data (AAD) used by GCM.
 *
 * aad:  AAD data; only the pointer is stored, the data is not copied.
 * byte: byte length of the AAD.
 *
 * When aad is NULL or byte is 0 the AAD is recorded as empty
 * (length 0, NULL buffer).
 */
void gcm_param_set_aad(gcm_param_t *param, uint8_t *aad, int32_t byte)
{
	if (aad == NULL || byte == 0) {
		param->aad.byte = 0;
		param->aad.buf = NULL;
		return;
	}

	param->aad.byte = byte;
	param->aad.buf = aad;
}

/**
 * Build the hash subkey H by applying the block cipher to a 16-byte
 * all-zero block.
 *
 * Returns a newly allocated 128-bit string holding the result, or
 * NULL when the bit string cannot be allocated (the cipher is not
 * invoked in that case).
 */
static bs128_t *new_hash_subkey128(Key *key, ciph128 ciph)
{
	uint8_t block[16] = {0};
	bs128_t *subkey = bs_new(128);

	if (subkey != NULL) {
		/* The cipher works in place on the zero block. */
		ciph(key, block);
		bs_set_data(subkey, block, 128);
	}

	return subkey;
}

/**
 * Derive the pre-counter block J0 from the IV.
 *
 * A 96-bit IV yields J0 = IV || 0^31 || 1.  Any other IV length is
 * hashed with ghash() using an empty first input.
 *
 * Returns a newly allocated bit string or NULL on failure.
 */
static bs128_t *gen_pre_counter_block(bitstr_t *iv, bs128_t *H)
{
	uint8_t block[16] = {0};
	bs128_t *counter;

	if (bit_length(iv->byte) != 96) {
		return ghash(H, NULL, 0, iv->buf, iv->byte);
	}

	counter = bs_new(128);
	if (counter == NULL) {
		return NULL;
	}

	/* J0 = IV || 0^31 || 1 */
	memcpy(block, iv->buf, iv->byte);
	block[15] |= 1;
	bs_set_data(counter, block, 128);

	return counter;
}

/**
 * Increment the last four bytes (offsets 12..15) of the counter block.
 *
 * len is accepted for interface compatibility but is not used.
 * Always returns 0.
 */
static int inc_counter_block(bs128_t *x, unsigned int len)
{
	uint8_t *low32 = &x->buf[12];

	(void)len;
	inc32(low32);

	return 0;
}

/**
 * Increment the counter block and run GCTR over 'in'.
 *
 * key:  cipher key.
 * J0:   pre-counter block; a clone is incremented and used, J0 is
 *       passed only to bs_clone().
 * in:   input data ('byte' bytes).
 * byte: byte length of 'in' and 'out'.
 * out:  output buffer.
 * ciph: 128-bit block cipher.
 *
 * gctr() takes its length as uint16_t, so passing 'byte' straight
 * through would silently truncate inputs larger than 65535 bytes.
 * The data is therefore processed in chunks that are a multiple of the
 * block size; an input that fits in one chunk results in exactly one
 * gctr() call, as before.  For multi-chunk inputs 'out' must either be
 * equal to 'in' or not overlap it.
 *
 * Returns 0 on success, -1 on failure (negative length, allocation
 * failure or gctr() failure).
 */
static int inccb_gctr(Key *key, bs128_t *J0, const uint8_t *in, int32_t byte,
		      uint8_t *out, ciph128 ciph)
{
	/* Largest multiple of 16 that fits in gctr()'s uint16_t length. */
	enum { CHUNK_BYTES = 65520 };
	bs128_t *ICB; /* initial counter block */
	int rc = 0;

	if (byte < 0) {
		return -1;
	}
	if (byte == 0) {
		/* Nothing to process: gctr() would return immediately. */
		return 0;
	}

	ICB = bs_clone(J0);
	if (ICB == NULL) {
		return -1;
	}
	inc_counter_block(ICB, 128);

	/* C = GCTR(K, ICB, P) */
	while (byte > 0) {
		int32_t n = (byte > CHUNK_BYTES) ? CHUNK_BYTES : byte;

		if (gctr(key, ICB, in, (uint16_t)n, out, ciph) != 0) {
			rc = -1;
			break;
		}
		in += n;
		out += n;
		byte -= n;

		if (byte > 0) {
			/*
			 * gctr() does not advance the counter after the
			 * last block it processed; do it here so that the
			 * next chunk continues the sequence.
			 */
			inc_counter_block(ICB, 128);
		}
	}

	bs_free(ICB);
	return rc;
}

/*
 * Feed 'size' bytes of 'src', zero-padded to a multiple of 128 bits,
 * into the running GHASH value y.
 *
 * Returns 0 on success (including size == 0, which is a no-op) and -1
 * on failure (allocation failure, bs_ghash() failure, or a size whose
 * padded bit length does not fit the int taken by bs_ghash()).
 */
static int ghash_padded(bs128_t *H, const uint8_t *src, unsigned int size,
			bs128_t *y)
{
	/* Largest byte count whose padded bit length still fits int32_t. */
	const unsigned int max_bytes = (unsigned int)(INT32_MAX / 128) * 16;
	uint8_t *padded;
	int32_t blocks;
	int rc;

	if (size == 0) {
		return 0;
	}
	if (size > max_bytes) {
		return -1;
	}

	blocks = bit128_block(bit_length(size));

	/* calloc() zero-fills, which provides the padding of the last block. */
	padded = calloc((size_t)blocks, 16);
	if (padded == NULL) {
		return -1;
	}
	memcpy(padded, src, size);

	rc = bs_ghash(H, padded, blocks * 128, y);
	free(padded);

	return rc;
}

/*
 * GHASH(H, A||0||C||0||[len(A)]64||[len(C)]64)
 *
 * H     : hash subkey
 * a     : first input (A); may be NULL when sizea is 0
 * sizea : byte length of a
 * c     : second input (C)
 * sizec : byte length of c
 *
 * Returns a newly allocated bit string holding the hash (to be released
 * with bs_free()), or NULL on failure.
 */
static bs128_t *ghash(bs128_t *H, uint8_t *a, unsigned int sizea,
		      const uint8_t *c, unsigned int sizec)
{
	bs128_t *ret;
	uint8_t lenblock[16]; /* [len(A)]64||[len(C)]64 */
	uint64_t lena64;
	uint64_t lenc64;

	ret = bs_new(128);
	if (ret == NULL) {
		return NULL;
	}

	/* (1) GHASH(H, A) => X1..m */
	if (ghash_padded(H, a, sizea, ret) != 0) {
		goto fail;
	}

	/* (2) GHASH(H, C) => X1+m..n */
	if (ghash_padded(H, c, sizec, ret) != 0) {
		goto fail;
	}

	/* (3) GHASH(H, [len(A)]64||[len(C)]64) => X1+m+n */
	/* Bit lengths are computed in 64 bits so that they cannot overflow. */
	lena64 = (uint64_t)sizea * 8;
	lenc64 = (uint64_t)sizec * 8;
	_ghash_len64(lena64, lenc64, lenblock, sizeof(lenblock));
	if (bs_ghash(H, lenblock, 128, ret) != 0) {
		goto fail;
	}

	return ret;

fail:
	bs_free(ret);
	return NULL;
}

/*
 * Build the 16-byte block [len(A)]64||[len(C)]64 and copy its first
 * 'len' bytes to buf.
 *
 * lena : value stored in bytes 0..7
 * lenc : value stored in bytes 8..15
 *
 * Always returns 0.
 */
static int _ghash_len64(uint64_t lena, uint64_t lenc, uint8_t *buf,
			uint16_t len)
{
	uint8_t packed[16];
	uint8_t *hi = packed;
	uint8_t *lo = packed + 8;

	copy_uint64_to_uint8(hi, lena);
	copy_uint64_to_uint8(lo, lenc);
	memcpy(buf, packed, len);

	return 0;
}

/*
 * Fold the 128-bit blocks of 'data' into the running hash value y:
 *
 * W  = Yi-1 xor Xi
 * Yi = W . H
 * return Ym (left in y)
 *
 * h    : hash subkey
 * data : input blocks
 * len  : bit length of data; must be a multiple of 128
 * y    : running hash value, updated in place
 *
 * Returns 0 on success and -1 on failure.  The working buffers are
 * released on every path.
 */
static int bs_ghash(bs128_t *h, uint8_t *data, int len, bs128_t *y)
{
	int i;
	int rc = -1;
	bs128_t *w; /* working buffer */
	bs128_t *xi; /* i-th block of x */
	uint8_t *xp = data; /* 128bit (16byte) block pointer of x */

	assert(len % 128 == 0);

	w = bs_new(128);
	xi = bs_new(128);
	if (w == NULL || xi == NULL) {
		goto done;
	}

	for (i = 0; i < len/128; i++) {
		bs_set_data(xi, xp, 128);
		/* Yi-1 xor X => y */
		if (bs_xor(y, xi) != 0) {
			goto done;
		}

		/* y . H => Yi */
		if (bs_product(y, h, w) != 0) {
			goto done;
		}
		bs_copy(w, y);
		xp += 16;
	}
	rc = 0;

done:
	if (w != NULL) {
		bs_free(w);
	}
	if (xi != NULL) {
		bs_free(xi);
	}
	return rc;
}

/* Rightmost bit (bit index 127) of a 128-bit string. */
#define LSB1(v)		(bs_check_bit(v, 127))

/*
 * Block multiplication used by GHASH: ret = x . y
 *
 * For i = 0..127, the current V is XORed into the result when bit i of
 * x is set, then V is shifted right by one bit and XORed with the
 * constant R (0xe1 followed by fifteen zero bytes) if the bit shifted
 * out was set.
 *
 * The constant R is allocated on first use and kept in a function-local
 * static for the lifetime of the process.
 * NOTE(review): that lazy initialization is not synchronized -- confirm
 * whether this function can be entered from several threads.
 *
 * Returns 0 on success and -1 when a working bit string cannot be
 * allocated.
 */
static int bs_product(bs128_t *x, bs128_t *y, bs128_t *ret)
{
	static bs128_t *R = NULL;
	static uint8_t r[]={
			0xe1, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
			0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};

	int i;
	bs128_t *v;

	if (R == NULL) {
		bs128_t *tmp = bs_new(128);

		if (tmp == NULL) {
			return -1;
		}
		bs_set_data(tmp, r, 128);
		/* Publish only a fully initialized constant. */
		R = tmp;
	}

	v = bs_clone(y);
	if (v == NULL) {
		return -1;
	}

	/* 3. For i = 0 to 127, calculate blocks Zi+1 and Vi+1 as follows: */
	/* initialize: Z0 => ret */
	bs_set_zero(ret, 0, 128);

	for (i = 0; i < 128; i++) {
		int lsb_set;

		if (bs_check_bit(x, i) != 0) {
			bs_xor(ret, v);
		}

		if (i == 127) {
			/* V128 is never used. */
			break;
		}

		/* The bit must be sampled before it is shifted out. */
		lsb_set = (LSB1(v) != 0);

		/* Vi+1 = Vi>>1, additionally xor R when the LSB was set */
		bs_rshift(v, 1);
		if (lsb_set) {
			bs_xor(v, R);
		}
	}

	bs_free(v);

	return 0;
}

/*
 * GCTR function.
 *
 * key:    cipher key.
 * cb:     initial counter block; incremented in place once for every
 *         block after the first.
 * in:     input data.
 * inbyte: byte length of 'in' and 'out'.
 * out:    output buffer.  It is written only after all of 'in' has
 *         been read, so 'out' may be the same buffer as 'in'.
 * ciph:   128-bit block cipher, applied in place to a copy of the
 *         counter block.
 *
 * Returns 0 on success (including an empty input) and -1 when a working
 * bit string cannot be allocated; 'out' is left untouched in that case.
 */
static int gctr(Key *key, bs128_t *cb, const uint8_t *in, uint16_t inbyte,
		uint8_t *out, ciph128 ciph)
{
	int32_t bits;
	uint8_t *p;
	bs128_t *xi;
	bs128_t *w;
	uint8_t cbi[16];
	unsigned int mod; /* modulo bit */

	/* Step 1. */
	/* If X is the empty string, then return the empty string as Y. */
	if (inbyte == 0) {
		return 0;
	}

	xi = bs_new(128);
	w = bs_new(128);
	if (xi == NULL || w == NULL) {
		if (xi != NULL) {
			bs_free(xi);
		}
		if (w != NULL) {
			bs_free(w);
		}
		return -1;
	}

	/*
	 * The working copy is declared only after the empty-input check:
	 * a variable-length array must have a size greater than zero.
	 */
	uint8_t buf[inbyte];

	/* Step 2. */
	/* n = len(X)/128 */
	bits = bit_length(inbyte);

	/* Step 3. */
	/* X = X1||X2||...||Xn */
	memcpy(buf, in, inbyte);
	p = buf;

	/*
	 * Iterate the following steps:
	 * . Get 128bit block (xi) from the buf.
	 * . Apply block cipher (ciph) to the counter block (cb).
	 * . XOR the block (xi) with the encrypted counter block.
	 * . Overwrite the buf with xi.
	 * . Increment the counter block (cb).
	 */
	while (bits > 0) {
		mod = bits > 128 ? 128 : bits;

		/* copy 128 bits block from p to Xi  */
		bs_set_data(xi, p, mod);

		/* E(K, CBi) => w */
		bs_get_data(cb, cbi, 128);
		(*ciph)(key, cbi);
		bs_set_data(w, cbi, mod);

		/* Overwrite p with (Xi xor w).  */
		bs_xor(xi, w);
		bs_get_data(xi, p, mod);

		bits -= mod;

		if (bits > 0) {
			/* Step 5. */
			/* CBi=inc32(CBi-1) */
			inc_counter_block(cb, 128);
			p += 16;
		}
	}
	bs_free(w);
	bs_free(xi);

	/* Step 8. */
	/* Y = Y1||Y2||..||Yn */
	memcpy(out, buf, inbyte);

	return 0;
}

/*
 * Add one to the 4-byte array p, read as a 32-bit unsigned integer
 * stored most significant byte first.  The value wraps around to zero
 * after 0xffffffff.
 *
 * Always returns 0.
 */
int inc32(uint8_t *p)
{
	int idx;

	/* Bump the last byte and ripple the carry toward p[0]. */
	for (idx = 3; idx >= 0; idx--) {
		p[idx]++;
		if (p[idx] != 0) {
			break;
		}
	}

	return 0;
}

/*
 * Store the 64-bit value q into the 8-byte array p, most significant
 * byte first.
 *
 * Always returns 0.
 */
int copy_uint64_to_uint8(uint8_t *p, uint64_t q)
{
	int idx;

	/* Peel off the low byte and fill the array from the end. */
	for (idx = 7; idx >= 0; idx--) {
		p[idx] = (uint8_t)(q & 0xff);
		q >>= 8;
	}

	return 0;
}
