/*
 * Copyright (c) 2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_alert.h"

/**
 * Write the body of the supported_versions extension
 * (the struct SupportedVersions) to a handshake message.
 *
 * RFC8446 4.2.1.  Supported Versions
 *
 * struct {
 *     select (Handshake.msg_type) {
 *         case client_hello:
 *             ProtocolVersion versions<2..254>;
 *         case server_hello: // and HelloRetryRequest
 *             ProtocolVersion selected_version;
 *     };
 * } SupportedVersions;
 *
 * The "_for_chello" variant writes the one-octet list length followed by
 * every entry of tls->supported_versions. The "_for_shello" variant
 * writes the single version derived from tls->negotiated_version.
 *
 * Both return the number of octets written, or -1 on failure.
 */
static int32_t write_supported_versions_for_chello(TLS *tls,
						   struct tls_hs_msg *msg);

/* server_hello counterpart of write_supported_versions_for_chello(). */
static int32_t write_supported_versions_for_shello(TLS *tls,
						   struct tls_hs_msg *msg);

/**
 * Read the body of the supported_versions extension
 * (the struct SupportedVersions) from a handshake message.
 *
 * RFC8446 4.2.1.  Supported Versions
 *
 * struct {
 *     select (Handshake.msg_type) {
 *         case client_hello:
 *             ProtocolVersion versions<2..254>;
 *         case server_hello: // and HelloRetryRequest
 *             ProtocolVersion selected_version;
 *     };
 * } SupportedVersions;
 *
 * Parsing starts at msg->msg[offset]. The versions that were read are
 * stored in tls->peer_supported_versions as a malloc()ed array.
 *
 * Both variants return the number of octets read, or -1 on failure.
 */
static int32_t read_supported_versions_in_chello(TLS *tls,
						 const struct tls_hs_msg *msg,
						 const uint32_t offset);

/* server_hello counterpart of read_supported_versions_in_chello(). */
static int32_t read_supported_versions_in_shello(TLS *tls,
						 const struct tls_hs_msg *msg,
						 const uint32_t offset);

/*
 * Write the client_hello form of SupportedVersions: a one-octet list
 * length followed by each entry of tls->supported_versions as a
 * two-octet value.
 *
 * Returns the number of octets written (1 + 2 * number of versions),
 * or -1 on failure. When the list length is out of range, an error is
 * recorded and TLS_ALERT_FATAL is invoked with internal_error.
 */
static int32_t write_supported_versions_for_chello(TLS *tls,
						   struct tls_hs_msg *msg)
{
	/*
	 * RFC8446 4.2.1.  Supported Versions
	 *
	 *                    ProtocolVersion versions<2..254>;
	 *
	 * Every ProtocolVersion occupies two octets, so the encoded list
	 * length is twice the number of configured versions.
	 */
	const int32_t min_list_bytes = 2;
	const int32_t max_list_bytes = TLS_VECTOR_1_BYTE_SIZE_MAX;
	if (tls->supported_versions.len * 2 < min_list_bytes ||
	    tls->supported_versions.len * 2 > max_list_bytes) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 9, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* The length prefix of the ProtocolVersion vector is one octet. */
	if (tls_hs_msg_write_1(msg,
			       tls->supported_versions.len * 2) == false) {
		return -1;
	}
	int32_t written = 1;

	/* Each element of the list is written as two octets. */
	int idx = 0;
	while (idx < tls->supported_versions.len) {
		if (tls_hs_msg_write_2(
			    msg, tls->supported_versions.list[idx]) == false) {
			return -1;
		}
		written += 2;
		idx++;
	}

	return written;
}

/*
 * Write the server_hello form of SupportedVersions: the single
 * two-octet selected_version derived from tls->negotiated_version.
 *
 * Returns the number of octets written (2), or -1 when the write fails.
 */
static int32_t write_supported_versions_for_shello(TLS *tls,
						   struct tls_hs_msg *msg)
{
	/* selected_version is one ProtocolVersion, i.e. two octets. */
	const int32_t selected_version_bytes = 2;
	const uint16_t selected = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	if (tls_hs_msg_write_2(msg, selected) == false) {
		return -1;
	}

	return selected_version_bytes;
}

/*
 * Read the client_hello form of SupportedVersions.
 *
 * Parses "ProtocolVersion versions<2..254>" starting at msg->msg[offset]:
 * a one-octet list length followed by that many octets holding two-octet
 * versions. On success the versions are stored in a newly malloc()ed
 * array that is assigned to tls->peer_supported_versions.list, with the
 * element count in tls->peer_supported_versions.len.
 *
 * TLS_ALERT_FATAL is invoked with decode_error when the data is truncated,
 * or when the list length is shorter than one version or is not a multiple
 * of the version size. It is invoked with internal_error when the
 * allocation fails.
 *
 * Returns the number of octets consumed (length octet + list), or -1 on
 * failure.
 *
 * NOTE(review): an array already referenced by
 * tls->peer_supported_versions.list is overwritten here without being
 * freed; confirm that the owner of that list releases it beforehand.
 */
static int32_t read_supported_versions_in_chello(TLS *tls,
						 const struct tls_hs_msg *msg,
						 const uint32_t offset)
{
	uint32_t read_bytes = 0;

	/* The vector length prefix occupies one octet. */
	const int32_t length_bytes = 1;
	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("supported_versions: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 0,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	/* The announced list must fit in the remaining message data. */
	uint8_t list_length = msg->msg[offset];
	if (msg->len < (offset + read_bytes + list_length)) {
		TLS_DPRINTF("supported_versions: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 1,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.2.1.  Supported Versions
	 *
	 *                    ProtocolVersion versions<2..254>;
	 *
	 * The list must hold at least one version, and (RFC8446 3.4) the
	 * length of a vector must be an exact multiple of the element
	 * size. An odd length would otherwise leave a trailing octet that
	 * is counted as consumed but never parsed.
	 */
	const uint8_t version_bytes = sizeof(uint16_t);
	if (list_length < version_bytes ||
	    (list_length % version_bytes) != 0) {
		TLS_DPRINTF("supported_versions: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 2,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	const uint32_t base = offset + read_bytes;
	uint8_t n = list_length / version_bytes;
	int32_t off = 0;
	uint16_t *list;

	if ((list = malloc(n * sizeof(*list))) == NULL) {
		TLS_DPRINTF("supported_versions: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 3,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* Decode each two-octet version in wire order. */
	for (uint8_t i = 0; i < n; i++, off += version_bytes) {
		list[i] = tls_util_read_2(&(msg->msg[base+off]));
	}
	read_bytes += list_length;

	tls->peer_supported_versions.len = n;
	tls->peer_supported_versions.list = list;

	return read_bytes;
}

/*
 * Read the server_hello form of SupportedVersions.
 *
 * Expects exactly one two-octet selected_version at msg->msg[offset]
 * with no data following it in msg. On success the version is stored in
 * a newly malloc()ed one-element array assigned to
 * tls->peer_supported_versions.list, and the length is set to 1.
 *
 * Returns the number of octets read (2), or -1 on failure.
 */
static int32_t read_supported_versions_in_shello(TLS *tls,
						 const struct tls_hs_msg *msg,
						 const uint32_t offset)
{
	const uint32_t selected_version_bytes = sizeof(uint16_t);

	/* The message must contain at least one complete version. */
	if (msg->len < (offset + selected_version_bytes)) {
		TLS_DPRINTF("supported_versions: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 4,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/* Nothing may follow the selected version. */
	if ((msg->len - offset) != selected_version_bytes) {
		TLS_DPRINTF("supported_versions: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 5,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	uint16_t *selected = malloc(selected_version_bytes);
	if (selected == NULL) {
		TLS_DPRINTF("supported_versions: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 6,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* The peer's list holds just the one version the server selected. */
	selected[0] = tls_util_read_2(&(msg->msg[offset]));

	tls->peer_supported_versions.list = selected;
	tls->peer_supported_versions.len = 1;

	return selected_version_bytes;
}

/*
 * Write the supported_versions extension (type, length and body) to msg.
 *
 * The extension is emitted only when TLS 1.3 or later is involved:
 * for a client_hello, when any entry of tls->supported_versions is at
 * least TLS_VER_TLS13; for a server_hello, when the negotiated version
 * is at least TLS_VER_TLS13. For any other case nothing is written.
 *
 * Returns the total number of octets written (0 when the extension is
 * not emitted), or -1 on failure.
 */
int32_t tls_hs_supported_versions_write(TLS *tls, struct tls_hs_msg *msg)
{
	bool tls13_involved = false;

	if (msg->type == TLS_HANDSHAKE_CLIENT_HELLO) {
		for (int idx = 0; idx < tls->supported_versions.len; idx++) {
			const uint16_t candidate =
				tls->supported_versions.list[idx];
			if (candidate >= TLS_VER_TLS13) {
				tls13_involved = true;
				break;
			}
		}
	} else if (msg->type == TLS_HANDSHAKE_SERVER_HELLO) {
		const uint16_t negotiated = tls_util_convert_protover_to_ver(
			&(tls->negotiated_version));
		if (negotiated >= TLS_VER_TLS13) {
			tls13_involved = true;
		}
	}

	if (! tls13_involved) {
		return 0;
	}

	/* Extension type (two octets). */
	const int32_t type_bytes = 2;
	if (! tls_hs_msg_write_2(msg, TLS_EXT_SUPPORTED_VERSIONS)) {
		return -1;
	}

	/*
	 * Remember where the extension length lives and write a
	 * placeholder; the real value is filled in once the body length
	 * is known.
	 */
	const int32_t length_pos = msg->len;
	const int32_t len_bytes = 2;
	if (! tls_hs_msg_write_2(msg, 0)) {
		return -1;
	}

	int32_t body_len;
	switch (msg->type) {
	case TLS_HANDSHAKE_CLIENT_HELLO:
		body_len = write_supported_versions_for_chello(tls, msg);
		if (body_len < 0) {
			return -1;
		}
		break;

	case TLS_HANDSHAKE_SERVER_HELLO:
		body_len = write_supported_versions_for_shello(tls, msg);
		if (body_len < 0) {
			return -1;
		}
		break;

	default:
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 7, NULL);
		return -1;
	}

	const int32_t body_len_max = TLS_EXT_SIZE_MAX;
	if (body_len > body_len_max) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 10, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* Replace the placeholder with the actual body length. */
	tls_util_write_2(&(msg->msg[length_pos]), body_len);

	return type_bytes + len_bytes + body_len;
}

/*
 * Read the body of the supported_versions extension from msg, starting
 * at msg->msg[offset], using the parser that matches msg->type
 * (client_hello or server_hello).
 *
 * Returns the number of octets read, or -1 on failure. Any other
 * message type records an internal error and fails.
 */
int32_t tls_hs_supported_versions_read(TLS *tls,
				       const struct tls_hs_msg *msg,
				       const uint32_t offset)
{
	int32_t consumed;

	switch (msg->type) {
	case TLS_HANDSHAKE_CLIENT_HELLO:
		consumed = read_supported_versions_in_chello(tls, msg, offset);
		break;

	case TLS_HANDSHAKE_SERVER_HELLO:
		consumed = read_supported_versions_in_shello(tls, msg, offset);
		break;

	default:
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS5,
			     ERR_PT_TLS_HS_EXT_SUPPORTED_VERSIONS + 8, NULL);
		return -1;
	}

	/* The parsers have already recorded the error and the alert. */
	if (consumed < 0) {
		return -1;
	}

	return consumed;
}
