/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_session.h"
#include "tls_record.h"
#include "tls_cipher.h"
#include "tls_alert.h"

/* This file handles the ChangeCipherSpec (CCS) protocol. The
 * ChangeCipherSpec message has the following structure (RFC 5246
 * section 7.1).
 *
 *   struct {
 *     enum {
 *       change_cipher_spec(1), (255)
 *     } type;
 *   } ChangeCipherSpec;
 */

/**
 * length in octets of a CCS record body: the message consists of the
 * single type octet only.
 */
static const uint32_t TLS_CCS_BODY_LENGTH = 1u;

/**
 * values of the type octet carried in a CCS record body.
 *
 * RFC 5246 section 7.1 defines change_cipher_spec(1) as the only value;
 * anything else received on the wire is rejected by tls_ccs_recv().
 */
enum tls_ccs_t {
	TLS_CCS_CHANGE_CIPHER_SPEC = 0x01,
};

/**
 * direction of the CCS record whose timing is being checked.
 *
 * internal to this module: check_stat() and its helpers use it to decide
 * whether a CCS may be received, or sent from now on, in the current
 * handshake state.
 */
enum tls_ccs_stat {
	TLS_CCS_RECV_STAT = 0,	/* a CCS record has been received */
	TLS_CCS_SEND_STAT = 1	/* a CCS record is about to be sent */
};

/**
 * check whether a CCS may be handled when the peer's CCS is read before
 * ours is written.
 *
 * check_stat() selects this ordering for a server in a full (non-resumed)
 * handshake and for a client in a resumed handshake.
 */
static bool check_stat_for_read_first(TLS *tls, enum tls_ccs_stat status);

/**
 * check whether a CCS may be handled when ours is written before the
 * peer's CCS is read.
 *
 * check_stat() selects this ordering for a client in a full (non-resumed)
 * handshake and for a server in a resumed handshake.
 */
static bool check_stat_for_write_first(TLS *tls, enum tls_ccs_stat status);

/**
 * check whether a CCS record in the given direction may be handled now.
 *
 * picks the read-first or write-first ordering from the connection entity
 * (client/server) and tls->resession (session resumption), then checks
 * the current handshake state against that ordering.
 */
static bool check_stat(TLS *tls, enum tls_ccs_stat status);

/**
 * check whether the current TLS 1.3 handshake state permits receiving a
 * (dummy) CCS record.
 */
static bool check_state_tls13(TLS *tls);

/**
 * read the type octet of a CCS record body.
 *
 * the only valid value is TLS_CCS_CHANGE_CIPHER_SPEC (1), see enum
 * tls_ccs_t; this function does not validate it, the caller does.
 *
 * TODO: this could simply be an inline function.
 */
static enum tls_ccs_t read_type(uint8_t *buff);

static bool check_stat_for_read_first(TLS *tls, enum tls_ccs_stat status) {
	/* read-first ordering: the peer's CCS is accepted only in
	 * TLS_STATE_HS_BEFORE_FINISH, and ours may be sent only once the
	 * state is TLS_STATE_HS_RECV_FINISH. */
	switch (status) {
	case TLS_CCS_RECV_STAT:
		return tls->state == TLS_STATE_HS_BEFORE_FINISH;

	case TLS_CCS_SEND_STAT:
		return tls->state == TLS_STATE_HS_RECV_FINISH;

	default:
		return false;
	}
}

static bool check_stat_for_write_first(TLS *tls, enum tls_ccs_stat status) {
	/* write-first ordering: ours may be sent only in
	 * TLS_STATE_HS_BEFORE_FINISH, and the peer's CCS is accepted only
	 * once the state is TLS_STATE_HS_SEND_FINISH. */
	switch (status) {
	case TLS_CCS_SEND_STAT:
		return tls->state == TLS_STATE_HS_BEFORE_FINISH;

	case TLS_CCS_RECV_STAT:
		return tls->state == TLS_STATE_HS_SEND_FINISH;

	default:
		return false;
	}
}

static bool check_stat(TLS *tls, enum tls_ccs_stat status) {
	bool read_first;

	/* a full-handshake client and a resuming server write their CCS
	 * first; a resuming client and a full-handshake server read first. */
	switch (tls->entity) {
	case TLS_CONNECT_CLIENT:
		read_first = (tls->resession == true);
		break;

	case TLS_CONNECT_SERVER:
		read_first = (tls->resession != true);
		break;

	default:
		assert(!"unknown connection stat.");
		return false;
	}

	return read_first
		? check_stat_for_read_first(tls, status)
		: check_stat_for_write_first(tls, status);
}

static bool check_state_tls13(TLS *tls) {
	bool may_receive = false;

	switch (tls->state) {
	/* state shared by both sides */
	case TLS_STATE_HS_BEFORE_FINISH:

	/* server side states */
	case TLS_STATE_HS_AFTER_RECV_CHELLO:
	case TLS_STATE_HS_BEFORE_SEND_HRREQ:
	case TLS_STATE_HS_BEFORE_RECV_2NDCHELLO:
	case TLS_STATE_HS_AFTER_RECV_2NDCHELLO:
	case TLS_STATE_HS_SEND_SHELLO:
	case TLS_STATE_HS_SEND_ENCEXT:
	case TLS_STATE_HS_SEND_CERTREQ:
	case TLS_STATE_HS_SEND_SCERT:
	case TLS_STATE_HS_SEND_SCERTVFY:
	case TLS_STATE_HS_SEND_FINISH:
	case TLS_STATE_HS_RECV_CCERT:
	case TLS_STATE_HS_RECV_CCERTVFY:

	/* client side states */
	case TLS_STATE_HS_AFTER_SEND_CHELLO:
	case TLS_STATE_HS_BEFORE_SEND_2NDCHELLO:
	case TLS_STATE_HS_AFTER_SEND_2NDCHELLO:
	case TLS_STATE_HS_AFTER_RECV_HRREQ:
	case TLS_STATE_HS_BEFORE_RECV_SHELLO:
	case TLS_STATE_HS_AFTER_RECV_SHELLO:
	case TLS_STATE_HS_BEFORE_RECV_2NDSHELLO:
	case TLS_STATE_HS_AFTER_RECV_2NDSHELLO:
	case TLS_STATE_HS_RECV_ENCEXT:
	case TLS_STATE_HS_RECV_CERTREQ:
	case TLS_STATE_HS_RECV_SCERT:
	case TLS_STATE_HS_RECV_SCERTVFY:
		may_receive = true;
		break;

	default:
		/* any other state: a CCS here is unexpected. */
		break;
	}

	return may_receive;
}

static enum tls_ccs_t read_type(uint8_t *buff) {
	return buff[0];
}

/**
 * send a ChangeCipherSpec record.
 *
 * - TLS 1.3 middlebox compatibility mode (RFC 8446 appendix D.4): in the
 *   states TLS_STATE_HS_SEND_SHELLO and TLS_STATE_HS_BEFORE_RECV_2NDCHELLO
 *   with TLS 1.3 negotiated, only the dummy CCS record is written; the
 *   write state is left untouched.
 * - otherwise the CCS may be sent only when check_stat() allows it. After
 *   the record is written, the pending compression algorithm and cipher
 *   suite become active for writing, and the write sequence number is
 *   reset to 0.
 *
 * @param tls  connection state.
 * @return true on success. On failure an error is recorded with
 *         OK_set_error(), a fatal internal_error alert is raised, and
 *         false is returned.
 */
bool tls_ccs_send(TLS *tls) {
	/* the CCS body is exactly one octet (TLS_CCS_BODY_LENGTH). A fixed
	 * array is used because a static const object is not an integer
	 * constant expression in C, so sizing by it would create a VLA. */
	uint8_t buf[] = { TLS_CCS_CHANGE_CIPHER_SPEC };
	const uint32_t len = sizeof(buf);

	/* send change cipher spec according to compatibility mode request. */
	bool tls13_compat = false;
	if (tls->state == TLS_STATE_HS_SEND_SHELLO ||
	    tls->state == TLS_STATE_HS_BEFORE_RECV_2NDCHELLO) {
		uint16_t version = tls_util_convert_protover_to_ver(
			&(tls->negotiated_version));
		tls13_compat = (version == TLS_VER_TLS13);
	}

	if (! tls13_compat && ! check_stat(tls, TLS_CCS_SEND_STAT)) {
		TLS_DPRINTF("CCS: invalid status.");
		OK_set_error(ERR_ST_TLS_INVALID_STATUS,
			     ERR_LC_TLS1, ERR_PT_TLS_CCS + 0, NULL);
		goto err;
	}

	/* one write path for both cases, so a failure in compatibility mode
	 * now records an error just like the regular path does. */
	ssize_t n UNUSED;
	if ((n = tls_record_write(TLS_CTYPE_CHANGE_CIPHER_SPEC,
				  tls, buf, len)) < 0) {
		TLS_DPRINTF("CCS: tls_record_write (%zd)", n);
		OK_set_error(ERR_ST_TLS_TLS_RECORD_WRITE,
			     ERR_LC_TLS1, ERR_PT_TLS_CCS + 1, NULL);
		goto err;
	}

	/* the TLS 1.3 dummy CCS does not switch the write cipher state. */
	if (tls13_compat) {
		return true;
	}

	tls->active_write.seqnum = 0;
	tls->active_write.compression_algorithm = tls->pending->compression_algorithm;

	/* set parameter of selected cipher suite. */
	if (! tls_cipher_param_set(tls->pending->cipher_suite,
				   &(tls->active_write.cipher))) {
		TLS_DPRINTF("CCS: tls_cipher_param_set");
		OK_set_error(ERR_ST_TLS_TLS_CIPHER_PARAM_SET,
			     ERR_LC_TLS1, ERR_PT_TLS_CCS + 3, NULL);
		goto err;
	}

	return true;

err:
	/* I think error in the writing status is internal error
	 * (probably implementation error). */
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
	return false;
}

/**
 * handle a received ChangeCipherSpec record.
 *
 * The record is first checked against the current handshake state. Its
 * body must then be exactly TLS_CCS_BODY_LENGTH octets holding
 * TLS_CCS_CHANGE_CIPHER_SPEC.
 *
 * - in the pre-negotiation states (TLS_STATE_HS_AFTER_SEND_CHELLO,
 *   TLS_STATE_HS_BEFORE_RECV_SHELLO, TLS_STATE_HS_BEFORE_RECV_CHELLO) and
 *   in TLS 1.3, a valid record is dropped without changing any state.
 * - in TLS 1.2 and earlier, the pending compression algorithm and cipher
 *   suite become active for reading, the read sequence number is reset to
 *   0, and the state becomes TLS_STATE_CCS_RECV.
 *
 * @param tls     connection state.
 * @param record  received CCS record.
 * @return true on success. On failure an error is recorded with
 *         OK_set_error(), a fatal alert is raised, and false is returned.
 */
bool tls_ccs_recv(TLS *tls, struct tls_record *record) {
	/*
	 * RFC8446 5.  Record Protocol
	 *
	 *    An implementation may receive an unencrypted record of type
	 *    change_cipher_spec consisting of the single byte value 0x01 at any
	 *    time after the first ClientHello message has been sent or received
	 *    and before the peer's Finished message has been received and MUST
	 *    simply drop it without further processing.  Note that this record may
	 *    appear at a point at the handshake where the implementation is
	 *    expecting protected records, and so it is necessary to detect this
	 *    condition prior to attempting to deprotect the record.  An
	 *    implementation which receives any other change_cipher_spec value or
	 *    which receives a protected change_cipher_spec record MUST abort the
	 *    handshake with an "unexpected_message" alert.  If an implementation
	 *    detects a change_cipher_spec record received before the first
	 *    ClientHello message or after the peer's Finished message, it MUST be
	 *    treated as an unexpected record type (though stateless servers may
	 *    not be able to distinguish these cases from allowed cases).
	 *
	 * In the states below (before ServerHello/ClientHello processing)
	 * negotiated_version is not consulted and the record is only dropped.
	 * Its body is still validated below, because any value other than
	 * 0x01 must abort the handshake.
	 */
	const bool before_negotiation =
		(tls->state == TLS_STATE_HS_AFTER_SEND_CHELLO ||
		 tls->state == TLS_STATE_HS_BEFORE_RECV_SHELLO ||
		 tls->state == TLS_STATE_HS_BEFORE_RECV_CHELLO);

	/* only meaningful when before_negotiation is false. */
	uint16_t version = 0;

	if (! before_negotiation) {
		version = tls_util_convert_protover_to_ver(
			&(tls->negotiated_version));

		switch (version) {
		case TLS_VER_TLS13:
			if (! check_state_tls13(tls)) {
				TLS_DPRINTF("CCS: invalid status.");
				OK_set_error(ERR_ST_TLS_INVALID_STATUS,
					     ERR_LC_TLS1, ERR_PT_TLS_CCS + 2, NULL);
				TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
				return false;
			}
			break;

		case TLS_VER_SSL30:
		case TLS_VER_TLS10:
		case TLS_VER_TLS11:
		case TLS_VER_TLS12:
		default:
			if (! check_stat(tls, TLS_CCS_RECV_STAT)) {
				TLS_DPRINTF("CCS: invalid status.");
				OK_set_error(ERR_ST_TLS_INVALID_STATUS,
					     ERR_LC_TLS1, ERR_PT_TLS_CCS + 4, NULL);
				TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
				return false;
			}
			break;
		}
	}

	/* checked before read_type() so that frag[0] is known to exist. */
	if (record->len != TLS_CCS_BODY_LENGTH) {
		TLS_DPRINTF("CCS: len != 1");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_CCS + 5, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return false;
	}

	/*
	 * RFC8446 5.  Record Protocol
	 *
	 *                                                            An
	 *    implementation which receives any other change_cipher_spec value or
	 *    which receives a protected change_cipher_spec record MUST abort the
	 *    handshake with an "unexpected_message" alert.
	 */
	enum tls_ccs_t type = read_type(record->frag);
	if (type != TLS_CCS_CHANGE_CIPHER_SPEC) {
		TLS_DPRINTF("CCS: unknown type value");
		OK_set_error(ERR_ST_TLS_INVALID_CCS_BODY,
			     ERR_LC_TLS1, ERR_PT_TLS_CCS + 6, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	/*
	 * RFC8446 5.  Record Protocol
	 *
	 *    An implementation may receive an unencrypted record of type
	 *    change_cipher_spec consisting of the single byte value 0x01 at any
	 *    time after the first ClientHello message has been sent or received
	 *    and before the peer's Finished message has been received and MUST
	 *    simply drop it without further processing.
	 */
	if (before_negotiation || version == TLS_VER_TLS13) {
		return true;
	}

	tls->active_read.seqnum = 0;
	tls->active_read.compression_algorithm = tls->pending->compression_algorithm;

	/* set parameter of selected cipher suite. */
	if (! tls_cipher_param_set(tls->pending->cipher_suite,
				   &(tls->active_read.cipher))) {
		TLS_DPRINTF("CCS: tls_cipher_param_set");
		OK_set_error(ERR_ST_TLS_TLS_CIPHER_PARAM_SET,
			     ERR_LC_TLS1, ERR_PT_TLS_CCS + 7, NULL);
		/* NOTE: internal error is better? */
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	tls->state = TLS_STATE_CCS_RECV;

	return true;
}
