/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_prf.h"
#include "tls_alert.h"
#include "tls_hkdf.h"
#include "tls_cipher.h"

#include <string.h>

/* for SHA256init, SHA256update and SHA256final. */
#include <aicrypto/ok_sha2.h>

/* for HMAC_SHA256, HMAC_SHA384. */
#include <aicrypto/ok_hmac.h>

/**
 * Which Finished message is being handled: the one this endpoint composes
 * and sends (write), or the one received from the peer that must be
 * checked (read). The direction selects which label is used.
 */
enum tls_finished_stat {
	/** composing this endpoint's own Finished message. */
	FINISHED_WRITE_STAT = 0,

	/** verifying the Finished message received from the peer. */
	FINISHED_READ_STAT = 1
};

/** finished_label used for the verify_data the client sends. */
static const uint8_t TLS_FINISHED_CLIENT_LABEL[] = "client finished";

/** finished_label used for the verify_data the server sends. */
static const uint8_t TLS_FINISHED_SERVER_LABEL[] = "server finished";

/**
 * byte length of a finished label, not counting the terminating NUL.
 * "client finished" and "server finished" are equally long, so this one
 * value sizes a buffer able to hold either label.
 */
static const uint32_t TLS_FINISHED_LABEL_LEN =
	sizeof (TLS_FINISHED_CLIENT_LABEL) - 1;

/**
 * copy the finished label a client uses into label.
 *
 * @param status FINISHED_WRITE_STAT copies "client finished" (our own
 *               Finished); FINISHED_READ_STAT copies "server finished"
 *               (the peer's Finished).
 * @param label  destination; must hold at least TLS_FINISHED_LABEL_LEN
 *               bytes. No NUL terminator is written.
 * @return number of bytes copied, or -1 for an unknown status (after
 *         an assertion failure in debug builds).
 */
static int32_t get_label_client(const enum tls_finished_stat status,
				uint8_t *label);

/**
 * copy the finished label a server uses into label.
 *
 * @param status FINISHED_WRITE_STAT copies "server finished" (our own
 *               Finished); FINISHED_READ_STAT copies "client finished"
 *               (the peer's Finished).
 * @param label  destination; must hold at least TLS_FINISHED_LABEL_LEN
 *               bytes. No NUL terminator is written.
 * @return number of bytes copied, or -1 for an unknown status (after
 *         an assertion failure in debug builds).
 */
static int32_t get_label_server(const enum tls_finished_stat status,
				uint8_t *label);

/**
 * copy the finished label for this connection's role (tls->entity) and
 * the given direction into label.
 *
 * @return number of bytes copied, or -1 when the entity is unknown (an
 *         internal error is recorded with OK_set_error in that case).
 */
static int32_t get_label(const TLS *tls,
			 const enum tls_finished_stat status,
			 uint8_t *label);

/**
 * compute TLS 1.2 verify_data with the SHA-256 based PRF:
 * PRF(master_secret, label, SHA-256 handshake hash), writing dest_len
 * bytes into dest.
 */
static void do_prf_sha256(TLS *tls,
			  const uint8_t *label, const int32_t label_len,
			  uint8_t *dest, const uint32_t dest_len);

/**
 * compute TLS 1.2 verify_data with the SHA-384 based PRF:
 * PRF(master_secret, label, SHA-384 handshake hash), writing dest_len
 * bytes into dest.
 */
static void do_prf_sha384(TLS *tls,
			  const uint8_t *label, const int32_t label_len,
			  uint8_t *dest, const uint32_t dest_len);

/**
 * compute TLS 1.3 verify_data with HMAC-SHA256: finished_key is derived
 * from base_key with HKDF-Expand-Label(label, empty context), then the
 * handshake transcript hash is MACed with it into dest.
 *
 * @return false if the cipher suite's hash cannot be resolved or the
 *         key derivation fails; true otherwise.
 */
static bool do_hmac_sha256(TLS *tls, uint8_t *base_key,
			   char *label, size_t label_len,
			   uint8_t *dest);

/**
 * compute TLS 1.3 verify_data with HMAC-SHA384: finished_key is derived
 * from base_key with HKDF-Expand-Label(label, empty context), then the
 * handshake transcript hash is MACed with it into dest.
 *
 * @return false if the cipher suite's hash cannot be resolved or the
 *         key derivation fails; true otherwise.
 */
static bool do_hmac_sha384(TLS *tls, uint8_t *base_key,
			   char *label, size_t label_len,
			   uint8_t *dest);

/**
 * compute this endpoint's TLS 1.2 verify_data into msg->msg and set
 * msg->len. Returns the verify_data length, or -1 on failure.
 */
static int32_t write_verify_data_tls12(TLS *tls, struct tls_hs_msg *msg);

/**
 * compute this endpoint's TLS 1.3 verify_data into msg->msg and set
 * msg->len. Returns the verify_data length, or -1 on failure.
 */
static int32_t write_verify_data_tls13(TLS *tls, struct tls_hs_msg *msg);

/**
 * write verify_data for the negotiated protocol version. Only TLS 1.2 and
 * TLS 1.3 are supported; other versions record an internal error and
 * return -1.
 */
static int32_t write_verify_data(TLS *tls, struct tls_hs_msg *msg);

/**
 * recompute the verify_data expected from the peer (TLS 1.2) and compare
 * it with the bytes at msg->msg[offset]. A fatal alert is raised on a
 * short message or a mismatch.
 *
 * @return number of bytes consumed (the verify_data length), or -1.
 */
static int32_t read_verify_data_tls12(TLS *tls,
				const struct tls_hs_msg *msg,
				const uint32_t offset);

/**
 * recompute the verify_data expected from the peer (TLS 1.3) and compare
 * it with the bytes at msg->msg[offset]. A fatal alert is raised on a
 * short message or a mismatch.
 *
 * @return number of bytes consumed (the verify_data length), or -1.
 */
static int32_t read_verify_data_tls13(TLS *tls,
				const struct tls_hs_msg *msg,
				const uint32_t offset);

/**
 * read and verify the peer's verify_data for the negotiated protocol
 * version. Only TLS 1.2 and TLS 1.3 are supported; other versions
 * record an internal error and return -1.
 */
static int32_t read_verify_data(TLS *tls,
				const struct tls_hs_msg *msg,
				const uint32_t offset);

static int32_t get_label_client(const enum tls_finished_stat status,
				uint8_t *label) {
	/*
	 * A client signs its own Finished with "client finished" and checks
	 * the server's Finished against "server finished".
	 */
	const uint8_t *src;
	int32_t n;

	if (status == FINISHED_WRITE_STAT) {
		src = TLS_FINISHED_CLIENT_LABEL;
		n = sizeof (TLS_FINISHED_CLIENT_LABEL) - 1;
	} else if (status == FINISHED_READ_STAT) {
		src = TLS_FINISHED_SERVER_LABEL;
		n = sizeof (TLS_FINISHED_SERVER_LABEL) - 1;
	} else {
		assert(!"unknown label stat");
		return -1;
	}

	/* label bytes only; no NUL terminator is copied. */
	memcpy(label, src, n);
	return n;
}

static int32_t get_label_server(const enum tls_finished_stat status,
				uint8_t *label) {
	/*
	 * Mirror image of the client side: a server signs its own Finished
	 * with "server finished" and checks the client's Finished against
	 * "client finished".
	 */
	if (status != FINISHED_WRITE_STAT && status != FINISHED_READ_STAT) {
		assert(!"unknown label stat");
		return -1;
	}

	const int writing = (status == FINISHED_WRITE_STAT);
	const uint8_t *src = writing
		? TLS_FINISHED_SERVER_LABEL
		: TLS_FINISHED_CLIENT_LABEL;
	const int32_t n = writing
		? (int32_t) (sizeof (TLS_FINISHED_SERVER_LABEL) - 1)
		: (int32_t) (sizeof (TLS_FINISHED_CLIENT_LABEL) - 1);

	/* label bytes only; no NUL terminator is copied. */
	memcpy(label, src, n);
	return n;
}

static int32_t get_label(const TLS *tls,
			 const enum tls_finished_stat status,
			 uint8_t *label) {
	/* the label pair to use depends on which end of the connection we are. */
	if (tls->entity == TLS_CONNECT_CLIENT) {
		return get_label_client(status, label);
	}

	if (tls->entity == TLS_CONNECT_SERVER) {
		return get_label_server(status, label);
	}

	/* neither client nor server: a programming error. */
	assert(!"unknown connection entity.");
	OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
		     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 4,
		     NULL);
	return -1;
}

static void do_prf_sha256(TLS *tls,
			  const uint8_t *label, const int32_t label_len,
			  uint8_t *dest, const uint32_t dest_len) {
	/*
	 * RFC5246 7.4.9:
	 *   verify_data = PRF(master_secret, finished_label,
	 *                     Hash(handshake_messages))
	 * The handshake hash must be the PRF's own hash, SHA-256 here.
	 */
	const enum tls_hs_sighash_hash_algo algo = TLS_HASH_ALGO_SHA256;
	const uint8_t digest_len = tls_hs_sighash_get_hash_size(algo);
	uint8_t digest[digest_len];

	tls_hs_hash_get_digest(algo, tls, digest);

	tls_prf_sha256(tls->pending->master_secret,
		       sizeof (tls->pending->master_secret),
		       label, label_len,
		       digest, digest_len,
		       dest, dest_len);
}

static void do_prf_sha384(TLS *tls,
			  const uint8_t *label, const int32_t label_len,
			  uint8_t *dest, const uint32_t dest_len) {
	/*
	 * RFC5246 7.4.9:
	 *   verify_data = PRF(master_secret, finished_label,
	 *                     Hash(handshake_messages))
	 * Cipher suites with a SHA-384 PRF also hash the handshake with
	 * SHA-384, so the seed is the SHA-384 handshake digest.
	 */
	const enum tls_hs_sighash_hash_algo algo = TLS_HASH_ALGO_SHA384;
	const uint8_t digest_len = tls_hs_sighash_get_hash_size(algo);
	uint8_t digest[digest_len];

	tls_hs_hash_get_digest(algo, tls, digest);

	tls_prf_sha384(tls->pending->master_secret,
		       sizeof (tls->pending->master_secret),
		       label, label_len,
		       digest, digest_len,
		       dest, dest_len);
}

static bool do_hmac_sha256(TLS *tls, uint8_t *base_key,
			   char *label, size_t label_len,
			   uint8_t *dest) {
	/*
	 * RFC8446 4.4.4.  Finished
	 *
	 *    The verify_data value is computed as follows:
	 *
	 *       verify_data =
	 *           HMAC(finished_key,
	 *                Transcript-Hash(Handshake Context,
	 *                                Certificate*, CertificateVerify*))
	 *
	 *    HMAC [RFC2104] uses the Hash algorithm for the handshake.  As noted
	 *    above, the HMAC input can generally be implemented by a running hash,
	 *    i.e., just the handshake hash at this point.
	 *
	 *    In previous versions of TLS, the verify_data was always 12 octets
	 *    long.  In TLS 1.3, it is the size of the HMAC output for the Hash
	 *    used for the handshake.
	 */

	/* step 1: the handshake hash is fixed by the negotiated cipher suite. */
	const enum tls_hs_sighash_hash_algo tls_hash =
		tls_cipher_hashalgo(tls->pending->cipher_suite);
	if (tls_hash == TLS_HASH_ALGO_NONE) {
		return false;
	}

	/* step 2: translate it into the HKDF hash descriptor. */
	const int32_t ai_hash = tls_hs_sighash_get_ai_hash_type(tls_hash);
	if (ai_hash < 0) {
		return false;
	}

	hkdf_hash_t hash;
	if (HKDF_get_hash(ai_hash, &hash) < 0) {
		return false;
	}

	/*
	 * step 3: RFC8446 4.4.4.  Finished
	 *
	 *    finished_key =
	 *        HKDF-Expand-Label(BaseKey, "finished", "", Hash.length)
	 *
	 * the context is the empty string, hence NULL / 0.
	 */
	uint8_t finished_key[hash.len];
	if (!tls_key_hkdf_expand_label(&hash, base_key, label, label_len,
				       NULL, 0, hash.len, finished_key)) {
		return false;
	}

	/* step 4: verify_data = HMAC-SHA256(finished_key, transcript hash). */
	uint8_t transcript[hash.len];
	tls_hs_hash_get_digest(tls_hash, tls, transcript);
	HMAC_SHA256(hash.len, transcript,
		    hash.len, finished_key, (unsigned char *) dest);

	return true;
}

static bool do_hmac_sha384(TLS *tls, uint8_t *base_key,
			   char *label, size_t label_len,
			   uint8_t *dest) {
	/*
	 * RFC8446 4.4.4.  Finished
	 *
	 *    The verify_data value is computed as follows:
	 *
	 *       verify_data =
	 *           HMAC(finished_key,
	 *                Transcript-Hash(Handshake Context,
	 *                                Certificate*, CertificateVerify*))
	 *
	 *    HMAC [RFC2104] uses the Hash algorithm for the handshake.  As noted
	 *    above, the HMAC input can generally be implemented by a running hash,
	 *    i.e., just the handshake hash at this point.
	 *
	 *    In previous versions of TLS, the verify_data was always 12 octets
	 *    long.  In TLS 1.3, it is the size of the HMAC output for the Hash
	 *    used for the handshake.
	 */

	/* resolve the handshake hash from the negotiated cipher suite ... */
	const enum tls_hs_sighash_hash_algo tls_hash =
		tls_cipher_hashalgo(tls->pending->cipher_suite);
	if (tls_hash == TLS_HASH_ALGO_NONE) {
		return false;
	}

	/* ... and the HKDF descriptor for that hash. */
	const int32_t ai_hash = tls_hs_sighash_get_ai_hash_type(tls_hash);
	if (ai_hash < 0) {
		return false;
	}

	hkdf_hash_t hash;
	if (HKDF_get_hash(ai_hash, &hash) < 0) {
		return false;
	}

	/*
	 * RFC8446 4.4.4.  Finished
	 *
	 *    finished_key =
	 *        HKDF-Expand-Label(BaseKey, "finished", "", Hash.length)
	 *
	 * the context is the empty string, hence NULL / 0.
	 */
	uint8_t finished_key[hash.len];
	if (!tls_key_hkdf_expand_label(&hash, base_key, label, label_len,
				       NULL, 0, hash.len, finished_key)) {
		return false;
	}

	/* verify_data = HMAC-SHA384(finished_key, transcript hash). */
	uint8_t transcript[hash.len];
	tls_hs_hash_get_digest(tls_hash, tls, transcript);
	HMAC_SHA384(sizeof (transcript), transcript,
		    hash.len, finished_key, (unsigned char *) dest);

	return true;
}

/**
 * compute this endpoint's TLS 1.2 verify_data (RFC5246 7.4.9) into
 * msg->msg and set msg->len.
 *
 * @return the verify_data length on success, -1 on failure. An unknown
 *         PRF algorithm is reported as an internal error instead of
 *         silently "succeeding" with an unwritten message body.
 */
static int32_t write_verify_data_tls12(TLS *tls, struct tls_hs_msg *msg) {
	const int32_t verify_length = tls->active_write.cipher.verify_length;

	uint8_t label[TLS_FINISHED_LABEL_LEN];
	const int32_t label_len = get_label(tls, FINISHED_WRITE_STAT, label);
	if (label_len < 0) {
		/* should not happen */
		return -1;
	}

	switch(tls->active_write.cipher.prf_algorithm) {
	case TLS_PRF_SHA256:
		do_prf_sha256(tls, label, label_len,
			      &(msg->msg[0]), verify_length);
		break;

	case TLS_PRF_SHA384:
		do_prf_sha384(tls, label, label_len,
			      &(msg->msg[0]), verify_length);
		break;

	default:
		assert(!"PRF algorithm is unknown.");
		/*
		 * with NDEBUG the assert is a no-op; without this return the
		 * function would report success and a Finished message with
		 * uninitialized verify_data would be sent.
		 */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 11,
			     NULL);
		return -1;
	}

	msg->len = verify_length;
	return verify_length;
}

/**
 * compute this endpoint's TLS 1.3 verify_data (RFC8446 4.4.4) into
 * msg->msg and set msg->len.
 *
 * @return the verify_data length on success, -1 on failure. An unknown
 *         PRF algorithm is reported as an internal error instead of
 *         silently "succeeding" with an unwritten message body.
 */
static int32_t write_verify_data_tls13(TLS *tls, struct tls_hs_msg *msg) {
	const int32_t verify_length = tls->active_write.cipher.verify_length;

	char label[] = "finished";
	size_t label_len = sizeof(label) - 1;

	/*
	 * RFC8446 4.4.  Authentication Messages
	 *
	 *    The following table defines the Handshake Context and MAC Base Key
	 *    for each scenario:
	 *
	 *    +-----------+-------------------------+-----------------------------+
	 *    | Mode      | Handshake Context       | Base Key                    |
	 *    +-----------+-------------------------+-----------------------------+
	 *    | Server    | ClientHello ... later   | server_handshake_traffic_   |
	 *    |           | of EncryptedExtensions/ | secret                      |
	 *    |           | CertificateRequest      |                             |
	 *    |           |                         |                             |
	 *    | Client    | ClientHello ... later   | client_handshake_traffic_   |
	 *    |           | of server               | secret                      |
	 *    |           | Finished/EndOfEarlyData |                             |
	 *    |           |                         |                             |
	 *    | Post-     | ClientHello ... client  | client_application_traffic_ |
	 *    | Handshake | Finished +              | secret_N                    |
	 *    |           | CertificateRequest      |                             |
	 *    +-----------+-------------------------+-----------------------------+
	 *
	 * NOTE(review): do_hmac_* writes the full HMAC output into msg->msg;
	 * this assumes verify_length equals the handshake hash length - confirm.
	 */
	switch (tls->active_write.cipher.prf_algorithm) {
	case TLS_HKDF_SHA256:
		if (do_hmac_sha256(tls, tls->active_write.secret, label,
				   label_len, &(msg->msg[0])) == false) {
			return -1;
		}
		break;

	case TLS_HKDF_SHA384:
		if (do_hmac_sha384(tls, tls->active_write.secret, label,
				   label_len, &(msg->msg[0])) == false) {
			return -1;
		}
		break;

	default:
		assert(!"PRF algorithm is unknown.");
		/*
		 * with NDEBUG the assert is a no-op; fail explicitly rather than
		 * return success with msg->msg / msg->len never written.
		 */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 12,
			     NULL);
		return -1;
	}

	msg->len = verify_length;
	return verify_length;
}

static int32_t write_verify_data(TLS *tls, struct tls_hs_msg *msg) {
	const uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	/* the two protocol versions with a Finished implementation. */
	if (version == TLS_VER_TLS12) {
		return write_verify_data_tls12(tls, msg);
	}
	if (version == TLS_VER_TLS13) {
		return write_verify_data_tls13(tls, msg);
	}

	/* known legacy versions whose verify_data is not implemented. */
	if (version == TLS_VER_SSL30 ||
	    version == TLS_VER_TLS10 ||
	    version == TLS_VER_TLS11) {
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 5,
			     NULL);
		return -1;
	}

	/* anything else is an unrecognized version. */
	OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
		     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 6,
		     NULL);
	return -1;
}

/**
 * verify the peer's TLS 1.2 Finished message (RFC5246 7.4.9).
 *
 * Recompute the verify_data the peer should have sent and compare it,
 * in constant time, with the verify_length bytes at msg->msg[offset].
 *
 * @return verify_length on success; -1 after raising a fatal alert on a
 *         short message (decode_error), a mismatch (decrypt_error) or an
 *         internal failure (internal_error).
 */
static int32_t read_verify_data_tls12(TLS *tls,
				const struct tls_hs_msg *msg,
				const uint32_t offset) {
	uint8_t label[TLS_FINISHED_LABEL_LEN];
	const int32_t label_len = get_label(tls, FINISHED_READ_STAT, label);
	if (label_len < 0) {
		/* should not happen */
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	const uint32_t verify_length = tls->active_read.cipher.verify_length;

	/*
	 * RFC5246 7.4.9.  Finished
	 *
	 *           opaque verify_data[verify_data_length];
	 *
	 * the bytes are read from msg->msg[offset], so the space remaining
	 * after offset (not the whole message length) must cover them.
	 */
	if (offset > msg->len || msg->len - offset < verify_length) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 0, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	uint8_t dest[verify_length];

	switch(tls->active_read.cipher.prf_algorithm) {
	case TLS_PRF_SHA256:
		do_prf_sha256(tls, label, label_len, dest, verify_length);
		break;

	case TLS_PRF_SHA384:
		do_prf_sha384(tls, label, label_len, dest, verify_length);
		break;

	default:
		assert(!"PRF algorithm is unknown.");
		/* with NDEBUG, never compare against an uninitialized dest. */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 13, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 * compare in constant time: memcmp may stop at the first differing
	 * byte, leaking through timing how much of the MAC matched.
	 */
	uint8_t diff = 0;
	for (uint32_t i = 0; i < verify_length; i++) {
		diff |= (uint8_t) (msg->msg[offset + i] ^ dest[i]);
	}

	if (diff != 0) {
		OK_set_error(ERR_ST_TLS_INVALID_DIGEST,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 1, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECRYPT_ERROR);
		return -1;
	}

	return verify_length;
}

/**
 * verify the peer's TLS 1.3 Finished message (RFC8446 4.4.4).
 *
 * Recompute the HMAC-based verify_data the peer should have sent and
 * compare it, in constant time, with the verify_length bytes at
 * msg->msg[offset].
 *
 * @return verify_length on success; -1 on failure. A fatal alert is
 *         raised for a short message (decode_error), a mismatch
 *         (decrypt_error) or an unknown PRF (internal_error).
 */
static int32_t read_verify_data_tls13(TLS *tls,
				const struct tls_hs_msg *msg,
				const uint32_t offset) {
	const uint32_t verify_length = tls->active_read.cipher.verify_length;

	char label[] = "finished";
	size_t label_len = sizeof(label) - 1;

	/*
	 * RFC8446 4.4.4.  Finished
	 *
	 *           opaque verify_data[Hash.length];
	 *
	 * the bytes are read from msg->msg[offset], so the space remaining
	 * after offset (not the whole message length) must cover them.
	 */
	if (offset > msg->len || msg->len - offset < verify_length) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 7,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.4.  Authentication Messages
	 *
	 *    The following table defines the Handshake Context and MAC Base Key
	 *    for each scenario:
	 *
	 *    +-----------+-------------------------+-----------------------------+
	 *    | Mode      | Handshake Context       | Base Key                    |
	 *    +-----------+-------------------------+-----------------------------+
	 *    | Server    | ClientHello ... later   | server_handshake_traffic_   |
	 *    |           | of EncryptedExtensions/ | secret                      |
	 *    |           | CertificateRequest      |                             |
	 *    |           |                         |                             |
	 *    | Client    | ClientHello ... later   | client_handshake_traffic_   |
	 *    |           | of server               | secret                      |
	 *    |           | Finished/EndOfEarlyData |                             |
	 *    |           |                         |                             |
	 *    | Post-     | ClientHello ... client  | client_application_traffic_ |
	 *    | Handshake | Finished +              | secret_N                    |
	 *    |           | CertificateRequest      |                             |
	 *    +-----------+-------------------------+-----------------------------+
	 *
	 * NOTE(review): do_hmac_* writes the full HMAC output into dest; this
	 * assumes verify_length equals the handshake hash length - confirm.
	 */
	uint8_t dest[verify_length];
	switch (tls->active_read.cipher.prf_algorithm) {
	case TLS_HKDF_SHA256:
		if (do_hmac_sha256(tls, tls->active_read.secret,
				   label, label_len, dest) == false) {
			return -1;
		}
		break;

	case TLS_HKDF_SHA384:
		if (do_hmac_sha384(tls, tls->active_read.secret,
				   label, label_len, dest) == false) {
			return -1;
		}
		break;

	default:
		assert(!"PRF algorithm is unknown.");
		/* with NDEBUG, never compare against an uninitialized dest. */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 14,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.4.4.  Finished
	 *
	 *    Recipients of Finished messages MUST verify that the contents are
	 *    correct and if incorrect MUST terminate the connection with a
	 *    "decrypt_error" alert.
	 *
	 * compare in constant time: memcmp may stop at the first differing
	 * byte, leaking through timing how much of the MAC matched.
	 */
	uint8_t diff = 0;
	for (uint32_t i = 0; i < verify_length; i++) {
		diff |= (uint8_t) (msg->msg[offset + i] ^ dest[i]);
	}

	if (diff != 0) {
		TLS_DPRINTF("verify data mismatch");
		OK_set_error(ERR_ST_TLS_INVALID_DIGEST,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 8,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECRYPT_ERROR);
		return -1;
	}

	return verify_length;
}

static int32_t read_verify_data(TLS *tls,
				const struct tls_hs_msg *msg,
				const uint32_t offset) {
	const uint16_t ver = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	switch (ver) {
	case TLS_VER_TLS13:
		return read_verify_data_tls13(tls, msg, offset);

	case TLS_VER_TLS12:
		return read_verify_data_tls12(tls, msg, offset);

	case TLS_VER_TLS11:
	case TLS_VER_TLS10:
	case TLS_VER_SSL30:
		/* verify_data checking for these versions is not implemented. */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 9,
			     NULL);
		break;

	default:
		/* unrecognized protocol version. */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 10,
			     NULL);
		break;
	}

	return -1;
}

/**
 * build this endpoint's Finished handshake message.
 *
 * The message body consists solely of verify_data, which
 * write_verify_data() stores into msg->msg (setting msg->len) on success.
 *
 * @return a newly allocated message, or NULL on failure (any partially
 *         built message is released with tls_hs_msg_free()).
 */
struct tls_hs_msg * tls_hs_finished_compose(TLS *tls) {
	struct tls_hs_msg *msg;

	if ((msg = tls_hs_msg_init()) == NULL) {
		TLS_DPRINTF("tls_hs_msg_init");
		return NULL;
	}

	msg->type = TLS_HANDSHAKE_FINISHED;

	/* the returned length is not needed: msg->len is already set. */
	if (write_verify_data(tls, msg) < 0) {
		TLS_DPRINTF("write_verify_data");
		goto failed;
	}

	return msg;

failed:
	tls_hs_msg_free(msg);
	return NULL;
}

bool tls_hs_finished_parse(TLS *tls, struct tls_hs_msg *msg) {
	/* guard: only a Finished handshake message is acceptable here. */
	if (msg->type != TLS_HANDSHAKE_FINISHED) {
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	/* verify_data starts at the beginning of the body. */
	const uint32_t start = 0;
	const int32_t consumed = read_verify_data(tls, msg, start);
	if (consumed < 0) {
		TLS_DPRINTF("read_verify_data");
		return false;
	}

	/* verify_data must be the entire body; extra bytes are malformed. */
	const uint32_t end = start + consumed;
	if (msg->len != end) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_FINISHED + 3, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return false;
	}

	return true;
}
