/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_alert.h"

#include <string.h>

/* for Key */
#include <aicrypto/ok_x509.h>

/* for PKCS12 and P12_get_usercert */
#include <aicrypto/ok_pkcs12.h>

/**
 * Generate a premaster secret, encrypt it with the server's RSA public
 * key and write it to msg as an EncryptedPreMasterSecret.
 */
static int32_t write_rsa_encypted_premaster(TLS *tls,
					    struct tls_hs_msg *msg);

/**
 * Write the exchange_keys body of a ClientKeyExchange message to msg,
 * according to the negotiated key exchange method.
 */
static int32_t write_exchange_keys(TLS *tls, struct tls_hs_msg *msg);

/**
 * Replace the premaster secret with a dummy (random) value when the
 * ClientKeyExchange received from the client could not be processed
 * (see RFC 5246 section 7.4.7.1).
 */
static int32_t make_dummy_premster(TLS *tls, const struct tls_hs_msg *msg);

/**
 * Read and decrypt the RSA-encrypted premaster secret from the received
 * ClientKeyExchange message, starting at offset.
 */
static int32_t read_rsa_encypted_premaster(TLS *tls,
					   const struct tls_hs_msg *msg,
					   const uint32_t offset);

/**
 * Read the exchange_keys body of a received ClientKeyExchange message,
 * starting at offset, according to the negotiated key exchange method.
 */
static int32_t read_exchange_keys(TLS *tls,
				  const struct tls_hs_msg *msg,
				  const uint32_t offset);

/**
 * Compose a ClientKeyExchange handshake message (defined below).
 */
struct tls_hs_msg * tls_hs_ckeyexc_compose(TLS *tls);

/**
 * Build the RSA-encrypted premaster secret and write it into msg.
 *
 * A fresh premaster secret is stored in tls->premaster_secret: the
 * ClientHello.client_version followed by 46 random bytes. It is then
 * encrypted with the public key of the certificate obtained from
 * tls->pkcs12_server. Finally it is written to msg as a 2-byte length
 * followed by the ciphertext.
 *
 * @return total number of bytes written to msg (length prefix plus
 *         ciphertext), or -1 on failure.
 */
static int32_t write_rsa_encypted_premaster(TLS *tls,
					    struct tls_hs_msg *msg) {
	/*
	 * RFC 5246 section 7.4.7.1:
	 *
	 *   struct {
	 *     public-key-encrypted PreMasterSecret pre_master_secret;
	 *   } EncryptedPreMasterSecret;
	 *
	 *   struct {
	 *     ProtocolVersion client_version;
	 *     opaque random[46];
	 *   } PreMasterSecret;
	 *
	 * On the wire a public-key-encrypted element is an
	 * opaque <0..2^16-1> vector: a 2-byte length, then the ciphertext.
	 */
	const int32_t prefix_len = 2;

	/*
	 * RFC 5246: the version number in the PreMasterSecret is the one
	 * offered in ClientHello.client_version, not the version negotiated
	 * for the connection. This is designed to prevent rollback attacks.
	 */
	tls->premaster_secret[0] = tls->client_version.major;
	tls->premaster_secret[1] = tls->client_version.minor;
	if (tls_util_get_random(&(tls->premaster_secret[2]), 46) == false) {
		TLS_DPRINTF("tls_util_get_random");
		OK_set_error(ERR_ST_TLS_GET_RANDOM,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 0, NULL);
		return -1;
	}
	tls->premaster_secret_len = TLS_PREMASTER_SECRET_RSA_LEN;

	/* The secret is encrypted under the server certificate's public key. */
	Cert *server_cert = P12_get_usercert(tls->pkcs12_server);
	if (server_cert == NULL) {
		TLS_DPRINTF("P12_get_usercert");
		OK_set_error(ERR_ST_TLS_P12_GET_USERCERT,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 1, NULL);
		return -1;
	}

	Key *server_key = server_cert->pubkey;
	if (server_key == NULL) {
		TLS_DPRINTF("pubkey");
		OK_set_error(ERR_ST_TLS_CET_NO_PUBKEY,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 2, NULL);
		return -1;
	}

	/* The length prefix plus a key-sized ciphertext must fit in msg. */
	if ((uint32_t) (2 + server_key->size) > msg->max) {
		TLS_DPRINTF("msg->max");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 3, NULL);
		return -1;
	}

	/* Place the ciphertext right after the (not yet written) prefix. */
	int32_t cipher_len = tls_util_pkcs1_encrypt(server_key,
						    &(msg->msg[prefix_len]),
						    &(tls->premaster_secret[0]),
						    TLS_PREMASTER_SECRET_RSA_LEN);
	if (cipher_len < 0) {
		TLS_DPRINTF("tls_util_pkcs1_encrypt");
		OK_set_error(ERR_ST_TLS_PKCS1_ENCRYPT,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 4, NULL);
		return -1;
	}

	/* The ciphertext length must be representable in the 2-byte prefix. */
	const int32_t vector_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (cipher_len > vector_max) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 14, NULL);
		return -1;
	}

	/* Fill in the 2-byte length prefix. This assumes tls_hs_msg_write_2()
	 * writes at the start of the still-empty body; confirm against its
	 * definition. */
	if (! tls_hs_msg_write_2(msg, cipher_len)) {
		TLS_DPRINTF("tls_hs_msg_write_2");
		return -1;
	}

	const int32_t total = prefix_len + cipher_len;
	msg->len = total;
	return total;
}

/**
 * Write the exchange_keys body of a ClientKeyExchange message into msg,
 * dispatching on the negotiated key exchange method (tls->keymethod).
 *
 * RSA and the ECDH(E) methods are handled. The finite-field
 * Diffie-Hellman methods are not implemented yet and fail.
 *
 * @return number of bytes written, or -1 on failure.
 */
static int32_t write_exchange_keys(TLS *tls, struct tls_hs_msg *msg) {
	int32_t written;

	switch (tls->keymethod) {
	case TLS_KXC_RSA:
		written = write_rsa_encypted_premaster(tls, msg);
		break;

	case TLS_KXC_ECDHE_ECDSA:
	case TLS_KXC_ECDHE_RSA:
	case TLS_KXC_ECDH_anon:
	case TLS_KXC_ECDH_ECDSA:
	case TLS_KXC_ECDH_RSA:
		/*
		 * RFC 4492 section 5.7: for ec_diffie_hellman the
		 * exchange_keys field is a ClientECDiffieHellmanPublic.
		 */
		written = tls_hs_ecdh_ckeyexc_write_exchange_keys(tls, msg);
		break;

	case TLS_KXC_DHE_DSS:
	case TLS_KXC_DHE_RSA:
	case TLS_KXC_DH_DSS:
	case TLS_KXC_DH_RSA:
	case TLS_KXC_DH_anon:
		/* ClientDiffieHellmanPublic: TODO, not implemented yet. */
	default:
		return -1;
	}

	return (written < 0) ? -1 : written;
}

/**
 * Replace the premaster secret with a random one after the client's
 * RSA-encrypted premaster secret could not be processed.
 *
 * RFC 5246 section 7.4.7.1 forbids alerting in this case. The server
 * must continue the handshake with a randomly generated premaster
 * secret, so the failure only surfaces at the Finished check.
 *
 * The replacement is always TLS_PREMASTER_SECRET_RSA_LEN bytes long.
 * tls->premaster_secret_len may not have been set yet when this is
 * called. A zero-length (or all-zero) secret would be predictable by the
 * client and would turn the failure path into a padding oracle.
 *
 * @param tls  connection whose premaster secret and length are overwritten.
 * @param msg  the received ClientKeyExchange message.
 * @return msg->len, so the whole message is treated as consumed; or -1
 *         if no random bytes could be obtained (fail closed instead of
 *         continuing with a predictable secret).
 */
static int32_t make_dummy_premster(TLS *tls, const struct tls_hs_msg *msg) {
	TLS_DPRINTF("ckeyexc: make dummy premaster secret");

	tls->premaster_secret_len = TLS_PREMASTER_SECRET_RSA_LEN;
	if (! tls_util_get_random(&(tls->premaster_secret[0]),
				  TLS_PREMASTER_SECRET_RSA_LEN)) {
		TLS_DPRINTF("tls_util_get_random");
		return -1;
	}

	return msg->len;
}

/**
 * Read and decrypt the RSA-encrypted premaster secret of a received
 * ClientKeyExchange message, starting at offset.
 *
 * On every processing failure (short message, missing or non-RSA private
 * key, decryption error, wrong length, version mismatch), an error is
 * recorded. The premaster secret is then replaced via
 * make_dummy_premster() instead of alerting, as RFC 5246 section 7.4.7.1
 * requires.
 *
 * NOTE(review): failure paths return early, so their timing differs from
 * the success path. RFC 5246 section 7.4.7.1 recommends generating the
 * random replacement before decrypting to avoid such side channels.
 * Consider restructuring.
 *
 * @return number of bytes consumed (2-byte length prefix plus ciphertext)
 *         on success; on failure paths, the value of make_dummy_premster().
 */
static int32_t read_rsa_encypted_premaster(TLS *tls,
					   const struct tls_hs_msg *msg,
					   const uint32_t offset) {
	/* RFC 5246 section 7.4.7.1 says
	 *
	 *   In any case, a TLS server MUST NOT generate an alert if
	 *   processing an RSA-encrypted premaster secret message fails,
	 *   or the version number is not as expected.  Instead, it MUST
	 *   continue the handshake with a randomly generated premaster
	 *   secret.  It may be useful to log the real cause of failure
	 *   for troubleshooting purposes; however, care must be taken
	 *   to avoid leaking the information to an attacker (through,
	 *   e.g., timing, log files, or other channels.)
	 */
	int32_t read_bytes = 0;

	/* Fix the premaster length before any failure path.
	 * make_dummy_premster() may size its random replacement from
	 * premaster_secret_len, which may not have been set yet here. A
	 * zero length would yield an empty, client-predictable secret. */
	tls->premaster_secret_len = TLS_PREMASTER_SECRET_RSA_LEN;

	const int32_t length_bytes = 2;
	if (msg->len < (offset + length_bytes)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 5, NULL);
		return make_dummy_premster(tls, msg);
	}
	read_bytes += length_bytes;

	/* The length prefix sits at offset, not necessarily at byte 0. */
	int32_t len = tls_util_read_2(&(msg->msg[offset]));
	TLS_DPRINTF("ckeyexc: len = %d", len);

	/*
	 * RFC5246 7.4.7.1.  RSA-Encrypted Premaster Secret Message
	 *
	 *       struct {
	 *           public-key-encrypted PreMasterSecret pre_master_secret;
	 *       } EncryptedPreMasterSecret;
	 *
	 * RFC5246 4.7.  Cryptographic Attributes
	 *
	 *                 A public-key-encrypted element is encoded as an opaque
	 *    vector <0..2^16-1>, where the length is specified by the encryption
	 *    algorithm and key.
	 */

	Key *privkey;
	if ((privkey = P12_get_privatekey(tls->pkcs12_server)) == NULL) {
		TLS_DPRINTF("P12_get_privatekey");
		OK_set_error(ERR_ST_TLS_P12_GET_PRIVATEKEY,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 6, NULL);
		return make_dummy_premster(tls, msg);
	}

	/* check whether the key type is private key of rsa. */
	if (privkey->key_type != KEY_RSA_PRV) {
		OK_set_error(ERR_ST_TLS_NOT_RSA_PRIV,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 7, NULL);
		return make_dummy_premster(tls, msg);
	}

	/* The ciphertext announced by the prefix must be fully present. */
	if (msg->len < (offset + read_bytes + len)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 8, NULL);
		return make_dummy_premster(tls, msg);
	}

	uint8_t buff[msg->len];
	int32_t seclen;
	if ((seclen = tls_util_pkcs1_decrypt(privkey,
					     &(buff[0]),
					     &(msg->msg[offset + read_bytes]),
					     len)) < 0) {
		OK_set_error(ERR_ST_TLS_PKCS1_DECRYPT,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 9, NULL);
		return make_dummy_premster(tls, msg);
	}
	read_bytes += len;

	if (seclen != TLS_PREMASTER_SECRET_RSA_LEN) {
		TLS_DPRINTF("premaster_secret length");
		OK_set_error(ERR_ST_TLS_INVALID_PREMASTER_SECRET,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 10, NULL);
		return make_dummy_premster(tls, msg);
	}

	/* RFC 5246 section 7.4.7.1 says.
	 *
	 *   If ClientHello.client_version is TLS 1.1 or higher, server
	 *   implementations MUST check the version number as described
	 *   in the note below.  If the version number is TLS 1.0 or
	 *   earlier, server implementations SHOULD check the version
	 *   number, but MAY have a configuration option to disable the
	 *   check.  Note that if the check fails, the PreMasterSecret
	 *   SHOULD be randomized as described below. */
	bool check_version = true;
	switch(tls->client_version.minor) {
	case TLS_MINOR_SSL30:
	case TLS_MINOR_TLS10:
		/* TODO: implement option */
		break;

	case TLS_MINOR_TLS11:
	case TLS_MINOR_TLS12:
	default:
		break;
	}

	/* Either byte differing from ClientHello.client_version is a
	 * mismatch. Requiring both to differ would accept a rolled-back
	 * minor version. */
	if ((check_version == true) &&
	    ((buff[0] != tls->client_version.major) ||
	     (buff[1] != tls->client_version.minor))) {
		TLS_DPRINTF("client_version unmatch.");
		OK_set_error(ERR_ST_TLS_INVALID_PREMASTER_SECRET,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 11, NULL);
		return make_dummy_premster(tls, msg);
	}

	/* premaster_secret_len was set to TLS_PREMASTER_SECRET_RSA_LEN above,
	 * which equals seclen here. */
	memcpy(&(tls->premaster_secret[0]),
	       &(buff[0]), tls->premaster_secret_len);

	return read_bytes;
}

/**
 * Read the exchange_keys body of a received ClientKeyExchange message,
 * starting at offset, dispatching on tls->keymethod.
 *
 * The finite-field Diffie-Hellman methods are not implemented. Like
 * write_exchange_keys(), they fail with -1 instead of tripping the
 * "unknown key algorithm" assertion, which is reserved for values
 * outside the known set.
 *
 * @return number of bytes consumed, or -1 on failure.
 */
static int32_t read_exchange_keys(TLS *tls,
				  const struct tls_hs_msg *msg,
				  const uint32_t offset) {
	int32_t len;

	switch(tls->keymethod) {
	case TLS_KXC_RSA:
		len = read_rsa_encypted_premaster(tls, msg, offset);
		break;

	case TLS_KXC_ECDHE_ECDSA:
	case TLS_KXC_ECDHE_RSA:
	case TLS_KXC_ECDH_anon:
	case TLS_KXC_ECDH_ECDSA:
	case TLS_KXC_ECDH_RSA:
		len = tls_hs_ecdh_ckeyexc_read_exchange_keys(tls, msg, offset);
		break;

	case TLS_KXC_DHE_DSS:
	case TLS_KXC_DHE_RSA:
	case TLS_KXC_DH_DSS:
	case TLS_KXC_DH_RSA:
	case TLS_KXC_DH_anon:
		/* ClientDiffieHellmanPublic: TODO, not implemented yet. */
		return -1;

	default:
		assert(!"unknown key algorithm selected.");
		return -1;
	}

	return (len < 0) ? -1 : len;
}

/**
 * Compose a ClientKeyExchange handshake message for the negotiated key
 * exchange method.
 *
 * @return a newly initialised message owned by the caller, or NULL if
 *         the message could not be allocated or its body could not be
 *         written (the partially built message is freed in that case).
 */
struct tls_hs_msg * tls_hs_ckeyexc_compose(TLS *tls) {
	struct tls_hs_msg *ckeyexc = tls_hs_msg_init();
	if (ckeyexc == NULL) {
		TLS_DPRINTF("tls_hs_msg_init");
		return NULL;
	}

	ckeyexc->type = TLS_HANDSHAKE_CLIENT_KEY_EXCHANGE;

	if (write_exchange_keys(tls, ckeyexc) >= 0) {
		return ckeyexc;
	}

	tls_hs_msg_free(ckeyexc);
	return NULL;
}

/**
 * Parse a received ClientKeyExchange handshake message.
 *
 * @param tls  connection state, updated by the key exchange readers.
 * @param msg  the received handshake message.
 * @return true if msg is a ClientKeyExchange whose body was read and
 *         consumed exactly. false otherwise:
 *         - wrong message type: an unexpected_message fatal alert is sent;
 *         - the key exchange body cannot be read: no alert is sent here;
 *         - trailing bytes remain after the body: a decode_error fatal
 *           alert is sent.
 */
bool tls_hs_ckeyexc_parse(TLS *tls,
			  struct tls_hs_msg *msg) {
	uint32_t offset = 0;

	if (msg->type != TLS_HANDSHAKE_CLIENT_KEY_EXCHANGE) {
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 12, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	int32_t keylen;
	if ((keylen = read_exchange_keys(tls, msg, offset)) < 0) {
		return false;
	}
	offset += keylen;

	/* The body must consume the whole message. A fatal alert has been
	 * raised, so the caller must not continue the handshake. */
	if (msg->len != offset) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CKEYEXC + 13, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return false;
	}

	return true;
}
