/*
 * Copyright (c) 2016-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_handshake_ecdh.h"
#include "tls_alert.h"
#include "tls_cert.h"

/**
 * Write the ECParameters of tls->ecdh into msg.
 *
 *  struct {
 *      ECCurveType    curve_type;
 *      select (curve_type) {
 *          case explicit_prime:
 *              opaque      prime_p <1..2^8-1>;
 *              ECCurve     curve;
 *              ECPoint     base;
 *              opaque      order <1..2^8-1>;
 *              opaque      cofactor <1..2^8-1>;
 *          case explicit_char2:
 *              uint16      m;
 *              ECBasisType basis;
 *              select (basis) {
 *                  case ec_trinomial:
 *                      opaque  k <1..2^8-1>;
 *                  case ec_pentanomial:
 *                      opaque  k1 <1..2^8-1>;
 *                      opaque  k2 <1..2^8-1>;
 *                      opaque  k3 <1..2^8-1>;
 *              };
 *              ECCurve     curve;
 *              ECPoint     base;
 *              opaque      order <1..2^8-1>;
 *              opaque      cofactor <1..2^8-1>;
 *          case named_curve:
 *              NamedCurve namedcurve;
 *      };
 *  } ECParameters;
 *
 * Currently only the named_curve variant is implemented.
 *
 * Returns the number of bytes appended to msg, or -1 on failure.
 */
static int32_t write_curve_params(TLS *tls, struct tls_hs_msg *msg);

/**
 * Write the public point of key into msg as an ECPoint.
 *
 * struct {
 *     opaque point <1..2^8-1>;
 * } ECPoint;
 *
 * point:   This is the byte string representation of an elliptic curve
 *    point following the conversion routine in Section 4.3.6 of ANSI
 *    X9.62 [7].  This byte string may represent an elliptic curve point
 *    in uncompressed or compressed format; it MUST conform to what the
 *    client has requested through a Supported Point Formats Extension
 *    if this extension was used.
 *
 * Returns the number of bytes appended (1-byte length + point), or -1.
 */
static int32_t write_public(struct tls_hs_ecdh_key *key,
			    struct tls_hs_msg *msg);

/**
 * Read ECParameters starting at msg->msg[offset] and store curve_type /
 * namedcurve into tls->ecdh.
 *
 * Returns the number of bytes consumed, or -1 on failure.
 */
static int32_t read_curve_params(TLS *tls, struct tls_hs_msg *msg,
				 const uint32_t offset);

/**
 * Read an ECPoint starting at msg->msg[offset] into key->ecpoint and
 * key->ecpoint_len.
 *
 * Returns the number of bytes consumed (1-byte length + point), or -1 if
 * the point cannot be read from msg.
 */
static int32_t read_public(struct tls_hs_ecdh_key *key,
			   const struct tls_hs_msg *msg,
			   const uint32_t offset);

/**
 * Calculate the ECDH premaster secret for the NIST prime curves, using the
 * aicrypto ECDSA-type keys held in ctx (Pubkey_ECDSA / Prvkey_ECDSA).
 * The result is stored in tls->premaster_secret.
 *
 * Returns 0 on success, -1 on failure.
 */
static int32_t calc_ecdsa_shared_secret(TLS *tls, struct tls_hs_ecdh *ctx);

/**
 * Calculate the X25519 shared secret into tls->premaster_secret.
 *
 * Returns 0 on success, -1 on failure.
 */
static int32_t calc_x25519_shared_secret(TLS *tls, struct tls_hs_ecdh *ctx);

/**
 * Calculate the X448 shared secret into tls->premaster_secret.
 *
 * Returns 0 on success, -1 on failure.
 */
static int32_t calc_x448_shared_secret(TLS *tls, struct tls_hs_ecdh *ctx);

static int32_t write_curve_params(TLS *tls, struct tls_hs_msg *msg)
{
	/* ServerKeyExchange message has following structure.
	 *
	 * curve_type == explicit_prime
	 * | curve_type                (1) |
	 * | prime_p                   (x) |
	 * | curve                     (x) |
	 * | base                      (x) |
	 * | order                     (x) |
	 * | cofactor                  (x) |
	 * | public                    (x) |
	 *
	 * curve_type == explicit_char2
	 * | curve_type                (1) |
	 * | m                         (x) |
	 * | basis                     (x) |
	 * | k                         (x) | 1
	 * | k1                        (x) | 2
	 * | k2                        (x) | 2
	 * | k3                        (x) | 2
	 * | curve                     (x) |
	 * | base                      (x) |
	 * | order                     (x) |
	 * | cofactor                  (x) |
	 * | public                    (x) |
	 *
	 * curve_type == named_curve
	 * | curve_type                (1) |
	 * | namedcurve                (2) |
	 * | public                    (x) |
	 *
	 * Only named_curve is implemented. The explicit encodings are refused
	 * before anything is written: emitting the curve_type byte without the
	 * parameters that must follow it would yield a message the peer
	 * cannot parse.
	 */
	switch (tls->ecdh->curvetype) {
	case TLS_ECC_CTYPE_NAMED_CURVE:
		break;
	case TLS_ECC_CTYPE_EXPLICIT_PRIME:	/* TODO: not implemented. */
	case TLS_ECC_CTYPE_EXPLICIT_CHAR2:	/* TODO: not implemented. */
	default:
		/* ECCurveType is unsupported or wrong. */
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE_TYPE,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 6, NULL);
		return -1;
	}

	/* ECCurveType curve_type */
	const int32_t curve_type_bytes = 1;
	if (tls_hs_msg_write_1(msg, tls->ecdh->curvetype) == false) {
		return -1;
	}

	/* NamedCurve namedcurve */
	const int32_t named_curve_bytes = 2;
	if (tls_hs_msg_write_2(msg, tls->ecdh->namedcurve) == false) {
		return -1;
	}

	return curve_type_bytes + named_curve_bytes;
}

static int32_t write_public(struct tls_hs_ecdh_key *key, struct tls_hs_msg *msg)
{
	int32_t offset = 0;

	/* Populate key->ecpoint / key->ecpoint_len, which are emitted below. */
	if (tls_hs_ecdhkey_set_to_ecpoint(key) != true) {
		return -1;
	}

	/*
	 *  struct {
	 *          opaque point <1..2^8-1>;
	 *  } ECPoint;
	 *
	 * The vector's length prefix is a single byte. A point that is empty or
	 * longer than 255 bytes cannot be represented, so refuse it rather than
	 * writing a length byte that disagrees with the data that follows.
	 */
	if (key->ecpoint_len < 1 || key->ecpoint_len > 255) {
		return -1;
	}

	if (tls_hs_msg_write_1(msg, key->ecpoint_len) == false) {
		return -1;
	}
	offset += 1;

	if (tls_hs_msg_write_n(msg, key->ecpoint, key->ecpoint_len) == false) {
		return -1;
	}
	offset += key->ecpoint_len;

	return offset;
}

static int32_t read_curve_params(TLS *tls, struct tls_hs_msg *msg,
				 const uint32_t offset)
{
	uint32_t read_bytes = 0;

	/*
	 * The message comes from the peer, so every field is bounds-checked
	 * against msg->len before it is read. msg itself is never modified
	 * here.
	 */

	/* ECCurveType curve_type (1 byte) */
	const uint32_t curve_type_bytes = 1;
	if (msg->len < (offset + curve_type_bytes)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 7, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	tls->ecdh->curvetype = msg->msg[offset];
	read_bytes += curve_type_bytes;

	switch (tls->ecdh->curvetype) {
	case TLS_ECC_CTYPE_NAMED_CURVE: {
		/*
		 *  NamedCurve namedcurve;  (2 bytes)
		 */
		const uint32_t named_curve_bytes = 2;
		if (msg->len < (offset + read_bytes + named_curve_bytes)) {
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 7,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}
		tls->ecdh->namedcurve =
			tls_util_read_2(&(msg->msg[offset + read_bytes]));
		/* TODO: validate namedcurve (e.g. against the curves offered). */
		read_bytes += named_curve_bytes;
		break;
	}
	case TLS_ECC_CTYPE_EXPLICIT_PRIME:	/* TODO: not implemented. */
	case TLS_ECC_CTYPE_EXPLICIT_CHAR2:	/* TODO: not implemented. */
	default:
		/*
		 * Unsupported: fail instead of returning a short count that
		 * would make the caller parse the curve parameters as an
		 * ECPoint.
		 */
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE_TYPE,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 7, NULL);
		return -1;
	}

	return read_bytes;
}

static int32_t read_public(struct tls_hs_ecdh_key *key,
			   const struct tls_hs_msg *msg,
			   const uint32_t offset)
{
	uint32_t read_bytes = 0;

	/*
	 * | length                    (1) |
	 * | public                    (n) | 1 <= n < 256
	 */
	const uint32_t length_bytes = 1;
	if (msg->len < (offset + length_bytes)) {
		return -1;
	}
	const uint8_t public_length = msg->msg[offset];
	read_bytes += length_bytes;

	/* ECPoint.point is opaque <1..2^8-1>: an empty point is malformed. */
	if (public_length < 1) {
		return -1;
	}

	if (msg->len < (offset + read_bytes + public_length)) {
		return -1;
	}

	/*
	 * The length is chosen by the peer. key->ecpoint is copied into
	 * directly on an uninitialised caller-stack struct, so it is an
	 * in-struct array and sizeof gives its capacity. Never copy more than
	 * it can hold.
	 */
	if (public_length > sizeof(key->ecpoint)) {
		return -1;
	}

	key->ecpoint_len = public_length;
	memcpy(key->ecpoint, &(msg->msg[offset + read_bytes]), public_length);

	read_bytes += public_length;

	return read_bytes;
}

static int32_t calc_ecdsa_shared_secret(TLS *tls, struct tls_hs_ecdh *ctx)
{
	ECp *ret;

	if ((ret = ECp_new()) == NULL) {
		OK_set_error(ERR_ST_TLS_ECP_NEW,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 0, NULL);
		return -1;
	}

	/*
	 * RFC4492 section 5.10:
	 *
	 * All ECDH calculations (including parameter and key generation as well
	 * as the shared secret calculation) are performed according to [6]
	 * using the ECKAS-DH1 scheme with the identity map as key derivation
	 * function (KDF), so that the premaster secret is the x-coordinate of
	 * the ECDH shared secret elliptic curve point represented as an octet
	 * string.  ...
	 */
	Pubkey_ECDSA *peer_pubkey = (Pubkey_ECDSA *)ctx->peer_pubkey;
	Prvkey_ECDSA *my_prvkey = (Prvkey_ECDSA *)ctx->my_prvkey;
	if (ECp_multi(peer_pubkey->E, peer_pubkey->W,
					my_prvkey->k, ret) != 0) {
		OK_set_error(ERR_ST_TLS_ECP_MULTI,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 1, NULL);
		ECp_free(ret);
		return -1;
	}

	/*
	 * RFC4492 section 5.10:
	 *
	 * ...      Note that this octet string (Z in IEEE 1363 terminology) as
	 * output by FE2OSP, the Field Element to Octet String Conversion
	 * Primitive, has constant length for any given field; leading zeros
	 * found in this octet string MUST NOT be truncated.
	 */
	memset(tls->premaster_secret, 0, sizeof(tls->premaster_secret));
	int byte = LN_now_byte(ret->x);
	/* FE2OSP */
	if (LN_get_num_c(ret->x, byte, tls->premaster_secret) != 0) {
		OK_set_error(ERR_ST_TLS_FE2OSP,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 2, NULL);
		ECp_free(ret);
		return -1;
	}
	tls->premaster_secret_len = byte;

	ECp_free(ret);
	return 0;
}

/*
 * Compute the X25519 shared secret between ctx->my_prvkey and
 * ctx->peer_pubkey into tls->premaster_secret (X25519_KEY_LENGTH bytes).
 *
 * Returns 0 on success, or -1 with the error recorded via OK_set_error().
 */
static int32_t calc_x25519_shared_secret(TLS *tls, struct tls_hs_ecdh *ctx)
{
	Pubkey_X25519 *peer_pubkey = (Pubkey_X25519 *)ctx->peer_pubkey;
	Prvkey_X25519 *my_prvkey = (Prvkey_X25519 *)ctx->my_prvkey;
	if (X25519_generate_shared_secret(my_prvkey, peer_pubkey,
					  tls->premaster_secret) < 0) {
		TLS_DPRINTF("X25519_generate_shared_secret");
		OK_set_error(ERR_ST_TLS_X25519_SHARED_SECRET,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 8, NULL);
		return -1;
	}

	/*
	 * A low-order peer point yields an all-zero shared secret. RFC 7748
	 * section 6 and RFC 8446 section 7.4.2 require aborting in that case.
	 * OR-accumulate rather than exiting early so that timing does not
	 * depend on the secret's contents.
	 */
	unsigned int acc = 0;
	for (size_t i = 0; i < X25519_KEY_LENGTH; i++) {
		acc |= tls->premaster_secret[i];
	}
	if (acc == 0) {
		TLS_DPRINTF("X25519 shared secret is all zero");
		OK_set_error(ERR_ST_TLS_X25519_SHARED_SECRET,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 8, NULL);
		return -1;
	}
	tls->premaster_secret_len = X25519_KEY_LENGTH;

	return 0;
}

/*
 * Compute the X448 shared secret between ctx->my_prvkey and
 * ctx->peer_pubkey into tls->premaster_secret (X448_KEY_LENGTH bytes).
 *
 * Returns 0 on success, or -1 with the error recorded via OK_set_error().
 */
static int32_t calc_x448_shared_secret(TLS *tls, struct tls_hs_ecdh *ctx)
{
	Pubkey_X448 *peer_pubkey = (Pubkey_X448 *)ctx->peer_pubkey;
	Prvkey_X448 *my_prvkey = (Prvkey_X448 *)ctx->my_prvkey;
	if (X448_generate_shared_secret(my_prvkey, peer_pubkey,
					tls->premaster_secret) < 0) {
		TLS_DPRINTF("X448_generate_shared_secret");
		OK_set_error(ERR_ST_TLS_X448_SHARED_SECRET,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 9, NULL);
		return -1;
	}

	/*
	 * A low-order peer point yields an all-zero shared secret. RFC 7748
	 * section 6 and RFC 8446 section 7.4.2 require aborting in that case.
	 * OR-accumulate rather than exiting early so that timing does not
	 * depend on the secret's contents.
	 */
	unsigned int acc = 0;
	for (size_t i = 0; i < X448_KEY_LENGTH; i++) {
		acc |= tls->premaster_secret[i];
	}
	if (acc == 0) {
		TLS_DPRINTF("X448 shared secret is all zero");
		OK_set_error(ERR_ST_TLS_X448_SHARED_SECRET,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 9, NULL);
		return -1;
	}
	tls->premaster_secret_len = X448_KEY_LENGTH;

	return 0;
}

/*
 * Compute the premaster secret for the curve recorded in ctx->namedcurve,
 * dispatching to the matching curve-family implementation.
 *
 * Returns that implementation's result (0 on success, -1 on failure).
 * Returns -1 with ERR_ST_TLS_UNSUPPORTED_CURVE set if the curve is not
 * handled here.
 */
int32_t tls_hs_ecdh_calc_shared_secret(TLS *tls, struct tls_hs_ecdh *ctx)
{
	int32_t (*calc)(TLS *, struct tls_hs_ecdh *) = NULL;

	switch (ctx->namedcurve) {
	case TLS_ECC_CURVE_SECP192R1:
	case TLS_ECC_CURVE_SECP224R1:
	case TLS_ECC_CURVE_SECP256R1:
	case TLS_ECC_CURVE_SECP384R1:
	case TLS_ECC_CURVE_SECP521R1:
		/* NIST prime curves use the aicrypto ECDSA key types. */
		calc = calc_ecdsa_shared_secret;
		break;
	case TLS_ECC_CURVE_X25519:
		calc = calc_x25519_shared_secret;
		break;
	case TLS_ECC_CURVE_X448:
		calc = calc_x448_shared_secret;
		break;
	default:
		break;
	}

	if (calc == NULL) {
		TLS_DPRINTF("unknown named curve");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 10, NULL);
		return -1;
	}

	return calc(tls, ctx);
}

enum tls_hs_named_curve tls_hs_ecdh_get_named_curve(int curve)
{
	switch (curve) {
	case ECP_X962_prime192v1:
		return TLS_ECC_CURVE_SECP192R1;

	case ECP_secp224r1:
		return TLS_ECC_CURVE_SECP224R1;

	case ECP_X962_prime256v1:
		return TLS_ECC_CURVE_SECP256R1;

	case ECP_secp384r1:
		return TLS_ECC_CURVE_SECP384R1;

	case ECP_secp521r1:
		return TLS_ECC_CURVE_SECP521R1;

	default:
		/* unsupported */
		assert(!"unsupported curve.");
		return 0;
	}
}

/*
 * Map an aicrypto curve-origin type (ECP_ORG_*) to an EC point format.
 *
 * Explicit char2 and prime curve origins map to their ANSI X9.62 compressed
 * formats. Any other origin maps to the uncompressed format.
 */
enum tls_hs_ecc_ec_point_format tls_hs_ecdh_get_point_format(int curve_type)
{
	if (curve_type == ECP_ORG_char2Param) {
		return TLS_ECC_PF_ANSIX962_COMPRESSED_CHAR2;
	}

	if (curve_type == ECP_ORG_primeParam) {
		return TLS_ECC_PF_ANSIX962_COMPRESSED_PRIME;
	}

	return TLS_ECC_PF_UNCOMPRESSED;
}

/*
 * Map an aicrypto curve-origin type (ECP_ORG_*) to a TLS ECCurveType.
 *
 * Explicit char2 / prime origins map to explicit_char2 / explicit_prime.
 * Any other origin is treated as a named curve.
 */
enum tls_hs_ecc_ec_curve_type tls_hs_ecdh_get_curve_type(int curve_type)
{
	return (curve_type == ECP_ORG_char2Param) ?
			TLS_ECC_CTYPE_EXPLICIT_CHAR2 :
	       (curve_type == ECP_ORG_primeParam) ?
			TLS_ECC_CTYPE_EXPLICIT_PRIME :
			TLS_ECC_CTYPE_NAMED_CURVE;
}

/*
 * Write the ECDHE server parameters of a ServerKeyExchange: ECParameters
 * followed by a freshly generated ephemeral public point.
 *
 * The generated private key is handed to tls->ecdh (marked ephemeral so
 * tls_hs_ecdh_free() releases it). The public key is owned by this function
 * and is released on every path.
 *
 * Returns the number of bytes appended to msg, or -1 on failure.
 */
int32_t tls_hs_ecdh_skeyexc_write_server_params(TLS *tls,
						struct tls_hs_msg *msg)
{
	int32_t offset = 0;

	struct tls_hs_ecdh_key key;

	if (tls_hs_ecdhkey_gen_for_server(tls->ecdh->namedcurve,
					  &key) != true) {
		return -1;
	}

	/* The private key's memory will be released at tls_hs_ecdh_free() */
	tls->ecdh->my_prvkey = key.prv;
	tls->ecdh->prvkey_ephemeral = true;

	/* write ECParameters */
	const int32_t curve_params_len = write_curve_params(tls, msg);
	if (curve_params_len < 0) {
		/* key.pub was previously leaked on this path. */
		Key_free(key.pub);
		return -1;
	}
	offset += curve_params_len;

	/* write ECPoint (the ephemeral ECDH public key) */
	const int32_t public_len = write_public(&key, msg);
	Key_free(key.pub);

	if (public_len < 0) {
		return -1;
	}
	offset += public_len;

	return offset;
}

/*
 * Parse the ECDHE server parameters of a ServerKeyExchange that start at
 * msg->msg[offset].
 *
 * First reads ECParameters into tls->ecdh, then the server's ephemeral
 * public point, which is installed as the peer public key.
 *
 * Returns the number of bytes consumed, or -1 on failure. A fatal alert is
 * raised when the public point cannot be decoded or installed.
 */
int32_t tls_hs_ecdh_skeyexc_read_server_params(TLS *tls, struct tls_hs_msg *msg,
					       const uint32_t offset)
{
	struct tls_hs_ecdh_key peer_key;
	uint32_t cursor = offset;

	/* ECParameters */
	const int32_t params_len = read_curve_params(tls, msg, cursor);
	if (params_len < 0) {
		TLS_DPRINTF("read_curve_params");
		return -1;
	}
	cursor += params_len;

	/* ECPoint: the server's ephemeral ECDH public key */
	const int32_t point_len = read_public(&peer_key, msg, cursor);
	if (point_len < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 3, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	cursor += point_len;

	if (tls_hs_ecdhkey_set_peer_pubkey(tls->ecdh, &peer_key) != true) {
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	return cursor - offset;
}

/*
 * Write the client's ECDH public value for a ClientKeyExchange and derive
 * the premaster secret.
 *
 * An ephemeral key pair is generated for the peer's curve. The private half
 * is handed to tls->ecdh (marked ephemeral so tls_hs_ecdh_free() releases
 * it). The public half is written to msg and then released.
 *
 * Returns the number of bytes appended to msg, or -1 on failure. A fatal
 * internal_error alert is raised if the shared secret cannot be computed.
 */
int32_t tls_hs_ecdh_ckeyexc_write_exchange_keys(TLS *tls,
						struct tls_hs_msg *msg)
{
	struct tls_hs_ecdh *ecdh = tls->ecdh;
	struct tls_hs_ecdh_key ephemeral;

	/*
	 * RFC4492 section 5.7.
	 *  struct {
	 *      select (PublicValueEncoding) {
	 *          case implicit: struct { };
	 *          case explicit: ECPoint ecdh_Yc;
	 *      } ecdh_public;
	 *  } ClientECDiffieHellmanPublic;
	 *
	 * TODO PublicValueEncoding: only the explicit form is written.
	 */

	if (tls_hs_ecdhkey_gen_for_client(ecdh->peer_pubkey,
					  &ephemeral) != true) {
		return -1;
	}

	/* The private key's memory will be released at tls_hs_ecdh_free() */
	ecdh->my_prvkey = ephemeral.prv;
	ecdh->prvkey_ephemeral = true;

	/* ECPoint ecdh_Yc; the public key is not needed after this. */
	const int32_t written = write_public(&ephemeral, msg);
	Key_free(ephemeral.pub);
	if (written < 0) {
		return -1;
	}

	if (tls_hs_ecdh_calc_shared_secret(tls, ecdh) != 0) {
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	return written;
}

/*
 * Read the client's ECDH public value from a ClientKeyExchange starting at
 * msg->msg[offset], install it as the peer public key and derive the
 * premaster secret.
 *
 * Returns the number of bytes consumed, or -1 on failure. A fatal alert is
 * raised on every failure path (decode_error for a malformed point,
 * internal_error otherwise).
 */
int32_t tls_hs_ecdh_ckeyexc_read_exchange_keys(TLS *tls,
					       const struct tls_hs_msg *msg,
					       const uint32_t offset)
{
	struct tls_hs_ecdh_key peer_key;

	assert(tls->ecdh != NULL);

	/* TODO PublicValueEncoding: only the explicit ECPoint form is read. */

	/* ECPoint: the client's ephemeral ECDH public key */
	const int32_t point_len = read_public(&peer_key, msg, offset);
	if (point_len < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 4, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	if (tls_hs_ecdhkey_set_peer_pubkey(tls->ecdh, &peer_key) != true) {
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	if (tls_hs_ecdh_calc_shared_secret(tls, tls->ecdh) != 0) {
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	return point_len;
}

/*
 * Install our ECDH private key from p12 into tls->ecdh.
 * Thin wrapper; returns the result of
 * tls_hs_ecdhkey_set_my_privkey_from_pkcs12().
 */
bool tls_hs_ecdh_set_privkey_from_pkcs12(TLS *tls, PKCS12 *p12)
{
	struct tls_hs_ecdh *ecdh = tls->ecdh;

	return tls_hs_ecdhkey_set_my_privkey_from_pkcs12(ecdh, p12);
}

/*
 * Install the peer's ECDH public key from p12 into tls->ecdh.
 * Thin wrapper; returns the result of
 * tls_hs_ecdhkey_set_peer_pubkey_from_pkcs12().
 */
bool tls_hs_ecdh_set_pubkey_from_pkcs12(TLS *tls, PKCS12 *p12)
{
	struct tls_hs_ecdh *ecdh = tls->ecdh;

	return tls_hs_ecdhkey_set_peer_pubkey_from_pkcs12(ecdh, p12);
}

/*
 * Allocate a zero-initialised ECDH context and attach it to tls->ecdh.
 *
 * Returns false without side effects if tls->ecdh is already set. On
 * allocation failure it records the error, raises a fatal internal_error
 * alert and returns false.
 */
bool tls_hs_ecdh_alloc(TLS *tls)
{
	/* Never overwrite (and thereby leak) an existing context. */
	if (tls->ecdh != NULL) {
		return false;
	}

	struct tls_hs_ecdh *ctx = calloc(1, sizeof(*ctx));
	if (ctx == NULL) {
		TLS_DPRINTF("ecdh: calloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_CALLOC,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDH + 5, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return false;
	}

	tls->ecdh = ctx;
	return true;
}

/*
 * Select the curve for the exchange: named_curve form, using the first
 * entry of ecdh->eclist.
 *
 * Precondition: ecdh->eclist is non-NULL and holds at least one entry,
 * because list[0] is read unconditionally.
 */
void tls_hs_ecdh_set_curve(struct tls_hs_ecdh *ecdh)
{
	assert(ecdh != NULL);
	/* Catch a missing curve list in debug builds instead of crashing. */
	assert(ecdh->eclist != NULL);

	ecdh->curvetype = TLS_ECC_CTYPE_NAMED_CURVE;
	ecdh->namedcurve = ecdh->eclist->list[0];
}

/*
 * Set ecdh's curve from the EC key described by cinfo (static ECDH).
 *
 * curvetype is always taken from the certificate. namedcurve is updated
 * only when that type is named_curve. For explicit_prime / explicit_char2
 * (not implemented) and any other type, namedcurve is left unchanged.
 */
void tls_hs_ecdh_set_curve_by_cert(struct tls_hs_ecdh *ecdh,
				   struct tls_cert_info *cinfo)
{
	assert(ecdh != NULL);
	assert(cinfo != NULL);

	ecdh->curvetype = tls_cert_info_ecc_get_type(cinfo);

	/* TODO: explicit_char2 / explicit_prime are not implemented. */
	if (ecdh->curvetype == TLS_ECC_CTYPE_NAMED_CURVE) {
		ecdh->namedcurve = tls_cert_info_ecc_get_curve(cinfo);
	}
}

/*
 * Release an ECDH context together with everything it owns.
 *
 * The curve and point-format lists are always freed. The peer public key
 * and our private key are freed only when flagged as ephemeral, i.e. when
 * this module created them. Keys loaded from elsewhere are left to their
 * owner.
 */
void tls_hs_ecdh_free(struct tls_hs_ecdh *ecdh)
{
	assert(ecdh);

	/* Negotiation lists (freed in this order: peer, ours, point formats). */
	tls_hs_ecc_eclist_free(ecdh->peer_eclist);
	ecdh->peer_eclist = NULL;
	tls_hs_ecc_eclist_free(ecdh->eclist);
	ecdh->eclist = NULL;
	tls_hs_ecc_pflist_free(ecdh->pflist);
	ecdh->pflist = NULL;

	/* Ephemeral keys only. */
	if (ecdh->peer_pubkey != NULL && ecdh->pubkey_ephemeral == true) {
		Key_free(ecdh->peer_pubkey);
		ecdh->peer_pubkey = NULL;
	}
	if (ecdh->my_prvkey != NULL && ecdh->prvkey_ephemeral == true) {
		Key_free(ecdh->my_prvkey);
		ecdh->my_prvkey = NULL;
	}

	free(ecdh);
}

