/*
 * Copyright (c) 2016-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_alert.h"

/**
 * Write the struct KeyShareClientHello and its length.
 *
 * RFC8446 4.2.8.  Key Share
 *
 *       struct {
 *           NamedGroup group;
 *           opaque key_exchange<1..2^16-1>;
 *       } KeyShareEntry;
 *
 *       struct {
 *           KeyShareEntry client_shares<0..2^16-1>;
 *       } KeyShareClientHello;
 */
static int32_t write_keyshare_for_chello(TLS *tls, struct tls_hs_msg *msg);

/**
 * Write the struct KeyShareClientHello and its length for second client
 * hello.
 */
static int32_t write_keyshare_for_2ndchello(TLS *tls, struct tls_hs_msg *msg);

/**
 * Write the struct KeyShareServerHello.
 *
 * RFC8446 4.2.8.  Key Share
 *
 *       struct {
 *           NamedGroup group;
 *           opaque key_exchange<1..2^16-1>;
 *       } KeyShareEntry;
 *
 *       struct {
 *           KeyShareEntry server_share;
 *       } KeyShareServerHello;
 */
static int32_t write_keyshare_for_shello(TLS *tls, struct tls_hs_msg *msg);

/**
 * Write the struct KeyShareHelloRetryRequest.
 *
 * RFC8446 4.2.8.  Key Share
 *
 *       struct {
 *           NamedGroup selected_group;
 *       } KeyShareHelloRetryRequest;
 */
static int32_t write_keyshare_for_hrr(TLS *tls, struct tls_hs_msg *msg);

/**
 * Read the struct KeyShareClientHello.
 *
 * RFC8446 4.2.8.  Key Share
 *
 *       struct {
 *           NamedGroup group;
 *           opaque key_exchange<1..2^16-1>;
 *       } KeyShareEntry;
 *
 *       struct {
 *           KeyShareEntry client_shares<0..2^16-1>;
 *       } KeyShareClientHello;
 */
static int32_t read_keyshare_in_chello(TLS *tls,
					const struct tls_hs_msg *msg,
					const uint32_t offset);

/**
 * Read the struct KeyShareServerHello.
 *
 * RFC8446 4.2.8.  Key Share
 *
 *       struct {
 *           NamedGroup group;
 *           opaque key_exchange<1..2^16-1>;
 *       } KeyShareEntry;
 *
 *       struct {
 *           KeyShareEntry server_share;
 *       } KeyShareServerHello;
 */
static int32_t read_keyshare_in_shello(TLS *tls,
					const struct tls_hs_msg *msg,
					const uint32_t offset);

/**
 * Read the struct KeyShareHelloRetryRequest.
 *
 * RFC8446 4.2.8.  Key Share
 *
 *       struct {
 *           NamedGroup selected_group;
 *       } KeyShareHelloRetryRequest;
 */
static int32_t read_keyshare_in_hrr(TLS *tls,
				    const struct tls_hs_msg *msg,
				    const uint32_t offset);

/**
 * Serialize the KeyShareClientHello body for the first ClientHello.
 *
 * One key share is generated per group reported by
 * tls_hs_ecc_get_supported_eclist().  Every share is appended to
 * tls->interim_params->share_head and written to msg as a KeyShareEntry
 * (group, key_exchange length, key_exchange).  The two-octet length of
 * the client_shares vector is written as a placeholder first and patched
 * once all entries are in place.
 *
 * @param tls  connection whose interim parameters receive the shares.
 * @param msg  handshake message being built.
 * @return number of octets appended to msg (vector length prefix
 *         included), or -1 on failure.
 */
static int32_t write_keyshare_for_chello(TLS *tls, struct tls_hs_msg *msg)
{
	struct tls_hs_interim_params *interim = tls->interim_params;
	struct tls_hs_ecc_eclist groups;
	const int32_t len_pos = msg->len;
	int32_t written = 0;

	/*
	 * RFC8446 4.2.8 requires every KeyShareEntry to belong to a group
	 * listed in "supported_groups", in the same order.
	 *
	 * TODO: the default supported curve list is assumed to be shared by
	 * "supported_groups" and "key_share" for now.  that may stop being
	 * true, in which case the list actually sent should be consulted
	 * instead of the default one.
	 */
	tls_hs_ecc_get_supported_eclist(&groups);

	/* placeholder for the client_shares vector length (two octets). */
	if (! tls_hs_msg_write_2(msg, 0)) {
		return -1;
	}
	written += 2;

	/* RFC8446 4.2.8: opaque key_exchange<1..2^16-1>; */
	const int32_t key_len_min = 1;

	/*
	 * one entry per supported group.  RFC8446 4.2.8 forbids offering
	 * two entries for the same group or an entry for a group outside
	 * "supported_groups"; iterating the supported list once satisfies
	 * the latter.
	 */
	for (int idx = 0; idx < groups.len; idx++) {
		struct tls_hs_key_share *entry = tls_hs_keyshare_init();
		if (entry == NULL) {
			return -1;
		}

		/*
		 * queue the share before anything can fail: queued shares
		 * are freed later, including in the error case.
		 */
		TAILQ_INSERT_TAIL(&(interim->share_head), entry, link);

		if (! tls_hs_keyshare_generate(entry, groups.list[idx])) {
			return -1;
		}

		struct tls_hs_ecdh_key *pubkey = &(entry->key);
		TLS_DPRINTF("keyshare: named group = %d", entry->group);
		TLS_DPRINTF("keyshare: ecpoint_len = %d", pubkey->ecpoint_len);

		if (pubkey->ecpoint_len < key_len_min) {
			TLS_DPRINTF("key share: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_KEYSHARE + 29, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		/* NamedGroup group */
		const uint16_t wire_group = groups.list[idx];
		if (! tls_hs_msg_write_2(msg, wire_group)) {
			return -1;
		}
		written += 2;

		/* key_exchange length */
		if (! tls_hs_msg_write_2(msg, pubkey->ecpoint_len)) {
			return -1;
		}
		written += 2;

		/* key_exchange */
		if (! tls_hs_msg_write_n(msg, pubkey->ecpoint,
					 pubkey->ecpoint_len)) {
			return -1;
		}
		written += pubkey->ecpoint_len;
	}

	/* RFC8446 4.2.8: KeyShareEntry client_shares<0..2^16-1>; */
	const int32_t list_len = written - 2;
	const int32_t list_len_min = 0;
	const int32_t list_len_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (list_len < list_len_min || list_len_max < list_len) {
		TLS_DPRINTF("key share: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 30, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* patch the placeholder with the real vector length. */
	tls_util_write_2(msg->msg + len_pos, list_len);

	return written;
}

/**
 * Serialize the KeyShareClientHello body for the second ClientHello,
 * i.e. the one sent in answer to a HelloRetryRequest.
 *
 * RFC8446 4.2.8 requires the retried ClientHello to carry exactly one
 * new KeyShareEntry, for the group named in the HelloRetryRequest.  That
 * group is read from tls->ecdh->namedcurve.  The generated share is
 * appended to tls->interim_params->share_head.
 *
 * @param tls  connection holding the selected group.
 * @param msg  handshake message being built.
 * @return number of octets appended to msg (vector length prefix
 *         included), or -1 on failure (an internal_error alert is raised
 *         through TLS_ALERT_FATAL in that case).
 */
static int32_t write_keyshare_for_2ndchello(TLS *tls, struct tls_hs_msg *msg)
{
	int32_t written = 0;

	/* placeholder for the client_shares vector length (two octets). */
	const int32_t len_pos = msg->len;
	if (! tls_hs_msg_write_2(msg, 0)) {
		TLS_DPRINTF("tls_hs_msg_write_2");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	written += 2;

	struct tls_hs_key_share *entry = tls_hs_keyshare_init();
	if (entry == NULL) {
		TLS_DPRINTF("tls_hs_keyshare_init");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 * queue the share before anything can fail: queued shares are
	 * freed later, including in the error case.
	 */
	TAILQ_INSERT_TAIL(&(tls->interim_params->share_head), entry, link);

	const uint16_t selected = tls->ecdh->namedcurve;
	if (! tls_hs_keyshare_generate(entry, selected)) {
		TLS_DPRINTF("tls_hs_keyshare_generate");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	struct tls_hs_ecdh_key *pubkey = &(entry->key);
	TLS_DPRINTF("keyshare: named group = %d", entry->group);
	TLS_DPRINTF("keyshare: ecpoint_len = %d", pubkey->ecpoint_len);

	/* RFC8446 4.2.8: opaque key_exchange<1..2^16-1>; */
	const int32_t key_len_min = 1;
	if (pubkey->ecpoint_len < key_len_min) {
		TLS_DPRINTF("key share: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 31, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* NamedGroup group */
	if (! tls_hs_msg_write_2(msg, selected)) {
		TLS_DPRINTF("tls_hs_msg_write_2");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	written += 2;

	/* key_exchange length */
	if (! tls_hs_msg_write_2(msg, pubkey->ecpoint_len)) {
		TLS_DPRINTF("tls_hs_msg_write_2");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	written += 2;

	/* key_exchange */
	if (! tls_hs_msg_write_n(msg, pubkey->ecpoint, pubkey->ecpoint_len)) {
		TLS_DPRINTF("tls_hs_msg_write_n");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	written += pubkey->ecpoint_len;

	/* RFC8446 4.2.8: KeyShareEntry client_shares<0..2^16-1>; */
	const int32_t list_len = written - 2;
	const int32_t list_len_min = 0;
	const int32_t list_len_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (list_len < list_len_min || list_len_max < list_len) {
		TLS_DPRINTF("key share: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 32, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* patch the placeholder with the real vector length. */
	tls_util_write_2(msg->msg + len_pos, list_len);

	return written;
}

/**
 * Serialize the KeyShareServerHello body (a single KeyShareEntry).
 *
 * Only TLS_KXC_ECDHE is implemented.  For that method an ephemeral key
 * pair is generated for the group in tls->ecdh->namedcurve:
 *
 *  - the private key is handed over to tls->ecdh->my_prvkey (and
 *    prvkey_ephemeral is set), so it is not released here;
 *  - the public key is only needed to produce the encoded point and is
 *    released with Key_free() on every path once it has been generated.
 *
 * @param tls  connection holding the key exchange method and the group.
 * @param msg  handshake message being built.
 * @return number of octets appended to msg, or -1 on failure (including
 *         every key exchange method other than TLS_KXC_ECDHE).
 */
static int32_t write_keyshare_for_shello(TLS *tls, struct tls_hs_msg *msg)
{
	int32_t offset = 0;
	struct tls_hs_ecdh_key key;

	/*
	 * RFC8446 4.2.8.  Key Share
	 *
	 *    If using (EC)DHE key establishment, servers offer exactly one
	 *    KeyShareEntry in the ServerHello.  This value MUST be in the same
	 *    group as the KeyShareEntry value offered by the client that the
	 *    server has selected for the negotiated key exchange.  Servers
	 *    MUST NOT send a KeyShareEntry for any group not indicated in the
	 *    client's "supported_groups" extension and MUST NOT send a
	 *    KeyShareEntry when using the "psk_ke" PskKeyExchangeMode.
	 */
	/* TODO: add procedure for DHE and PSK */
	uint16_t group;
	switch (tls->keymethod) {
	case TLS_KXC_ECDHE:
		group = tls->ecdh->namedcurve;
		if (tls_hs_ecdhkey_gen_for_server(group, &key) != true) {
			TLS_DPRINTF("keyshare: tls_hs_ecdhkey_gen_for_server");
			return -1;
		}

		/* The private key's memory will be released at tls_hs_ecdh_free() */
		tls->ecdh->my_prvkey = key.prv;
		tls->ecdh->prvkey_ephemeral = true;

		if (tls_hs_ecdhkey_set_to_ecpoint(&key) != true) {
			TLS_DPRINTF("keyshare: tls_hs_ecdhkey_set_to_ecpoint");
			Key_free(key.pub);
			return -1;
		}

		if (! tls_hs_msg_write_2(msg, group)) {
			TLS_DPRINTF("keyshare: tls_hs_msg_write_2");
			Key_free(key.pub);
			return -1;
		}
		offset += 2;

		/*
		 * RFC8446 4.2.8.  Key Share
		 *
		 *           opaque key_exchange<1..2^16-1>;
		 */
		const int32_t key_len_min = 1;
		if (key.ecpoint_len < key_len_min) {
			TLS_DPRINTF("key share: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_KEYSHARE + 33, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			/*
			 * the public key is still owned by this function
			 * here; release it like every other failure path
			 * does, otherwise it leaks.
			 */
			Key_free(key.pub);
			return -1;
		}

		if (tls_hs_msg_write_2(msg, key.ecpoint_len) == false) {
			TLS_DPRINTF("keyshare: tls_hs_msg_write_2");
			Key_free(key.pub);
			return -1;
		}
		offset += 2;

		if (tls_hs_msg_write_n(msg, key.ecpoint, key.ecpoint_len) == false) {
			TLS_DPRINTF("keyshare: tls_hs_msg_write_n");
			Key_free(key.pub);
			return -1;
		}
		offset += key.ecpoint_len;

		/* the encoded point has been written; the key object is done. */
		Key_free(key.pub);

		TLS_DPRINTF("keyshare: named group = %d", group);
		TLS_DPRINTF("keyshare: ecpoint_len = %d", key.ecpoint_len);

		return offset;

	case TLS_KXC_DHE:
	case TLS_KXC_PSK:
	case TLS_KXC_PSK_DHE:
	case TLS_KXC_PSK_ECDHE:
		/* Not implemented */
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEY_EXCHANGE,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE, NULL);
		return -1;

	default:
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEY_EXCHANGE,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 1, NULL);
		return -1;
	}
}

/**
 * Serialize the KeyShareHelloRetryRequest body.
 *
 * RFC8446 4.2.8: in a HelloRetryRequest the extension data is just
 *
 *       struct {
 *           NamedGroup selected_group;
 *       } KeyShareHelloRetryRequest;
 *
 * where selected_group is the mutually supported group the server wants
 * the client to retry with.  The value written is tls->ecdh->namedcurve.
 *
 * @param tls  connection holding the selected group.
 * @param msg  handshake message being built.
 * @return number of octets appended to msg (always 2), or -1 when the
 *         write fails.
 */
static int32_t write_keyshare_for_hrr(TLS *tls, struct tls_hs_msg *msg)
{
	const int32_t group_bytes = 2;
	const uint16_t selected_group = tls->ecdh->namedcurve;

	if (tls_hs_msg_write_2(msg, selected_group)) {
		return group_bytes;
	}

	TLS_DPRINTF("keyshare: tls_hs_msg_write_2");
	return -1;
}

/**
 * Parse the KeyShareClientHello body found at msg->msg[offset].
 *
 * Every KeyShareEntry is validated against the peer's "supported_groups"
 * list (tls->ecdh->peer_eclist): the group must be listed there, entries
 * must follow the same order, and no group may appear twice.  Entries
 * for SECP256R1, X25519 and X448 are copied into newly allocated shares
 * appended to tls->interim_params->share_head; entries for other groups
 * are validated and skipped.
 *
 * An empty client_shares vector is accepted (the client is asking for a
 * HelloRetryRequest); nothing is stored in that case.
 *
 * @param tls     connection being negotiated.
 * @param msg     received handshake message.
 * @param offset  position of the extension data inside msg->msg.
 * @return number of octets consumed, or -1 on failure (a fatal alert is
 *         raised through TLS_ALERT_FATAL in that case).
 */
static int32_t read_keyshare_in_chello(TLS *tls,
					   const struct tls_hs_msg *msg,
					   const uint32_t offset)
{
	uint32_t read_bytes = 0;

	const uint32_t length_bytes = 2;
	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("key share: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	const uint16_t list_length = tls_util_read_2(&(msg->msg[offset]));
	if (msg->len < (offset + read_bytes + list_length)) {
		TLS_DPRINTF("key share: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 3, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	if (list_length == 0) {
		/*
		 * RFC8446 4.2.8.  Key Share
		 *
		 *    This vector MAY be empty if the client is requesting a
		 *    HelloRetryRequest.
		 */
		/*
		 * client requests hello retry request but return without doing
		 * anything here because selection of key exchange method is
		 * performed later.
		 */
		return read_bytes;
	}

	/*
	 * at least one KeyShareEntry follows, so the peer's supported_groups
	 * list is needed from here on.  without this guard a missing list
	 * would be dereferenced, and an empty one would be used as the size
	 * of the variable length array below (a zero-sized VLA is undefined
	 * behavior).  with no group offered any key share exceeds the
	 * supported_groups count, hence the same error point as that check.
	 */
	struct tls_hs_ecdh *ecdh = tls->ecdh;
	if (ecdh == NULL || ecdh->peer_eclist == NULL ||
	    ecdh->peer_eclist->len <= 0) {
		TLS_DPRINTF("keyshare: no group offered in supported_groups");
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 4,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	/* parse KeyShareEntry structure */
	int group_count = 0;
	int group_idx;
	int prev_group_idx = -1;
	struct tls_hs_ecc_eclist *eclist = ecdh->peer_eclist;
	/* per supported group: how many key shares named it so far. */
	int group_types[eclist->len];
	enum tls_hs_named_curve namedgroup;
	struct tls_hs_key_share *share;

	memset(group_types, 0, sizeof(group_types));

	while (read_bytes < length_bytes + list_length) {
		/*
		 * RFC8446 4.2.8.  Key Share
		 *
		 *    This vector MAY be empty if the client is requesting a
		 *    HelloRetryRequest.  Each KeyShareEntry value MUST correspond to a
		 *    group offered in the "supported_groups" extension and MUST appear in
		 *    the same order.  However, the values MAY be a non-contiguous subset
		 *    of the "supported_groups" extension and MAY omit the most preferred
		 *    groups.  Such a situation could arise if the most preferred groups
		 *    are new and unlikely to be supported in enough places to make
		 *    pregenerating key shares for them efficient.
		 *
		 *    Clients can offer as many KeyShareEntry values as the number of
		 *    supported groups it is offering, each representing a single set of
		 *    key exchange parameters.  For instance, a client might offer shares
		 *    for several elliptic curves or multiple FFDHE groups.  The
		 *    key_exchange values for each KeyShareEntry MUST be generated
		 *    independently.  Clients MUST NOT offer multiple KeyShareEntry values
		 *    for the same group.  Clients MUST NOT offer any KeyShareEntry values
		 *    for groups not listed in the client's "supported_groups" extension.
		 *    Servers MAY check for violations of these rules and abort the
		 *    handshake with an "illegal_parameter" alert if one is violated.
		 */

		/*
		 * check if the number of key shares exceeds the one
		 * in supported_groups extension.
		 */
		if (eclist->len == group_count) {
			TLS_DPRINTF("keyshare: the number of key shares exceeds the "
			            "one in suuporetd_groups");
			OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
				     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 4,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
			return -1;
		}

		const uint16_t named_group_bytes = 2;
		if (msg->len < (offset + read_bytes + named_group_bytes)) {
			TLS_DPRINTF("keyshare: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_KEYSHARE + 5, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}

		namedgroup = tls_util_read_2(&(msg->msg[offset + read_bytes]));
		read_bytes += named_group_bytes;
		TLS_DPRINTF("keyshare: named group = %d", namedgroup);

		const uint16_t share_length_bytes = 2;
		if (msg->len < (offset + read_bytes + share_length_bytes)) {
			TLS_DPRINTF("keyshare: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_KEYSHARE + 6, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}

		/*
		 * locate the group in supported_groups.  the loop does not
		 * stop at the first hit, so the last matching index wins.
		 */
		group_idx = -1;
		for (int i = 0; i < eclist->len; i++) {
			if (eclist->list[i] == namedgroup) {
				group_idx = i;
			}
		}

		/* check if the group is in supported_groups extension. */
		if (group_idx < 0) {
			TLS_DPRINTF("keyshare: group isn't in supported_groups");
			OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
				     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 7,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
			return -1;
		}

		/*
		 * check if the group appears in the same order with
		 * supported_groups extension.
		 *
		 * note: the test is skipped while prev_group_idx is -1 (no
		 * entry seen yet) and also when it is 0.  the only order
		 * violation possible after index 0 is index 0 again, which
		 * the duplicate check right below rejects.
		 */
		if (prev_group_idx > 0 && prev_group_idx >= group_idx) {
			TLS_DPRINTF("keyshare: not the same order with"
			            " supported_groups");
			OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
				     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 8,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
			return -1;
		}

		/* check if the group appears multiple times. */
		group_types[group_idx]++;
		if (group_types[group_idx] > 1) {
			TLS_DPRINTF("keyshare: the same group appears multiple"
			            " times");
			OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
				     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 9,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
			return -1;
		}

		prev_group_idx = group_idx;

		uint32_t share_length;
		share_length = tls_util_read_2(&(msg->msg[offset + read_bytes]));
		read_bytes += share_length_bytes;

		if (msg->len < (offset + read_bytes + share_length)) {
			TLS_DPRINTF("keyshare: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_KEYSHARE + 10, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}

		/*
		 * RFC8446 4.2.8.  Key Share
		 *
		 *           opaque key_exchange<1..2^16-1>;
		 *
		 * an empty key_exchange is as malformed as an oversized one.
		 */
		if (share_length < 1 ||
		    share_length > TLS_EXT_KEY_SHARE_EXCHANGE_SIZE_MAX) {
			TLS_DPRINTF("keyshare: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_KEYSHARE + 11, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}

		/* TODO: add procedure for Finite-Field-Group */
		switch (namedgroup) {
		case TLS_NAMED_GROUP_SECP256R1:
		case TLS_NAMED_GROUP_X25519:
		case TLS_NAMED_GROUP_X448:
			/* the point is copied into a fixed-size buffer. */
			if (share_length > sizeof(share->key.ecpoint)) {
				TLS_DPRINTF("keyshare: ecpoint size overflow");
				OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
					     ERR_LC_TLS5,
					     ERR_PT_TLS_HS_EXT_KEYSHARE + 12,
					     NULL);
				TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
				return -1;
			}

			if ((share = malloc(sizeof(struct tls_hs_key_share))) == NULL) {
				TLS_DPRINTF("malloc: %s", strerror(errno));
				OK_set_error(ERR_ST_TLS_MALLOC,
					     ERR_LC_TLS5,
					     ERR_PT_TLS_HS_EXT_KEYSHARE + 13,
					     NULL);
				TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
				return -1;
			}
			memset(share, 0, sizeof(struct tls_hs_key_share));

			share->group = namedgroup;
			share->key.ecpoint_len = share_length;
			memcpy(&(share->key.ecpoint), &(msg->msg[offset + read_bytes]),
			       share_length);

			TAILQ_INSERT_TAIL(&(tls->interim_params->share_head),
					  share, link);
			break;

		case TLS_NAMED_GROUP_SECP384R1:
		case TLS_NAMED_GROUP_SECP521R1:
			/* Not implemented */
			break;

		default:
			break;
		}

		group_count++;
		read_bytes += share_length;
	}

	/* the entries must fill the announced vector length exactly. */
	if (read_bytes - length_bytes != list_length) {
		TLS_DPRINTF("keyshare: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_KEYSHARE + 14, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	return read_bytes;
}

/**
 * Parse the KeyShareServerHello body (a single KeyShareEntry) found at
 * msg->msg[offset].
 *
 * The server's group must match one of the shares queued on
 * tls->interim_params->share_head, and, after a second ClientHello, the
 * group recorded in tls->ecdh->namedcurve.  For SECP256R1, X25519 and
 * X448 the peer's public key is installed with
 * tls_hs_ecdhkey_set_peer_pubkey(), the private key of the matching
 * share is moved to tls->ecdh->my_prvkey, and every queued share is
 * released.  Other groups are accepted but nothing is installed.
 *
 * @param tls     connection being negotiated.
 * @param msg     received handshake message.
 * @param offset  position of the extension data inside msg->msg.
 * @return number of octets consumed (group, length and key_exchange),
 *         or -1 on failure (a fatal alert is raised through
 *         TLS_ALERT_FATAL in that case).
 */
static int32_t read_keyshare_in_shello(TLS *tls,
					const struct tls_hs_msg *msg,
					const uint32_t offset) {
	uint32_t read_bytes = 0;

	const uint16_t named_group_bytes = 2;
	if (msg->len < (offset + named_group_bytes)) {
		TLS_DPRINTF("keyshare: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 15,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	enum tls_hs_named_curve namedgroup = tls_util_read_2(
		&(msg->msg[offset]));
	read_bytes += named_group_bytes;

	/*
	 * RFC8446 4.2.8.  Key Share
	 *
	 *    If using (EC)DHE key establishment, servers offer exactly one
	 *    KeyShareEntry in the ServerHello.  This value MUST be in the same
	 *    group as the KeyShareEntry value offered by the client that the
	 *    server has selected for the negotiated key exchange.  Servers
	 *    MUST NOT send a KeyShareEntry for any group not indicated in the
	 *    client's "supported_groups" extension and MUST NOT send a
	 *    KeyShareEntry when using the "psk_ke" PskKeyExchangeMode.  If using
	 *    (EC)DHE key establishment and a HelloRetryRequest containing a
	 *    "key_share" extension was received by the client, the client MUST
	 *    verify that the selected NamedGroup in the ServerHello is the same as
	 *    that in the HelloRetryRequest.  If this check fails, the client MUST
	 *    abort the handshake with an "illegal_parameter" alert.
	 */
	/* find the share this client offered for the server's group. */
	struct tls_hs_key_share *share;
	struct tls_hs_key_share *used_share = NULL;
	TAILQ_FOREACH(share, &(tls->interim_params->share_head), link) {
		if (share->group == namedgroup) {
			used_share = share;
			break;
		}
	}

	if (used_share == NULL) {
		TLS_DPRINTF("keyshare: named group is missing: %d", namedgroup);
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 16,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	/* after a retry the group must be the one named in the hrr. */
	if (tls_hs_check_state(tls, TLS_STATE_HS_AFTER_SEND_2NDCHELLO)
	    == true &&
	    tls->ecdh->namedcurve != namedgroup) {
		TLS_DPRINTF("keyshare: named group doesn't match the one"
			    " in hrr: %d", namedgroup);
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 17,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	/*
	 * a truncated entry is a length problem, reported like the other
	 * length checks of this function (decode_error alert).
	 */
	const uint16_t share_length_bytes = 2;
	if (msg->len < (offset + read_bytes + share_length_bytes)) {
		TLS_DPRINTF("keyshare: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 18,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	uint32_t share_length;
	share_length = tls_util_read_2(&(msg->msg[offset + read_bytes]));
	read_bytes += share_length_bytes;

	if (msg->len < (offset + read_bytes + share_length)) {
		TLS_DPRINTF("keyshare: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 19,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.2.8.  Key Share
	 *
	 *           opaque key_exchange<1..2^16-1>;
	 *
	 * an empty key_exchange is as malformed as an oversized one.
	 */
	if (share_length < 1 ||
	    share_length > TLS_EXT_KEY_SHARE_EXCHANGE_SIZE_MAX) {
		TLS_DPRINTF("keyshare: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 20,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/* TODO: add procedure for Finite-Field-Group */
	struct tls_hs_ecdh_key key;
	/*
	 * only ecpoint and ecpoint_len are filled in below; clear the rest
	 * so the callee never sees indeterminate members.
	 */
	memset(&key, 0, sizeof(key));
	switch (namedgroup) {
	case TLS_NAMED_GROUP_SECP256R1:
	case TLS_NAMED_GROUP_X25519:
	case TLS_NAMED_GROUP_X448:
		/* the point is copied into a fixed-size buffer. */
		if (share_length > sizeof(key.ecpoint)) {
			TLS_DPRINTF("keyshare: ecpoint size overflow");
			OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_KEYSHARE + 21, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		key.ecpoint_len = share_length;
		memcpy(&(key.ecpoint), &(msg->msg[offset + read_bytes]),
		       share_length);

		tls->ecdh->namedcurve = namedgroup;

		if (tls_hs_ecdhkey_set_peer_pubkey(tls->ecdh, &key) != true) {
			TLS_DPRINTF("keyshare: tls_hs_ecdhkey_set_peer_pubkey");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		/* The private key's memory will be released at tls_hs_ecdh_free() */
		tls->ecdh->my_prvkey = used_share->key.prv;
		tls->ecdh->prvkey_ephemeral = true;

		/*
		 * the private key now belongs to tls->ecdh: detach it from
		 * the share before freeing so it is not released twice.
		 */
		TAILQ_REMOVE(&(tls->interim_params->share_head), used_share, link);
		used_share->key.prv = NULL;
		tls_hs_keyshare_free(used_share);

		/* the shares offered for the other groups are not needed. */
		while (!TAILQ_EMPTY(&(tls->interim_params->share_head))) {
			share = TAILQ_FIRST(&(tls->interim_params->share_head));
			TAILQ_REMOVE(&(tls->interim_params->share_head), share, link);
			tls_hs_keyshare_free(share);
		}

		TLS_DPRINTF("keyshare: Decided: %d", namedgroup);
		TLS_DPRINTF("keyshare: ecpoint_len = %d", key.ecpoint_len);
		break;

	case TLS_NAMED_GROUP_SECP384R1:
	case TLS_NAMED_GROUP_SECP521R1:
		/* Not implemented */
		break;

	default:
		break;
	}

	/*
	 * the key_exchange octets belong to the entry as well; count them
	 * so the result covers everything consumed, as the client hello
	 * reader does.
	 */
	read_bytes += share_length;

	return read_bytes;
}

/**
 * Parse the KeyShareHelloRetryRequest body (selected_group) found at
 * msg->msg[offset].
 *
 * The selected group is rejected with an illegal_parameter alert when
 * it is not in the list returned by tls_hs_ecc_get_supported_eclist(),
 * or when a share for it is already queued on
 * tls->interim_params->share_head (i.e. it was offered in the first
 * ClientHello).  On success the group is stored in
 * tls->ecdh->namedcurve for the second ClientHello.
 *
 * @param tls     connection being negotiated.
 * @param msg     received handshake message.
 * @param offset  position of the extension data inside msg->msg.
 * @return number of octets consumed (always 2), or -1 on failure.
 */
static int32_t read_keyshare_in_hrr(TLS *tls,
				    const struct tls_hs_msg *msg,
				    const uint32_t offset) {
	uint32_t read_bytes = 0;

	/*
	 * RFC8446 4.2.8.  Key Share
	 *
	 *    Upon receipt of this extension in a HelloRetryRequest, the client
	 *    MUST verify that (1) the selected_group field corresponds to a group
	 *    which was provided in the "supported_groups" extension in the
	 *    original ClientHello and (2) the selected_group field does not
	 *    correspond to a group which was provided in the "key_share" extension
	 *    in the original ClientHello.  If either of these checks fails, then
	 *    the client MUST abort the handshake with an "illegal_parameter"
	 *    alert.  Otherwise, when sending the new ClientHello, the client MUST
	 *    replace the original "key_share" extension with one containing only a
	 *    new KeyShareEntry for the group indicated in the selected_group field
	 *    of the triggering HelloRetryRequest.
	 */
	const uint16_t named_group_bytes = 2;
	if (msg->len < (offset + named_group_bytes)) {
		TLS_DPRINTF("keyshare: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 22,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	enum tls_hs_named_curve namedgroup = tls_util_read_2(
		&(msg->msg[offset]));
	read_bytes += named_group_bytes;

	/* (1) the group must be one of the supported groups. */
	struct tls_hs_ecc_eclist glist;
	tls_hs_ecc_get_supported_eclist(&glist);

	bool supported = false;
	for (int i = 0; i < glist.len; i++) {
		if (glist.list[i] == namedgroup) {
			supported = true;
			break;
		}
	}

	if (supported == false) {
		TLS_DPRINTF("keyshare: missing named group in supported groups"
			    " of first client hello");
		/* a bad value, not a bad length: matches the alert below. */
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 23,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	/* (2) the group must not have been offered in key_share already. */
	struct tls_hs_key_share *share;
	TAILQ_FOREACH(share, &(tls->interim_params->share_head), link) {
		if (share->group == namedgroup) {
			TLS_DPRINTF("keyshare: named group sent in key share of"
				    " first client hello");
			/* a bad value, not a bad length: matches the alert. */
			OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_KEYSHARE + 24, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
			return -1;
		}
	}

	/* remember the group for the second client hello. */
	tls->ecdh->namedcurve = namedgroup;

	return read_bytes;
}

/**
 * Allocate a zero-filled key share.
 *
 * @return the new share, or NULL when the allocation fails (the error
 *         is recorded through OK_set_error).
 */
struct tls_hs_key_share *tls_hs_keyshare_init(void)
{
	struct tls_hs_key_share *fresh = calloc(1, sizeof(*fresh));

	if (fresh != NULL) {
		return fresh;
	}

	TLS_DPRINTF("keyshare: %s", strerror(errno));
	OK_set_error(ERR_ST_TLS_CALLOC,
		     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 25,
		     NULL);
	return NULL;
}

/**
 * Fill a key share with a freshly generated key for the given group.
 *
 * The group is recorded in the share, a key is generated with
 * tls_hs_ecdhkey_gen_for_server(), and tls_hs_ecdhkey_set_to_ecpoint()
 * is then applied to it.
 *
 * @param share  share to fill in.
 * @param group  named group to generate the key for.
 * @return true on success, false when either step fails.  The group
 *         stays recorded in the share even on failure.
 */
bool tls_hs_keyshare_generate(struct tls_hs_key_share *share,
			       enum tls_hs_named_curve group)
{
	struct tls_hs_ecdh_key *keypair = &(share->key);

	share->group = group;

	if (tls_hs_ecdhkey_gen_for_server(group, keypair) != true) {
		TLS_DPRINTF("keyshare: tls_hs_ecdhkey_gen_for_server failure");
		return false;
	}

	if (tls_hs_ecdhkey_set_to_ecpoint(keypair) != true) {
		TLS_DPRINTF("keyshare: tls_hs_ecdhkey_set_to_ecpoint failure");
		return false;
	}

	return true;
}

void tls_hs_keyshare_free(struct tls_hs_key_share *share)
{
	if (share == NULL) {
		return;
	}

	if (share->key.prv != NULL) {
		Key_free(share->key.prv);
	}

	if (share->key.pub != NULL) {
		Key_free(share->key.pub);
	}

	free(share);
}

/**
 * Append the "key_share" extension to a ClientHello or a ServerHello /
 * HelloRetryRequest.
 *
 * The extension is omitted (return value 0) when:
 *  - the message is a ClientHello and TLS 1.3 is not among the
 *    supported versions;
 *  - the message is a ServerHello and the negotiated version is not
 *    TLS 1.3, or the client did not send "key_share";
 *  - the message is of any other type.
 *
 * Otherwise the extension type and a placeholder length are written,
 * the body is produced by the writer matching the message type and
 * handshake state, and the length is patched afterwards.
 *
 * @param tls  connection being negotiated.
 * @param msg  handshake message being built.
 * @return number of octets appended to msg, 0 when the extension is
 *         omitted, or -1 on failure.
 */
int32_t tls_hs_keyshare_write(TLS *tls, struct tls_hs_msg *msg)
{
	bool *received = tls->interim_params->recv_ext_flags;

	/* decide whether the extension has to be sent at all. */
	if (msg->type == TLS_HANDSHAKE_CLIENT_HELLO) {
		if (tls_util_check_version_in_supported_version(
			&(tls->supported_versions), TLS_VER_TLS13) == false) {
			return 0;
		}
	} else if (msg->type == TLS_HANDSHAKE_SERVER_HELLO) {
		const uint16_t negotiated = tls_util_convert_protover_to_ver(
			&(tls->negotiated_version));
		if (negotiated != TLS_VER_TLS13) {
			return 0;
		}

		/* TODO: also check psk mode is used or not. */
		if (received[TLS_EXT_KEY_SHARE] == false) {
			TLS_DPRINTF("keyshare: didn't receive in client"
				    " hello");
			return 0;
		}
	} else {
		return 0;
	}

	int32_t total = 0;

	/* extension type */
	if (! tls_hs_msg_write_2(msg, TLS_EXT_KEY_SHARE)) {
		return -1;
	}
	total += 2;

	/* extension length: placeholder now, patched below. */
	const int32_t len_pos = msg->len;
	if (! tls_hs_msg_write_2(msg, 0)) {
		return -1;
	}
	total += 2;

	/* extension body, depending on message type and handshake state. */
	int32_t body_len;
	if (msg->type == TLS_HANDSHAKE_SERVER_HELLO) {
		if (tls_hs_check_state(tls, TLS_STATE_HS_BEFORE_SEND_HRREQ)
		    == true) {
			body_len = write_keyshare_for_hrr(tls, msg);
		} else {
			body_len = write_keyshare_for_shello(tls, msg);
		}
	} else if (msg->type == TLS_HANDSHAKE_CLIENT_HELLO) {
		if (tls_hs_check_state(tls, TLS_STATE_HS_BEFORE_SEND_2NDCHELLO)
		    == true) {
			body_len = write_keyshare_for_2ndchello(tls, msg);
		} else {
			body_len = write_keyshare_for_chello(tls, msg);
		}
	} else {
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 26,
			     NULL);
		return -1;
	}

	if (body_len < 0) {
		return -1;
	}

	const int32_t extlen_max = TLS_EXT_SIZE_MAX;
	if (body_len > extlen_max) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 34,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	total += body_len;

	tls_util_write_2(msg->msg + len_pos, body_len);

	return total;
}

/**
 * Parse a received "key_share" extension body.
 *
 * Nothing is read (return value 0) unless the negotiated version is
 * TLS 1.3.  A ServerHello is handed to the HelloRetryRequest reader or
 * to the ServerHello reader depending on the handshake state, and is
 * rejected with an unsupported_extension alert when this side did not
 * send "key_share" itself.  A ClientHello goes to the ClientHello
 * reader.  Any other message type is an internal error.
 *
 * @param tls     connection being negotiated.
 * @param msg     received handshake message.
 * @param offset  position of the extension data inside msg->msg.
 * @return value reported by the selected reader, 0 when the extension
 *         is ignored, or -1 on failure.
 */
int32_t tls_hs_keyshare_read(TLS *tls, const struct tls_hs_msg *msg,
			      const uint32_t offset)
{
	const uint16_t negotiated = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));
	if (negotiated != TLS_VER_TLS13) {
		return 0;
	}

	bool *sent = tls->interim_params->sent_ext_flags;
	int32_t consumed;

	if (msg->type == TLS_HANDSHAKE_SERVER_HELLO) {
		/* the server may only answer an extension that was sent. */
		if (sent[TLS_EXT_KEY_SHARE] == false) {
			TLS_DPRINTF("keyshare: didn't send in client hello");
			OK_set_error(ERR_ST_TLS_UNSUPPORTED_EXTENSION,
				     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 27,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNSUPPORTED_EXTENSION);
			return -1;
		}

		if (tls_hs_check_state(tls, TLS_STATE_HS_AFTER_RECV_HRREQ)
		    == true) {
			consumed = read_keyshare_in_hrr(tls, msg, offset);
			if (consumed < 0) {
				TLS_DPRINTF("keyshare: read_keyshare_in_hrr");
				return -1;
			}
		} else {
			consumed = read_keyshare_in_shello(tls, msg, offset);
			if (consumed < 0) {
				TLS_DPRINTF("keyshare: read_keyshare_in_shello");
				return -1;
			}
		}
	} else if (msg->type == TLS_HANDSHAKE_CLIENT_HELLO) {
		consumed = read_keyshare_in_chello(tls, msg, offset);
		if (consumed < 0) {
			TLS_DPRINTF("keyshare: read_keyshare_in_chello");
			return -1;
		}
	} else {
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_KEYSHARE + 28,
			     NULL);
		return -1;
	}

	return consumed;
}
