/* hkdf.c */
/*
 * Copyright (c) 2017-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <aicrypto/nrg_kdf.h>
#include <aicrypto/ok_tool.h>
#include <aicrypto/ok_hmac.h>
#include <aicrypto/ok_sha2.h>
#include <aicrypto/nrg_sha3.h>

#ifdef HKDF_DEBUG
/**
 * Print a buffer to stdout as space-separated two-digit hex bytes,
 * followed by a newline. Debug helper only.
 *
 * @param[in] buf   bytes to print.
 * @param[in] bytes number of bytes to print from buf.
 */
static inline void dump(const uint8_t *buf, const size_t bytes)
{
	size_t i;

	/* index has the same type as the bound: no signed/unsigned mix */
	for (i = 0; i < bytes; i++) {
		printf("%02x ", buf[i]);
	}
	printf("\n");
}
#endif /* HKDF_DEBUG */

/**
 * Look up the HMAC function that belongs to a hash algorithm.
 *
 * Only OBJ_HASH_SHA1, OBJ_HASH_SHA256, OBJ_HASH_SHA384 and
 * OBJ_HASH_SHA512 are recognized.
 *
 * @param[in] hash_algo aioid of hash.
 * @return the matching HMAC function, or NULL if hash_algo is not one
 *         of the supported identifiers.
 */
static inline hash_func_t get_hmac(int hash_algo)
{
	switch (hash_algo) {
	case OBJ_HASH_SHA1:
		return HMAC_SHA1;
	case OBJ_HASH_SHA256:
		return HMAC_SHA256;
	case OBJ_HASH_SHA384:
		return HMAC_SHA384;
	case OBJ_HASH_SHA512:
		return HMAC_SHA512;
	default:
		break;
	}

	/* unsupported */
	return NULL;
}


/**
 * Fill in an hkdf_hash_t for the given hash algorithm.
 *
 * On success hash->hmac is set to the HMAC function selected by
 * get_hmac() and hash->len to the value reported by hash_size().
 * On failure *hash is left untouched.
 *
 * @param[in]  hash_algo aioid of hash.
 * @param[out] hash      structure to fill in; must not be NULL.
 * @return 0 on success, -1 if hash is NULL, the algorithm has no HMAC
 *         function, or hash_size() reports -1.
 */
int HKDF_get_hash(const int hash_algo, hkdf_hash_t *hash)
{
	hash_func_t mac;
	size_t digest_len;

	if (hash == NULL) {
		return -1;
	}

	mac = get_hmac(hash_algo);
	if (mac == NULL) {
		return -1;
	}

	/* hash_size() signals failure with -1, seen here as (size_t)-1 */
	digest_len = hash_size(hash_algo);
	if (digest_len == (size_t)-1) {
		return -1;
	}

	hash->len = digest_len;
	hash->hmac = mac;

	return 0;
}


/**
 * HKDF-Extract step (RFC 5869, section 2.2): PRK = HMAC-Hash(salt, IKM).
 *
 * @param[in]  hash     HMAC function and digest length to use.
 * @param[in]  salt     optional salt, used as the HMAC key. May be NULL
 *                      only when salt_len is 0.
 * @param[in]  salt_len length of salt in bytes.
 * @param[in]  ikm      input keying material; must not be NULL.
 * @param[in]  ikm_len  length of ikm in bytes; must not be 0.
 * @param[out] prk      receives the HMAC output; must not be NULL and
 *                      must have room for hash->len bytes.
 * @return 0 on success, -1 on an invalid argument.
 */
int HKDF_Extract(const hkdf_hash_t *hash,
		 const uint8_t *salt, const size_t salt_len,
		 const uint8_t *ikm, const size_t ikm_len, uint8_t *prk)
{
	/*
	 * Stand-in key for the "no salt" case. It is always passed with a
	 * length of 0, so only its address matters; a one-element array is
	 * used because ISO C does not allow zero-length arrays. HMAC pads
	 * short keys with zeros (RFC 2104), so an empty key gives the same
	 * result as the HashLen zero bytes that RFC 5869 specifies.
	 */
	static const uint8_t default_salt[1] = { 0 };

	if (hash == NULL || hash->hmac == NULL || hash->len == 0) {
		return -1;
	}

	if (salt == NULL) {
		if (salt_len != 0) {
			return -1;
		}
		salt = default_salt;
	}

	if (ikm == NULL || ikm_len == 0) {
		return -1;
	}

	if (prk == NULL) {
		return -1;
	}

	/* argument order is (text, key, output): IKM is the text, salt the key */
	(*(hash->hmac))(ikm_len, (unsigned char *) ikm,
			salt_len, (unsigned char *) salt,
			(unsigned char *) prk);
#ifdef HKDF_DEBUG
	dump(prk, hash->len);
#endif /* HKDF_DEBUG */

	return 0;
}

/**
 * HKDF-Expand step (RFC 5869, section 2.3).
 *
 * Writes the first okm_len bytes of T(1) | T(2) | ... | T(N) to okm,
 * where T(0) is empty and T(i) = HMAC-Hash(PRK, T(i-1) | info | i).
 *
 * The output is assembled in a private staging buffer and copied to okm
 * only at the very end, so okm is not modified while prk and info are
 * still being read.
 *
 * NOTE(review): the scratch buffers hold keying material and are not
 * securely wiped before being released.
 *
 * @param[in]  hash     HMAC function and digest length to use.
 * @param[in]  prk      pseudorandom key; must not be NULL.
 * @param[in]  prk_len  length of prk in bytes; at least hash->len.
 * @param[in]  info     optional context; may be NULL only when
 *                      info_len is 0.
 * @param[in]  info_len length of info in bytes.
 * @param[out] okm      receives okm_len bytes; must not be NULL.
 * @param[in]  okm_len  requested output length, 1 .. 255 * hash->len.
 * @return 0 on success, -1 on an invalid argument or when a scratch
 *         buffer cannot be allocated.
 */
int HKDF_Expand(const hkdf_hash_t *hash,
		const uint8_t *prk, const size_t prk_len,
		const uint8_t *info, const size_t info_len,
		uint8_t *okm, const size_t okm_len)
{
	unsigned char *msg = NULL;
	unsigned char *stage = NULL;
	size_t blocks;
	size_t filled = 0;
	int ret = -1;
	int n;
	int msg_len;
	int i;

	if (hash == NULL || hash->hmac == NULL || hash->len == 0) {
		return -1;
	}

	const size_t hlen = hash->len;

	if (prk == NULL || prk_len < hlen) {
		return -1;
	}

	/* info is optional, can be a zero-length string. */
	if (info == NULL && info_len != 0) {
		return -1;
	}

	if (okm == NULL || okm_len == 0) {
		return -1;
	}

	/*
	 * N = ceil(L/HashLen). The block index is appended to each message
	 * as a single octet, so N must not exceed 255. Checking N directly
	 * avoids computing hash->len * 255.
	 */
	blocks = okm_len / hlen;
	if (okm_len % hlen > 0) {
		blocks++;
	}
	if (blocks > 255) {
		return -1;
	}
	n = (int)blocks;

	/* the message length is tracked in an int; make sure it fits */
	if (hlen >= (size_t)INT_MAX ||
	    info_len > (size_t)INT_MAX - hlen - 1) {
		return -1;
	}

#ifdef HKDF_DEBUG
	printf("ceil(%d/%d)=%d\n", (int)okm_len, (int)hash->len, n);
#endif /* HKDF_DEBUG */

	unsigned char t[hlen];

	/*
	 * msg holds T(i-1) | info | i and is sized by the caller-supplied
	 * info_len, so it lives on the heap rather than the stack. stage
	 * collects the output before it is copied to okm.
	 */
	msg = malloc(hlen + info_len + 1);
	stage = malloc(okm_len);
	if (msg == NULL || stage == NULL) {
		goto done;
	}

	/* T(0) = empty string (zero length) */
	memset(t, 0, hlen);

	/*
	 * T = T(1) | T(2) | T(3) | ... | T(N);
	 * i is an int so that "i <= n" terminates when n is 255; a
	 * one-byte counter would wrap to 0 and loop forever.
	 */
	for (i = 1; i <= n; i++) {
		size_t chunk;

		msg_len = 0;
		if (i > 1) {
			memcpy(msg, t, hlen);
			msg_len += (int)hlen;
		}
		if (info_len != 0) {
			memcpy(msg + msg_len, info, info_len);
			msg_len += (int)info_len;
		}
		msg[msg_len] = (unsigned char)i;
		msg_len += 1;

#ifdef HKDF_DEBUG
		printf("T(%d)|info|n: ", i-1);
		dump(msg, msg_len);
#endif /* HKDF_DEBUG */

		/* T(n) = HMAC-Hash(PRK, T(n-1) | info | n) */
		(*(hash->hmac))(msg_len, (unsigned char *) msg,
			prk_len, (unsigned char *) prk,
			(unsigned char *) t);
#ifdef HKDF_DEBUG
		printf("T(%d)= ", i);
		dump(t, hash->len);
#endif /* HKDF_DEBUG */

		/* the last block may only be needed in part */
		chunk = okm_len - filled;
		if (chunk > hlen) {
			chunk = hlen;
		}
		memcpy(stage + filled, t, chunk);
		filled += chunk;
	}

	memcpy(okm, stage, okm_len);

#ifdef HKDF_DEBUG
	printf("OKM: ");
	dump(okm, okm_len);
#endif /* HKDF_DEBUG */

	ret = 0;

done:
	free(stage);
	free(msg);

	return ret;
}


/**
 * One-shot HKDF: HKDF_Extract() followed by HKDF_Expand().
 *
 * The intermediate pseudorandom key is hash->len bytes long and is kept
 * in a local buffer.
 *
 * @param[in]  hash     HMAC function and digest length to use.
 * @param[in]  salt     salt handed to HKDF_Extract().
 * @param[in]  salt_len length of salt in bytes.
 * @param[in]  ikm      input keying material handed to HKDF_Extract().
 * @param[in]  ikm_len  length of ikm in bytes.
 * @param[in]  info     context handed to HKDF_Expand().
 * @param[in]  info_len length of info in bytes.
 * @param[out] okm      output keying material written by HKDF_Expand().
 * @param[in]  okm_len  number of bytes requested in okm.
 * @return 0 on success, -1 if hash is unusable or either step fails.
 */
int HKDF(const hkdf_hash_t *hash,
	 const uint8_t *salt, const size_t salt_len,
	 const uint8_t *ikm, const size_t ikm_len,
	 const uint8_t *info, const size_t info_len,
	 uint8_t *okm, const size_t okm_len)
{
	int rc;

	/* hash->len sizes the local buffer below, so validate it first */
	if (!(hash != NULL && hash->hmac != NULL && hash->len != 0)) {
		return -1;
	}

	uint8_t pseudo_key[hash->len];

	rc = HKDF_Extract(hash, salt, salt_len, ikm, ikm_len, pseudo_key);
	if (rc == 0) {
		/* expand only when the extract step succeeded */
		rc = HKDF_Expand(hash, pseudo_key, hash->len,
				 info, info_len, okm, okm_len);
	}

	return (rc == 0) ? 0 : -1;
}

