/*
 * Copyright (c) 2016-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_alert.h"

#include <stdint.h>

/**
 * Set of Named Curves that are implemented by ecc module.
 *
 * write_elliptic_curve_list() emits these entries in array order, so the
 * order below is the order advertised to the peer. Curves received from
 * the peer are accepted only if check_named_curve_supported() finds them
 * in this table.
 *
 * The table is not const: tls_hs_ecc_get_supported_eclist() stores a
 * pointer to it in eclist->list.
 *
 * The curves inside "#if 0" are recognised names whose arithmetic is not
 * implemented yet; they are kept so they can be enabled later.
 *
 * @see ecc/ecc_std.c ECPm_get_recommended_elliptic_curve()
 */
static enum tls_hs_named_curve tls_hs_ecc_supported_named_curves[] = {
	/** Curve448 */
	TLS_ECC_CURVE_X448,
	/** Curve25519 */
	TLS_ECC_CURVE_X25519,
	/** NIST P-256: ECP_X962_prime256v1 */
	TLS_ECC_CURVE_SECP256R1,
#if 0 /* not implemented */
#ifdef TLS_OBSOLETE_ALGO
	/** NIST P-192: ECP_X962_prime192v1 */
	TLS_ECC_CURVE_SECP192R1,
	/** NIST P-224: ECP_secp224r1 */
	TLS_ECC_CURVE_SECP224R1,
#endif
	/** NIST P-384: ECP_secp384r1 */
	TLS_ECC_CURVE_SECP384R1,
	/** NIST P-521: ECP_secp521r1 */
	TLS_ECC_CURVE_SECP521R1,
#endif
};

/**
 * Set of ECPointFormat.
 *
 * The uncompressed point format is the default format in that
 * implementations of this document (RFC 4492 5.1.2) MUST support it for
 * all of their supported curves.
 *
 * A ClientHello advertises this table as-is (see
 * write_ec_point_format_list()). Point formats received from the peer are
 * kept only if check_point_format_supported() finds them here.
 *
 * NOTE(review): unlike the curve table this array is not static, so it
 * may be referenced from other translation units. Confirm before
 * narrowing its linkage.
 */
enum tls_hs_ecc_ec_point_format tls_hs_ecc_supported_point_format[] = {
	/** default format. */
	TLS_ECC_PF_UNCOMPRESSED,
};

/**
 * Write the struct EllipticCurveList and its length.
 *
 * RFC4492 5.1.1.  Supported Elliptic Curves Extension
 *
 * struct {
 *     NamedCurve elliptic_curve_list<1..2^16-1>
 * } EllipticCurveList;
 *
 * The list written is tls_hs_ecc_supported_named_curves.
 *
 * @return number of bytes appended to msg, or -1 on error.
 */
static int32_t write_elliptic_curve_list(TLS *tls, struct tls_hs_msg *msg);

/**
 * Write the struct ECPointFormatList and its length.
 *
 * RFC4492 5.1.2.  Supported Point Formats Extension
 *
 *      struct {
 *          ECPointFormat ec_point_format_list<1..2^8-1>
 *      } ECPointFormatList;
 *
 * For a ServerHello the list previously saved in tls->ecdh->pflist is
 * written; for a ClientHello tls_hs_ecc_supported_point_format is written.
 *
 * @return number of bytes appended to msg, or -1 on error.
 */
static int32_t write_ec_point_format_list(TLS *tls, struct tls_hs_msg *msg);

/**
 * Allocate a tls_hs_ecc_eclist holding a copy of the first count elliptic
 * curves of list.
 *
 * @return the new list (release it with tls_hs_ecc_eclist_free()), or NULL
 *         on allocation failure.
 */
static struct tls_hs_ecc_eclist *compose_eclist(enum tls_hs_named_curve *list,
						int16_t count);

/**
 * Save a copy of the specified elliptic curve list to tls->ecdh->eclist.
 *
 * @return 0 on success; -1 if tls->ecdh is NULL, an eclist was already
 *         saved, or allocation fails.
 */
static int32_t save_eclist(TLS *tls, enum tls_hs_named_curve *list,
			   int16_t count);

/**
 * Save a copy of the peer's elliptic curve list, exactly as received
 * (unfiltered), to tls->ecdh->peer_eclist.
 *
 * @return 0 on success; -1 if tls->ecdh is NULL, a peer_eclist was already
 *         saved, or allocation fails.
 */
static int32_t save_peer_eclist(TLS *tls, enum tls_hs_named_curve *list,
			   int16_t count);

/**
 * Save a copy of the specified point format list to tls->ecdh->pflist.
 *
 * @return 0 on success; -1 if tls->ecdh is NULL, a pflist was already
 *         saved, or allocation fails.
 */
static int32_t save_pflist(TLS *tls, enum tls_hs_ecc_ec_point_format *list,
			   int16_t count);

/**
 * @return true if ec is listed in tls_hs_ecc_supported_named_curves.
 */
static bool check_named_curve_supported(enum tls_hs_named_curve ec);

/**
 * @return true if pf is listed in tls_hs_ecc_supported_point_format.
 */
static bool check_point_format_supported(enum tls_hs_ecc_ec_point_format pf);


/*
 * Encode tls_hs_ecc_supported_named_curves as an EllipticCurveList
 * (NamedGroupList in TLS 1.3). The encoding is a two-octet byte length
 * followed by one two-octet NamedCurve per entry, in array order.
 *
 * Returns the number of bytes appended to msg, or -1 on error.
 */
static int32_t write_elliptic_curve_list(TLS *tls UNUSED,
					 struct tls_hs_msg *msg)
{
	/* Every NamedCurve occupies two octets on the wire. */
	const int32_t curve_bytes = 2;

	/* number of elements */
	const int32_t n = sizeof(tls_hs_ecc_supported_named_curves)
		/ sizeof(tls_hs_ecc_supported_named_curves[0]);

	/*
	 * Number of bytes of the encoded list. This must be derived from the
	 * wire size of a NamedCurve. sizeof(enum tls_hs_named_curve) is
	 * typically 4 and does not describe what is written below.
	 */
	const int32_t len = n * curve_bytes;

	/*
	 * RFC8446 4.2.7.  Supported Groups
	 *
	 *           NamedGroup named_group_list<2..2^16-1>;
	 *
	 * RFC4492 5.1.1.  Supported Elliptic Curves Extension
	 *
	 *             NamedCurve elliptic_curve_list<1..2^16-1>
	 *
	 * 65534 is the largest even byte count that does not exceed 2^16-1.
	 */
	const int32_t list_length_min = 2;
	const int32_t list_length_max = (2 << (16 - 1)) - 2;
	if (len < list_length_min || list_length_max < len) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 12, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t offset = 0;

	/*
	 * Write the length of the struct EllipticCurveList.
	 * This length occupies two octets.
	 */
	if (! tls_hs_msg_write_2(msg, len)) {
		return -1;
	}
	offset += 2;

	/*
	 * Write the list of the NamedCurve.
	 * The one element occupies two octets.
	 */
	for (int32_t i = 0; i < n; ++i) {
		if (! tls_hs_msg_write_2(msg,
					tls_hs_ecc_supported_named_curves[i])) {
			return -1;
		}
		offset += curve_bytes;
	}

	return offset;
}

/*
 * Encode an ECPointFormatList: a one-octet byte length followed by one
 * octet per ECPointFormat.
 *
 * A ServerHello carries the formats saved from the peer in
 * tls->ecdh->pflist; a ClientHello carries tls_hs_ecc_supported_point_format.
 * Any other message type is a programming error.
 *
 * Returns the number of bytes appended to msg, or -1 on error.
 */
static int32_t write_ec_point_format_list(TLS *tls,
					  struct tls_hs_msg *msg)
{
	enum tls_hs_ecc_ec_point_format *formats = NULL;
	/* Each format is a single octet, so element count == byte count. */
	int32_t count = 0;
	int32_t bytes = 0;

	if (msg->type == TLS_HANDSHAKE_SERVER_HELLO) {
		formats = tls->ecdh->pflist->list;
		bytes = tls->ecdh->pflist->len;
		count = bytes;
	} else if (msg->type == TLS_HANDSHAKE_CLIENT_HELLO) {
		formats = tls_hs_ecc_supported_point_format;
		count = sizeof(tls_hs_ecc_supported_point_format)
			/ sizeof(tls_hs_ecc_supported_point_format[0]);
		bytes = count;
	} else {
		/* With NDEBUG, bytes stays 0 and the range check rejects it. */
		assert(!"message type error");
	}

	/*
	 * RFC4492 5.1.2.  Supported Point Formats Extension
	 *
	 *             ECPointFormat ec_point_format_list<1..2^8-1>
	 */
	const int32_t vec_min = 1;
	const int32_t vec_max = TLS_VECTOR_1_BYTE_SIZE_MAX;
	if (bytes < vec_min || bytes > vec_max) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 13, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* One-octet vector length. */
	if (tls_hs_msg_write_1(msg, bytes) == false) {
		return -1;
	}
	int32_t written = 1;

	/* One octet per ECPointFormat. */
	for (int32_t k = 0; k < count; k++) {
		if (tls_hs_msg_write_1(msg, formats[k]) == false) {
			return -1;
		}
		written++;
	}

	return written;
}

/*
 * Allocate a tls_hs_ecc_eclist holding a copy of the first count entries
 * of list. The result can be released with tls_hs_ecc_eclist_free().
 *
 * Returns NULL (with the error recorded) on allocation failure.
 */
static struct tls_hs_ecc_eclist *compose_eclist(enum tls_hs_named_curve *list,
						int16_t count)
{
	struct tls_hs_ecc_eclist *eclist;

	if ((eclist = malloc(sizeof(*eclist))) == NULL) {
		TLS_DPRINTF("ecc: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 1, NULL);
		return NULL;
	}

	const size_t elem_size = sizeof(enum tls_hs_named_curve);
	const size_t eclist_size = (count > 0) ? (size_t) count * elem_size : 0;

	eclist->len = count;

	/*
	 * count may legitimately be 0, e.g. when none of the peer's curves
	 * is supported. malloc(0) is allowed to return NULL, which would be
	 * misreported as an allocation failure, so always reserve room for
	 * at least one element.
	 */
	const size_t alloc_size = (eclist_size > 0) ? eclist_size : elem_size;
	if ((eclist->list = malloc(alloc_size)) == NULL) {
		TLS_DPRINTF("ecc: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 8, NULL);
		free(eclist);
		return NULL;
	}

	if (eclist_size > 0) {
		memcpy(eclist->list, list, eclist_size);
	}

	return eclist;
}

/*
 * Store a freshly allocated copy of list in tls->ecdh->eclist.
 *
 * tls->ecdh must already exist (it is set up by init_handshake()), and
 * the list can be stored only once. Returns 0 on success, -1 otherwise.
 */
static int32_t save_eclist(TLS *tls, enum tls_hs_named_curve *list,
			   int16_t count)
{
	const bool ready = (tls->ecdh != NULL) && (tls->ecdh->eclist == NULL);
	if (!ready) {
		OK_set_error(ERR_ST_NULLPOINTER,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 9, NULL);
		return -1;
	}

	struct tls_hs_ecc_eclist *composed = compose_eclist(list, count);
	if (composed == NULL) {
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls->ecdh->eclist = composed;
	return 0;
}

/*
 * Store a freshly allocated, unfiltered copy of the peer's curve list in
 * tls->ecdh->peer_eclist.
 *
 * tls->ecdh must already exist (it is set up by init_handshake()), and
 * the list can be stored only once. Returns 0 on success, -1 otherwise.
 */
static int32_t save_peer_eclist(TLS *tls, enum tls_hs_named_curve *list,
				int16_t count)
{
	struct tls_hs_ecc_eclist *composed;

	if (tls->ecdh != NULL && tls->ecdh->peer_eclist == NULL) {
		composed = compose_eclist(list, count);
		if (composed != NULL) {
			tls->ecdh->peer_eclist = composed;
			return 0;
		}

		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	OK_set_error(ERR_ST_NULLPOINTER,
		     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 10, NULL);
	return -1;
}

/*
 * Store a freshly allocated copy of the first count point formats of list
 * in tls->ecdh->pflist.
 *
 * tls->ecdh must already exist (it is set up by init_handshake()), and
 * the list can be stored only once. Returns 0 on success, -1 otherwise.
 */
static int32_t save_pflist(TLS *tls, enum tls_hs_ecc_ec_point_format *list,
			   int16_t count)
{
	struct tls_hs_ecc_pflist *pflist;

	/*
	 * tls->ecdh must have been initialized with init_handshake().
	 */
	if (tls->ecdh == NULL || tls->ecdh->pflist != NULL) {
		OK_set_error(ERR_ST_NULLPOINTER,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 11, NULL);
		return -1;
	}

	if ((pflist = malloc(sizeof(*pflist))) == NULL) {
		TLS_DPRINTF("ecc: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	const size_t elem_size = sizeof(enum tls_hs_ecc_ec_point_format);
	const size_t pflist_size = (count > 0) ? (size_t) count * elem_size : 0;

	pflist->len = count;

	/*
	 * count is 0 when the peer offered only unsupported formats.
	 * malloc(0) may return NULL, which must not be mistaken for an
	 * allocation failure, so always reserve room for one element.
	 */
	const size_t alloc_size = (pflist_size > 0) ? pflist_size : elem_size;
	if ((pflist->list = malloc(alloc_size)) == NULL) {
		TLS_DPRINTF("ecc: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 3, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		free(pflist);
		return -1;
	}

	if (pflist_size > 0) {
		memcpy(pflist->list, list, pflist_size);
	}

	tls->ecdh->pflist = pflist;

	return 0;
}


/*
 * Report whether ec appears in tls_hs_ecc_supported_named_curves.
 */
static bool check_named_curve_supported(enum tls_hs_named_curve ec)
{
	const size_t total = sizeof(tls_hs_ecc_supported_named_curves) /
		sizeof(tls_hs_ecc_supported_named_curves[0]);

	for (size_t idx = 0; idx < total; idx++) {
		if (tls_hs_ecc_supported_named_curves[idx] == ec) {
			return true;
		}
	}

	return false;
}

static bool check_point_format_supported(enum tls_hs_ecc_ec_point_format pf)
{
	int n = sizeof(tls_hs_ecc_supported_point_format) /
		sizeof(enum tls_hs_ecc_ec_point_format);

	int found = 0;

	for (int i = 0; i < n; i++) {
		if (tls_hs_ecc_supported_point_format[i] == pf) {
			found = 1;
			break;
		}
	}

	return found;
}

/*
 * Append the elliptic_curves (supported_groups) extension to msg.
 *
 * In EncryptedExtensions the extension is written only if the client sent
 * it. In any other message it is written only when tls->client_version is
 * TLS 1.0, 1.1 or 1.2.
 *
 * Returns the number of bytes appended (0 when the extension is skipped),
 * or -1 on error.
 */
int32_t tls_hs_ecc_write_elliptic_curves(TLS *tls, struct tls_hs_msg *msg)
{
	if (msg->type != TLS_HANDSHAKE_ENCRYPTED_EXTENSIONS) {
		const uint16_t version = tls_util_convert_protover_to_ver(
			&(tls->client_version));
		if (version != TLS_VER_TLS10 &&
		    version != TLS_VER_TLS11 &&
		    version != TLS_VER_TLS12) {
			return 0;
		}
	} else if (! tls->interim_params->recv_ext_flags[TLS_EXT_ELLIPTIC_CURVES]) {
		/* The client did not offer the extension: nothing to answer. */
		return 0;
	}

	/* extension_type (two octets) */
	if (tls_hs_msg_write_2(msg, TLS_EXT_ELLIPTIC_CURVES) == false) {
		return -1;
	}

	/* extension_data length: reserve two octets, patched below. */
	const int32_t len_pos = msg->len;
	if (tls_hs_msg_write_2(msg, 0) == false) {
		return -1;
	}

	const int32_t body_len = write_elliptic_curve_list(tls, msg);
	if (body_len < 0) {
		return -1;
	}

	const int32_t body_max = TLS_EXT_SIZE_MAX;
	if (body_len > body_max) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 14, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls_util_write_2(&(msg->msg[len_pos]), body_len);

	/* type (2) + length (2) + extension_data */
	return 2 + 2 + body_len;
}

/*
 * Append the ec_point_formats extension to msg.
 *
 * In a ClientHello it is written when any entry of tls->supported_versions
 * is TLS 1.0, 1.1 or 1.2. In any other message it is written only if the
 * ClientHello carried the extension (tls->ecdh->pflist was saved) and the
 * negotiated version is TLS 1.0, 1.1 or 1.2.
 *
 * Returns the number of bytes appended (0 when the extension is skipped),
 * or -1 on error.
 */
int32_t tls_hs_ecc_write_ec_point_formats(TLS *tls, struct tls_hs_msg *msg)
{
	bool applicable = false;

	if (msg->type == TLS_HANDSHAKE_CLIENT_HELLO) {
		for (int i = 0; !applicable && i < tls->supported_versions.len; i++) {
			const uint16_t candidate = tls->supported_versions.list[i];
			applicable = (candidate == TLS_VER_TLS10 ||
				      candidate == TLS_VER_TLS11 ||
				      candidate == TLS_VER_TLS12);
		}
	} else if (tls->ecdh->pflist != NULL) {
		const uint16_t negotiated = tls_util_convert_protover_to_ver(
			&(tls->negotiated_version));
		applicable = (negotiated == TLS_VER_TLS10 ||
			      negotiated == TLS_VER_TLS11 ||
			      negotiated == TLS_VER_TLS12);
	}

	if (!applicable) {
		return 0;
	}

	/* extension_type (two octets) */
	if (tls_hs_msg_write_2(msg, TLS_EXT_EC_POINT_FORMATS) == false) {
		return -1;
	}

	/* extension_data length: reserve two octets, patched below. */
	const int32_t len_pos = msg->len;
	if (tls_hs_msg_write_2(msg, 0) == false) {
		return -1;
	}

	const int32_t body_len = write_ec_point_format_list(tls, msg);
	if (body_len < 0) {
		return -1;
	}

	const int32_t body_max = TLS_EXT_SIZE_MAX;
	if (body_len > body_max) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 15, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls_util_write_2(&(msg->msg[len_pos]), body_len);

	/* type (2) + length (2) + extension_data */
	return 2 + 2 + body_len;
}

/*
 * Parse an EllipticCurveList / NamedGroupList starting at msg->msg[offset].
 *
 * Curves also present in tls_hs_ecc_supported_named_curves are saved to
 * tls->ecdh->eclist, and the complete peer list to tls->ecdh->peer_eclist.
 * For an EncryptedExtensions message the list is parsed but not saved.
 *
 * Returns the number of bytes consumed, or -1 on error. Malformed input
 * additionally raises a fatal decode_error alert.
 */
int32_t tls_hs_ecc_read_elliptic_curves(TLS *tls,
					const struct tls_hs_msg *msg,
					const uint32_t offset)
{
	uint32_t read_bytes = 0;

	/* The vector is prefixed with a two-octet byte length. */
	const int32_t length_bytes = 2;
	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 4, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	uint16_t list_length = tls_util_read_2(&(msg->msg[offset]));
	if (msg->len < (offset + read_bytes + list_length)) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 5, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.2.7.  Supported Groups
	 *
	 *           NamedGroup named_group_list<2..2^16-1>;
	 *
	 * RFC4492 5.1.1.  Supported Elliptic Curves Extension
	 *
	 *             NamedCurve elliptic_curve_list<1..2^16-1>
	 *
	 * Each entry is two octets, so the byte length must also be even.
	 * Otherwise the trailing octet would be counted as consumed without
	 * being parsed. 65534 is the largest even value not exceeding
	 * 2^16-1.
	 *
	 * The length comes from the peer, so a violation is reported as
	 * decode_error, not internal_error.
	 */
	const int32_t list_length_min = 2;
	const int32_t list_length_max = (2 << (16 - 1)) - 2;
	if (list_length < list_length_min || list_length_max < list_length ||
	    (list_length % 2) != 0) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 16, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	const uint32_t base = offset + read_bytes;
	/* 1 <= n <= 32767 after the checks above, so n also fits int16_t. */
	uint16_t n = list_length / 2;
	/*
	 * NOTE(review): these VLAs are sized by peer-controlled input (up to
	 * 32767 entries each). Consider heap allocation if this can run on
	 * threads with small stacks.
	 */
	enum tls_hs_named_curve eclist[n];
	enum tls_hs_named_curve peer_eclist[n];
	int32_t off = 0;
	int16_t count = 0;

	for (uint16_t i = 0; i < n; i++) {
		enum tls_hs_named_curve ec;
		ec = tls_util_read_2(&(msg->msg[base + off]));
		TLS_DPRINTF("ecc: check curve = %.2x", ec);
		if (check_named_curve_supported(ec)) {
			TLS_DPRINTF("ecc: accept curve = %.2x", ec);
			eclist[count] = ec;
			count++;
		}
		/* The peer list is kept verbatim, including unsupported curves. */
		peer_eclist[i] = ec;
		off += 2;
	}
	read_bytes += list_length;

	if (msg->type == TLS_HANDSHAKE_ENCRYPTED_EXTENSIONS) {
		/* TODO: can learn preferable groups for next session */
		TLS_DPRINTF("ecc: ignore supported_groups");
	} else {
		if (save_eclist(tls, eclist, count) != 0) {
			return -1;
		}

		if (save_peer_eclist(tls, peer_eclist, n) != 0) {
			return -1;
		}
	}

	return read_bytes;
}

/*
 * Parse an ECPointFormatList starting at msg->msg[offset]. The formats
 * that are also in tls_hs_ecc_supported_point_format are saved to
 * tls->ecdh->pflist.
 *
 * Returns the number of bytes consumed, or -1 on error. Malformed input
 * additionally raises a fatal decode_error alert.
 */
int32_t tls_hs_ecc_read_point_format(TLS *tls,
				     const struct tls_hs_msg *msg,
				     const uint32_t offset)
{
	uint32_t read_bytes = 0;

	/* The vector is prefixed with a one-octet byte length. */
	const int32_t length_bytes = 1;
	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 6, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	uint8_t list_length = msg->msg[offset];
	/*
	 * RFC4492 5.1.2.  Supported Point Formats Extension
	 *
	 *             ECPointFormat ec_point_format_list<1..2^8-1>
	 *
	 * An empty list is malformed peer input, so it is reported as
	 * decode_error, not internal_error. The upper bound is implied by
	 * the uint8_t length field.
	 */
	const int32_t list_length_min = 1;
	if (list_length < list_length_min) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 17, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	if (msg->len < (offset + read_bytes + list_length)) {
		TLS_DPRINTF("ecc: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_EXT_ECC + 7, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	const uint32_t base = offset + read_bytes;
	/* list_length is in 1..255 here, so this VLA is small and non-empty. */
	enum tls_hs_ecc_ec_point_format pflist[list_length];
	int32_t count = 0;
	for (uint16_t i = 0; i < list_length; i++) {
		enum tls_hs_ecc_ec_point_format pf = msg->msg[base + i];
		if (check_point_format_supported(pf)) {
			pflist[count] = pf;
			count++;
		}
	}
	read_bytes += list_length;

	if (save_pflist(tls, pflist, count) != 0) {
		return -1;
	}

	return read_bytes;
}

/*
 * Fill eclist with the locally supported named curves.
 *
 * eclist->list points directly at the static tls_hs_ecc_supported_named_curves
 * table (no copy is made). Such an eclist must therefore not be passed to
 * tls_hs_ecc_eclist_free(), which frees eclist->list.
 */
void tls_hs_ecc_get_supported_eclist(struct tls_hs_ecc_eclist *eclist)
{
	const int32_t entries = sizeof(tls_hs_ecc_supported_named_curves)
		/ sizeof(tls_hs_ecc_supported_named_curves[0]);

	eclist->list = tls_hs_ecc_supported_named_curves;
	eclist->len = entries;
}

/*
 * Fill pflist with the locally supported point formats.
 *
 * pflist->list points directly at the tls_hs_ecc_supported_point_format
 * table (no copy is made). Such a pflist must therefore not be passed to
 * tls_hs_ecc_pflist_free(), which frees pflist->list.
 */
void tls_hs_ecc_get_supported_pflist(struct tls_hs_ecc_pflist *pflist)
{
	const int32_t entries = sizeof(tls_hs_ecc_supported_point_format)
		/ sizeof(tls_hs_ecc_supported_point_format[0]);

	pflist->list = tls_hs_ecc_supported_point_format;
	pflist->len = entries;
}

/*
 * Release an eclist allocated by this module: both the curve array and
 * the structure itself. Passing NULL is a no-op.
 */
void tls_hs_ecc_eclist_free(struct tls_hs_ecc_eclist *eclist)
{
	if (eclist != NULL) {
		free(eclist->list);
		free(eclist);
	}
}

/*
 * Release a pflist allocated by this module: both the format array and
 * the structure itself. Passing NULL is a no-op.
 */
void tls_hs_ecc_pflist_free(struct tls_hs_ecc_pflist *pflist)
{
	if (pflist != NULL) {
		free(pflist->list);
		free(pflist);
	}
}
