/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_alert.h"

/**
 * Check whether a received extension is permitted in EncryptedExtensions
 * for the negotiated protocol version.
 *
 * check_ext_availability_tls13() holds the TLS 1.3 list of permitted
 * extension types; check_ext_availability() selects it from the negotiated
 * version and rejects every other version.
 *
 * TODO: these functions should probably be moved to another file
 * (e.g. handshake/extension/extension.c?).
 */

static bool check_ext_availability_tls13(const enum tls_extension_type type);

static bool check_ext_availability(TLS *tls, const enum tls_extension_type type);

/**
 * Process one extension taken from the received extension list: reject a
 * repeated extension type, then hand the extension data to the
 * type-specific reader. Returns true on success.
 */
static bool read_ext_list(TLS *tls,
			  const enum tls_extension_type type,
			  const struct tls_hs_msg *msg,
			  const uint32_t offset);

/**
 * Write the EncryptedExtensions extensions into msg. Returns the total
 * length reported by the individual extension writers, or -1 on error.
 */
static int32_t write_encrypted_extensions(TLS *tls,
					 struct tls_hs_msg *msg);

/**
 * Parse the extension block of a received EncryptedExtensions message,
 * starting at offset. Returns the number of bytes consumed, or -1 on error.
 */
static int32_t read_encrypted_extensions(TLS *tls, struct tls_hs_msg *msg,
					 uint32_t offset);

/**
 * Report whether an extension type may appear in a TLS 1.3
 * EncryptedExtensions message.
 *
 * The accepted set mirrors the extensions marked "EE" in the extension
 * table of RFC 8446 section 4.2; every other type is rejected.
 *
 * @param type extension type taken from the received message.
 * @return true if the type is allowed in EncryptedExtensions.
 */
static bool check_ext_availability_tls13(const enum tls_extension_type type) {
	return type == TLS_EXT_SERVER_NAME
		|| type == TLS_EXT_MAX_FRAGMENT_LENGTH
		|| type == TLS_EXT_SUPPORTED_GROUPS
		|| type == TLS_EXT_USE_SRTP
		|| type == TLS_EXT_HEARTBEAT
		|| type == TLS_EXT_APP_LAYER_PROTO_NEGOTIATION
		|| type == TLS_EXT_CLIENT_CERTIFICATE_TYPE
		|| type == TLS_EXT_SERVER_CERTIFICATE_TYPE
		|| type == TLS_EXT_EARLY_DATA;
}

/**
 * Report whether an extension type is acceptable in EncryptedExtensions
 * under the version negotiated on this connection.
 *
 * Only TLS 1.3 has a permitted-extension list here. SSL 3.0 through
 * TLS 1.2, as well as any unrecognised version, always yield false.
 *
 * @param tls  connection whose negotiated_version is consulted.
 * @param type extension type taken from the received message.
 * @return true if the extension may be processed.
 */
static bool check_ext_availability(TLS *tls,
				   const enum tls_extension_type type) {
	const uint16_t ver = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	if (ver != TLS_VER_TLS13) {
		/* older protocol versions and unknown values: not supported */
		return false;
	}

	return check_ext_availability_tls13(type);
}

/**
 * Process one extension taken from the received EncryptedExtensions list.
 *
 * The first step rejects a second occurrence of the same extension type
 * and records the type in interim_params->recv_ext_flags. The second step
 * dispatches the data to the reader for that type.
 *
 * Every failure path leaves a fatal alert raised, either here or inside
 * the type-specific reader, because interpret_ext_list() does not send
 * one itself.
 *
 * @param tls    connection state.
 * @param type   extension type. Unknown types are assumed to be filtered
 *               out before this call, since type indexes recv_ext_flags
 *               directly.
 * @param msg    pseudo message whose body is the extension data.
 * @param offset position in msg where the extension data starts.
 * @return true on success, false after a fatal alert.
 */
static bool read_ext_list(TLS *tls,
			  const enum tls_extension_type type,
			  const struct tls_hs_msg *msg,
			  const uint32_t offset)
{
	/*
	 * RFC8446 4.2.  Extensions
	 *
	 *                        There MUST NOT be more than one extension of the
	 *    same type in a given extension block.
	 */
	bool *recv_exts = tls->interim_params->recv_ext_flags;
	if (recv_exts[type] == true) {
		TLS_DPRINTF("encext: extensions of same type come multiple times");
		OK_set_error(ERR_ST_TLS_SAME_TYPE_EXTENSION,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_ENCEXT + 0, NULL);
		/* the caller relies on this function to send the alert. */
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return false;
	}
	recv_exts[type] = true;

	switch(type) {
	case TLS_EXT_SERVER_NAME:
		/* the reader raises its own alert on failure. */
		return tls_hs_servername_read(tls, msg, offset) >= 0;

	case TLS_EXT_SUPPORTED_GROUPS:
		/* the reader raises its own alert on failure. */
		return tls_hs_ecc_read_elliptic_curves(tls, msg, offset) >= 0;

	default:
		/*
		 * check_ext_availability_tls13() accepts more types than are
		 * handled here (e.g. heartbeat, early_data, ALPN). Reaching this
		 * point must not abort the process, and must not fail silently
		 * either. Treat it as an internal error and alert the peer.
		 */
		TLS_DPRINTF("encext: no reader for received extension type");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return false;
	}
}

/**
 * Append the EncryptedExtensions extensions to msg.
 *
 * For each extension, the writer is called and its returned length is
 * summed. When the returned length is positive, the type is marked in
 * interim_params->sent_ext_flags.
 *
 * @param tls connection state.
 * @param msg message being composed; the extension writers append to it.
 * @return sum of the lengths reported by the extension writers (0 when
 *         none wrote anything), or -1 if any writer failed.
 */
static int32_t write_encrypted_extensions(TLS *tls,
					 struct tls_hs_msg *msg) {
	bool *sent_exts = tls->interim_params->sent_ext_flags;
	int32_t total = 0;

	const int32_t server_name_len = tls_hs_servername_write(tls, msg);
	if (server_name_len < 0) {
		return -1;
	}
	if (server_name_len > 0) {
		sent_exts[TLS_EXT_SERVER_NAME] = true;
	}
	total += server_name_len;

	const int32_t supported_groups_len =
		tls_hs_ecc_write_elliptic_curves(tls, msg);
	if (supported_groups_len < 0) {
		return -1;
	}
	if (supported_groups_len > 0) {
		/*
		 * Use the same identifier that the receive side of this file uses
		 * (check_ext_availability_tls13 / read_ext_list) for this
		 * extension.
		 */
		sent_exts[TLS_EXT_SUPPORTED_GROUPS] = true;
	}
	total += supported_groups_len;

	return total;
}

/**
 * Parse the extension block of a received EncryptedExtensions message.
 *
 * The parsing is delegated to tls_hs_extension_parse(). The result is then
 * bounded by the vector limit of RFC 8446 4.3.1:
 *
 *       struct {
 *           Extension extensions<0..2^16-1>;
 *       } EncryptedExtensions;
 *
 * NOTE(review): tls_hs_encext_parse() compares this return value against
 * the whole message length, which suggests it includes the 2-byte length
 * prefix. If so, this bound is two bytes stricter than the vector limit.
 * TODO confirm against tls_hs_extension_parse().
 *
 * @param tls    connection state.
 * @param msg    received message.
 * @param offset position where the extension block starts.
 * @return number of bytes consumed, or -1 on error.
 */
static int32_t read_encrypted_extensions(TLS *tls, struct tls_hs_msg *msg,
					 uint32_t offset) {
	const int32_t parsed = tls_hs_extension_parse(tls, msg, offset);
	if (parsed < 0) {
		TLS_DPRINTF("tls_hs_extension_parse");
		return -1;
	}

	const int32_t parsed_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (parsed <= parsed_max) {
		return parsed;
	}

	OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
		     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_ENCEXT + 3, NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
	return -1;
}

/**
 * Validate and apply every extension collected in interim_params->head,
 * then free the collected entries.
 *
 * Each entry is checked with check_ext_availability(). If the check fails,
 * the handshake is aborted with an illegal_parameter alert. Otherwise the
 * entry's raw data is wrapped in a pseudo EncryptedExtensions message and
 * passed to read_ext_list().
 *
 * On any failure the function returns early and the list is left in
 * params->head. Presumably it is released elsewhere; TODO confirm.
 *
 * @param tls connection state.
 * @return true when every extension was accepted, false otherwise.
 */
static bool interpret_ext_list(TLS *tls) {
	struct tls_hs_interim_params *params = tls->interim_params;
	struct tls_extension *entry;

	/*
	 * RFC8446 4.3.1.  Encrypted Extensions
	 *
	 *                   The client MUST check EncryptedExtensions for the
	 *    presence of any forbidden extensions and if any are found MUST abort
	 *    the handshake with an "illegal_parameter" alert.
	 */
	TAILQ_FOREACH(entry, &(params->head), link) {
		if (! check_ext_availability(tls, entry->type)) {
			TLS_ALERT_FATAL(tls,
				TLS_ALERT_DESC_ILLEGAL_PARAMETER);
			return false;
		}

		/* view the extension body as a standalone message. */
		struct tls_hs_msg view;
		view.type = TLS_HANDSHAKE_ENCRYPTED_EXTENSIONS;
		view.len = entry->len;
		view.max = entry->len;
		view.msg = entry->opaque;

		/*
		 * No alert here: read_ext_list() and the readers it calls are
		 * expected to send one on failure.
		 */
		if (! read_ext_list(tls, entry->type, &view, 0)) {
			return false;
		}
	}

	/* every extension was consumed; drop the list. */
	while ((entry = TAILQ_FIRST(&(params->head))) != NULL) {
		TAILQ_REMOVE(&(params->head), entry, link);
		tls_extension_free(entry);
	}

	return true;
}

/**
 * Build an EncryptedExtensions handshake message.
 *
 * The message has the following structure:
 *
 * | type                 (1) |
 * | length of message    (3) |
 * | extension length     (2) |
 * | extension            (n) |
 *
 * Only the body (extension length + extensions) is written here.
 *
 * @param tls connection state.
 * @return newly allocated message, or NULL on failure. On failure the
 *         partially built message is freed, and an internal_error alert is
 *         raised when the extension block is too long.
 */
struct tls_hs_msg *tls_hs_encext_compose(TLS *tls) {
	struct tls_hs_msg *msg;
	int32_t pos;
	int32_t extlen;

	if ((msg = tls_hs_msg_init()) == NULL) {
		TLS_DPRINTF("tls_hs_msg_init");
		return NULL;
	}

	msg->type = TLS_HANDSHAKE_ENCRYPTED_EXTENSIONS;

	/* reserve the 2-byte extension length; it is back-patched below. */
	pos = msg->len;
	if (tls_hs_msg_write_2(msg, 0) == false) {
		goto failed;
	}

	if ((extlen = write_encrypted_extensions(tls, msg)) < 0) {
		goto failed;
	}

	/*
	 * RFC8446 4.3.1.  Encrypted Extensions
	 *
	 *       struct {
	 *           Extension extensions<0..2^16-1>;
	 *       } EncryptedExtensions;
	 *
	 * extlen is known to be non-negative here, so only the upper bound
	 * needs checking.
	 */
	const int32_t extensions_length_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (extensions_length_max < extlen) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_ENCEXT + 4, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		goto failed;
	}

	/* fill in the real extension length at the reserved position. */
	tls_util_write_2(&(msg->msg[pos]), extlen);

	return msg;

failed:
	tls_hs_msg_free(msg);
	return NULL;
}

/**
 * Parse a received EncryptedExtensions handshake message.
 *
 * The checks are:
 * - the message type must be EncryptedExtensions, otherwise an
 *   unexpected_message alert is sent;
 * - the extension block is read;
 * - the block must account for the whole message body, otherwise a
 *   decode_error alert is sent.
 *
 * @param tls connection state.
 * @param msg received handshake message.
 * @return true on success, false on any error.
 */
bool tls_hs_encext_parse(TLS *tls, struct tls_hs_msg *msg) {
	if (msg->type != TLS_HANDSHAKE_ENCRYPTED_EXTENSIONS) {
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_ENCEXT + 1, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	const int32_t consumed = read_encrypted_extensions(tls, msg, 0);
	if (consumed < 0) {
		return false;
	}

	/* trailing bytes after the extension block are a decode error. */
	if ((uint32_t) consumed != msg->len) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_ENCEXT + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return false;
	}

	return true;
}

/**
 * Apply the extensions collected from a parsed EncryptedExtensions message.
 *
 * @param tls connection state.
 * @return true on success, false if any extension was rejected.
 */
bool tls_hs_encext_interpret(TLS *tls) {
	return interpret_ext_list(tls);
}
