/*
 * Copyright (c) 2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_alert.h"

#include <errno.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>

/*
 * Name types a client puts in its ServerNameList, in the order they are
 * written by write_server_name_by_client(). RFC 6066 defines only
 * host_name, so that is the single entry. The table is never modified.
 */
static const enum tls_hs_server_name_type server_name_types[] = {
	TLS_SNI_HOST_NAME
};

/**
 * check if hostname is valid string and not literal ip address.
 *
 * @param hostname  NUL-terminated host name to validate.
 * @return true when the name is acceptable as an SNI HostName; false when
 *         it is empty, too long, a literal IPv4/IPv6 address, or could not
 *         be classified by getaddrinfo(). An error is recorded with
 *         OK_set_error() on every false return.
 */
static bool check_hostname(const char *hostname);

/**
 * write ServerName structure in case of host_name.
 *
 * Appends the host_name name_type byte, a 2-byte length and the bytes of
 * tls->server_name (without its terminating NUL) to msg.
 *
 * RFC6066 3.  Server Name Indication
 *
 *       struct {
 *           NameType name_type;
 *           select (name_type) {
 *               case host_name: HostName;
 *           } name;
 *       } ServerName;
 *
 *       enum {
 *           host_name(0), (255)
 *       } NameType;
 *
 *       opaque HostName<1..2^16-1>;
 *
 * @return number of bytes appended, or -1 on failure (a fatal alert has
 *         been raised).
 */
static int32_t write_hostname(TLS *tls, struct tls_hs_msg *msg);

/**
 * read ServerName structure in case of host_name.
 *
 * Parses the 2-byte length and the HostName bytes found at
 * msg->msg[offset]; the name_type byte in front of them has already been
 * consumed by the caller. The validated name is stored, as a newly
 * allocated string, in tls->server_name.
 *
 * RFC6066 3.  Server Name Indication
 *
 *       struct {
 *           NameType name_type;
 *           select (name_type) {
 *               case host_name: HostName;
 *           } name;
 *       } ServerName;
 *
 *       enum {
 *           host_name(0), (255)
 *       } NameType;
 *
 *       opaque HostName<1..2^16-1>;
 *
 * @return number of bytes consumed, or -1 on failure (a fatal alert has
 *         been raised).
 */
static int32_t read_hostname(TLS *tls, const struct tls_hs_msg *msg,
			      const uint32_t offset);

/**
 * write ServerNameList structure.
 *
 * Appends a 2-byte list length followed by one ServerName for each entry
 * of server_name_types[]. Used when building a ClientHello.
 *
 * RFC6066 3.  Server Name Indication
 *
 *       struct {
 *           ServerName server_name_list<1..2^16-1>
 *       } ServerNameList;
 *
 * @return number of bytes appended including the list length field, or -1
 *         on failure.
 */
static int32_t write_server_name_by_client(TLS *tls, struct tls_hs_msg *msg);

/**
 * read ServerNameList structure.
 *
 * Parses the ServerNameList a client sent in its ClientHello. Repeated or
 * unknown name types are rejected.
 *
 * RFC6066 3.  Server Name Indication
 *
 *       struct {
 *           ServerName server_name_list<1..2^16-1>
 *       } ServerNameList;
 *
 * @return number of bytes consumed, or -1 on failure (a fatal alert has
 *         been raised).
 */
static int32_t read_server_name_by_server(TLS *tls,
					  const struct tls_hs_msg *msg,
					  const uint32_t offset);

/**
 * read empty ServerNameList structure.
 *
 * Checks the server_name extension a server returns in ServerHello or
 * EncryptedExtensions, whose extension_data must be empty.
 *
 * RFC6066 3.  Server Name Indication
 *
 *       struct {
 *           ServerName server_name_list<1..2^16-1>
 *       } ServerNameList;
 *
 * @return 0 on success, or -1 on failure (a fatal alert has been raised).
 */
static int32_t read_server_name_by_client(TLS *tls,
					  const struct tls_hs_msg *msg,
					  const uint32_t offset);

/*
 * Validate a NUL-terminated host name before it is accepted as an SNI
 * HostName.
 *
 * Returns false, after recording an error with OK_set_error(), when the
 * name is
 *   - empty,
 *   - longer than TLS_EXT_SIZE_MAX bytes,
 *   - a literal IPv4/IPv6 address (forbidden by RFC 6066, section 3), or
 *   - rejected by getaddrinfo() for a reason other than EAI_NONAME, so the
 *     literal-address probe itself could not be completed.
 * Returns true otherwise.
 */
static bool check_hostname(const char *hostname)
{
	const size_t hostname_len_max = TLS_EXT_SIZE_MAX;

	/*
	 * Scan at most one byte past the limit so an overlong name is detected
	 * without walking the whole string: the result exceeds
	 * hostname_len_max exactly when strlen(hostname) > hostname_len_max.
	 */
	size_t len = strnlen(hostname, hostname_len_max + 1);
	if (len == 0) {
		TLS_DPRINTF("empty hostname");
		OK_set_error(ERR_ST_TLS_SERVER_NAME_EMPTY,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 0,
			     NULL);
		return false;
	}

	if (len > hostname_len_max) {
		TLS_DPRINTF("too long hostname");
		OK_set_error(ERR_ST_TLS_SERVER_NAME_TOOLONG,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 1,
			     NULL);
		return false;
	}

	/*
	 * Literal ip address is not permitted. With AI_NUMERICHOST,
	 * getaddrinfo() succeeds only for a numeric address string and
	 * reports EAI_NONAME for anything else, without name resolution.
	 */
	struct addrinfo hints = {
		.ai_family = AF_UNSPEC,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_NUMERICHOST,
		.ai_protocol = 0,
		.ai_canonname = NULL,
		.ai_addr = NULL,
		.ai_next = NULL,
	};
	struct addrinfo *aihead = NULL;

	int gaicode = getaddrinfo(hostname, NULL, &hints, &aihead);
	if (gaicode == 0) {
		TLS_DPRINTF("literal ip address: %s", hostname);
		OK_set_error(ERR_ST_TLS_SERVER_NAME_NUMHOST,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 2,
			     NULL);
		freeaddrinfo(aihead);
		return false;
	}

	if (gaicode != EAI_NONAME) {
		TLS_DPRINTF("getaddrinfo: %s", gai_strerror(gaicode));
		OK_set_error(ERR_ST_TLS_GETADDRINFO,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 3,
			     NULL);
		return false;
	}

	/* TODO: check if hostname is FQDN */

	return true;
}

/*
 * Append one host_name ServerName entry to msg:
 *
 *   uint8  name_type = host_name (TLS_SNI_HOST_NAME)
 *   uint16 length
 *   opaque HostName[length]     -- the bytes of tls->server_name, no NUL
 *
 * Returns the number of bytes appended (1 + 2 + strlen(tls->server_name)).
 * Returns -1 after raising a fatal internal_error alert in two cases: a
 * write to msg fails, or the name length lies outside
 * 1..TLS_VECTOR_2_BYTE_SIZE_MAX.
 */
static int32_t write_hostname(TLS *tls, struct tls_hs_msg *msg)
{
	/*
	 * RFC 6066 3.  Server Name Indication
	 *
	 *    "HostName" contains the fully qualified DNS hostname of the server,
	 *    as understood by the client.  The hostname is represented as a byte
	 *    string using ASCII encoding without a trailing dot.  This allows the
	 *    support of internationalized domain names through the use of A-labels
	 *    defined in [RFC5890].  DNS hostnames are case-insensitive.  The
	 *    algorithm to compare hostnames is described in [RFC5890], Section
	 *    2.3.2.4.
	 *
	 *    Literal IPv4 and IPv6 addresses are not permitted in "HostName".
	 */
	enum { name_type_field = 1, name_length_field = 2 };

	/*
	 * RFC6066 3.  Server Name Indication
	 *
	 *       opaque HostName<1..2^16-1>;
	 */
	const size_t host_len_min = 1;
	const size_t host_len_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	size_t host_len = 0;

	if (!tls_hs_msg_write_1(msg, TLS_SNI_HOST_NAME)) {
		goto internal_error;
	}

	host_len = strlen(tls->server_name);
	if (host_len < host_len_min || host_len > host_len_max) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 18,
			     NULL);
		goto internal_error;
	}

	/* length prefix, then the raw name bytes */
	if (!tls_hs_msg_write_2(msg, host_len)
	    || !tls_hs_msg_write_n(msg, (uint8_t *)(tls->server_name),
				   host_len)) {
		goto internal_error;
	}

	return (int32_t)(name_type_field + name_length_field + host_len);

internal_error:
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
	return -1;
}

/*
 * Parse one host_name ServerName body starting at msg->msg[offset]:
 *
 *   uint16 length
 *   opaque HostName[length]
 *
 * The name_type byte in front of it has already been consumed by the
 * caller.
 *
 * On success, tls->server_name is set to a newly allocated, NUL-terminated
 * copy of the name, and the number of bytes consumed (2 + length) is
 * returned. On failure a fatal alert is raised and -1 is returned:
 *   decode_error      - truncated input or an empty HostName;
 *   illegal_parameter - the name contains a NUL byte or is rejected by
 *                       check_hostname() (e.g. a literal IP address);
 *   internal_error    - memory allocation failed.
 */
static int32_t read_hostname(TLS *tls, const struct tls_hs_msg *msg,
			      const uint32_t offset)
{
	uint32_t read_bytes = 0;

	const uint32_t length_bytes = 2;
	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 4,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	const uint16_t name_length = tls_util_read_2(&(msg->msg[offset]));
	if (msg->len < (offset + read_bytes + name_length)) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 5,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC6066 3.  Server Name Indication
	 *
	 *       opaque HostName<1..2^16-1>;
	 */
	const uint16_t hostname_size_min = 1;
	if (name_length < hostname_size_min) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 6,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	const void *name = &(msg->msg[offset + read_bytes]);

	/*
	 * HostName is opaque on the wire but is handled as a C string from
	 * here on. An embedded NUL would silently truncate it, so that e.g.
	 * "example.com\0evil" would be validated and stored as "example.com".
	 */
	if (memchr(name, '\0', name_length) != NULL) {
		TLS_DPRINTF("hostname contains NUL byte");
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 21,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	/*
	 * Copy to the heap once, instead of into a stack VLA that can reach
	 * 64 KiB; the copy becomes tls->server_name on success.
	 */
	char *hostname = malloc((size_t)name_length + 1);
	if (hostname == NULL) {
		TLS_DPRINTF("malloc %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_STRDUP,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 7,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	memcpy(hostname, name, name_length);
	hostname[name_length] = '\0';
	TLS_DPRINTF("server_name: %s", hostname);

	if (check_hostname(hostname) == false) {
		free(hostname);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	/*
	 * NOTE(review): any previous tls->server_name is overwritten without
	 * being freed (e.g. if a second ClientHello is parsed); confirm who
	 * owns that pointer before adding a free() here.
	 */
	tls->server_name = hostname;

	read_bytes += name_length;

	return read_bytes;
}

/*
 * Append the client's ServerNameList to msg:
 *
 *   uint16      list length (written as 0, patched once entries are known)
 *   ServerName  one entry per element of server_name_types[]
 *
 * Returns the number of bytes appended, including the 2-byte list length,
 * or -1 after a fatal internal_error alert has been raised.
 */
static int32_t write_server_name_by_client(TLS *tls, struct tls_hs_msg *msg)
{
	int32_t offset = 0;

	/*
	 * RFC 6066 3.  Server Name Indication
	 *
	 *                                                        The client SHOULD
	 *    include the same server_name extension in the session resumption
	 *    request as it did in the full handshake that established the session.
	 */

	/*
	 * RFC 6066 3.  Server Name Indication
	 *
	 *    Currently, the only server names supported are DNS hostnames;
	 *    however, this does not imply any dependency of TLS on DNS, and other
	 *    name types may be added in the future (by an RFC that updates this
	 *    document).  The data structure associated with the host_name NameType
	 *    is a variable-length vector that begins with a 16-bit length.  For
	 *    backward compatibility, all future data structures associated with
	 *    new NameTypes MUST begin with a 16-bit length field.  TLS MAY treat
	 *    provided server names as opaque data and pass the names and types to
	 *    the application.
	 */
	const size_t list_length = sizeof(server_name_types)
		/ sizeof(server_name_types[0]);

	/* write dummy length of list of the struct ServerName */
	const int32_t list_len_bytes = 2;
	int pos = msg->len;
	if (tls_hs_msg_write_2(msg, 0) == false) {
		/* raise the alert here like every other write failure does */
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	offset += list_len_bytes;

	for (size_t i = 0; i < list_length; i++) {
		int32_t len = 0;

		switch (server_name_types[i]) {
		case TLS_SNI_HOST_NAME:
			/* write_hostname() raises its own alert on failure */
			if ((len = write_hostname(tls, msg)) < 0) {
				return -1;
			}
			break;

		default:
			/* no writer for this type: it contributes nothing */
			break;
		}

		offset += len;
	}

	/*
	 * RFC6066 3.  Server Name Indication
	 *
	 *           ServerName server_name_list<1..2^16-1>
	 */
	const int32_t entries_length = offset - list_len_bytes;
	const int32_t list_length_min = 1;
	const int32_t list_length_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (entries_length < list_length_min
	    || list_length_max < entries_length) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 19,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls_util_write_2(&(msg->msg[pos]), entries_length);

	return offset;
}

/*
 * Parse the ServerNameList carried in a ClientHello server_name extension,
 * starting at msg->msg[offset]:
 *
 *   uint16      list length (>= 1)
 *   ServerName  entries, each a 1-byte name_type followed by its body
 *
 * host_name entries are handed to read_hostname(), which stores the name
 * in tls->server_name. Errors are reported as follows:
 *   illegal_parameter - a name_type is repeated;
 *   unrecognized_name - a name_type is unknown;
 *   decode_error      - the list is truncated, or its entries do not
 *                       exactly fill the declared list length.
 *
 * Returns the number of bytes consumed (2 + list length), or -1 after a
 * fatal alert has been raised.
 */
static int32_t read_server_name_by_server(TLS *tls,
					  const struct tls_hs_msg *msg,
					  const uint32_t offset)
{
	uint32_t read_bytes = 0;

	const uint32_t length_bytes = 2;
	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 8,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	const uint16_t list_length = tls_util_read_2(&(msg->msg[offset]));
	if (msg->len < (offset + read_bytes + list_length)) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 9,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC6066 3.  Server Name Indication
	 *
	 *           ServerName server_name_list<1..2^16-1>
	 */
	const uint16_t server_name_list_size_min = 1;
	if (list_length < server_name_list_size_min) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 10,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/* Position (relative to offset) just past the last ServerName. */
	const uint32_t list_end = length_bytes + list_length;

	/*
	 * One "already seen" flag per NameType. name_type is a single byte,
	 * so the table needs UINT8_MAX + 1 entries for every value, including
	 * 255, to be in bounds.
	 */
	bool recv_type_flags[UINT8_MAX + 1] = { false };

	/*
	 * RFC 6066 3.  Server Name Indication
	 *
	 *    A server that implements this extension MUST NOT accept the request
	 *    to resume the session if the server_name extension contains a
	 *    different name.  Instead, it proceeds with a full handshake to
	 *    establish a new session.
	 */
	const uint16_t type_byte = 1;
	while (read_bytes < list_end) {
		if (msg->len < (offset + read_bytes + type_byte)) {
			TLS_DPRINTF("invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_SERVERNAME + 11,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}

		const uint8_t name_type = msg->msg[offset + read_bytes];
		read_bytes += type_byte;

		/*
		 * RFC 6066 3.  Server Name Indication
		 *
		 *    The ServerNameList MUST NOT contain more than one name of the same
		 *    name_type.  If the server understood the ClientHello extension but
		 *    does not recognize the server name, the server SHOULD take one of two
		 *    actions: either abort the handshake by sending a fatal-level
		 *    unrecognized_name(112) alert or continue the handshake.  It is NOT
		 *    RECOMMENDED to send a warning-level unrecognized_name(112) alert,
		 *    because the client's behavior in response to warning-level alerts is
		 *    unpredictable.  If there is a mismatch between the server name used
		 *    by the client application and the server name of the credential
		 *    chosen by the server, this mismatch will become apparent when the
		 *    client application performs the server endpoint identification, at
		 *    which point the client application will have to decide whether to
		 *    proceed with the communication.
		 */
		if (recv_type_flags[name_type] == true) {
			TLS_DPRINTF("multiple same type server_name");
			OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_SERVERNAME + 12,
				     NULL);
			TLS_ALERT_FATAL(tls,
					TLS_ALERT_DESC_ILLEGAL_PARAMETER);
			return -1;
		}
		recv_type_flags[name_type] = true;

		int32_t len;
		switch (name_type) {
		case TLS_SNI_HOST_NAME:
			if ((len = read_hostname(tls, msg, offset + read_bytes))
			    < 0) {
				return -1;
			}
			break;

		default:
			TLS_DPRINTF("unknown name type");
			OK_set_error(ERR_ST_TLS_UNRECOGNIZED_NAME,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_SERVERNAME + 13,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNRECOGNIZED_NAME);
			return -1;
		}

		read_bytes += len;
	}

	/*
	 * read_hostname() bounds each HostName only by msg->len, so the final
	 * entry can run past the declared list length. Require the entries to
	 * fill the list exactly.
	 */
	if (read_bytes != list_end) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 22,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	return read_bytes;
}

/*
 * Validate the server_name extension a server sends back. Its
 * extension_data must be empty, so nothing is consumed on success.
 *
 * Returns 0 on success, or -1 after a fatal decode_error alert.
 *
 * NOTE(review): success requires both msg->len == 0 and offset == 0. That
 * only makes sense if the caller passes a message view scoped to this
 * extension's data; confirm against the extension parser.
 */
static int32_t read_server_name_by_client(TLS *tls,
					  const struct tls_hs_msg *msg,
					  const uint32_t offset)
{
	/*
	 * RFC 6066 3.  Server Name Indication
	 *
	 *    A server that receives a client hello containing the "server_name"
	 *    extension MAY use the information contained in the extension to guide
	 *    its selection of an appropriate certificate to return to the client,
	 *    and/or other aspects of security policy.  In this event, the server
	 *    SHALL include an extension of type "server_name" in the (extended)
	 *    server hello.  The "extension_data" field of this extension SHALL be
	 *    empty.
	 */
	const bool data_is_empty = (msg->len == 0) && (offset == 0);
	if (data_is_empty) {
		return 0;
	}

	TLS_DPRINTF("invalid record length");
	OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
		     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 14,
		     NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
	return -1;
}

/*
 * Decide whether tls_hs_servername_write() should emit a server_name
 * extension in msg:
 *   - ClientHello: always;
 *   - ServerHello: only when TLS 1.0, 1.1 or 1.2 was negotiated;
 *   - EncryptedExtensions: only when TLS 1.3 was negotiated;
 *   - any other message: never.
 */
static bool servername_ext_wanted(TLS *tls, const struct tls_hs_msg *msg)
{
	uint16_t version;

	switch (msg->type) {
	case TLS_HANDSHAKE_CLIENT_HELLO:
		return true;

	case TLS_HANDSHAKE_SERVER_HELLO:
		/*
		 * RFC 6066 3.  Server Name Indication
		 *
		 *                              When resuming a session, the server MUST
		 *    NOT include a server_name extension in the server hello.
		 *
		 * NOTE(review): resumption is not checked here; confirm the caller
		 * does not invoke this writer for an abbreviated handshake.
		 */
		version = tls_util_convert_protover_to_ver(
			&(tls->negotiated_version));
		return version == TLS_VER_TLS10
			|| version == TLS_VER_TLS11
			|| version == TLS_VER_TLS12;

	case TLS_HANDSHAKE_ENCRYPTED_EXTENSIONS:
		version = tls_util_convert_protover_to_ver(
			&(tls->negotiated_version));
		return version == TLS_VER_TLS13;

	default:
		return false;
	}
}

/*
 * Append a server_name extension to msg when one is called for.
 *
 * Nothing is written, and 0 is returned, when tls->server_name is NULL or
 * servername_ext_wanted() rejects this message. Otherwise the layout is:
 *
 *   uint16 extension_type = server_name
 *   uint16 extension_data length (patched in after the data is written)
 *   extension_data: a ServerNameList in a ClientHello; empty in a
 *                   ServerHello or EncryptedExtensions
 *
 * Returns the total number of bytes appended, or -1 on failure.
 */
int32_t tls_hs_servername_write(TLS *tls, struct tls_hs_msg *msg)
{
	/* server name isn't set by application or sent by client. */
	if (tls->server_name == NULL) {
		return 0;
	}

	if (servername_ext_wanted(tls, msg) == false) {
		return 0;
	}

	int32_t off = 0;

	const uint32_t type_bytes = 2;
	if (tls_hs_msg_write_2(msg, TLS_EXT_SERVER_NAME) == false) {
		/* same alert as the length write below */
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	off += type_bytes;

	/* write dummy length bytes; the real length is patched in below. */
	int32_t pos = msg->len;

	const uint32_t len_bytes = 2;
	if (tls_hs_msg_write_2(msg, 0) == false) {
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	off += len_bytes;

	int32_t sn_len;
	switch (msg->type) {
	case TLS_HANDSHAKE_CLIENT_HELLO:
		if ((sn_len = write_server_name_by_client(tls, msg)) < 0) {
			TLS_DPRINTF("write_server_name_by_client");
			return -1;
		}
		break;

	case TLS_HANDSHAKE_SERVER_HELLO:
	case TLS_HANDSHAKE_ENCRYPTED_EXTENSIONS:
		/*
		 * RFC 6066 3.  Server Name Indication
		 *
		 *    A server that receives a client hello containing the "server_name"
		 *    extension MAY use the information contained in the extension to guide
		 *    its selection of an appropriate certificate to return to the client,
		 *    and/or other aspects of security policy.  In this event, the server
		 *    SHALL include an extension of type "server_name" in the (extended)
		 *    server hello.  The "extension_data" field of this extension SHALL be
		 *    empty.
		 */
		sn_len = 0;
		break;

	default:
		TLS_DPRINTF("unexpected message type");
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 15,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	off += sn_len;

	const int32_t sn_len_max = TLS_EXT_SIZE_MAX;
	if (sn_len > sn_len_max) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS5, ERR_PT_TLS_HS_EXT_SERVERNAME + 20,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls_util_write_2(&(msg->msg[pos]), sn_len);

	return off;
}

/*
 * Parse a received server_name extension whose data starts at
 * msg->msg[offset].
 *
 *   - ClientHello: the client's ServerNameList is parsed, and the host
 *     name is recorded in tls->server_name.
 *   - ServerHello / EncryptedExtensions: the extension is only acceptable
 *     if server_name was sent in our ClientHello, and its data must be
 *     empty.
 *   - any other message: internal_error.
 *
 * Returns the number of bytes consumed, or -1 on failure.
 */
int32_t tls_hs_servername_read(TLS *tls, const struct tls_hs_msg *msg,
			       const uint32_t offset)
{
	bool *sent_exts = tls->interim_params->sent_ext_flags;
	int32_t consumed;

	if (msg->type == TLS_HANDSHAKE_CLIENT_HELLO) {
		consumed = read_server_name_by_server(tls, msg, offset);
		if (consumed < 0) {
			TLS_DPRINTF("read_server_name_by_server");
			return -1;
		}
		return consumed;
	}

	if (msg->type == TLS_HANDSHAKE_SERVER_HELLO
	    || msg->type == TLS_HANDSHAKE_ENCRYPTED_EXTENSIONS) {
		/* a server_name reply we never asked for is not allowed */
		if (!sent_exts[TLS_EXT_SERVER_NAME]) {
			TLS_DPRINTF("not sent in client hello");
			OK_set_error(ERR_ST_TLS_UNSUPPORTED_EXTENSION,
				     ERR_LC_TLS5,
				     ERR_PT_TLS_HS_EXT_SERVERNAME + 16,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNSUPPORTED_EXTENSION);
			return -1;
		}

		consumed = read_server_name_by_client(tls, msg, offset);
		if (consumed < 0) {
			TLS_DPRINTF("read_server_name_by_client");
			return -1;
		}
		return consumed;
	}

	TLS_DPRINTF("unexpected message type");
	OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
		     ERR_LC_TLS5,
		     ERR_PT_TLS_HS_EXT_SERVERNAME + 17,
		     NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
	return -1;
}
