/*
 * Copyright (c) 2016-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_handshake_ecdh.h"

/**
 * Map a TLS named curve to the standard-parameter identifier that is
 * handed to ECPm_get_std_parameter().
 *
 * Only secp192r1, secp224r1, secp256r1, secp384r1 and secp521r1 are
 * mapped; any other curve yields -1.
 */
static int get_ECParam_type(enum tls_hs_named_curve curve);

/**
 * Validate the public point of an ECDSA-type key that is used for ECDH.
 *
 * Checks that the point is not the point at infinity, that both
 * coordinates lie in [0, p), and that the point satisfies
 * y^2 = x^3 + ax + b (mod p).  Returns true when every check passes.
 */
static bool validate_ecdsa_pubkey(Pubkey_ECDSA *pubkey);

/**
 * Generate an ephemeral ECDH key pair (held as ECDSA-type keys) on the
 * curve parameters E.
 *
 * On success the validated pair is stored in key->prv and key->pub and
 * true is returned; on failure the intermediate keys are freed and false
 * is returned.
 */
static bool ecdsakey_generate(ECParam *E, struct tls_hs_ecdh_key *key);

/**
 * Generate an ephemeral X25519 key pair and store it in key->prv and
 * key->pub.  Returns true on success.
 */
static bool x25519key_generate(struct tls_hs_ecdh_key *key);

/**
 * Generate an ephemeral X448 key pair and store it in key->prv and
 * key->pub.  Returns true on success.
 */
static bool x448key_generate(struct tls_hs_ecdh_key *key);

/**
 * Generate an ephemeral key pair on the given named prime curve.
 *
 * Looks up the standard parameters for the curve and delegates to
 * ecdsakey_generate().  Returns false when the curve is not supported or
 * its parameters cannot be obtained.
 */
static bool ecdsakey_gen_for_server(enum tls_hs_named_curve curve,
				     struct tls_hs_ecdh_key *key);

/**
 * Encode the ECDSA-type public key in key->pub as an octet string (via
 * ECp_P2OS()) and store it in key->ecpoint / key->ecpoint_len.
 */
static bool set_ecdsakey_to_ecpoint(struct tls_hs_ecdh_key *key);

/**
 * Copy the X25519 public key in key->pub into key->ecpoint and set
 * key->ecpoint_len to X25519_KEY_LENGTH.
 */
static bool set_x25519key_to_ecpoint(struct tls_hs_ecdh_key *key);

/**
 * Copy the X448 public key in key->pub into key->ecpoint and set
 * key->ecpoint_len to X448_KEY_LENGTH.
 */
static bool set_x448key_to_ecpoint(struct tls_hs_ecdh_key *key);

/**
 * Decode the EC point received from the peer (key->ecpoint) on the curve
 * ctx->namedcurve, validate it, and store the resulting public key in
 * ctx->peer_pubkey, marking it as ephemeral.
 */
static bool set_peer_ecdsa_pubkey(struct tls_hs_ecdh *ctx,
				  struct tls_hs_ecdh_key *key);

/**
 * Copy the X25519 public value received from the peer (key->ecpoint) into
 * a newly allocated public key and store it in ctx->peer_pubkey, marking
 * it as ephemeral.
 */
static bool set_peer_x25519_pubkey(struct tls_hs_ecdh *ctx,
				   struct tls_hs_ecdh_key *key);

/**
 * Copy the X448 public value received from the peer (key->ecpoint) into a
 * newly allocated public key and store it in ctx->peer_pubkey, marking it
 * as ephemeral.
 */
static bool set_peer_x448_pubkey(struct tls_hs_ecdh *ctx,
				   struct tls_hs_ecdh_key *key);

/**
 * Translate a TLS named curve into the identifier of the matching
 * standard curve parameter set.
 *
 * @param curve named curve to translate.
 * @return the ECP_* identifier for the curve, or -1 when the curve is not
 *         one of the prime curves listed below (assert() fires first in
 *         builds where assertions are enabled).
 */
static int get_ECParam_type(enum tls_hs_named_curve curve)
{
	int param_id = -1;

	switch (curve) {
	case TLS_ECC_CURVE_SECP192R1:
		param_id = ECP_X962_prime192v1;
		break;

	case TLS_ECC_CURVE_SECP224R1:
		param_id = ECP_secp224r1;
		break;

	case TLS_ECC_CURVE_SECP256R1:
		param_id = ECP_X962_prime256v1;
		break;

	case TLS_ECC_CURVE_SECP384R1:
		param_id = ECP_secp384r1;
		break;

	case TLS_ECC_CURVE_SECP521R1:
		param_id = ECP_secp521r1;
		break;

	default:
		/* callers are expected to filter out every other curve */
		assert(!"unsupported curve.");
		break;
	}

	return param_id;
}

/**
 * Check that an ECDSA-type public key is a valid point on its own curve.
 *
 * RFC 8446 section 4.2.8.2 (ECDHE Parameters) requires, for secp256r1,
 * secp384r1 and secp521r1, that each peer validates the other's public
 * value Q in three steps: (1) Q is not the point at infinity (O),
 * (2) for Q = (x, y) both x and y are in the correct interval, and
 * (3) (x, y) is a solution of the elliptic curve equation.  Membership in
 * the correct subgroup need not be verified for these curves.
 *
 * RFC 8422 section 5.11 (Public Key Validation) states the same duty for
 * the NIST curves: the receiver MUST check that x and y satisfy
 * y^2 = x^3 + ax + b mod p.  Skipping the check lets an attacker learn
 * information about the private key, up to recovering it completely in a
 * few requests when the key is not really ephemeral.
 *
 * @param pubkey key to check; its point (W) and curve parameters (E) are
 *               read.
 * @return true when all three steps pass, false otherwise.  Every failure
 *         path, including a failed allocation or arithmetic call, first
 *         records ERR_ST_TLS_INVALID_NIST_PUB through OK_set_error().
 */
static bool validate_ecdsa_pubkey(Pubkey_ECDSA *pubkey)
{
	ECParam *curve = pubkey->E;
	ECp *point = pubkey->W;
	bool on_curve = false;

	LNm *x_cubed = NULL;
	LNm *a_times_x = NULL;
	LNm *rhs = NULL;
	LNm *lhs = NULL;
	LNm *scratch_a = NULL;
	LNm *scratch_b = NULL;

	/* step (1): the point at infinity (O) is never acceptable */
	if (point->infinity) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 12, NULL);
		return false;
	}

	/* step (2): x must satisfy 0 <= x < p ... */
	if (point->x->neg || LN_cmp(point->x, curve->p) >= 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 13, NULL);
		return false;
	}

	/* ... and y must satisfy 0 <= y < p */
	if (point->y->neg || LN_cmp(point->y, curve->p) >= 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 14, NULL);
		return false;
	}

	/*
	 * step (3): evaluate both sides of y^2 = x^3 + ax + b (mod p).
	 * All working numbers are allocated up front so that a single
	 * cleanup path can release them.
	 */
	if ((x_cubed = LN_alloc()) == NULL ||
	    (a_times_x = LN_alloc()) == NULL ||
	    (rhs = LN_alloc()) == NULL ||
	    (lhs = LN_alloc()) == NULL ||
	    (scratch_a = LN_alloc()) == NULL ||
	    (scratch_b = LN_alloc()) == NULL) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 15,
			     NULL);
		goto cleanup;
	}

	/* scratch_a = x^2 */
	if (LN_sqr(point->x, scratch_a) < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 16,
			     NULL);
		goto cleanup;
	}

	/* x_cubed = x * x^2 */
	if (LN_multi(point->x, scratch_a, x_cubed) < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 17,
			     NULL);
		goto cleanup;
	}

	/* a_times_x = a * x */
	if (LN_multi(point->x, curve->a, a_times_x) < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 18,
			     NULL);
		goto cleanup;
	}

	/* scratch_a = x^3 + ax */
	if (LN_plus(x_cubed, a_times_x, scratch_a) < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 19,
			     NULL);
		goto cleanup;
	}

	/* scratch_b = x^3 + ax + b */
	if (LN_plus(scratch_a, curve->b, scratch_b) < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 20,
			     NULL);
		goto cleanup;
	}

	/* rhs = (x^3 + ax + b) mod p; scratch_a only receives the by-product */
	if (LN_div_mod(scratch_b, curve->p, scratch_a, rhs) < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 21,
			     NULL);
		goto cleanup;
	}

	/* scratch_a = y^2 */
	if (LN_sqr(point->y, scratch_a) < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 22,
			     NULL);
		goto cleanup;
	}

	/* lhs = y^2 mod p; scratch_b only receives the by-product */
	if (LN_div_mod(scratch_a, curve->p, scratch_b, lhs) < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 23,
			     NULL);
		goto cleanup;
	}

	/* the point is on the curve exactly when both residues agree */
	if (LN_cmp(lhs, rhs) == 0) {
		on_curve = true;
	} else {
		OK_set_error(ERR_ST_TLS_INVALID_NIST_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 24,
			     NULL);
	}

cleanup:
	LN_free(x_cubed);
	LN_free(a_times_x);
	LN_free(scratch_a);
	LN_free(scratch_b);
	LN_free(rhs);
	LN_free(lhs);

	return on_curve;
}

/**
 * Create an ephemeral key pair on the curve described by E.
 *
 * The private key is generated first, the public key is derived from it,
 * and the public key is then run through validate_ecdsa_pubkey().
 *
 * @param E   curve parameters to generate the key on.
 * @param key receives the pair in key->prv / key->pub on success; it is
 *            left untouched on failure.
 * @return true when a validated pair was stored, false otherwise (both
 *         intermediate key objects are freed in that case).
 */
static bool ecdsakey_generate(ECParam *E, struct tls_hs_ecdh_key *key)
{
	Prvkey_ECDSA *secret = ECDSAprvkey_new();
	Pubkey_ECDSA *point = NULL;

	if (secret == NULL) {
		return false;
	}

	point = ECDSApubkey_new();
	if (point == NULL) {
		goto discard_secret;
	}

	/* generate, derive and validate; stop at the first step that fails */
	if (ECDSAprv_generate(E, secret) != 0 ||
	    ECDSAprv_2pub(secret, point) != 0 ||
	    !validate_ecdsa_pubkey(point)) {
		goto discard_pair;
	}

	key->prv = (Key *)secret;
	key->pub = (Key *)point;
	return true;

discard_pair:
	ECDSAkey_free((Key *)point);
discard_secret:
	ECDSAkey_free((Key *)secret);
	return false;
}

/**
 * Create an ephemeral X25519 key pair.
 *
 * @param key receives the pair in key->prv / key->pub on success; it is
 *            left untouched on failure.
 * @return true when the pair was stored, false otherwise (both
 *         intermediate key objects are freed in that case).
 */
static bool x25519key_generate(struct tls_hs_ecdh_key *key)
{
	Prvkey_X25519 *secret = X25519prvkey_new();
	Pubkey_X25519 *share = NULL;

	if (secret == NULL) {
		return false;
	}

	share = X25519pubkey_new();
	if (share == NULL) {
		goto discard_secret;
	}

	/* generate the private key, then derive the public key from it */
	if (X25519prv_generate(secret) != 0 ||
	    X25519prv_2pub(secret, share) != 0) {
		goto discard_pair;
	}

	key->prv = (Key *)secret;
	key->pub = (Key *)share;
	return true;

discard_pair:
	X25519key_free((Key *)share);
discard_secret:
	X25519key_free((Key *)secret);
	return false;
}

/**
 * Create an ephemeral X448 key pair.
 *
 * @param key receives the pair in key->prv / key->pub on success; it is
 *            left untouched on failure.
 * @return true when the pair was stored, false otherwise (both
 *         intermediate key objects are freed in that case).
 */
static bool x448key_generate(struct tls_hs_ecdh_key *key)
{
	Prvkey_X448 *secret = X448prvkey_new();
	Pubkey_X448 *share = NULL;

	if (secret == NULL) {
		return false;
	}

	share = X448pubkey_new();
	if (share == NULL) {
		goto discard_secret;
	}

	/* generate the private key, then derive the public key from it */
	if (X448prv_generate(secret) != 0 ||
	    X448prv_2pub(secret, share) != 0) {
		goto discard_pair;
	}

	key->prv = (Key *)secret;
	key->pub = (Key *)share;
	return true;

discard_pair:
	X448key_free((Key *)share);
discard_secret:
	X448key_free((Key *)secret);
	return false;
}

/**
 * Create an ephemeral key pair on a named prime curve.
 *
 * @param curve named curve to generate the key on.
 * @param key   receives the generated pair (see ecdsakey_generate()).
 * @return true on success; false when the curve has no parameter set,
 *         the parameters cannot be loaded, or key generation fails.
 */
static bool ecdsakey_gen_for_server(enum tls_hs_named_curve curve,
				     struct tls_hs_ecdh_key *key)
{
	int param_id = get_ECParam_type(curve);
	ECParam *params;
	bool ok;

	if (param_id < 0) {
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 0, NULL);
		return false;
	}

	params = ECPm_get_std_parameter(param_id);
	if (params == NULL) {
		OK_set_error(ERR_ST_TLS_ECPM_GET_STD_PARAMETER,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 1, NULL);
		return false;
	}

	ok = ecdsakey_generate(params, key);

	/*
	 * The parameter set is replicated into the generated pub and prv
	 * keys, so this copy can be released whatever the outcome was.
	 */
	ECPm_free(params);

	return ok;
}

/**
 * Serialize the ECDSA-type public key in key->pub into key->ecpoint.
 *
 * @param key key holder; key->pub is read, key->ecpoint and
 *            key->ecpoint_len are written.
 * @return true on success, false when ECp_P2OS() fails (the error is
 *         recorded as ERR_ST_TLS_ECP_P2OS).
 */
static bool set_ecdsakey_to_ecpoint(struct tls_hs_ecdh_key *key)
{
	Pubkey_ECDSA *ecpub = (Pubkey_ECDSA *)(key->pub);
	int32_t nbytes;
	unsigned char *encoded;

	/*
	 * NOTE(review): the literal 4 presumably selects the uncompressed
	 * point encoding (leading 0x04 octet) -- confirm against ECp_P2OS().
	 */
	encoded = ECp_P2OS(ecpub->W, 4, &nbytes);
	if (encoded != NULL) {
		memcpy(key->ecpoint, encoded, nbytes);
		key->ecpoint_len = nbytes;
		/* the encoding buffer is not kept once it has been copied */
		free(encoded);
		return true;
	}

	OK_set_error(ERR_ST_TLS_ECP_P2OS,
		     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 2, NULL);
	return false;
}

/**
 * Copy the X25519 public key in key->pub into key->ecpoint.
 *
 * @param key key holder; key->pub is read, key->ecpoint and
 *            key->ecpoint_len are written.
 * @return always true.
 */
static bool set_x25519key_to_ecpoint(struct tls_hs_ecdh_key *key)
{
	const Pubkey_X25519 *share = (Pubkey_X25519 *)(key->pub);

	/* the public value always has the fixed X25519 length */
	memcpy(key->ecpoint, share->key, X25519_KEY_LENGTH);
	key->ecpoint_len = X25519_KEY_LENGTH;

	return true;
}

/**
 * Copy the X448 public key in key->pub into key->ecpoint.
 *
 * @param key key holder; key->pub is read, key->ecpoint and
 *            key->ecpoint_len are written.
 * @return always true.
 */
static bool set_x448key_to_ecpoint(struct tls_hs_ecdh_key *key)
{
	const Pubkey_X448 *share = (Pubkey_X448 *)(key->pub);

	/* the public value always has the fixed X448 length */
	memcpy(key->ecpoint, share->key, X448_KEY_LENGTH);
	key->ecpoint_len = X448_KEY_LENGTH;

	return true;
}

static bool set_peer_ecdsa_pubkey(struct tls_hs_ecdh *ctx,
				  struct tls_hs_ecdh_key *key)
{
	ECParam *E = NULL;
	Pubkey_ECDSA *pub = NULL;
	int type;

	type = get_ECParam_type(ctx->namedcurve);
	if (type < 0) {
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 3, NULL);
		return false;
	}

	E = ECPm_get_std_parameter(type);
	if (E == NULL) {
		OK_set_error(ERR_ST_TLS_ECPM_GET_STD_PARAMETER,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 4, NULL);
		return false;
	}

	pub = ECDSApubkey_new();
	if (pub == NULL) {
		OK_set_error(ERR_ST_TLS_ECDSAPUBKEY_NEW,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 5, NULL);
		ECPm_free(E);
		return false;
	}

	pub->E = E;
	ECp_free(pub->W);
	pub->W = NULL;

	pub->W = ECp_OS2P(E, key->ecpoint, key->ecpoint_len);
	if (pub->W == NULL) {
		OK_set_error(ERR_ST_TLS_ECP_OS2EP,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 6, NULL);
		ECDSAkey_free((Key*) pub);
		return false;
	}

	pub->size = pub->E->psize >> 3;

	if (validate_ecdsa_pubkey(pub) == false) {
		ECDSAkey_free((Key*) pub);
		return false;
	}

	/* The public key's memory will be released at tls_hs_ecdh_free() */
	ctx->peer_pubkey = (Key *)pub;
	ctx->pubkey_ephemeral = true;

	return true;
}

static bool set_peer_x25519_pubkey(struct tls_hs_ecdh *ctx,
				   struct tls_hs_ecdh_key *key)
{
	Pubkey_X25519 *pub = NULL;
	if ((pub = X25519pubkey_new()) == NULL) {
		TLS_DPRINTF("X25519pubkey_new");
		return false;
	}

	memcpy(pub->key, key->ecpoint, key->ecpoint_len);

	/* The public key's memory will be released at tls_hs_ecdh_free() */
	ctx->peer_pubkey = (Key *)pub;
	ctx->pubkey_ephemeral = true;

	return true;
}

static bool set_peer_x448_pubkey(struct tls_hs_ecdh *ctx,
				 struct tls_hs_ecdh_key *key)
{
	Pubkey_X448 *pub = NULL;
	if ((pub = X448pubkey_new()) == NULL) {
		TLS_DPRINTF("X448pubkey_new");
		return false;
	}

	memcpy(pub->key, key->ecpoint, key->ecpoint_len);

	/* The public key's memory will be released at tls_hs_ecdh_free() */
	ctx->peer_pubkey = (Key *)pub;
	ctx->pubkey_ephemeral = true;

	return true;
}

/**
 * Generate an ephemeral ECDH key pair for the given named curve.
 *
 * @param curve named curve to generate the key for.
 * @param key   receives the generated pair in key->prv / key->pub.
 * @return true on success; false when generation fails or the curve is
 *         not handled (ERR_ST_TLS_UNSUPPORTED_CURVE is recorded then).
 */
bool tls_hs_ecdhkey_gen_for_server(enum tls_hs_named_curve curve,
				   struct tls_hs_ecdh_key *key)
{
	bool generated = false;

	switch (curve) {
	case TLS_ECC_CURVE_X25519:
		generated = x25519key_generate(key);
		break;

	case TLS_ECC_CURVE_X448:
		generated = x448key_generate(key);
		break;

	/* the prime curves share one generator */
	case TLS_ECC_CURVE_SECP192R1:
	case TLS_ECC_CURVE_SECP224R1:
	case TLS_ECC_CURVE_SECP256R1:
	case TLS_ECC_CURVE_SECP384R1:
	case TLS_ECC_CURVE_SECP521R1:
		generated = ecdsakey_gen_for_server(curve, key);
		break;

	default:
		TLS_DPRINTF("unknown curve");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 25,
			     NULL);
		break;
	}

	return generated;
}

/**
 * Generate an ephemeral ECDH key pair that matches the peer's public key.
 *
 * @param peer_pubkey public key received from the peer; its key_type
 *                    selects the algorithm.
 * @param key         receives the generated pair in key->prv / key->pub.
 * @return true on success; false when generation fails or the key type
 *         is not handled (ERR_ST_TLS_NOT_ECDH_PUB is recorded then).
 */
bool tls_hs_ecdhkey_gen_for_client(Key *peer_pubkey,
				   struct tls_hs_ecdh_key *key)
{
	bool generated = false;

	switch (peer_pubkey->key_type) {
	case KEY_X25519_PUB:
		generated = x25519key_generate(key);
		break;

	case KEY_X448_PUB:
		generated = x448key_generate(key);
		break;

	case KEY_ECDSA_PUB:
		/*
		 * The ephemeral key has to live on the same curve as the
		 * peer's key, so the curve parameters carried by the peer's
		 * public key are reused for generation.
		 */
		generated = ecdsakey_generate(
			((Pubkey_ECDSA *)peer_pubkey)->E, key);
		break;

	default:
		TLS_DPRINTF("unknown key type");
		OK_set_error(ERR_ST_TLS_NOT_ECDH_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 26,
			     NULL);
		break;
	}

	return generated;
}

/**
 * Write the public key held in key->pub into key->ecpoint, using the
 * encoding that belongs to the key's type.
 *
 * @param key key holder; key->pub is read, key->ecpoint and
 *            key->ecpoint_len are written.
 * @return true on success; false when encoding fails or the key type is
 *         not handled (ERR_ST_TLS_NOT_ECDH_PUB is recorded then).
 */
bool tls_hs_ecdhkey_set_to_ecpoint(struct tls_hs_ecdh_key *key)
{
	bool stored = false;

	switch (key->pub->key_type) {
	case KEY_X25519_PUB:
		stored = set_x25519key_to_ecpoint(key);
		break;

	case KEY_X448_PUB:
		stored = set_x448key_to_ecpoint(key);
		break;

	case KEY_ECDSA_PUB:
		stored = set_ecdsakey_to_ecpoint(key);
		break;

	default:
		TLS_DPRINTF("unknown key type");
		OK_set_error(ERR_ST_TLS_NOT_ECDH_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 27,
			     NULL);
		break;
	}

	return stored;
}

/**
 * Turn the public value received from the peer into a key object and
 * store it in the context, according to ctx->namedcurve.
 *
 * @param ctx ECDH context; ctx->namedcurve is read, the peer key members
 *            are written by the selected helper.
 * @param key holder of the received value (ecpoint / ecpoint_len).
 * @return true on success; false when the helper fails or the curve is
 *         not handled (ERR_ST_TLS_NOT_ECDH_PUB is recorded then).
 */
bool tls_hs_ecdhkey_set_peer_pubkey(struct tls_hs_ecdh *ctx,
				    struct tls_hs_ecdh_key *key)
{
	bool accepted = false;

	switch (ctx->namedcurve) {
	case TLS_ECC_CURVE_X25519:
		accepted = set_peer_x25519_pubkey(ctx, key);
		break;

	case TLS_ECC_CURVE_X448:
		accepted = set_peer_x448_pubkey(ctx, key);
		break;

	/* the prime curves share one decoder */
	case TLS_ECC_CURVE_SECP192R1:
	case TLS_ECC_CURVE_SECP224R1:
	case TLS_ECC_CURVE_SECP256R1:
	case TLS_ECC_CURVE_SECP384R1:
	case TLS_ECC_CURVE_SECP521R1:
		accepted = set_peer_ecdsa_pubkey(ctx, key);
		break;

	default:
		TLS_DPRINTF("unknown named curve");
		OK_set_error(ERR_ST_TLS_NOT_ECDH_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 28,
			     NULL);
		break;
	}

	return accepted;
}

/**
 * Use the private key stored in a PKCS#12 object as this side's ECDH
 * private key.
 *
 * @param ctx ECDH context; ctx->my_prvkey and ctx->prvkey_ephemeral are
 *            written on success.
 * @param p12 PKCS#12 object to take the private key from.
 * @return true when an ECDSA-type private key was found and stored;
 *         false when there is no private key or it has another type.
 */
bool tls_hs_ecdhkey_set_my_privkey_from_pkcs12(struct tls_hs_ecdh *ctx,
					       PKCS12 *p12)
{
	Key *stored_key = P12_get_privatekey(p12);
	bool accepted = false;

	if (stored_key == NULL) {
		OK_set_error(ERR_ST_TLS_P12_GET_PRIVATEKEY,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 7, NULL);
	} else if (stored_key->key_type != KEY_ECDSA_PRV) {
		OK_set_error(ERR_ST_TLS_NOT_ECDH_PRIV,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 8, NULL);
	} else {
		/*
		 * The key is not an ephemeral one; the flag presumably tells
		 * tls_hs_ecdh_free() not to release it -- confirm.
		 */
		ctx->prvkey_ephemeral = false;
		ctx->my_prvkey = stored_key;
		accepted = true;
	}

	return accepted;
}

/**
 * Use the public key of the user certificate stored in a PKCS#12 object
 * as the peer's ECDH public key.
 *
 * @param ctx ECDH context; ctx->peer_pubkey and ctx->pubkey_ephemeral are
 *            written on success.
 * @param p12 PKCS#12 object to take the user certificate from.
 * @return true when an ECDSA-type public key was found and stored; false
 *         when there is no user certificate, the certificate carries no
 *         public key, or the key has another type.
 */
bool tls_hs_ecdhkey_set_peer_pubkey_from_pkcs12(struct tls_hs_ecdh *ctx,
						PKCS12 *p12)
{
	Cert *user_cert = P12_get_usercert(p12);
	Key *cert_key;
	bool accepted = false;

	if (user_cert == NULL) {
		OK_set_error(ERR_ST_TLS_P12_GET_USERCERT,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 9, NULL);
		return false;
	}

	cert_key = user_cert->pubkey;
	if (cert_key == NULL) {
		OK_set_error(ERR_ST_TLS_CET_NO_PUBKEY,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 10,
			     NULL);
	} else if (cert_key->key_type != KEY_ECDSA_PUB) {
		OK_set_error(ERR_ST_TLS_NOT_ECDH_PUB,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_ECDH_ECDHKEY + 11,
			     NULL);
	} else {
		/*
		 * The key belongs to the certificate and is not an ephemeral
		 * one; the flag presumably tells tls_hs_ecdh_free() not to
		 * release it -- confirm.
		 */
		ctx->pubkey_ephemeral = false;
		ctx->peer_pubkey = cert_key;
		accepted = true;
	}

	return accepted;
}
