/* sha3test.c */
/*
 * Copyright (c) 2014-2017 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>

#include <aicrypto/nrg_sha3.h>
#include "Modes/KeccakHash.h"
#include "sha3test.h"
#include "sha3testnist.h"

/* Location of the KeccakCodePackage KAT files, relative to PATH. */
#define TESTVECTORS_DIR	"KeccakCodePackage/TestVectors/"

/*
 * Base directory handed to get_fpath() when looking up test vectors.
 * Defaults to the current directory unless predefined (e.g. -DPATH=...).
 */
#ifndef PATH
# define PATH	"."
#endif

/* One-shot hash under test: (input length in bytes, input, output). */
typedef void (*HASHCompute)(int len, unsigned char *in, unsigned char *ret);
/* Prepares a KCP Keccak_HashInstance for one SHA-3/SHAKE variant. */
typedef void (*KCPHashInit)(Keccak_HashInstance *instance);
/* Returns the expected output for the NIST vector of the given bit length. */
typedef uint8_t *(*NISTFunc)(size_t databitlen);

/*
 * test/getfpath.c: resolves filename under path. Callers here treat NULL as
 * "not found" and free() a non-NULL result.
 */
char *get_fpath(char *path, char* filename);


/**
 * Write "name: " and then the first len bytes of cs as two-digit lowercase
 * hexadecimal to stdout, ending with a newline.
 */
static void print_charhex(const char *name, const unsigned char *cs,
			  const int len)
{
	int pos = 0;

	printf("%s: ", name);
	while (pos < len) {
		printf("%02x", cs[pos++]);
	}
	putchar('\n');
}

/**
 * Print the digest of the NUL-terminated string in for every SHA-3 variant.
 *
 * SHA3-224/256/384/512 are printed at their fixed sizes. SHAKE128/256 are
 * printed at DEFAULT_SHAKE128_DIGESTSIZE / DEFAULT_SHAKE256_DIGESTSIZE bytes.
 * The label shows that length in bits, e.g. "SHAKE256[n]".
 */
void test_sha3_all(unsigned char *in)
{
	unsigned char ret_sha3_224[SHA3_224_DIGESTSIZE];
	unsigned char ret_sha3_256[SHA3_256_DIGESTSIZE];
	unsigned char ret_sha3_384[SHA3_384_DIGESTSIZE];
	unsigned char ret_sha3_512[SHA3_512_DIGESTSIZE];
	unsigned char ret_shake128[DEFAULT_SHAKE128_DIGESTSIZE];
	unsigned char ret_shake256[DEFAULT_SHAKE256_DIGESTSIZE];
	char hashname[32];
	/* Input length is the same for every variant; compute it once. */
	size_t inlen = strlen((const char *)in);

	NRG_SHA3_224(inlen, in, ret_sha3_224);
	print_charhex("SHA3-224", ret_sha3_224, sizeof ret_sha3_224);

	NRG_SHA3_256(inlen, in, ret_sha3_256);
	print_charhex("SHA3-256", ret_sha3_256, sizeof ret_sha3_256);

	NRG_SHA3_384(inlen, in, ret_sha3_384);
	print_charhex("SHA3-384", ret_sha3_384, sizeof ret_sha3_384);

	NRG_SHA3_512(inlen, in, ret_sha3_512);
	print_charhex("SHA3-512", ret_sha3_512, sizeof ret_sha3_512);

	snprintf(hashname, sizeof hashname,
		 "SHAKE128[%zu]", (sizeof ret_shake128) * 8);
	NRG_SHAKE128(inlen, in, ret_shake128);
	print_charhex(hashname, ret_shake128, sizeof ret_shake128);

	/* Label must reflect the SHAKE256 buffer, not the SHAKE128 one. */
	snprintf(hashname, sizeof hashname,
		 "SHAKE256[%zu]", (sizeof ret_shake256) * 8);
	NRG_SHAKE256(inlen, in, ret_shake256);
	print_charhex(hashname, ret_shake256, sizeof ret_shake256);
}

/**
 * Read the next line of a test-vector file into linebuf.
 *
 * The trailing '\n' is removed, together with a preceding '\r' for CRLF
 * files. A final line that ends at EOF without a newline is returned as-is.
 * A line that does not fit in linebuf is reported on stdout. It is then
 * discarded up to and including its newline, and reading resumes with the
 * next line.
 *
 * @param hash_name     label used in the "too long line" message
 * @param fn_testvector file name used in the "too long line" message
 * @param linebuf       destination buffer
 * @param linebufsz     size of linebuf in bytes
 * @param fp            stream to read from
 * @return linebuf on success. NULL at end of file, on a read error, or when
 *         linebufsz is too small to hold any line.
 */
static unsigned char *readline(const char *hash_name,
			       const char *fn_testvector,
			       unsigned char *linebuf, const int linebufsz,
			       FILE *fp)
{
	char *buf = (char *)linebuf;
	char *nlptr;

	/* fgets() with a size below 2 stores nothing and never advances. */
	if (linebufsz < 2) {
		return NULL;
	}

	while (fgets(buf, linebufsz, fp) != NULL) {
		nlptr = strchr(buf, '\n');

		/* No newline and not at EOF: the buffer filled up mid-line. */
		if (nlptr == NULL && !feof(fp)) {
			printf("%s: too long line in %s\n",
			       hash_name, fn_testvector);
			/* Discard the remainder of the over-long line. */
			while (fgets(buf, linebufsz, fp) != NULL) {
				if (strchr(buf, '\n') != NULL) {
					break;
				}
			}
			continue;
		}

		/* Last line of the file without a terminating newline. */
		if (nlptr == NULL) {
			nlptr = buf + strlen(buf);
		}
		/* Also drop the '\r' of a CRLF line ending. */
		if (nlptr > buf && nlptr[-1] == '\r') {
			nlptr--;
		}
		*nlptr = '\0';
		return linebuf;
	}

	return NULL;
}

/**
 * Parse the decimal value of a "Len = <bits>" line.
 *
 * The number is taken from just after the first '='. Leading whitespace is
 * skipped by strtol().
 *
 * @return the parsed length, or -1 if the line has no '=', no digits
 *         follow it, or the value is negative.
 */
static int parse_len(const unsigned char *linebuf)
{
	const char *p;
	char *end;
	long val;

	p = strchr((const char *)linebuf, '=');
	if (p == NULL) {
		return -1;
	}
	p++;

	val = strtol(p, &end, 10);
	/* end == p means strtol found no digits at all. */
	if (end == p || val < 0) {
		return -1;
	}

	return (int)val;
}

/**
 * Decode the hexadecimal value of a "Key = hexdigits" line into bytes.
 *
 * Decoding starts after the first '=' and skips leading blanks. It takes
 * two hex digits at a time and writes one byte each into ret. ret is
 * zero-filled first, so bytes past the decoded value are 0.
 *
 * Decoding stops in two cases:
 *  - retsz bytes have been written. Longer values are deliberately
 *    truncated, because a "Squeezed" value may be longer than the buffer
 *    being compared.
 *  - Fewer than two characters remain. A trailing odd digit is ignored.
 *
 * @return 0 on success, -1 if the line has no '=' or contains a non-hex
 *         character in the decoded range.
 */
static int parse_hexstr(unsigned char *ret, const int retsz,
			const unsigned char *linebuf)
{
	const char *p;
	int i;

	if (retsz > 0) {
		memset(ret, 0, retsz);
	}

	p = strchr((const char *)linebuf, '=');
	if (p == NULL) {
		return -1;
	}
	p++;
	while (*p == ' ' || *p == '\t') {
		p++;
	}

	for (i = 0; i < retsz && p[0] != '\0' && p[1] != '\0'; i++) {
		char hex[3];

		/* ctype functions require values representable as unsigned char. */
		if (!isxdigit((unsigned char)p[0]) ||
		    !isxdigit((unsigned char)p[1])) {
			return -1;
		}

		hex[0] = *p++;
		hex[1] = *p++;
		hex[2] = '\0';

		ret[i] = (unsigned char)strtol(hex, NULL, 16);
	}

	return 0;
}

/**
 * Common function of SHA-3 testcases using a KeccakCodePackage test vector.
 *
 * The file is located with get_fpath(PATH, testvector) and read line by
 * line:
 *  - "Len = " gives the message length in bits.
 *  - "Msg = " gives the message.
 *  - "MD = " / "Squeezed = " gives the expected output and completes the
 *    vector. The message is hashed with hashfunc, and the first digestsize
 *    bytes are compared with the expected value.
 *
 * Comment lines ('#'), blank lines and unrecognised lines are ignored.
 * Vectors whose length is not a multiple of 8 are also ignored, because
 * hashfunc takes its input length in bytes.
 *
 * @param hash_name  label used in all messages
 * @param testvector path of the test-vector file, relative to PATH
 * @param hashfunc   one-shot hash under test
 * @param digestsize number of output bytes compared per vector
 * @return 0 if every vector passes, or if the file is missing or unreadable
 *         (the test is then skipped). -1 on the first mismatch or malformed
 *         line.
 */
static int test_sha3_kcp_vector(const char *hash_name, char *testvector,
				HASHCompute hashfunc, const int digestsize)
{
	char *fn_testvector = get_fpath(PATH, testvector);
	FILE *fp;
	unsigned char linebuf[5120];
	int rc = 0;

	/* Bit length of the current vector; -1 until a valid "Len = " line. */
	int msglen = -1;
	uint8_t msg[1024];
	uint8_t expect[digestsize];
	uint8_t digest[digestsize];

	if (fn_testvector == NULL) {
		printf("%s: not found %s\n", hash_name, testvector);
		printf("%s: skip test using keccak test vector.\n", hash_name);
		return 0;
	}

	if ((fp = fopen(fn_testvector, "r")) == NULL) {
		printf("%s: cannot open %s\n", hash_name, fn_testvector);
		printf("%s: skip test using keccak test vector.\n", hash_name);
		free(fn_testvector);
		return 0;
	}

	while (readline(hash_name, fn_testvector, linebuf, sizeof linebuf,
			fp) != NULL) {

		/* Ignore comment lines (starting with '#') and blank lines. */
		if (linebuf[0] == '#' || linebuf[0] == '\0') {
			continue;
		}

		/* Parse "Len = " line. */
		if (strncmp((const char *)linebuf, "Len = ",
			    strlen("Len = ")) == 0) {
			msglen = parse_len(linebuf);
			continue;
		}

		/*
		 * The len argument of NRG_SHA3_224 and so on is a number of
		 * bytes. Skip vectors whose bit length is unknown or not a
		 * multiple of 8.
		 */
		if (msglen < 0 || msglen % 8 != 0) {
			continue;
		}

		/* Parse "Msg = " line. */
		if (strncmp((const char *)linebuf, "Msg = ",
			    strlen("Msg = ")) == 0) {
			if (parse_hexstr(msg, sizeof msg, linebuf) != 0) {
				printf("%s: Len=%d malformed Msg line in %s\n",
				       hash_name, msglen, fn_testvector);
				rc = -1;
				break;
			}
			continue;
		}

		/* Only an "MD = " or "Squeezed = " line completes a vector. */
		if (strncmp((const char *)linebuf, "MD = ",
			    strlen("MD = ")) != 0 &&
		    strncmp((const char *)linebuf, "Squeezed = ",
			    strlen("Squeezed = ")) != 0) {
			continue;
		}

		if (parse_hexstr(expect, sizeof expect, linebuf) != 0) {
			printf("%s: Len=%d malformed digest line in %s\n",
			       hash_name, msglen, fn_testvector);
			rc = -1;
			break;
		}

		/*
		 * msg holds at most sizeof msg bytes. Hashing more would read
		 * past the buffer, so such vectors cannot be run here.
		 */
		if ((size_t)(msglen / 8) > sizeof msg) {
			printf("%s: Len=%d skip (message exceeds test buffer)\n",
			       hash_name, msglen);
			continue;
		}

		/* Compute message digest */
		(*hashfunc)(msglen / 8, msg, digest);

		/* Verify expect digest */
		if (memcmp(expect, digest, sizeof digest) != 0) {
			printf("%s: Len=%d test error\n", hash_name, msglen);
			print_charhex("  expect", expect, sizeof expect);
			print_charhex("  digest", digest, sizeof digest);
			rc = -1;
			break;
		}
		printf("%s: Len=%d test ok\n", hash_name, msglen);
	}

	fclose(fp);
	free(fn_testvector);

	return rc;
}

/**
 * Check one NIST example vector against the one-shot NRG implementation.
 *
 * The message is passed as tv->databitlen / 8 bytes. The only caller in
 * this file uses it for lengths 0 and 1600. The first digestsize output
 * bytes are compared with expected.
 *
 * @return 0 on match, -1 (after printing both values) on mismatch.
 */
static int test_sha3_nist_nrg(const char *hash_name, HASHCompute nrghash,
			      struct tv *tv, const uint8_t *expected,
			      const int digestsize)
{
	uint8_t out[digestsize];
	int same;

	memset(out, 0, sizeof out);
	nrghash(tv->databitlen / 8, tv->data, out);
	same = (memcmp(expected, out, sizeof out) == 0);

	if (!same) {
		printf("%s: NRG Len=%zu test error\n", hash_name, tv->databitlen);
		print_charhex("  expect", expected, digestsize);
		print_charhex("  digest", out, sizeof out);
		return -1;
	}

	printf("%s: NRG Len=%zu test ok\n", hash_name, tv->databitlen);
	return 0;
}

/**
 * Check one NIST example vector against the KeccakCodePackage reference.
 *
 * kcphashinit selects the variant. All tv->databitlen bits of tv->data are
 * absorbed, then Keccak_HashFinal() writes into digest. When
 * squeezedoutputsize > 0 (the SHAKE callers), output is then squeezed into
 * digest. The first digestsize bytes are compared with expected.
 *
 * @param digestsize         size of the compared output, in bytes
 * @param squeezedoutputsize amount passed to Keccak_HashSqueeze() (0 =
 *                           none). It is limited to digestsize * 8 so the
 *                           squeeze can never write past the digestsize-byte
 *                           buffer. Only that prefix is compared, so the
 *                           result does not change.
 * @return 0 on match, -1 (after printing both values) on mismatch.
 */
static int test_sha3_nist_kcp(const char *hash_name, KCPHashInit kcphashinit,
			      struct tv *tv, const uint8_t *expected,
			      const int digestsize, const int squeezedoutputsize)
{
	uint8_t digest[digestsize];
	Keccak_HashInstance hashInstance;
	int squeezebits = squeezedoutputsize;

	/* Keccak_HashSqueeze() takes its length in bits. */
	if (squeezebits > digestsize * 8) {
		squeezebits = digestsize * 8;
	}

	memset(digest, 0, digestsize);

	/* Compute message digest */
	(*kcphashinit)(&hashInstance);
	Keccak_HashUpdate(&hashInstance, tv->data, tv->databitlen);
	Keccak_HashFinal(&hashInstance, digest);
	if (squeezebits > 0) {
		Keccak_HashSqueeze(&hashInstance, digest, squeezebits);
	}

	/* Verify expect digest */
	if (memcmp(expected, digest, sizeof digest) == 0) {
		printf("%s: keccak Len=%zu test ok\n",
		       hash_name, tv->databitlen);
	} else {
		printf("%s: keccak Len=%zu test error\n",
		       hash_name, tv->databitlen);
		print_charhex("  expect", expected, digestsize);
		print_charhex("  digest", digest, sizeof digest);
		return -1;
	}

	return 0;
}

/**
 * Run the NIST example vectors (nist_tv[0 .. NIST_TEST_NUM-1]) for one
 * SHA-3/SHAKE variant.
 *
 * For each vector, the expected output comes from get_expected.
 *  - Vectors of 0 and 1600 bits are checked against nrghash, which takes a
 *    byte count.
 *  - Every vector is checked against the KCP reference that kcphashinit
 *    sets up.
 *
 * @return 0 if every check passes, -1 on the first failure.
 */
static int test_sha3_nist_vector(const char *hash_name, HASHCompute nrghash,
				 KCPHashInit kcphashinit,
				 NISTFunc get_expected,
				 const int digestsize,
				 const int squeezedoutputsize)
{
	int i;

	for (i = 0; i < NIST_TEST_NUM; i++) {
		struct tv *tv = &nist_tv[i];
		uint8_t *expected = (*get_expected)(tv->databitlen);

		/* Comparing against a NULL pointer would be undefined. */
		if (expected == NULL) {
			printf("%s: no expected value for Len=%zu\n",
			       hash_name, tv->databitlen);
			return -1;
		}

		if (tv->databitlen == 0 || tv->databitlen == 1600) {
			if (test_sha3_nist_nrg(hash_name, nrghash, tv,
					       expected, digestsize) != 0) {
				return -1;
			}
		}

		if (test_sha3_nist_kcp(hash_name, kcphashinit, tv, expected,
				       digestsize, squeezedoutputsize) != 0) {
			return -1;
		}
	}

	return 0;
}

/* KCPHashInit adapter: prepare a KCP instance for SHA3-224. */
static void KCP_SHA3_224(Keccak_HashInstance *hashInstance)
{
	Keccak_HashInitialize_SHA3_224(hashInstance);
}

/**
 * Testcase for SHA3-224.
 *
 * First runs the KeccakCodePackage ShortMsgKAT file through NRG_SHA3_224.
 * Then runs the NIST example vectors via test_sha3_nist_vector().
 *
 * @return 0 on success, -1 on failure.
 */
int test_sha3_224(void)
{
	const char *hash_name = "sha3-224";
	char *testvector = TESTVECTORS_DIR "ShortMsgKAT_SHA3-224.txt";

	if (test_sha3_kcp_vector(hash_name, testvector,
				 (HASHCompute)NRG_SHA3_224,
				 SHA3_224_DIGESTSIZE) < 0) {
		return -1;
	}

	return test_sha3_nist_vector(hash_name, (HASHCompute)NRG_SHA3_224,
				     KCP_SHA3_224,
				     get_sha3_224_expected,
				     SHA3_224_DIGESTSIZE, 0);
}

/* KCPHashInit adapter: prepare a KCP instance for SHA3-256. */
static void KCP_SHA3_256(Keccak_HashInstance *hashInstance)
{
	Keccak_HashInitialize_SHA3_256(hashInstance);
}

/**
 * Testcase for SHA3-256.
 *
 * First runs the KeccakCodePackage ShortMsgKAT file through NRG_SHA3_256.
 * Then runs the NIST example vectors via test_sha3_nist_vector().
 *
 * @return 0 on success, -1 on failure.
 */
int test_sha3_256(void)
{
	const char *hash_name = "sha3-256";
	char *testvector = TESTVECTORS_DIR "ShortMsgKAT_SHA3-256.txt";

	if (test_sha3_kcp_vector(hash_name, testvector,
				 (HASHCompute)NRG_SHA3_256,
				 SHA3_256_DIGESTSIZE) < 0) {
		return -1;
	}

	return test_sha3_nist_vector(hash_name, (HASHCompute)NRG_SHA3_256,
				     KCP_SHA3_256,
				     get_sha3_256_expected,
				     SHA3_256_DIGESTSIZE, 0);
}

/* KCPHashInit adapter: prepare a KCP instance for SHA3-384. */
static void KCP_SHA3_384(Keccak_HashInstance *hashInstance)
{
	Keccak_HashInitialize_SHA3_384(hashInstance);
}

/**
 * Testcase for SHA3-384.
 *
 * First runs the KeccakCodePackage ShortMsgKAT file through NRG_SHA3_384.
 * Then runs the NIST example vectors via test_sha3_nist_vector().
 *
 * @return 0 on success, -1 on failure.
 */
int test_sha3_384(void)
{
	const char *hash_name = "sha3-384";
	char *testvector = TESTVECTORS_DIR "ShortMsgKAT_SHA3-384.txt";

	if (test_sha3_kcp_vector(hash_name, testvector,
				 (HASHCompute)NRG_SHA3_384,
				 SHA3_384_DIGESTSIZE) < 0) {
		return -1;
	}

	return test_sha3_nist_vector(hash_name, (HASHCompute)NRG_SHA3_384,
				     KCP_SHA3_384,
				     get_sha3_384_expected,
				     SHA3_384_DIGESTSIZE, 0);
}

/* KCPHashInit adapter: prepare a KCP instance for SHA3-512. */
static void KCP_SHA3_512(Keccak_HashInstance *hashInstance)
{
	Keccak_HashInitialize_SHA3_512(hashInstance);
}

/**
 * Testcase for SHA3-512.
 *
 * First runs the KeccakCodePackage ShortMsgKAT file through NRG_SHA3_512.
 * Then runs the NIST example vectors via test_sha3_nist_vector().
 *
 * @return 0 on success, -1 on failure.
 */
int test_sha3_512(void)
{
	const char *hash_name = "sha3-512";
	char *testvector = TESTVECTORS_DIR "ShortMsgKAT_SHA3-512.txt";

	if (test_sha3_kcp_vector(hash_name, testvector,
				 (HASHCompute)NRG_SHA3_512,
				 SHA3_512_DIGESTSIZE) < 0) {
		return -1;
	}

	return test_sha3_nist_vector(hash_name, (HASHCompute)NRG_SHA3_512,
				     KCP_SHA3_512,
				     get_sha3_512_expected,
				     SHA3_512_DIGESTSIZE, 0);
}

/* KCPHashInit adapter: prepare a KCP instance for SHAKE128. */
static void KCP_SHAKE128(Keccak_HashInstance *hashInstance)
{
	Keccak_HashInitialize_SHAKE128(hashInstance);
}

/**
 * Testcase for SHAKE128.
 *
 * First runs the KeccakCodePackage ShortMsgKAT file through NRG_SHAKE128,
 * comparing DEFAULT_SHAKE128_DIGESTSIZE bytes. Then runs the NIST example
 * vectors via test_sha3_nist_vector().
 *
 * @return 0 on success, -1 on failure.
 */
int test_shake128(void)
{
	const char *hash_name = "shake128";
	char *testvector = TESTVECTORS_DIR "ShortMsgKAT_SHAKE128.txt";

	if (test_sha3_kcp_vector(hash_name, testvector,
				 (HASHCompute)NRG_SHAKE128,
				 DEFAULT_SHAKE128_DIGESTSIZE) < 0) {
		return -1;
	}

	/* 4096 is the squeezedoutputsize handed to the KCP reference check. */
	return test_sha3_nist_vector(hash_name, (HASHCompute)NRG_SHAKE128,
				     KCP_SHAKE128,
				     get_shake128_expected,
				     DEFAULT_SHAKE128_DIGESTSIZE, 4096);
}

/* KCPHashInit adapter: prepare a KCP instance for SHAKE256. */
static void KCP_SHAKE256(Keccak_HashInstance *hashInstance)
{
	Keccak_HashInitialize_SHAKE256(hashInstance);
}

/**
 * Testcase for SHAKE256.
 *
 * First runs the KeccakCodePackage ShortMsgKAT file through NRG_SHAKE256,
 * comparing DEFAULT_SHAKE256_DIGESTSIZE bytes. Then runs the NIST example
 * vectors via test_sha3_nist_vector().
 *
 * @return 0 on success, -1 on failure.
 */
int test_shake256(void)
{
	const char *hash_name = "shake256";
	char *testvector = TESTVECTORS_DIR "ShortMsgKAT_SHAKE256.txt";

	if (test_sha3_kcp_vector(hash_name, testvector,
				 (HASHCompute)NRG_SHAKE256,
				 DEFAULT_SHAKE256_DIGESTSIZE) < 0) {
		return -1;
	}

	/* 4096 is the squeezedoutputsize handed to the KCP reference check. */
	return test_sha3_nist_vector(hash_name, (HASHCompute)NRG_SHAKE256,
				     KCP_SHAKE256,
				     get_shake256_expected,
				     DEFAULT_SHAKE256_DIGESTSIZE, 4096);
}
