/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_session.h"
#include "tls_cipher.h"
#include "tls_prf.h"
#include "tls_hkdf.h"
#include "tls_handshake.h"

#include <string.h>

/* for AESkey_new and AES_set_iv */
#include <aicrypto/ok_aes.h>

/* for ChaChakey_new */
#include <aicrypto/nrg_chacha.h>

/* for DES3key_new_c and DES3_set_iv */
#include <aicrypto/ok_des.h>

#ifdef HAVE_ARC4
/* for RC4key_new */
#include <aicrypto/ok_rc4.h>
#endif

/* for OK_do_digest. */
#include <aicrypto/ok_tool.h>

/*
 * Selects how derive_secret() obtains the Transcript-Hash value that is
 * used as the HKDF-Expand-Label context.
 */
enum context_type {
	ZERO_LENGTH_STRING = 0,	/* digest of an empty input, Hash("") */
	MESSAGE_HASH = 1,	/* digest returned by tls_hs_hash_get_digest() */
};

/**
 * Copy the client and server write_MAC_key out of a generated key block
 * into the TLS connection objects.
 *
 * The first len bytes of block become client->mac_key and the following
 * len bytes become server->mac_key; both buffers are allocated with malloc.
 *
 * @param[in,out] client TLS connection object that receives
 *                client_write_MAC_key.
 * @param[in,out] server TLS connection object that receives
 *                server_write_MAC_key.
 * @param[in] block generated key block.
 * @param[in] block_len size of block in bytes. An error is returned when
 *            it is shorter than (len * 2).
 * @param[in] len write_MAC_key length. When len is 0, no write_MAC_key is
 *            set.
 * @return number of bytes of block consumed (0 when len is 0), or -1 on
 *         error.
 */
static int32_t make_mac_key(struct tls_connection *client,
			    struct tls_connection *server,
			    const uint8_t *block,
			    const uint32_t block_len,
			    const uint32_t len);

/**
 * Create an aicrypto key object from a write_key and a write_IV.
 *
 * For every cipher suite the aicrypto key is built from the write_key and
 * write_IV parts of the key material; param.cipher_algorithm selects the
 * aicrypto key type (AES, ChaCha20, 3DES or, when HAVE_ARC4 is defined,
 * RC4).
 *
 * @param[in] param parameters of the selected cipher suite.
 * @param[in] enckey write_key bytes.
 * @param[in] iv write_IV bytes (used by the AES and ChaCha20 keys).
 * @return a new Key object, or NULL on failure.
 */
static Key * make_aicrypto_key(const struct tls_cipher_param param,
			       uint8_t *enckey,
			       uint8_t *iv);

/**
 * Create the client and server aicrypto keys from a key block and set them
 * to the TLS connection objects.
 *
 * The block is read as client_write_key, server_write_key,
 * client_write_IV and server_write_IV (in that order). Each
 * write_key/write_IV pair is turned into a Key object by
 * make_aicrypto_key().
 *
 * @return number of bytes of block consumed, or -1 on error. On error the
 *         MAC keys set by make_mac_key() may be released.
 */
static int32_t make_enc_key(const struct tls_cipher_param param,
			    struct tls_connection *client,
			    struct tls_connection *server,
			    const uint8_t *block,
			    const uint32_t block_len);

/**
 * Set the contents of a generated key block to the TLS structure.
 *
 * The key block must already have been generated by the caller; this
 * function only splits it into MAC keys and encryption keys/IVs.
 *
 * @return true on success, false on error.
 */
static bool make_key_block(TLS *tls,
			   const struct tls_cipher_param param,
			   const uint8_t *block,
			   const uint32_t block_len);

/**
 * Compose an HkdfLabel structure (RFC8446 7.1) into buf.
 *
 * @param[out] buf output buffer.
 * @param[in] label label without the "tls13 " prefix.
 * @param[in] label_len length of label in bytes; must be non-zero.
 * @param[in] context context bytes (may be NULL when context_len is 0).
 * @param[in] context_len length of context in bytes.
 * @param[in] hash_len value written to the HkdfLabel length field.
 * @return number of bytes written to buf, or -1 on error.
 */
static int32_t compose_hkdf_label(uint8_t *buf, char *label, size_t label_len,
				  uint8_t *context, size_t context_len,
				  size_t hash_len);

/**
 * Perform the Derive-Secret(Secret, Label, Messages) process.
 *
 * RFC8446 7.1.  Key Schedule
 *
 *        Derive-Secret(Secret, Label, Messages) =
 *             HKDF-Expand-Label(Secret, Label,
 *                               Transcript-Hash(Messages), Hash.length)
 *
 * @param[in] secret Secret input (hash->len bytes).
 * @param[in] type how the Transcript-Hash value is obtained:
 *            ZERO_LENGTH_STRING digests an empty input, MESSAGE_HASH uses
 *            tls_hs_hash_get_digest().
 * @param[out] okm output buffer of hash->len bytes.
 * @return true on success, false on error.
 */
static bool derive_secret(TLS *tls, hkdf_hash_t *hash, uint8_t *secret,
			  char *label, size_t label_len, enum context_type type,
			  uint8_t *okm);

/**
 * Perform the Derive-Secret(Secret, Label, Messages) process for
 * client_application_traffic_secret_0, server_application_traffic_secret_0
 * and exporter_master_secret.
 *
 * The transcript hash is not recomputed here; the context is read from
 * tls->application_secret_context (hash->len bytes).
 *
 * RFC8446 7.1.  Key Schedule
 *
 *              +-----> Derive-Secret(., "c ap traffic",
 *              |                     ClientHello...server Finished)
 *              |                     = client_application_traffic_secret_0
 *              |
 *              +-----> Derive-Secret(., "s ap traffic",
 *              |                     ClientHello...server Finished)
 *              |                     = server_application_traffic_secret_0
 *              |
 *              +-----> Derive-Secret(., "exp master",
 *              |                     ClientHello...server Finished)
 *              |                     = exporter_master_secret
 */
static bool derive_application_secret(TLS *tls, hkdf_hash_t *hash,
				      uint8_t *secret, char *label,
				      size_t label_len, uint8_t *okm);

/**
 * Get the HKDF hash that corresponds to the hash algorithm of the pending
 * cipher suite (tls->pending->cipher_suite).
 *
 * @return true on success, false when no usable hash is found.
 */
static bool get_hkdf_hash(TLS *tls, hkdf_hash_t *hash);

static int32_t make_mac_key(struct tls_connection *client,
			    struct tls_connection *server,
			    const uint8_t *block,
			    const uint32_t block_len,
			    const uint32_t len) {
	int32_t offset = 0;

	/* cipher suites without a MAC key consume nothing from the block. */
	if (len == 0) {
		return offset;
	}

	/*
	 * compare against block_len / 2 instead of computing len * 2 so that a
	 * huge len cannot wrap around and slip through the check. the consumed
	 * byte count must also be representable in the int32_t return value.
	 */
	if (len > block_len / 2 || len > INT32_MAX / 2) {
		OK_set_error(ERR_ST_TLS_INVALID_BLOCK_LENGTH,
			     ERR_LC_TLS4, ERR_PT_TLS_KEY + 0, NULL);
		return -1;
	}

	if ((client->mac_key = malloc(len)) == NULL) {
		TLS_DPRINTF("key: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS4, ERR_PT_TLS_KEY + 1, NULL);
		return -1;
	}

	if ((server->mac_key = malloc(len)) == NULL) {
		TLS_DPRINTF("key: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS4, ERR_PT_TLS_KEY + 2, NULL);
		/* do not leave a dangling pointer in the connection object. */
		free(client->mac_key);
		client->mac_key = NULL;
		return -1;
	}

	/* client_write_MAC_key comes first, server_write_MAC_key next. */
	memcpy(client->mac_key, &(block[offset]), len);
	offset += (int32_t)len;

	memcpy(server->mac_key, &(block[offset]), len);
	offset += (int32_t)len;

	return offset;
}

static Key * make_aicrypto_key(const struct tls_cipher_param param,
			       uint8_t *enckey,
			       uint8_t *iv) {
	Key *key = NULL;

	uint32_t enc_key_length = param.enc_key_length;
	uint32_t block_length   = param.block_length;

	/* %u with an explicit cast: the values are unsigned (uint32_t). */
	TLS_DPRINTF("key: enc_key_length %u", (unsigned int)enc_key_length);
	TLS_DPRINTF("key: block_length   %u", (unsigned int)block_length);

	switch (param.cipher_algorithm) {
	case TLS_BULK_CIPHER_AES:
		if ((key = (Key *)AESkey_new(enc_key_length, enckey,
					     block_length)) == NULL) {
			OK_set_error(ERR_ST_TLS_AESKEY_NEW, ERR_LC_TLS4,
				     ERR_PT_TLS_KEY + 3, NULL);
			return NULL;
		}

		/*
		 * {client, server} write IV is required by some AEAD ciphers.
		 */
		AES_set_iv((Key_AES *)key, iv);
		break;

	case TLS_BULK_CIPHER_CHACHA20:
		if ((key = (Key *)ChaChakey_new(enckey, iv)) == NULL) {
			OK_set_error(ERR_ST_TLS_CHACHAKEY_NEW,
				     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 0, NULL);
			return NULL;
		}
		break;

	case TLS_BULK_CIPHER_3DES:
		if ((key = (Key *)DES3key_new_c(enc_key_length,
						enckey)) == NULL) {
			OK_set_error(ERR_ST_TLS_DES3KEY_NEW, ERR_LC_TLS4,
				     ERR_PT_TLS_KEY + 4, NULL);
			return NULL;
		}

		/*
		 * the IV is not set here. SSL 3.0 and TLS 1.0 would
		 * additionally need DES3_set_iv((Key_3DES *)key, iv).
		 */
		break;

#ifdef HAVE_ARC4
	case TLS_BULK_CIPHER_RC4:
		if ((key = (Key *)RC4key_new(enc_key_length,
					     enckey)) == NULL) {
			OK_set_error(ERR_ST_TLS_RC4KEY_NEW, ERR_LC_TLS4,
				     ERR_PT_TLS_KEY + 5, NULL);
			return NULL;
		}
		break;
#endif

	default:
		assert(!"selected cipher algorithm is unknown.");
		/* with NDEBUG, report failure by returning NULL. */
		key = NULL;
		break;
	}

	return key;
}

static int32_t make_enc_key(const struct tls_cipher_param param,
			    struct tls_connection *client,
			    struct tls_connection *server,
			    const uint8_t *block,
			    const uint32_t block_len) {
	int32_t offset = 0;

	int32_t enc_key_length  = param.enc_key_length;
	int32_t fixed_iv_length = param.fixed_iv_length;
	/*
	 * NOTE: assume iv length of AES and 3DES is equal to each block length
	 * here. if new algorithm appears and its iv length is longer than its
	 * block length, iv length will used as buffer length.
	 */
	int32_t iv_length = fixed_iv_length >= param.block_length ?
	    fixed_iv_length : param.block_length;

	/*
	 * every buffer is declared before the first "goto err" so the jump
	 * never enters the scope of a VLA (C11 6.8.6.1), and each VLA gets a
	 * positive size even for suites without a key or an IV (e.g. RC4).
	 */
	int32_t key_buf_len = enc_key_length > 0 ? enc_key_length : 1;
	int32_t iv_buf_len  = iv_length > 0 ? iv_length : 1;

	uint8_t client_key[key_buf_len];
	uint8_t server_key[key_buf_len];
	uint8_t client_iv[iv_buf_len];
	uint8_t server_iv[iv_buf_len];

	memset(client_key, 0, sizeof(client_key));
	memset(server_key, 0, sizeof(server_key));
	/* IV bytes beyond fixed_iv_length must read as zero. */
	memset(client_iv, 0, sizeof(client_iv));
	memset(server_iv, 0, sizeof(server_iv));

	/* client_write_key and server_write_key. */
	if (block_len < (uint32_t)(offset + (enc_key_length * 2))) {
		OK_set_error(ERR_ST_TLS_INVALID_BLOCK_LENGTH,
			     ERR_LC_TLS4, ERR_PT_TLS_KEY + 6, NULL);
		goto err;
	}

	if (enc_key_length > 0) {
		memcpy(client_key, &(block[offset]), (size_t)enc_key_length);
		offset += enc_key_length;

		memcpy(server_key, &(block[offset]), (size_t)enc_key_length);
		offset += enc_key_length;
	}

	/* client_write_IV and server_write_IV. */
	if (block_len < (uint32_t)(offset + (fixed_iv_length * 2))) {
		OK_set_error(ERR_ST_TLS_INVALID_BLOCK_LENGTH,
			     ERR_LC_TLS4, ERR_PT_TLS_KEY + 7, NULL);
		goto err;
	}

	if (fixed_iv_length > 0) {
		memcpy(client_iv, &(block[offset]), (size_t)fixed_iv_length);
		offset += fixed_iv_length;

		memcpy(server_iv, &(block[offset]), (size_t)fixed_iv_length);
		offset += fixed_iv_length;
	}

	/* make aicrypto's encryption key structure. */
	if ((client->key = make_aicrypto_key(param,
					     client_key, client_iv)) == NULL) {
		goto err;
	}

	if ((server->key = make_aicrypto_key(param,
					     server_key, server_iv)) == NULL) {
		Key_free(client->key);
		client->key = NULL;
		goto err;
	}

	return offset;

err:
	/*
	 * release the MAC keys set up by make_mac_key() and clear the
	 * pointers so that any later cleanup of the connection objects cannot
	 * free them a second time.
	 */
	free(client->mac_key);
	client->mac_key = NULL;
	free(server->mac_key);
	server->mac_key = NULL;
	return -1;
}

static bool make_key_block(TLS *tls,
			   const struct tls_cipher_param param,
			   const uint8_t *block,
			   const uint32_t block_len) {
	/*
	 * Map the client/server halves of the key block onto this side's
	 * read/write connection states:
	 *
	 *       client  server
	 *        |        |
	 *   read | <----- | write ... server write
	 *        |        |
	 *  write | -----> | read  ... client write
	 *        |        |
	 */
	const bool is_client = (tls->entity == TLS_CONNECT_CLIENT);
	struct tls_connection *client_conn =
	    is_client ? &(tls->active_write) : &(tls->active_read);
	struct tls_connection *server_conn =
	    is_client ? &(tls->active_read) : &(tls->active_write);

	/* {client,server}_write_MAC_key come first in the key block. */
	const int32_t consumed = make_mac_key(client_conn, server_conn, block,
					      block_len, param.mac_key_length);
	if (consumed < 0) {
		return false;
	}

	/*
	 * a NULL bulk cipher means no encryption, so there is no
	 * {client,server}_key or {client,server}_IV to set up.
	 */
	if (param.cipher_algorithm == TLS_BULK_CIPHER_NULL) {
		return true;
	}

	/* the encryption keys and IVs follow the MAC keys. */
	const uint32_t used = (uint32_t)consumed;
	return make_enc_key(param, client_conn, server_conn,
			    block + used, block_len - used) >= 0;
}

static int32_t compose_hkdf_label(uint8_t *buf, char *label, size_t label_len,
				  uint8_t *context, size_t context_len,
				  size_t hash_len) {
	/*
	 * RFC8446 7.1.  Key Schedule
	 *
	 *        struct {
	 *            uint16 length = Length;
	 *            opaque label<7..255> = "tls13 " + Label;
	 *            opaque context<0..255> = Context;
	 *        } HkdfLabel;
	 */
	static const char label_header[] = "tls13 ";
	const size_t label_header_len = sizeof(label_header) - 1;

	/* validate every field before anything is written to buf. */
	if (hash_len > TLS_HKDF_LABEL_HASH_SIZE_MAX) {
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_HASH,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 1, NULL);
		return -1;
	}

	/*
	 * total label length is at least 7, so length of argument "label"
	 * must be greater than 0.
	 */
	if (label_len == 0) {
		OK_set_error(ERR_ST_TLS_INVALID_HKDF_LABEL_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 2, NULL);
		return -1;
	}

	if (label_header_len + label_len > TLS_HKDF_INNER_LABEL_SIZE_MAX) {
		OK_set_error(ERR_ST_TLS_INVALID_HKDF_LABEL_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 3, NULL);
		return -1;
	}

	if (context_len > TLS_HKDF_LABEL_CONTEXT_SIZE_MAX) {
		OK_set_error(ERR_ST_TLS_INVALID_HKDF_CONTEXT_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 4, NULL);
		return -1;
	}

	size_t offset = 0;

	/* uint16 length */
	tls_util_write_2(buf, hash_len);
	offset += 2;

	/*
	 * opaque label<7..255>: a one-byte length followed by "tls13 " and
	 * exactly label_len bytes of the caller's label. the bytes are
	 * copied with memcpy so the length prefix always matches what is
	 * written.
	 */
	buf[offset] = (uint8_t)(label_header_len + label_len);
	offset += 1;

	memcpy(&buf[offset], label_header, label_header_len);
	offset += label_header_len;

	memcpy(&buf[offset], label, label_len);
	offset += label_len;

	/* opaque context<0..255> */
	buf[offset] = (uint8_t)context_len;
	offset += 1;

	if (context_len > 0) {
		memcpy(&buf[offset], context, context_len);
		offset += context_len;
	}

	return (int32_t)offset;
}

bool tls_key_hkdf_expand_label(hkdf_hash_t *hash, uint8_t *secret,
			      char *label, size_t label_len,
			      uint8_t *context, size_t context_len,
			      size_t length, uint8_t *okm) {
	/*
	 * RFC8446 7.1:
	 *
	 *    HKDF-Expand-Label(Secret, Label, Context, Length) =
	 *        HKDF-Expand(Secret, HkdfLabel, Length)
	 */
	uint8_t info[TLS_HKDF_LABEL_SIZE_MAX];
	const int32_t info_len = compose_hkdf_label(info, label, label_len,
						    context, context_len,
						    length);
	if (info_len < 0) {
		return false;
	}

	return HKDF_Expand(hash, secret, hash->len,
			   info, info_len, okm, length) >= 0;
}

static bool derive_secret(TLS *tls, hkdf_hash_t *hash, uint8_t *secret,
			  char *label, size_t label_len, enum context_type type,
			  uint8_t *okm) {
	const enum tls_hs_sighash_hash_algo tls_hash =
	    tls_cipher_hashalgo(tls->pending->cipher_suite);
	if (tls_hash == TLS_HASH_ALGO_NONE) {
		return false;
	}

	/* Transcript-Hash(Messages), Hash.length bytes long. */
	uint8_t transcript[hash->len];

	if (type == ZERO_LENGTH_STRING) {
		/* Messages is the empty string: digest a zero-length input. */
		const int32_t ai_hash = tls_hs_sighash_get_ai_hash_type(tls_hash);
		if (ai_hash < 0) {
			return false;
		}

		uint8_t empty_input[hash->len];
		int digest_len;
		if (OK_do_digest(ai_hash, empty_input, 0,
				 transcript, &digest_len) == NULL) {
			TLS_DPRINTF("OK_do_digest");
			return false;
		}
	} else if (type == MESSAGE_HASH) {
		tls_hs_hash_get_digest(tls_hash, tls, transcript);
	} else {
		TLS_DPRINTF("unknown context type %d", type);
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 5, NULL);
		return false;
	}

	return tls_key_hkdf_expand_label(hash, secret, label, label_len,
					 transcript, sizeof(transcript),
					 hash->len, okm);
}

static bool derive_application_secret(TLS *tls, hkdf_hash_t *hash,
				      uint8_t *secret, char *label,
				      size_t label_len, uint8_t *okm) {
	/*
	 * the Transcript-Hash value is not recomputed here; it is read from
	 * tls->application_secret_context as Hash.length bytes.
	 */
	return tls_key_hkdf_expand_label(hash, secret, label, label_len,
					 tls->application_secret_context,
					 hash->len, hash->len, okm);
}

static bool get_hkdf_hash(TLS *tls, hkdf_hash_t *hash) {
	/*
	 * cipher suite -> TLS hash algorithm -> aicrypto hash id
	 * -> HKDF hash descriptor.
	 */
	const enum tls_hs_sighash_hash_algo tls_hash =
	    tls_cipher_hashalgo(tls->pending->cipher_suite);
	if (tls_hash == TLS_HASH_ALGO_NONE) {
		return false;
	}

	const int32_t ai_hash = tls_hs_sighash_get_ai_hash_type(tls_hash);
	if (ai_hash < 0) {
		return false;
	}

	return HKDF_get_hash(ai_hash, hash) >= 0;
}

bool tls_key_derive_early_secret(TLS *tls) {
	/*
	 * RFC8446 7.1: the Hash function used by Transcript-Hash and HKDF is
	 * the cipher suite hash algorithm.
	 */
	hkdf_hash_t hash;
	if (!get_hkdf_hash(tls, &hash)) {
		return false;
	}

	/*
	 * RFC8446 7.1.  Key Schedule
	 *
	 *              0
	 *              |
	 *              v
	 *    PSK ->  HKDF-Extract = Early Secret
	 *              |
	 *              +-----> Derive-Secret(., "ext binder" | "res binder", "")
	 *              |                     = binder_key
	 *              |
	 *              +-----> Derive-Secret(., "c e traffic", ClientHello)
	 *              |                     = client_early_traffic_secret
	 *              |
	 *              +-----> Derive-Secret(., "e exp master", ClientHello)
	 *              |                     = early_exporter_master_secret
	 *              v
	 *        Derive-Secret(., "derived", "")
	 *
	 * "0" is a string of Hash.length zero bytes. A secret that is not
	 * available is replaced by that 0-value without skipping the round,
	 * so without a PSK the Early Secret is HKDF-Extract(0, 0).
	 */
	uint8_t zero_salt[hash.len];
	uint8_t zero_ikm[hash.len];
	memset(zero_salt, 0, sizeof(zero_salt));
	memset(zero_ikm, 0, sizeof(zero_ikm));

	uint8_t early[hash.len];
	if (HKDF_Extract(&hash, zero_salt, hash.len,
			 zero_ikm, hash.len, early) < 0) {
		return false;
	}

	/* kept for tls_key_derive_handshake_secret(). */
	tls->early_secret = malloc(hash.len);
	if (tls->early_secret == NULL) {
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 6, NULL);
		return false;
	}
	memcpy(tls->early_secret, early, hash.len);

	return true;
}

bool tls_key_derive_handshake_secret(TLS *tls) {
	/*
	 * RFC8446 7.1: the Hash function used by Transcript-Hash and HKDF is
	 * the cipher suite hash algorithm.
	 */
	hkdf_hash_t hash;
	if (!get_hkdf_hash(tls, &hash)) {
		return false;
	}

	/*
	 * RFC8446 7.1.  Key Schedule
	 *
	 *    PSK ->  HKDF-Extract = Early Secret
	 *              |
	 *              v
	 *        Derive-Secret(., "derived", "")
	 *              |
	 *              v
	 *    (EC)DHE -> HKDF-Extract = Handshake Secret
	 *
	 * Step 1: salt = Derive-Secret(Early Secret, "derived", "").
	 */
	char derived_label[] = "derived";
	uint8_t derived[hash.len];
	if (!derive_secret(tls, &hash, tls->early_secret, derived_label,
			   sizeof(derived_label) - 1, ZERO_LENGTH_STRING,
			   derived)) {
		return false;
	}

	/*
	 * RFC8446 7.1: once all the values which are to be derived from a
	 * given secret have been computed, that secret SHOULD be erased.
	 */
	free(tls->early_secret);
	tls->early_secret = NULL;

	/* Step 2: IKM = (EC)DHE shared secret for a supported group. */
	switch (tls->ecdh->namedcurve) {
	case TLS_NAMED_GROUP_SECP256R1:
	case TLS_NAMED_GROUP_SECP384R1:
	case TLS_NAMED_GROUP_SECP521R1:
	case TLS_NAMED_GROUP_X25519:
	case TLS_NAMED_GROUP_X448:
		break;

	default:
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 7, NULL);
		return false;
	}

	if (tls_hs_ecdh_calc_shared_secret(tls, tls->ecdh) != 0) {
		TLS_DPRINTF("key: tls_hs_ecdh_calc_shared_secret");
		return false;
	}

	/* the shared secret is stored as the premaster secret. */
	uint8_t *shared = tls->premaster_secret;
	size_t shared_len = tls->premaster_secret_len;

	/* Step 3: Handshake Secret = HKDF-Extract(salt, IKM). */
	uint8_t handshake[hash.len];
	if (HKDF_Extract(&hash, derived, hash.len,
			 shared, shared_len, handshake) < 0) {
		return false;
	}

	tls->handshake_secret = malloc(hash.len);
	if (tls->handshake_secret == NULL) {
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 8, NULL);
		return false;
	}
	memcpy(tls->handshake_secret, handshake, hash.len);

	return true;
}

/*
 * Derive client_handshake_traffic_secret and server_handshake_traffic_secret
 * from the Handshake Secret and store them in the connection objects.
 *
 * RFC8446 7.1.  Key Schedule
 *
 *    (EC)DHE -> HKDF-Extract = Handshake Secret
 *              |
 *              +-----> Derive-Secret(., "c hs traffic",
 *              |                     ClientHello...ServerHello)
 *              |                     = client_handshake_traffic_secret
 *              |
 *              +-----> Derive-Secret(., "s hs traffic",
 *              |                     ClientHello...ServerHello)
 *              |                     = server_handshake_traffic_secret
 *
 * Each connection object owns its secret: a previously stored secret is
 * released when it is replaced, as in the application traffic secret
 * functions. On failure the previous secret is left in place.
 *
 * Returns true on success, false on error.
 */
bool tls_key_derive_handshake_traffic_secrets(TLS *tls) {
	hkdf_hash_t hash;
	if (get_hkdf_hash(tls, &hash) == false) {
		return false;
	}

	struct tls_connection *client;
	struct tls_connection *server;
	switch (tls->entity) {
	case TLS_CONNECT_CLIENT:
		client = &(tls->active_write);
		server = &(tls->active_read);
		break;

	case TLS_CONNECT_SERVER:
		client = &(tls->active_read);
		server = &(tls->active_write);
		break;

	default:
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 9, NULL);
		return false;
	}

	uint8_t okm[hash.len];
	uint8_t *new_secret;

	/* client_handshake_traffic_secret */
	char client_label[] = "c hs traffic";
	size_t client_label_len = sizeof(client_label) - 1;
	if (derive_secret(tls, &hash, tls->handshake_secret,
			  client_label, client_label_len, MESSAGE_HASH, okm)
	    == false) {
		return false;
	}

	if ((new_secret = malloc(hash.len)) == NULL) {
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 10, NULL);
		return false;
	}
	memcpy(new_secret, okm, hash.len);
	/* replace (not leak) any secret installed earlier. */
	free(client->secret);
	client->secret = new_secret;

	/* server_handshake_traffic_secret */
	char server_label[] = "s hs traffic";
	size_t server_label_len = sizeof(server_label) - 1;
	if (derive_secret(tls, &hash, tls->handshake_secret,
			  server_label, server_label_len, MESSAGE_HASH, okm)
	    == false) {
		return false;
	}

	if ((new_secret = malloc(hash.len)) == NULL) {
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 11, NULL);
		return false;
	}
	memcpy(new_secret, okm, hash.len);
	free(server->secret);
	server->secret = new_secret;

	return true;
}

/*
 * Derive the TLS 1.3 Master Secret from the Handshake Secret and store it
 * in tls->pending->master_secret.
 *
 * The Handshake Secret is released once the "derived" value has been
 * computed. Returns true on success, false on error (including a hash whose
 * output does not fit in tls->pending->master_secret).
 */
bool tls_key_derive_master_secret(TLS *tls) {
	/*
	 * RFC8446 7.1.  Key Schedule
	 *
	 *    The Hash function used by Transcript-Hash and HKDF is the cipher
	 *    suite hash algorithm.
	 */
	hkdf_hash_t hash;
	if (get_hkdf_hash(tls, &hash) == false) {
		return false;
	}

	/*
	 * master_secret in tls_session_param is a fixed-size array of 48
	 * bytes. SHA256 and SHA384 outputs fit, but a longer hash (e.g.
	 * SHA512) would overrun it. reject that before the Handshake Secret
	 * is consumed.
	 */
	if ((size_t)hash.len > sizeof(tls->pending->master_secret)) {
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_HASH,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 17, NULL);
		return false;
	}

	/*
	 * RFC8446 7.1.  Key Schedule
	 *
	 *    (EC)DHE -> HKDF-Extract = Handshake Secret
	 *              |
	 *              +-----> Derive-Secret(., "c hs traffic",
	 *              |                     ClientHello...ServerHello)
	 *              |                     = client_handshake_traffic_secret
	 *              |
	 *              +-----> Derive-Secret(., "s hs traffic",
	 *              |                     ClientHello...ServerHello)
	 *              |                     = server_handshake_traffic_secret
	 *              v
	 *        Derive-Secret(., "derived", "")
	 *              |
	 *              v
	 *    0 -> HKDF-Extract = Master Secret
	 */
	char label[] = "derived";
	size_t label_len = sizeof(label) - 1;
	uint8_t okm[hash.len];
	if (derive_secret(tls, &hash, tls->handshake_secret, label, label_len,
			  ZERO_LENGTH_STRING, okm) == false) {
		return false;
	}

	/*
	 * RFC8446 7.1.  Key Schedule
	 *
	 *    Once all the values which are to be derived from a given secret have
	 *    been computed, that secret SHOULD be erased.
	 */
	free(tls->handshake_secret);
	tls->handshake_secret = NULL;

	/* Master Secret = HKDF-Extract(salt = derived, IKM = 0). */
	uint8_t *salt = okm;
	size_t salt_len = hash.len;
	uint8_t ikm[hash.len];
	size_t ikm_len = hash.len;
	uint8_t prk[hash.len];

	memset(ikm, 0, sizeof(ikm));

	if (HKDF_Extract(&hash, salt, salt_len, ikm, ikm_len, prk) < 0) {
		return false;
	}

	/* bounded by the size check above. */
	memcpy(tls->pending->master_secret, prk, hash.len);

	return true;
}

bool tls_key_derive_application_traffic_secret(TLS *tls,
					      char *label,
					      struct tls_connection *connection) {
	/*
	 * Derive one application-phase secret with the given label from the
	 * master secret and replace connection->secret with it.
	 */
	hkdf_hash_t hash;
	if (!get_hkdf_hash(tls, &hash)) {
		return false;
	}

	uint8_t derived[hash.len];
	if (!derive_application_secret(tls, &hash, tls->pending->master_secret,
				       label, strlen(label), derived)) {
		return false;
	}

	uint8_t *previous = connection->secret;
	uint8_t *replacement = malloc(hash.len);
	if (replacement == NULL) {
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 12, NULL);
		return false;
	}
	memcpy(replacement, derived, hash.len);

	/* install the new secret before releasing the old one. */
	connection->secret = replacement;
	free(previous);

	return true;
}

bool tls_key_derive_application_traffic_secrets(TLS *tls) {
	/*
	 * Derive client_application_traffic_secret_0 ("c ap traffic") and
	 * server_application_traffic_secret_0 ("s ap traffic") from the master
	 * secret, replacing the secret held by each connection object.
	 */
	hkdf_hash_t hash;
	if (!get_hkdf_hash(tls, &hash)) {
		return false;
	}

	struct tls_connection *client;
	struct tls_connection *server;
	if (tls->entity == TLS_CONNECT_CLIENT) {
		client = &(tls->active_write);
		server = &(tls->active_read);
	} else if (tls->entity == TLS_CONNECT_SERVER) {
		client = &(tls->active_read);
		server = &(tls->active_write);
	} else {
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 13, NULL);
		return false;
	}

	char client_label[] = "c ap traffic";
	char server_label[] = "s ap traffic";

	/* processed in order: client first, then server. */
	const struct {
		struct tls_connection *conn;
		char *label;
		size_t label_len;
	} steps[] = {
		{ client, client_label, sizeof(client_label) - 1 },
		{ server, server_label, sizeof(server_label) - 1 },
	};

	uint8_t okm[hash.len];
	for (size_t i = 0; i < sizeof(steps) / sizeof(steps[0]); i++) {
		if (!derive_application_secret(tls, &hash,
					       tls->pending->master_secret,
					       steps[i].label,
					       steps[i].label_len, okm)) {
			return false;
		}

		uint8_t *previous = steps[i].conn->secret;
		uint8_t *replacement = malloc(hash.len);
		if (replacement == NULL) {
			/* error points +14 (client) and +15 (server). */
			OK_set_error(ERR_ST_TLS_MALLOC, ERR_LC_TLS6,
				     ERR_PT_TLS_KEY2 + 14 + (int)i, NULL);
			return false;
		}
		memcpy(replacement, okm, hash.len);

		steps[i].conn->secret = replacement;
		free(previous);
	}

	return true;
}

bool tls_key_make_master_secret(TLS *tls) {
	/* generate master_secret to use following algorithm (RFC 5246
	 * section 8.1).
	 *
	 *   For all key exchange methods, the same algorithm is used to
	 *   convert the pre_master_secret into the master_secret.  The
	 *   pre_master_secret should be deleted from memory once the
	 *   master_secret has been computed.
	 *
	 *      master_secret = PRF(pre_master_secret, "master secret",
	 *                          ClientHello.random + ServerHello.random)
	 *                          [0..47];
	 *
	 *   The master secret is always exactly 48 bytes in length.
	 *   The length of the premaster secret will vary depending on
	 *   key exchange method.
	 *
	 * Returns true on success, false on error.
	 */

	/* get parameter of selected cipher suite. */
	struct tls_cipher_param param;
	if (! tls_cipher_param_set(tls->pending->cipher_suite, &param)) {
		return false;
	}

	const uint32_t secret_len = tls->premaster_secret_len;

	const uint8_t  label[]   = "master secret";
	const uint32_t label_len = sizeof (label) - 1;

	const uint32_t crandlen = sizeof (tls->client_random);
	const uint32_t srandlen = sizeof (tls->server_random);
	const uint32_t seed_len = crandlen + srandlen;

	uint8_t seed[seed_len];

	/* client is first. server is next. */
	memcpy(&(seed[0]),        &(tls->client_random[0]), crandlen);
	memcpy(&(seed[crandlen]), &(tls->server_random[0]), srandlen);

	switch(param.prf_algorithm) {
	case TLS_PRF_SHA256:
		tls_prf_sha256(tls->premaster_secret,
			       secret_len,
			       label,
			       label_len,
			       seed,
			       seed_len,
			       &(tls->pending->master_secret[0]),
			       sizeof (tls->pending->master_secret));
		break;

	case TLS_PRF_SHA384:
		tls_prf_sha384(tls->premaster_secret,
			       secret_len,
			       label,
			       label_len,
			       seed,
			       seed_len,
			       &(tls->pending->master_secret[0]),
			       sizeof (tls->pending->master_secret));
		break;

	default:
		assert(!"selected prf algorithm is unknown");
		/* with NDEBUG, do not report success for an uncomputed
		 * master secret. */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 18, NULL);
		return false;
	}

	/* RFC 5246 section 8 says as follows. So, clear value of
	 * premaster secret.
	 *
	 *   The pre_master_secret should be deleted from memory once
	 *   the master_secret has been computed.
	 */
	memset(&(tls->premaster_secret[0]), 0, secret_len);

	return true;
}

bool tls_key_derive_application_traffic_secret_n(
	TLS *tls, struct tls_connection *connection) {
	/*
	 * RFC8446 7.2.  Updating Traffic Secrets
	 *
	 *    application_traffic_secret_N+1 =
	 *        HKDF-Expand-Label(application_traffic_secret_N,
	 *                          "traffic upd", "", Hash.length)
	 *
	 * The current connection->secret is replaced by the next one.
	 */
	hkdf_hash_t hash;
	if (!get_hkdf_hash(tls, &hash)) {
		return false;
	}

	char update_label[] = "traffic upd";
	uint8_t next[hash.len];
	if (!tls_key_hkdf_expand_label(&hash, connection->secret,
				       update_label, sizeof(update_label) - 1,
				       NULL, 0, hash.len, next)) {
		return false;
	}

	uint8_t *current = connection->secret;
	uint8_t *stored = malloc(hash.len);
	if (stored == NULL) {
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 16, NULL);
		return false;
	}
	memcpy(stored, next, hash.len);

	connection->secret = stored;
	free(current);

	return true;
}

bool tls_key_make_key_block(TLS *tls) {
	/* generate key block to use following algorithm (RFC 5246
	 * section 6.3).
	 *
	 *   To generate the key material, compute
	 *
	 *      key_block = PRF(SecurityParameters.master_secret,
	 *                      "key expansion",
	 *                      SecurityParameters.server_random +
	 *                      SecurityParameters.client_random);
	 *
	 *   until enough output has been generated.  Then, the
	 *   key_block is partitioned as follows:
	 *
	 *      client_write_MAC_key[SecurityParameters.mac_key_length]
	 *      server_write_MAC_key[SecurityParameters.mac_key_length]
	 *      client_write_key[SecurityParameters.enc_key_length]
	 *      server_write_key[SecurityParameters.enc_key_length]
	 *      client_write_IV[SecurityParameters.fixed_iv_length]
	 *      server_write_IV[SecurityParameters.fixed_iv_length]
	 *
	 * Returns true on success, false on error.
	 */

	/* get parameter of selected cipher suite. */
	struct tls_cipher_param param;
	if (! tls_cipher_param_set(tls->pending->cipher_suite, &param)) {
		return false;
	}

	uint32_t block_len = (param.mac_key_length  + param.mac_key_length  +
			      param.enc_key_length  + param.enc_key_length  +
			      param.fixed_iv_length + param.fixed_iv_length);

	/* a VLA must have a positive size even if no key material is
	 * needed; only block_len bytes are ever used. */
	uint8_t block[block_len > 0 ? block_len : 1];

	uint8_t  label[]   = "key expansion";
	uint32_t label_len = sizeof (label) - 1;

	const uint32_t srandlen = sizeof (tls->server_random);
	const uint32_t crandlen = sizeof (tls->client_random);
	const uint32_t seed_len = srandlen + crandlen;

	uint8_t seed[seed_len];

	/* server is first. client is next (the reverse of the master
	 * secret seed). */
	memcpy(&(seed[0]),        &(tls->server_random[0]), srandlen);
	memcpy(&(seed[srandlen]), &(tls->client_random[0]), crandlen);

	switch(param.prf_algorithm) {
	case TLS_PRF_SHA256:
		tls_prf_sha256(tls->pending->master_secret,
			       sizeof (tls->pending->master_secret),
			       label,
			       label_len,
			       seed,
			       seed_len,
			       &(block[0]),
			       block_len);
		break;

	case TLS_PRF_SHA384:
		tls_prf_sha384(tls->pending->master_secret,
			       sizeof (tls->pending->master_secret),
			       label,
			       label_len,
			       seed,
			       seed_len,
			       &(block[0]),
			       block_len);
		break;

	default:
		assert(!"selected prf algorithm is unknown");
		/* with NDEBUG, never install keys from an ungenerated
		 * (uninitialized) block. */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS6, ERR_PT_TLS_KEY2 + 19, NULL);
		return false;
	}

	/* set generated keyblock to tls structure. */
	if (! make_key_block(tls, param, block, block_len)) {
		return false;
	}

	return true;
}

bool tls_key_make_traffic_key(TLS *tls, struct tls_connection *connection) {
	/*
	 * RFC8446 7.1.  Key Schedule
	 *
	 *    The Hash function used by Transcript-Hash and HKDF is the cipher
	 *    suite hash algorithm.
	 */
	hkdf_hash_t hash;
	if (get_hkdf_hash(tls, &hash) == false) {
		return false;
	}

	/*
	 * RFC8446 7.3.  Traffic Key Calculation
	 *
	 *    The traffic keying material is generated from an input traffic secret
	 *    value using:
	 *
	 *    [sender]_write_key = HKDF-Expand-Label(Secret, "key", "", key_length)
	 *    [sender]_write_iv  = HKDF-Expand-Label(Secret, "iv", "", iv_length)
	 *
	 *    [sender] denotes the sending side.  The value of Secret for each
	 *    record type is shown in the table below.
	 *
	 *        +-------------------+---------------------------------------+
	 *        | Record Type       | Secret                                |
	 *        +-------------------+---------------------------------------+
	 *        | 0-RTT Application | client_early_traffic_secret           |
	 *        |                   |                                       |
	 *        | Handshake         | [sender]_handshake_traffic_secret     |
	 *        |                   |                                       |
	 *        | Application Data  | [sender]_application_traffic_secret_N |
	 *        +-------------------+---------------------------------------+
	 *
	 * Both labels use an empty ("") context, so no context buffer is
	 * passed.
	 */
	struct tls_cipher_param param = connection->cipher;
	uint8_t *secret = connection->secret;

	char key_label[] = "key";
	size_t key_label_len = sizeof(key_label) - 1;
	size_t key_len = param.enc_key_length;
	/* a VLA must have a positive size. */
	uint8_t key[key_len > 0 ? key_len : 1];
	if (tls_key_hkdf_expand_label(&hash, secret, key_label, key_label_len,
				      NULL, 0, key_len, key) == false) {
		return false;
	}

	char iv_label[] = "iv";
	size_t iv_label_len = sizeof(iv_label) - 1;
	size_t iv_len = param.fixed_iv_length + param.record_iv_length;
	/*
	 * the buffer is sized to at least block_length because
	 * make_aicrypto_key() may read more than iv_len bytes. only iv_len
	 * bytes are derived, so zero the whole buffer first (as make_enc_key()
	 * does) to keep uninitialized stack bytes out of the key.
	 */
	size_t iv_buf_len = iv_len >= param.block_length ?
	    iv_len : param.block_length;
	uint8_t iv[iv_buf_len > 0 ? iv_buf_len : 1];
	memset(iv, 0, sizeof(iv));
	if (tls_key_hkdf_expand_label(&hash, secret, iv_label, iv_label_len,
				      NULL, 0, iv_len, iv) == false) {
		return false;
	}

	/*
	 * RFC8446 5.3.  Per-Record Nonce
	 *
	 *                                               Each sequence number is
	 *    set to zero at the beginning of a connection and whenever the key is
	 *    changed; the first record transmitted under a particular traffic key
	 *    MUST use sequence number 0.
	 */
	Key *old_key = connection->key;
	Key *new_key;
	if ((new_key = make_aicrypto_key(param, key, iv)) == NULL) {
		return false;
	}
	connection->key = new_key;
	connection->seqnum = 0;

	if (old_key != NULL) {
		Key_free(old_key);
	}

	return true;
}
