/*
 * Copyright (c) 2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_alert.h"

/**
 * Derive the next application traffic secret for @p connection (used when
 * processing a KeyUpdate, see RFC 8446 7.2).
 * Returns false on failure.
 * @see tls_key.c
 */
extern bool tls_key_derive_application_traffic_secret_n(
	TLS *tls, struct tls_connection *connection);

/**
 * Compute the traffic keys of @p connection; called in this file right after
 * the traffic secret has been advanced. Returns false on failure.
 * @see tls_key.c
 */
extern bool tls_key_make_traffic_key(TLS *tls,
				     struct tls_connection *connection);

/* Serialize the KeyUpdate body into msg; returns bytes written or -1. */
static int32_t write_keyupdate(TLS *tls, struct tls_hs_msg *msg,
			       enum tls_keyupdate_request req);

/* Parse and apply the KeyUpdate body at offset; returns bytes read or -1. */
static int32_t read_keyupdate(TLS *tls, struct tls_hs_msg *msg,
			      uint32_t offset);

/*
 * Append the KeyUpdate body, a single request_update byte, to msg.
 *
 * Returns the number of bytes appended (1) on success. If the byte cannot
 * be written, a fatal internal_error alert is raised and -1 is returned.
 */
static int32_t write_keyupdate(TLS *tls, struct tls_hs_msg *msg,
			       enum tls_keyupdate_request req)
{
	const uint32_t request_update_len = 1;

	if (!tls_hs_msg_write_1(msg, (uint8_t)req)) {
		TLS_DPRINTF("keyupdate: failed to write request_update value");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	return (int32_t)request_update_len;
}

/*
 * Parse the KeyUpdate body at msg->msg[offset] and apply it.
 *
 * The body is a single request_update byte (RFC 8446 4.6.3). For either
 * valid value the application traffic secret and traffic keys of
 * tls->active_read are advanced. For "update_requested",
 * tls->need_sending_keyupdate is also set, recording that we owe the peer a
 * KeyUpdate of our own (see the RFC text quoted below).
 *
 * Returns the number of bytes consumed (1) on success. On failure a fatal
 * alert is raised and -1 is returned:
 *   - decode_error       if fewer than 1 byte remains after offset,
 *   - illegal_parameter  for an unknown request_update value,
 *   - internal_error     if key derivation fails.
 */
static int32_t read_keyupdate(TLS *tls, struct tls_hs_msg *msg,
			      uint32_t offset)
{
	const uint32_t request_update_length = 1;

	/*
	 * Written as "offset > len || len - offset < n" rather than
	 * "len < offset + n" so that a large offset cannot wrap around
	 * and slip past the check into an out-of-bounds read.
	 */
	if (offset > msg->len ||
	    msg->len - offset < request_update_length) {
		TLS_DPRINTF("keyupdate: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_KEYUPDATE + 0,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.6.3.  Key and Initialization Vector Update
	 *
	 *    If the request_update field is set to "update_requested", then the
	 *    receiver MUST send a KeyUpdate of its own with request_update set to
	 *    "update_not_requested" prior to sending its next Application Data
	 *    record.  This mechanism allows either side to force an update to the
	 *    entire connection, but causes an implementation which receives
	 *    multiple KeyUpdates while it is silent to respond with a single
	 *    update.  Note that implementations may receive an arbitrary number of
	 *    messages between sending a KeyUpdate with request_update set to
	 *    "update_requested" and receiving the peer's KeyUpdate, because those
	 *    messages may already be in flight.  However, because send and receive
	 *    keys are derived from independent traffic secrets, retaining the
	 *    receive traffic secret does not threaten the forward secrecy of data
	 *    sent before the sender changed keys.
	 */
	uint8_t keyupdate_request = msg->msg[offset];
	switch (keyupdate_request) {
	case TLS_KEYUPDATE_REQUESTED:
		tls->need_sending_keyupdate = true;
		/* Fall through */

	case TLS_KEYUPDATE_NOT_REQUESTED:
		/*
		 * RFC8446 7.2.  Updating Traffic Secrets
		 *
		 *    Once client_/server_application_traffic_secret_N+1 and its associated
		 *    traffic keys have been computed, implementations SHOULD delete
		 *    client_/server_application_traffic_secret_N and its associated
		 *    traffic keys.
		 */
		if (tls_key_derive_application_traffic_secret_n(tls,
								&(tls->active_read))
		    == false) {
			TLS_DPRINTF("keyupdate: "
				"tls_key_derive_application_traffic_secret_n");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		if (tls_key_make_traffic_key(tls, &(tls->active_read))
		    == false) {
			TLS_DPRINTF("keyupdate: tls_key_make_traffic_key");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}
		break;

	/*
	 * RFC8446 4.6.3.  Key and Initialization Vector Update
	 *
	 *    request_update:  Indicates whether the recipient of the KeyUpdate
	 *       should respond with its own KeyUpdate.  If an implementation
	 *       receives any other value, it MUST terminate the connection with an
	 *       "illegal_parameter" alert.
	 */
	default:
		TLS_DPRINTF("keyupdate: unknown value");
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_KEYUPDATE + 1,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	return (int32_t)request_update_length;
}

/*
 * Build a KeyUpdate handshake message whose request_update field is req.
 *
 * On success the caller owns the returned message and releases it with
 * tls_hs_msg_free(). Returns NULL on failure. A fatal internal_error alert
 * is raised when the message cannot be allocated or the body cannot be
 * written.
 */
struct tls_hs_msg *tls_hs_keyupdate_compose(TLS *tls,
					    enum tls_keyupdate_request req)
{
	struct tls_hs_msg *msg;
	if ((msg = tls_hs_msg_init()) == NULL) {
		TLS_DPRINTF("keyupdate: tls_hs_msg_init");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return NULL;
	}

	/* KeyUpdate message has following structure.
	 *
	 * | type                 (1) |
	 * | length of message    (3) |
	 * | request_update value (1) |
	 *
	 * Only the type (msg->type) and the request_update body are filled in
	 * here; the header encoding is presumably done by the generic
	 * handshake layer.
	 */

	msg->type = TLS_HANDSHAKE_KEY_UPDATE;

	if (write_keyupdate(tls, msg, req) < 0) {
		TLS_DPRINTF("keyupdate: write_keyupdate");
		goto failed;
	}

	return msg;

failed:
	tls_hs_msg_free(msg);
	return NULL;
}

/*
 * Validate a received KeyUpdate handshake message and apply it via
 * read_keyupdate(), which advances the read-side traffic keys.
 *
 * Returns true on success. On failure a fatal alert is raised and false
 * is returned:
 *   - unexpected_message if msg is not a KeyUpdate,
 *   - decode_error       if the body length is not exactly 1 byte,
 *   - whatever read_keyupdate() raises for a bad body.
 */
bool tls_hs_keyupdate_parse(TLS *tls, struct tls_hs_msg *msg)
{
	/* A KeyUpdate body is exactly one request_update byte. */
	const uint32_t keyupdate_length = 1;
	uint32_t offset = 0;

	if (msg->type != TLS_HANDSHAKE_KEY_UPDATE) {
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_KEYUPDATE + 2,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	/*
	 * Reject an oversized message before read_keyupdate() runs. That
	 * function rotates tls->active_read's traffic secret and keys and may
	 * set need_sending_keyupdate, so a malformed message must be refused
	 * before any of that state changes. A too-short message is rejected
	 * inside read_keyupdate() before it touches any state.
	 */
	if (msg->len > keyupdate_length) {
		goto invalid_length;
	}

	int32_t read_bytes = read_keyupdate(tls, msg, offset);
	if (read_bytes < 0) {
		TLS_DPRINTF("keyupdate: read_keyupdate");
		return false;
	}
	offset += (uint32_t)read_bytes;

	/* Defensive: the whole message must have been consumed. */
	if (msg->len != offset) {
		goto invalid_length;
	}

	return true;

invalid_length:
	OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
		     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_KEYUPDATE + 3,
		     NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
	return false;
}

