/*
 * Copyright (c) 2015-2016 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_prf.h"

#include <string.h>

/* for HMAC_SHA256, HMAC_SHA384 and SHA384_DIGESTSIZE */
#include <aicrypto/ok_hmac.h>

/**
 * calculate P_hash SHA256 (RFC 5246 section 5).
 *
 * Writes the first dest_len bytes of P_SHA256(secret, seed) to dest.
 *
 * @param secret      HMAC key.
 * @param secret_len  length of secret in bytes.
 * @param seed        P_hash seed (tls_prf_sha256 passes label + seed).
 * @param seed_len    length of seed in bytes.
 * @param dest        output buffer; must hold at least dest_len bytes.
 * @param dest_len    number of output bytes to produce.
 */
static void p_hash_sha256(const uint8_t *secret,
			  const uint32_t secret_len,
			  const uint8_t *seed,
			  const uint32_t seed_len,
			  uint8_t *dest,
			  const uint32_t dest_len);

/**
 * calculate P_hash SHA384 (RFC 5246 section 5).
 *
 * Writes the first dest_len bytes of P_SHA384(secret, seed) to dest.
 *
 * @param secret      HMAC key.
 * @param secret_len  length of secret in bytes.
 * @param seed        P_hash seed (tls_prf_sha384 passes label + seed).
 * @param seed_len    length of seed in bytes.
 * @param dest        output buffer; must hold at least dest_len bytes.
 * @param dest_len    number of output bytes to produce.
 */
static void p_hash_sha384(const uint8_t *secret,
			  const uint32_t secret_len,
			  const uint8_t *seed,
			  const uint32_t seed_len,
			  uint8_t *dest,
			  const uint32_t dest_len);

/*
 * P_SHA256 (RFC 5246 section 5): writes the first dest_len bytes of the
 * P_SHA256(secret, seed) output stream to dest.
 */
static void p_hash_sha256(const uint8_t *secret,
			  const uint32_t secret_len,
			  const uint8_t *seed,
			  const uint32_t seed_len,
			  uint8_t *dest,
			  const uint32_t dest_len) {
	/* P_hash uses the following algorithm (RFC 5246 section 5).
	 *
	 *   P_<hash>(secret, seed) = HMAC_hash(secret, A(1) + seed) +
	 *                            HMAC_hash(secret, A(2) + seed) +
	 *                            HMAC_hash(secret, A(3) + seed) +
	 *                            ...
	 *
	 *   A(0) = seed
	 *   A(i) = HMAC_hash(secret, A(i - 1))
	 *
	 * HMAC_SHA256 generates 32 bytes per iteration, so producing e.g.
	 * 80 bytes needs A(1)..A(3) (32 * 3 = 96); the last block is
	 * truncated.
	 */

	/* computed MAC length is fixed length 32 (HMAC_SHA256). */
	const uint32_t hmac_len = 32;

	/* NOTE: HMAC_SHA256 does not take const pointers, so the secret is
	 * copied into a writable buffer. The buffer always has at least one
	 * element because a zero-length VLA is undefined behavior. */
	uint8_t sec[secret_len > 0 ? secret_len : 1];
	if (secret_len > 0) {
		memcpy(sec, secret, secret_len);
	}

	/* buf holds "A(i) + seed": the first hmac_len bytes are A(i) and
	 * the remainder is the seed. The seed part never changes, so it is
	 * copied only once. A(i - 1) is either the seed (i == 1) or a
	 * previous hmac_len-byte MAC, so both fit inside buf and no buffer
	 * sized by seed_len alone is needed. */
	const uint32_t buf_len = hmac_len + seed_len;
	uint8_t buf[buf_len];
	if (seed_len > 0) {
		memcpy(&buf[hmac_len], seed, seed_len);
	}

	/* A(0) = seed (exactly seed_len bytes). */
	uint8_t *a_prev = &buf[hmac_len];
	uint32_t a_prev_len = seed_len;

	/* Count down the remaining bytes instead of counting an offset up,
	 * so a dest_len close to UINT32_MAX cannot wrap the loop index. */
	uint8_t *out = dest;
	uint32_t remaining = dest_len;
	while (remaining > 0) {
		/* A(i) = HMAC_hash(secret, A(i - 1)). Computed into a separate
		 * buffer because A(i - 1) may itself live in buf. */
		uint8_t a_i[hmac_len];
		HMAC_SHA256(a_prev_len, a_prev, secret_len, sec, a_i);

		/* save A(i) in front of the seed; it is also A(i - 1) for the
		 * next iteration. */
		memcpy(buf, a_i, hmac_len);
		a_prev = buf;
		a_prev_len = hmac_len;

		/* HMAC_SHA256(secret, A(i) + seed) */
		uint8_t block[hmac_len];
		HMAC_SHA256(buf_len, buf, secret_len, sec, block);

		const uint32_t n = (remaining < hmac_len) ? remaining : hmac_len;
		memcpy(out, block, n);
		out += n;
		remaining -= n;
	}
}

/*
 * P_SHA384 (RFC 5246 section 5): writes the first dest_len bytes of the
 * P_SHA384(secret, seed) output stream to dest.
 */
static void p_hash_sha384(const uint8_t *secret,
			  const uint32_t secret_len,
			  const uint8_t *seed,
			  const uint32_t seed_len,
			  uint8_t *dest,
			  const uint32_t dest_len) {
	/* P_hash uses the following algorithm (RFC 5246 section 5).
	 *
	 *   P_<hash>(secret, seed) = HMAC_hash(secret, A(1) + seed) +
	 *                            HMAC_hash(secret, A(2) + seed) +
	 *                            HMAC_hash(secret, A(3) + seed) +
	 *                            ...
	 *
	 *   A(0) = seed
	 *   A(i) = HMAC_hash(secret, A(i - 1))
	 *
	 * HMAC_SHA384 generates 48 bytes per iteration, so producing e.g.
	 * 80 bytes needs A(1)..A(2) (48 * 2 = 96); the last block is
	 * truncated.
	 */

	/* computed MAC length is fixed length 48 (HMAC_SHA384). */
	const uint32_t hmac_len = SHA384_DIGESTSIZE;

	/* NOTE: HMAC_SHA384 does not take const pointers, so the secret is
	 * copied into a writable buffer. The buffer always has at least one
	 * element because a zero-length VLA is undefined behavior. */
	uint8_t sec[secret_len > 0 ? secret_len : 1];
	if (secret_len > 0) {
		memcpy(sec, secret, secret_len);
	}

	/* buf holds "A(i) + seed": the first hmac_len bytes are A(i) and
	 * the remainder is the seed. The seed part never changes, so it is
	 * copied only once. A(i - 1) is either the seed (i == 1) or a
	 * previous hmac_len-byte MAC, so both fit inside buf and no buffer
	 * sized by seed_len alone is needed. */
	const uint32_t buf_len = hmac_len + seed_len;
	uint8_t buf[buf_len];
	if (seed_len > 0) {
		memcpy(&buf[hmac_len], seed, seed_len);
	}

	/* A(0) = seed (exactly seed_len bytes). */
	uint8_t *a_prev = &buf[hmac_len];
	uint32_t a_prev_len = seed_len;

	/* Count down the remaining bytes instead of counting an offset up,
	 * so a dest_len close to UINT32_MAX cannot wrap the loop index. */
	uint8_t *out = dest;
	uint32_t remaining = dest_len;
	while (remaining > 0) {
		/* A(i) = HMAC_hash(secret, A(i - 1)). Computed into a separate
		 * buffer because A(i - 1) may itself live in buf. */
		uint8_t a_i[hmac_len];
		HMAC_SHA384(a_prev_len, a_prev, secret_len, sec, a_i);

		/* save A(i) in front of the seed; it is also A(i - 1) for the
		 * next iteration. */
		memcpy(buf, a_i, hmac_len);
		a_prev = buf;
		a_prev_len = hmac_len;

		/* HMAC_SHA384(secret, A(i) + seed) */
		uint8_t block[hmac_len];
		HMAC_SHA384(buf_len, buf, secret_len, sec, block);

		const uint32_t n = (remaining < hmac_len) ? remaining : hmac_len;
		memcpy(out, block, n);
		out += n;
		remaining -= n;
	}
}

/*
 * TLS PRF with SHA-256 (RFC 5246 section 5):
 *
 *   PRF(secret, label, seed) = P_SHA256(secret, label + seed)
 *
 * Writes dest_len bytes of PRF output to dest.
 */
void tls_prf_sha256(const uint8_t *secret,
		    const uint32_t secret_len,
		    const uint8_t *label,
		    const uint32_t label_len,
		    const uint8_t *seed,
		    const uint32_t seed_len,
		    uint8_t *dest,
		    const uint32_t dest_len) {
	const uint32_t len = label_len + seed_len;

	/* label + seed. Keep at least one element: a zero-length VLA is
	 * undefined behavior. */
	uint8_t hash_seed[len > 0 ? len : 1];

	/* Skip empty copies so a NULL label/seed with length 0 is never
	 * handed to memcpy. */
	if (label_len > 0) {
		memcpy(&hash_seed[0], label, label_len);
	}
	if (seed_len > 0) {
		memcpy(&hash_seed[label_len], seed, seed_len);
	}

	p_hash_sha256(secret, secret_len, hash_seed, len, dest, dest_len);
}

/*
 * TLS PRF with SHA-384 (RFC 5246 section 5):
 *
 *   PRF(secret, label, seed) = P_SHA384(secret, label + seed)
 *
 * Writes dest_len bytes of PRF output to dest.
 */
void tls_prf_sha384(const uint8_t *secret,
		    const uint32_t secret_len,
		    const uint8_t *label,
		    const uint32_t label_len,
		    const uint8_t *seed,
		    const uint32_t seed_len,
		    uint8_t *dest,
		    const uint32_t dest_len) {
	const uint32_t len = label_len + seed_len;

	/* label + seed. Keep at least one element: a zero-length VLA is
	 * undefined behavior. */
	uint8_t hash_seed[len > 0 ? len : 1];

	/* Skip empty copies so a NULL label/seed with length 0 is never
	 * handed to memcpy. */
	if (label_len > 0) {
		memcpy(&hash_seed[0], label, label_len);
	}
	if (seed_len > 0) {
		memcpy(&hash_seed[label_len], seed, seed_len);
	}

	p_hash_sha384(secret, secret_len, hash_seed, len, dest, dest_len);
}
