/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_cipher.h"
#include "tls_compress.h"
#include "tls_alert.h"

#include <string.h>

/**
 * check whether an extension of the given type is allowed to appear in
 * the received server hello for the negotiated protocol version.
 *
 * check_ext_availability() dispatches on tls->negotiated_version (and, for
 * TLS 1.3, on whether a hello retry request was received) to one of the
 * per-message helpers below. each helper returns true only for the
 * extension types this module is prepared to read.
 *
 * TODO: i think this function should be moved to other file.
 * (e.g. handshake/extension/extension.c?).
 */
static bool check_ext_availability_tls12(const enum tls_extension_type type);

static bool check_ext_availability_tls13_shello(const enum tls_extension_type type);

static bool check_ext_availability_tls13_hrr(const enum tls_extension_type type);

static bool check_ext_availability(TLS *tls, const enum tls_extension_type type);

/**
 * write negotiated version to the send data.
 *
 * returns the number of bytes written to msg, or -1 on failure.
 */
static int32_t write_version(TLS *tls, struct tls_hs_msg *msg);

/**
 * write server random to the send data.
 *
 * the random is also kept in tls->server_random. returns the number of
 * bytes written to msg, or -1 on failure.
 */
static int32_t write_server_random(TLS *tls, struct tls_hs_msg *msg);

/**
 * write session id (length byte followed by the id) to the send data.
 *
 * returns the number of bytes written to msg, or -1 on failure.
 */
static int32_t write_session_id(TLS *tls, struct tls_hs_msg *msg);

/**
 * write selected cipher suite by server to the send data.
 *
 * returns the number of bytes written to msg, or -1 on failure.
 */
static int32_t write_cipher_suite(const TLS *tls, struct tls_hs_msg *msg);

/**
 * write selected compression algorithm by server to the send data.
 *
 * returns the number of bytes written to msg, or -1 on failure.
 */
static int32_t write_cmp_method(const TLS *tls, struct tls_hs_msg *msg);

/**
 * write extension block (length field plus extensions) to the send data.
 *
 * returns the number of bytes written to msg (possibly 0 when nothing is
 * sent), or -1 on failure.
 */
static int32_t write_extension(TLS *tls, struct tls_hs_msg *msg);

/**
 * read selected version by server from the received handshake.
 *
 * the read_* helpers below store what they parse in tls->interim_params
 * (or tls->server_random) and return the number of bytes consumed from
 * msg starting at offset, or -1 after raising a fatal alert.
 */
static inline int32_t read_version(TLS *tls,
				   const struct tls_hs_msg *msg,
				   const uint32_t offset);

/**
 * read server random from the received handshake, rejecting downgrade
 * sentinels the local supported versions require us to reject.
 */
static inline int32_t read_server_random(TLS *tls,
					 const struct tls_hs_msg *msg,
					 const uint32_t offset);

/**
 * read session id from the received handshake into
 * tls->interim_params (heap-allocated when non-empty).
 *
 * whether a resumption was accepted is decided later by
 * interpret_session_id().
 */
static inline int32_t read_session_id(TLS *tls,
				      const struct tls_hs_msg *msg,
				      const uint32_t offset);

/**
 * read selected cipher suite by server from the received handshake.
 */
static inline int32_t read_cipher_suite(TLS *tls,
					const struct tls_hs_msg *msg,
					const uint32_t offset);

/**
 * read selected compression algorithm by server from the received
 * handshake.
 */
static inline int32_t read_cmp_method(TLS *tls,
				      const struct tls_hs_msg *msg,
				      const uint32_t offset);

/**
 * read the body of a single received extension of the given type.
 *
 * returns true on success and false on failure (including a repeated
 * extension type).
 */
static bool read_ext_list(TLS *tls,
			  const enum tls_extension_type type,
			  const struct tls_hs_msg *msg,
			  const uint32_t offset);

/**
 * interpret server random stored in tls structure; for TLS 1.3 this
 * detects a hello retry request and may change the handshake state.
 */
static bool interpret_server_random(TLS *tls);

/**
 * interpret session id stored in tls_hs_interim_params.
 */
static bool interpret_session_id_up_to_tls12(TLS *tls);
static bool interpret_session_id_tls13(TLS *tls);
static bool interpret_session_id(TLS *tls);

/**
 * interpret cipher suite in tls_hs_interim_params.
 */
static bool interpret_cipher_suite(TLS *tls);

/**
 * interpret compression method stored in tls_hs_interim_params.
 */
static bool interpret_cmp_method(TLS *tls);

/**
 * interpret extensions stored in tls_hs_interim_params.
 *
 * the interpret_* helpers return true on success and false on failure.
 */
static bool interpret_ext_list(TLS *tls);

/*
 * special 32-byte Random value that identifies a HelloRetryRequest
 * (RFC8446 4.1.3). write_server_random() copies it over the whole server
 * random when a hello retry request is sent, and interpret_server_random()
 * compares a received server random against it.
 */
static uint8_t tls_hs_special_random_value_for_hrr[] = {
	0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
	0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
	0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
	0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C
};

/*
 * downgrade sentinel "DOWNGRD" followed by 0x00 (RFC8446 4.1.3): placed in
 * the last 8 bytes of the server random when TLS 1.1 or below is
 * negotiated.
 */
static uint8_t tls_hs_special_random_value_for_tls11[] = {
	0x44, 0x4F, 0x57, 0x4E, 0x47, 0x52, 0x44, 0x00
};

/*
 * downgrade sentinel "DOWNGRD" followed by 0x01 (RFC8446 4.1.3): placed in
 * the last 8 bytes of the server random when TLS 1.2 is negotiated.
 *
 * NOTE(review): none of these tables is written through in this file; they
 * could become const once the uint8_t * locals pointing at them are
 * const-qualified as well.
 */
static uint8_t tls_hs_special_random_value_for_tls12[] = {
	0x44, 0x4F, 0x57, 0x4E, 0x47, 0x52, 0x44, 0x01
};

static bool check_ext_availability_tls12(const enum tls_extension_type type) {
	/*
	 * a TLS 1.2 server hello may only carry the extensions this module
	 * knows how to read: server_name and ec_point_formats.
	 */
	return (type == TLS_EXT_SERVER_NAME) ||
	       (type == TLS_EXT_EC_POINT_FORMATS);
}

static bool check_ext_availability_tls13_shello(const enum tls_extension_type type) {
	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *    extensions:  A list of extensions.  The ServerHello MUST only include
	 *       extensions which are required to establish the cryptographic
	 *       context and negotiate the protocol version.  All TLS 1.3
	 *       ServerHello messages MUST contain the "supported_versions"
	 *       extension.  Current ServerHello messages additionally contain
	 *       either the "pre_shared_key" extension or the "key_share"
	 *       extension, or both (when using a PSK with (EC)DHE key
	 *       establishment).  Other extensions (see Section 4.2) are sent
	 *       separately in the EncryptedExtensions message.
	 */
	/* anything outside these three belongs in EncryptedExtensions. */
	return type == TLS_EXT_SUPPORTED_VERSIONS ||
	       type == TLS_EXT_KEY_SHARE ||
	       type == TLS_EXT_PRE_SHARED_KEY;
}

static bool check_ext_availability_tls13_hrr(const enum tls_extension_type type) {
	/*
	 * RFC8446 4.1.4.  Hello Retry Request
	 *
	 *    The server's extensions MUST contain "supported_versions".
	 *    Additionally, it SHOULD contain the minimal set of extensions
	 *    necessary for the client to generate a correct ClientHello pair.  As
	 *    with the ServerHello, a HelloRetryRequest MUST NOT contain any
	 *    extensions that were not first offered by the client in its
	 *    ClientHello, with the exception of optionally the "cookie" (see
	 *    Section 4.2.2) extension.
	 */
	/*
	 * RFC8446 4.1.4.  Hello Retry Request
	 *
	 *    Otherwise, the client MUST process all extensions in the
	 *    HelloRetryRequest and send a second updated ClientHello.  The
	 *    HelloRetryRequest extensions defined in this specification are:
	 *
	 *    -  supported_versions (see Section 4.2.1)
	 *
	 *    -  cookie (see Section 4.2.2)
	 *
	 *    -  key_share (see Section 4.2.8)
	 */
	/* accept exactly the three extensions listed above. */
	const bool is_hrr_ext = (type == TLS_EXT_SUPPORTED_VERSIONS) ||
				(type == TLS_EXT_COOKIE) ||
				(type == TLS_EXT_KEY_SHARE);
	return is_hrr_ext;
}

static bool check_ext_availability(TLS *tls,
				   const enum tls_extension_type type) {
	const uint16_t negotiated = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	if (negotiated == TLS_VER_TLS12) {
		return check_ext_availability_tls12(type);
	}

	if (negotiated == TLS_VER_TLS13) {
		/*
		 * the permitted set depends on whether the message just
		 * received was a hello retry request or a real server hello.
		 */
		if (tls_hs_check_state(tls, TLS_STATE_HS_AFTER_RECV_HRREQ)) {
			return check_ext_availability_tls13_hrr(type);
		}
		return check_ext_availability_tls13_shello(type);
	}

	/*
	 * SSL 3.0 (not supported), TLS 1.0/1.1 (extensions not implemented)
	 * and unknown versions: no extension is accepted.
	 */
	return false;
}

/*
 * write the (legacy) version field of the server hello.
 *
 * returns the number of bytes written (always 2) or -1 on failure. an
 * unrecognized negotiated version is treated as a failure: nothing would
 * be written, and reporting 2 bytes anyway would make the length that
 * tls_hs_shello_compose() accumulates disagree with the bytes in msg.
 */
static int32_t write_version(TLS *tls, struct tls_hs_msg *msg) {
	const int32_t version_length = 2;
	uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	switch (version) {
	case TLS_VER_SSL30:
	case TLS_VER_TLS10:
	case TLS_VER_TLS11:
	case TLS_VER_TLS12:
		/* up to TLS 1.2 the field carries the negotiated version. */
		if (! tls_hs_msg_write_1(msg, tls->negotiated_version.major)) {
			return -1;
		}

		if (! tls_hs_msg_write_1(msg, tls->negotiated_version.minor)) {
			return -1;
		}
		break;

	case TLS_VER_TLS13:
		/*
		 * RFC8446 4.1.3.  Server Hello
		 *
		 *    legacy_version:  In previous versions of TLS, this field was used for
		 *       version negotiation and represented the selected version number
		 *       for the connection.  Unfortunately, some middleboxes fail when
		 *       presented with new values.  In TLS 1.3, the TLS server indicates
		 *       its version using the "supported_versions" extension
		 *       (Section 4.2.1), and the legacy_version field MUST be set to
		 *       0x0303, which is the version number for TLS 1.2.  (See Appendix D
		 *       for details about backward compatibility.)
		 */
		if (! tls_hs_msg_write_2(msg, TLS_VER_TLS12)) {
			return -1;
		}
		break;

	default:
		/* unknown version: nothing was written, so fail loudly. */
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	return version_length;
}

/*
 * generate the server random into tls->server_random, apply the RFC 8446
 * downgrade sentinel or hello retry request marker, and append the result
 * to msg. returns the number of bytes written (32) or -1 on failure.
 */
static int32_t write_server_random(TLS *tls, struct tls_hs_msg *msg) {
	const uint32_t len = 32;
	const uint32_t sentinel_len = 8;

	/* start from a fully random value; it is patched below if needed. */
	if (! tls_util_get_random(&(tls->server_random[0]), len)) {
		return -1;
	}

	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *    TLS 1.3 has a downgrade protection mechanism embedded in the server's
	 *    random value.  TLS 1.3 servers which negotiate TLS 1.2 or below in
	 *    response to a ClientHello MUST set the last 8 bytes of their Random
	 *    value specially in their ServerHello.
	 *
	 *    If negotiating TLS 1.2, TLS 1.3 servers MUST set the last 8 bytes of
	 *    their Random value to the bytes:
	 *
	 *      44 4F 57 4E 47 52 44 01
	 *
	 *    If negotiating TLS 1.1 or below, TLS 1.3 servers MUST, and TLS 1.2
	 *    servers SHOULD, set the last 8 bytes of their ServerHello.Random
	 *    value to the bytes:
	 *
	 *      44 4F 57 4E 47 52 44 00
	 */
	/*
	 * NOTE(review): the TLS 1.2 sentinel is written whenever TLS 1.2 is
	 * negotiated, without checking that this server is itself configured
	 * for TLS 1.3; RFC 8446 only asks TLS 1.3 servers to do so. confirm
	 * that this server always supports TLS 1.3.
	 */
	const uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	if (version == TLS_VER_TLS13) {
		/* a hello retry request carries the fixed special random. */
		if (tls_hs_check_state(tls, TLS_STATE_HS_BEFORE_SEND_HRREQ)
		    == true) {
			memcpy(tls->server_random,
			       tls_hs_special_random_value_for_hrr, len);
		}
	} else {
		uint8_t *sentinel = NULL;

		if (version == TLS_VER_TLS12) {
			sentinel = tls_hs_special_random_value_for_tls12;
		} else if (version == TLS_VER_SSL30 ||
			   version == TLS_VER_TLS10 ||
			   version == TLS_VER_TLS11) {
			sentinel = tls_hs_special_random_value_for_tls11;
		}

		if (sentinel != NULL) {
			memcpy(&(tls->server_random[len - sentinel_len]),
			       sentinel, sentinel_len);
		}
	}

	if (! tls_hs_msg_write_n(msg, tls->server_random, len)) {
		return -1;
	}

	return len;
}

/*
 * write the session id vector (1-byte length followed by the id).
 *
 * TLS 1.3: echoes the session id held in tls->interim_params.
 * up to TLS 1.2: echoes tls->pending's session id, first minting a new
 * random 32-byte id when this is not a resumed session.
 *
 * returns the number of bytes written (1 + id length) or -1 on failure.
 */
static int32_t write_session_id(TLS *tls, struct tls_hs_msg *msg) {
	uint8_t *session;
	uint8_t seslen;
	uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	switch (version) {
	case TLS_VER_TLS13:
		seslen = tls->interim_params->seslen;
		session = tls->interim_params->session;
		break;

	case TLS_VER_SSL30:
	case TLS_VER_TLS10:
	case TLS_VER_TLS11:
	case TLS_VER_TLS12:
	default:
		if (tls->resession == false) {
			/*
			 * fixed-size buffer instead of a VLA: the size is a
			 * constant and VLAs are optional since C11. the id is
			 * generated locally first so tls->pending stays
			 * untouched if random generation fails.
			 */
			uint8_t fresh_id[32];

			if (! tls_util_get_random(&(fresh_id[0]),
						  sizeof(fresh_id))) {
				OK_set_error(ERR_ST_TLS_GET_RANDOM, ERR_LC_TLS3,
					     ERR_PT_TLS_HS_MSG_SHELLO + 0, NULL);
				return -1;
			}

			/* save session id. */
			tls->pending->session_id_length = sizeof(fresh_id);
			memcpy(&(tls->pending->session_id[0]), &(fresh_id[0]),
			       sizeof(fresh_id));
		}

		seslen = tls->pending->session_id_length;
		session = &(tls->pending->session_id[0]);
		break;
	}

	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *           opaque legacy_session_id_echo<0..32>;
	 */
	const uint8_t seslen_max = 32;
	if (seslen > seslen_max) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 9, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* length of session id. */
	if (! tls_hs_msg_write_1(msg, seslen)) {
		return -1;
	}

	/* session id itself. */
	if (! tls_hs_msg_write_n(msg, session, seslen)) {
		return -1;
	}

	return 1 + seslen;
}

/*
 * append the cipher suite stored in tls->pending as a 2-byte field.
 * returns 2 on success or -1 on failure.
 */
static int32_t write_cipher_suite(const TLS *tls, struct tls_hs_msg *msg) {
	const bool written = tls_hs_msg_write_2(msg, tls->pending->cipher_suite);
	return written ? 2 : -1;
}

/*
 * append the compression algorithm stored in tls->pending as a single
 * byte. returns 1 on success or -1 on failure.
 */
static int32_t write_cmp_method(const TLS *tls,
				struct tls_hs_msg *msg) {
	const bool written =
		tls_hs_msg_write_1(msg, tls->pending->compression_algorithm);
	return written ? 1 : -1;
}

/*
 * write the extension block of a regular server hello.
 *
 * a 2-byte length placeholder is reserved first and patched once all
 * extensions are written; if no extension is written the placeholder is
 * dropped again. every extension actually sent is recorded in
 * tls->interim_params->sent_ext_flags.
 *
 * returns the number of bytes written (0 when the block is omitted) or -1
 * on failure.
 */
static int32_t write_extension_shello(TLS *tls, struct tls_hs_msg *msg)
{
	/* reserve the extensions length field; patched below. */
	const int32_t pos = msg->len;
	const int32_t length_bytes = 2;
	if (tls_hs_msg_write_2(msg, 0) == false) {
		return -1;
	}

	int32_t extlen = 0;
	bool *sent_exts = tls->interim_params->sent_ext_flags;

	/*
	 * each writer returns the number of bytes it appended, 0 when it
	 * decided not to send its extension, or a negative value on error.
	 */
	const int32_t server_name_len = tls_hs_servername_write(tls, msg);
	if (server_name_len < 0) {
		return -1;
	} else if (server_name_len > 0) {
		sent_exts[TLS_EXT_SERVER_NAME] = true;
	}
	extlen += server_name_len;

	const int32_t ec_point_formats_len =
		tls_hs_ecc_write_ec_point_formats(tls, msg);
	if (ec_point_formats_len < 0) {
		return -1;
	} else if (ec_point_formats_len > 0) {
		sent_exts[TLS_EXT_EC_POINT_FORMATS] = true;
	}
	extlen += ec_point_formats_len;

	const int32_t supported_versions_len =
		tls_hs_supported_versions_write(tls, msg);
	if (supported_versions_len < 0) {
		return -1;
	} else if (supported_versions_len > 0) {
		sent_exts[TLS_EXT_SUPPORTED_VERSIONS] = true;
	}
	extlen += supported_versions_len;

	const int32_t key_share_len = tls_hs_keyshare_write(tls, msg);
	if (key_share_len < 0) {
		return -1;
	} else if (key_share_len > 0) {
		sent_exts[TLS_EXT_KEY_SHARE] = true;
	}
	extlen += key_share_len;

	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *           Extension extensions<6..2^16-1>;
	 *
	 * the lower bound is only enforced for TLS 1.3; older versions may
	 * legitimately send no extension at all.
	 */
	const int32_t extlen_min = 6;
	const int32_t extlen_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));
	if (version == TLS_VER_TLS13 && extlen < extlen_min) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 7,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	if (extlen > extlen_max) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 8, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* write extension length. */
	tls_util_write_2(&(msg->msg[pos]), extlen);

	/* no extension written: revert the reserved length bytes. */
	if (extlen == 0) {
		msg->len -= length_bytes;
		return 0;
	}

	return length_bytes + extlen;
}

/*
 * write the extension block of a hello retry request.
 *
 * a 2-byte length placeholder is reserved first and patched once the
 * extensions are written. every extension actually sent is recorded in
 * tls->interim_params->sent_ext_flags.
 *
 * returns the number of bytes written or -1 on failure.
 */
static int32_t write_extension_hrr(TLS *tls, struct tls_hs_msg *msg)
{
	/* reserve the extensions length field; patched below. */
	const int32_t pos = msg->len;
	const int32_t length_bytes = 2;
	if (tls_hs_msg_write_2(msg, 0) == false) {
		return -1;
	}

	int32_t extlen = 0;
	bool *sent_exts = tls->interim_params->sent_ext_flags;

	const int32_t supported_versions_len =
		tls_hs_supported_versions_write(tls, msg);
	if (supported_versions_len < 0) {
		return -1;
	} else if (supported_versions_len > 0) {
		sent_exts[TLS_EXT_SUPPORTED_VERSIONS] = true;
	}
	extlen += supported_versions_len;

	/*
	 * TODO: write cookie extension if it is needed. server can pack
	 * negotiation information into cookie extension to operate statelessly.
	 * aicrypto implementation operates statefully now, cookie extension has
	 * no meaning.
	 */

	const int32_t key_share_len = tls_hs_keyshare_write(tls, msg);
	if (key_share_len < 0) {
		return -1;
	} else if (key_share_len > 0) {
		sent_exts[TLS_EXT_KEY_SHARE] = true;
	}
	extlen += key_share_len;

	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *           Extension extensions<6..2^16-1>;
	 *
	 * a hello retry request uses the server hello structure, so the same
	 * bounds apply.
	 */
	const int32_t extensions_length_min = 6;
	const int32_t extensions_length_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (extlen < extensions_length_min || extensions_length_max < extlen) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 9, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* write extension length. */
	tls_util_write_2(&(msg->msg[pos]), extlen);

	/*
	 * extlen >= extensions_length_min here, so the block is never empty
	 * and the reserved length bytes never need to be rolled back.
	 */
	return length_bytes + extlen;
}

/*
 * write the extension block, choosing the hello retry request layout when
 * the handshake is about to send one and the server hello layout otherwise.
 */
static int32_t write_extension(TLS *tls, struct tls_hs_msg *msg)
{
	const bool sending_hrr =
		(tls_hs_check_state(tls, TLS_STATE_HS_BEFORE_SEND_HRREQ) == true);

	return sending_hrr ? write_extension_hrr(tls, msg)
			   : write_extension_shello(tls, msg);
}

/*
 * read the 2-byte (legacy) version at offset into
 * tls->interim_params->version. returns 2, or -1 after a decode_error alert
 * when the message is too short.
 */
static inline int32_t read_version(TLS *tls,
				   const struct tls_hs_msg *msg,
				   const uint32_t offset) {
	const int verlen = 2;

	if (offset + verlen > msg->len) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 1, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/* one byte major, one byte minor. */
	const uint8_t *field = &(msg->msg[offset]);
	tls->interim_params->version.major = field[0];
	tls->interim_params->version.minor = field[1];

	return verlen;
}

/*
 * read the 32-byte server random at offset into tls->server_random after
 * checking its last 8 bytes against the downgrade sentinels.
 *
 * returns 32, or -1 after raising decode_error (message too short) or
 * illegal_parameter (downgrade sentinel found).
 */
static inline int32_t read_server_random(TLS *tls,
					 const struct tls_hs_msg *msg,
					 const uint32_t offset) {
	const int32_t randlen = 32;

	if (msg->len < (offset + randlen)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 3, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *    TLS 1.3 clients receiving a ServerHello indicating TLS 1.2 or below
	 *    MUST check that the last 8 bytes are not equal to either of these
	 *    values.  TLS 1.2 clients SHOULD also check that the last 8 bytes are
	 *    not equal to the second value if the ServerHello indicates TLS 1.1 or
	 *    below.  If a match is found, the client MUST abort the handshake with
	 *    an "illegal_parameter" alert.  This mechanism provides limited
	 *    protection against downgrade attacks over and above what is provided
	 *    by the Finished exchange: because the ServerKeyExchange, a message
	 *    present in TLS 1.2 and below, includes a signature over both random
	 *    values, it is not possible for an active attacker to modify the
	 *    random values without detection as long as ephemeral ciphers are
	 *    used.  It does not provide downgrade protection when static RSA
	 *    is used.
	 */
	/*
	 * the negotiated version is not known yet at this point, so both
	 * sentinels are checked whenever the local configuration supports a
	 * version that obliges us to check them.
	 */
	const uint8_t sentinel_len = 8;
	const uint8_t *tail = &(msg->msg[offset + randlen - sentinel_len]);

	const bool local_tls13 = tls_util_check_version_in_supported_version(
		&(tls->supported_versions), TLS_VER_TLS13);
	if (local_tls13 == true &&
	    memcmp(tls_hs_special_random_value_for_tls12, tail,
		   sentinel_len) == 0) {
		OK_set_error(ERR_ST_TLS_PROTOCOL_VERSION_DOWNGRADE,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 0,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	const bool local_tls12 = tls_util_check_version_in_supported_version(
		&(tls->supported_versions), TLS_VER_TLS12);
	if ((local_tls13 == true || local_tls12 == true) &&
	    memcmp(tls_hs_special_random_value_for_tls11, tail,
		   sentinel_len) == 0) {
		OK_set_error(ERR_ST_TLS_PROTOCOL_VERSION_DOWNGRADE,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 1,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	memcpy(&(tls->server_random[0]), &(msg->msg[offset]), randlen);

	return randlen;
}

/*
 * read the session id vector (1-byte length, then at most 32 bytes) at
 * offset. the length goes to interim_params->seslen; a non-empty id is
 * copied into a freshly malloc'd interim_params->session.
 *
 * returns the number of bytes consumed, or -1 after raising a fatal alert.
 */
static inline int32_t read_session_id(TLS *tls,
				      const struct tls_hs_msg *msg,
				      const uint32_t offset) {
	const uint32_t length_bytes = 1;

	if (offset + length_bytes > msg->len) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 4, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	const uint8_t idlen = msg->msg[offset];

	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *           opaque legacy_session_id_echo<0..32>;
	 */
	const uint8_t idlen_max = 32;
	if (idlen > idlen_max) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	if (offset + length_bytes + idlen > msg->len) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 5, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	tls->interim_params->seslen = idlen;

	if (idlen > 0) {
		const uint8_t *src = &(msg->msg[offset + length_bytes]);

		tls->interim_params->session = malloc(1 * idlen);
		if (tls->interim_params->session == NULL) {
			OK_set_error(ERR_ST_TLS_MALLOC,
				     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 2,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}
		memcpy(tls->interim_params->session, src, idlen);
	}

	return length_bytes + idlen;
}

/*
 * read the 2-byte selected cipher suite at offset into
 * tls->interim_params->cipher_suite. returns 2, or -1 after a decode_error
 * alert when the message is too short.
 */
static inline int32_t read_cipher_suite(TLS *tls,
					const struct tls_hs_msg *msg,
					const uint32_t offset) {
	const uint32_t cslen = 2;

	if (offset + cslen > msg->len) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 7, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	const uint8_t *field = &(msg->msg[offset]);
	tls->interim_params->cipher_suite = tls_util_read_2(field);

	return cslen;
}

/*
 * read the single selected compression method byte at offset and store it
 * as a one-element list in tls->interim_params->cmplist (heap-allocated
 * here; it is not freed in this file).
 *
 * returns the number of bytes consumed (1), or -1 after raising a fatal
 * alert.
 */
static inline int32_t read_cmp_method(TLS *tls,
				      const struct tls_hs_msg *msg,
				      const uint32_t offset) {
	uint32_t read_bytes = 0;

	const uint32_t cmplen = 1;
	if (msg->len < (offset + cmplen)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 8, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	tls->interim_params->cmplen = cmplen;
	if ((tls->interim_params->cmplist = malloc(1 * cmplen)) == NULL) {
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 3,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	memcpy(tls->interim_params->cmplist, &(msg->msg[offset]), cmplen);
	read_bytes += cmplen;

	/* report consumed bytes like the other read_* helpers do. */
	return read_bytes;
}

/*
 * dispatch the body of one received extension to its reader and mark the
 * type as received in tls->interim_params->recv_ext_flags.
 *
 * returns true on success, false on failure. a repeated extension type now
 * raises an illegal_parameter alert here; previously it failed the
 * handshake without any alert, although interpret_ext_list() relies on
 * this function to have raised one. for the per-extension readers, the
 * alert is presumably raised inside the reader (not visible here).
 */
static bool read_ext_list(TLS *tls,
			  const enum tls_extension_type type,
			  const struct tls_hs_msg *msg,
			  const uint32_t offset)
{
	/*
	 * assume unknown extensions never come because check
	 * is performed before this function is called.
	 */
	/*
	 * RFC8446 4.2.  Extensions
	 *
	 *                        There MUST NOT be more than one extension of the
	 *    same type in a given extension block.
	 */
	bool *recv_exts = tls->interim_params->recv_ext_flags;
	if (recv_exts[type] == true) {
		TLS_DPRINTF("shello: extensions of same type come multiple times");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return false;
	}
	recv_exts[type] = true;

	switch(type) {
	case TLS_EXT_SERVER_NAME:
		if (tls_hs_servername_read(tls, msg, offset) < 0) {
			return false;
		}
		return true;

	case TLS_EXT_EC_POINT_FORMATS:
		if (tls_hs_ecc_read_point_format(tls, msg, offset) < 0) {
			return false;
		}
		return true;

	case TLS_EXT_SUPPORTED_VERSIONS:
		/*
		 * TODO: dequeue this extension structure earlier.
		 * extensions that was interpreted must not be listed here.
		 */
		return true;

	case TLS_EXT_COOKIE:
		if (tls_hs_cookie_read(tls, msg, offset) < 0) {
			return false;
		}
		return true;

	case TLS_EXT_KEY_SHARE:
		if (tls_hs_keyshare_read(tls, msg, offset) < 0) {
			return false;
		}
		return true;

	default:
		assert(!"unknown extension type");
		break;
	}

	return false;
}

/*
 * for TLS 1.3, detect whether the received "server hello" is in fact a
 * hello retry request. if it is, switch the handshake state to
 * TLS_STATE_HS_AFTER_RECV_HRREQ, or fail with unexpected_message when it is
 * the second one on this connection. always succeeds for other versions.
 */
static bool interpret_server_random(TLS *tls) {
	const uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	if (version != TLS_VER_TLS13) {
		return true;
	}

	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *    Upon receiving a message with type server_hello, implementations MUST
	 *    first examine the Random value and, if it matches this value, process
	 *    it as described in Section 4.1.4).
	 */
	const size_t randlen = 32;
	if (memcmp(&(tls->server_random[0]),
		   tls_hs_special_random_value_for_hrr, randlen) != 0) {
		/* an ordinary server hello. */
		return true;
	}

	/*
	 * RFC8446 4.1.4.  Hello Retry Request
	 *
	 *                                       If a client receives a second
	 *    HelloRetryRequest in the same connection (i.e., where the ClientHello
	 *    was itself in response to a HelloRetryRequest), it MUST abort the
	 *    handshake with an "unexpected_message" alert.
	 */
	if (tls_hs_check_state(tls, TLS_STATE_HS_AFTER_SEND_2NDCHELLO)
	    == true) {
		TLS_DPRINTF("shello: second hrr comes");
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS6,
			     ERR_PT_TLS_HS_MSG_SHELLO2 + 4,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	tls_hs_change_state(tls, TLS_STATE_HS_AFTER_RECV_HRREQ);
	return true;
}

/*
 * up to TLS 1.2: decide whether the server accepted the resumption the
 * client asked for, and record the server's session id in tls->pending.
 *
 * the id is read directly from tls->interim_params. the previous version
 * copied it into `uint8_t session[len]`, a variable length array that is
 * undefined behavior when the server sends an empty id (len == 0), and it
 * then called memcpy() with the NULL interim_params->session.
 *
 * returns true on success; false (after an internal_error alert) when a
 * replacement session cannot be allocated.
 */
static bool interpret_session_id_up_to_tls12(TLS *tls) {
	const uint8_t len = tls->interim_params->seslen;
	/* allocated by read_session_id() only when len > 0. */
	const uint8_t *session = tls->interim_params->session;

	/* check re-session. */
	if (tls->pending->session_id_length > 0) {
		/* whether re-session is succeed (len > 0 here, so session is set). */
		if ((len == tls->pending->session_id_length) &&
		    (memcmp(&(tls->pending->session_id[0]),
			    session,
			    len) == 0)) {
			tls->resession = true;
			return true;
		}

		/* if re-session was denied by server, generate new
		 * session. */
		tls_session_unrefer(tls->pending);
		if ((tls->pending = tls_session_new()) == NULL) {
			OK_set_error(ERR_ST_TLS_GET_SESSION, ERR_LC_TLS3,
				     ERR_PT_TLS_HS_MSG_SHELLO + 6, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return false;
		}
		tls_session_refer(tls->pending);
	}

	if (len > 0) {
		memcpy(&(tls->pending->session_id[0]), session, len);
	}
	tls->pending->session_id_length = len;

	return true;
}

/*
 * TLS 1.3: verify that legacy_session_id_echo equals the session id the
 * client sent (held in tls->pending). returns false after an
 * illegal_parameter alert on mismatch.
 *
 * the id is compared in place; the former `uint8_t session[len]` copy was
 * a zero-length VLA (undefined behavior) when the echo is empty, and its
 * memcpy() read from a NULL pointer in that case.
 */
static bool interpret_session_id_tls13(TLS *tls) {
	const uint8_t len = tls->interim_params->seslen;
	/* allocated by read_session_id() only when len > 0. */
	const uint8_t *session = tls->interim_params->session;

	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *    legacy_session_id_echo:  The contents of the client's
	 *       legacy_session_id field.  Note that this field is echoed even if
	 *       the client's value corresponded to a cached pre-TLS 1.3 session
	 *       which the server has chosen not to resume.  A client which
	 *       receives a legacy_session_id_echo field that does not match what
	 *       it sent in the ClientHello MUST abort the handshake with an
	 *       "illegal_parameter" alert.
	 */
	if (tls->pending->session_id_length != len) {
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 5, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return false;
	}

	if (len > 0 && memcmp(tls->pending->session_id, session, len) != 0) {
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_SHELLO2 + 6, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return false;
	}

	return true;
}

/*
 * interpret the received session id with the rules of the negotiated
 * version: echo check for TLS 1.3, resumption check for everything else.
 */
static bool interpret_session_id(TLS *tls) {
	const uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	if (version == TLS_VER_TLS13) {
		return interpret_session_id_tls13(tls);
	}

	/* SSL 3.0 through TLS 1.2, and any unrecognized version. */
	return interpret_session_id_up_to_tls12(tls);
}

/*
 * apply the cipher suite selected by the server via tls_cipher_set().
 * returns false after an illegal_parameter alert when it is rejected.
 */
static bool interpret_cipher_suite(TLS *tls) {
	/* TODO: should check that the cipher suite list in sent client
	 * hello has the cipher suite in received server hello. this is
	 * a countermeasure for FREAK attack. */
	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *    cipher_suite:  The single cipher suite selected by the server from
	 *       the list in ClientHello.cipher_suites.  A client which receives a
	 *       cipher suite that was not offered MUST abort the handshake with an
	 *       "illegal_parameter" alert.
	 */
	/*
	 * RFC8446 4.1.4.  Hello Retry Request
	 *
	 *    A client which receives a cipher suite that was not offered MUST
	 *    abort the handshake.
	 */
	const bool accepted =
		tls_cipher_set(tls, tls->interim_params->cipher_suite);
	if (accepted == false) {
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return false;
	}

	return true;
}

/*
 * apply the compression method selected by the server (the single entry
 * stored by read_cmp_method()). returns false after an illegal_parameter
 * alert when tls_compress_set() rejects it.
 */
static bool interpret_cmp_method(TLS *tls) {
	/*
	 * RFC8446 4.1.3.  Server Hello
	 *
	 *    legacy_compression_method:  A single byte which MUST have the
	 *       value 0.
	 */
	const uint32_t selected = tls->interim_params->cmplist[0];

	if (tls_compress_set(tls, selected)) {
		return true;
	}

	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
	return false;
}

/*
 * validate and read every extension queued in tls->interim_params->head.
 *
 * each extension type must be allowed for the negotiated version and
 * message kind (unsupported_extension alert otherwise). after a hello
 * retry request the queue is kept for the later client hello comparison;
 * otherwise it is drained and freed.
 */
static bool interpret_ext_list(TLS *tls) {
	struct tls_hs_interim_params *params = tls->interim_params;
	struct tls_extension *ext;

	TAILQ_FOREACH(ext, &(params->head), link) {
		/*
		 * RFC5246 7.4.1.4.  Hello Extensions
		 *
		 * An extension type MUST NOT appear in the ServerHello unless the same
		 * extension type appeared in the corresponding ClientHello.  If a
		 * client receives an extension type in ServerHello that it did not
		 * request in the associated ClientHello, it MUST abort the handshake
		 * with an unsupported_extension fatal alert.
		 */
		if (check_ext_availability(tls, ext->type) == false) {
			TLS_ALERT_FATAL(tls,
					TLS_ALERT_DESC_UNSUPPORTED_EXTENSION);
			return false;
		}

		/* present the extension body to the reader as a message. */
		struct tls_hs_msg body;
		body.type = TLS_HANDSHAKE_SERVER_HELLO;
		body.len = ext->len;
		body.max = ext->len;
		body.msg = ext->opaque;

		/* error reporting on failure is left to read_ext_list(). */
		if (read_ext_list(tls, ext->type, &body, 0) == false) {
			return false;
		}
	}

	/*
	 * if client receives hello retry request, preserve extensions for
	 * later comparison between client hello messages.
	 */
	if (tls_hs_check_state(tls, TLS_STATE_HS_AFTER_RECV_HRREQ) == true) {
		return true;
	}

	/*
	 * RFC4492 Supported Point Formats, actions of the receiver:
	 *
	 * A client that receives a ServerHello message containing a Supported
	 * Point Formats Extension MUST respect the server's choice of point
	 * formats during the handshake (cf. Sections 5.6 and 5.7).  If no
	 * Supported Point Formats Extension is received with the ServerHello,
	 * this is equivalent to an extension allowing only the uncompressed
	 * point format.
	 *
	 * NOTE(review): nothing is done for this here; presumably handled by
	 * tls_hs_ecc_read_point_format() and its users — confirm.
	 */

	while ((ext = TAILQ_FIRST(&(params->head))) != NULL) {
		TAILQ_REMOVE(&(params->head), ext, link);
		tls_extension_free(ext);
	}

	return true;
}

/*
 * build a server hello (or, in the TLS 1.3 retry state, a hello retry
 * request) handshake message.
 *
 * returns the new message, or NULL on failure; the partially built
 * message is released with tls_hs_msg_free() on every failure path.
 */
struct tls_hs_msg * tls_hs_shello_compose(TLS *tls) {
	struct tls_hs_msg *msg = tls_hs_msg_init();
	if (msg == NULL) {
		TLS_DPRINTF("tls_hs_msg_init");
		return NULL;
	}

	/*
	 * server hello layout. the handshake type is carried in msg->type and
	 * the 3-byte message length is not written by this function:
	 *
	 * | type                 (1) |  (msg->type)
	 * | length of message    (3) |  (not written here)
	 * | major version        (1) |
	 * | minor version        (1) |
	 * | server random       (32) |
	 * | session id length    (1) |
	 * | session id           (n) | 0 <= n <= 32
	 * | cipher suite         (2) |
	 * | compression method   (1) |
	 * | extension length     (2) | may be omitted when no extension is sent
	 * | extension            (n) |
	 */
	msg->type = TLS_HANDSHAKE_SERVER_HELLO;

	uint32_t total = 0;
	int32_t written;

	written = write_version(tls, msg);
	if (written < 0) {
		goto failed;
	}
	total += written;

	written = write_server_random(tls, msg);
	if (written < 0) {
		goto failed;
	}
	total += written;

	written = write_session_id(tls, msg);
	if (written < 0) {
		goto failed;
	}
	total += written;

	written = write_cipher_suite(tls, msg);
	if (written < 0) {
		goto failed;
	}
	total += written;

	written = write_cmp_method(tls, msg);
	if (written < 0) {
		goto failed;
	}
	total += written;

	written = write_extension(tls, msg);
	if (written < 0) {
		goto failed;
	}
	total += written;

	msg->len = total;

	return msg;

failed:
	tls_hs_msg_free(msg);
	return NULL;
}

/*
 * parse a received server hello into tls->interim_params (and
 * tls->server_random). the message must be consumed exactly; trailing or
 * missing bytes raise decode_error.
 *
 * returns true on success, false after a fatal alert.
 */
bool tls_hs_shello_parse(TLS *tls, struct tls_hs_msg *msg) {
	if (msg->type != TLS_HANDSHAKE_SERVER_HELLO) {
		TLS_DPRINTF("! TLS_HANDSHAKE_SERVER_HELLO");
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 10, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	uint32_t cursor = 0;
	int32_t consumed;

	consumed = read_version(tls, msg, cursor);
	if (consumed < 0) {
		TLS_DPRINTF("read_version");
		return false;
	}
	cursor += consumed;

	consumed = read_server_random(tls, msg, cursor);
	if (consumed < 0) {
		TLS_DPRINTF("read_server_random");
		return false;
	}
	cursor += consumed;

	consumed = read_session_id(tls, msg, cursor);
	if (consumed < 0) {
		TLS_DPRINTF("read_session_id");
		return false;
	}
	cursor += consumed;

	consumed = read_cipher_suite(tls, msg, cursor);
	if (consumed < 0) {
		TLS_DPRINTF("read_cipher_suite");
		return false;
	}
	cursor += consumed;

	consumed = read_cmp_method(tls, msg, cursor);
	if (consumed < 0) {
		TLS_DPRINTF("read_cmp_method");
		return false;
	}
	cursor += consumed;

	consumed = tls_hs_extension_parse(tls, msg, cursor);
	if (consumed < 0) {
		TLS_DPRINTF("tls_hs_extension_parse");
		return false;
	}
	cursor += consumed;

	/* every byte of the message must have been consumed. */
	if (msg->len != cursor) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SHELLO + 11, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return false;
	}

	return true;
}

/*
 * interpret a parsed server hello: random (hello retry request detection),
 * session id, cipher suite, compression method, then extensions. stops at
 * the first step that fails and returns false.
 */
bool tls_hs_shello_interpret(TLS *tls) {
	/*
	 * RFC8446 4.1.4.  Hello Retry Request
	 *
	 *    Upon receipt of a HelloRetryRequest, the client MUST check the
	 *    legacy_version, legacy_session_id_echo, cipher_suite, and
	 *    legacy_compression_method as specified in Section 4.1.3 and then
	 *    process the extensions, starting with determining the version using
	 *    "supported_versions".
	 */
	/* && short-circuits, so the steps run in order until one fails. */
	return interpret_server_random(tls) &&
	       interpret_session_id(tls) &&
	       interpret_cipher_suite(tls) &&
	       interpret_cmp_method(tls) &&
	       interpret_ext_list(tls);
}
