/*
 * Copyright (c) 2015-2016 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_digitally_signed.h"
#include "tls_alert.h"

/**
 * Write the server's key exchange parameters.
 *
 * Key exchange alg. | Structure        | References
 * ----------------- | ---------------- | ----------
 * dh_anon           | ServerDHParams   | RFC5246
 * dhe_dss           | ServerDHParams   | RFC5246
 * dhe_rsa           | ServerDHParams   | RFC5246
 * rsa               | (omitted)        | RFC5246
 * dh_dss            | (omitted)        | RFC5246
 * dh_rsa            | (omitted)        | RFC5246
 * ecdh_anon         | ServerECDHParams | RFC4492
 * ecdhe_ecdsa       | ServerECDHParams | RFC4492
 * ecdhe_rsa         | ServerECDHParams | RFC4492
 * ecdh_ecdsa        | ServerECDHParams | RFC4492
 * ecdh_rsa          | ServerECDHParams | RFC4492
 *
 * ServerDHParams (dh_anon, dhe_dss, dhe_rsa) are not implemented yet;
 * those branches currently hit an assert().
 *
 * On success the location and length of the written parameters are
 * recorded in tls->skeyexc_params / tls->skeyexc_params_len (NULL / 0
 * when nothing was written) so that the signature step can hash them.
 *
 * @return number of bytes written (0 when the method sends no params),
 *         or -1 on error.
 */
static int32_t write_params(TLS *tls, struct tls_hs_msg *msg);

/**
 * Write a signature over the server's key exchange parameters.
 *
 * Key exchange alg. | Structure        | References
 * ----------------- | ---------------- | ----------
 * dh_anon           | (not needed)     | RFC5246
 * dhe_dss           | *1               | RFC5246
 * dhe_rsa           | *1               | RFC5246
 * rsa               | (omitted)        | RFC5246
 * dh_dss            | (omitted)        | RFC5246
 * dh_rsa            | (omitted)        | RFC5246
 * ecdh_anon         | (not needed)     | RFC4492
 * ecdhe_ecdsa       | *2               | RFC4492
 * ecdhe_rsa         | *2               | RFC4492
 * ecdh_ecdsa        | (omitted)        | RFC4492
 * ecdh_rsa          | (omitted)        | RFC4492
 *
 * (*1)
 *      digitally-signed struct {
 *          opaque client_random[32];
 *          opaque server_random[32];
 *          ServerDHParams params;
 *      } signed_params;
 *
 * (*2)
 *      digitally-signed struct {
 *          opaque client_random[32];
 *          opaque server_random[32];
 *          ServerECDHParams params;
 *      } signed_params;
 *
 * The signature is produced with tls->pkcs12_server.
 *
 * @return number of bytes written (0 when no signature is needed),
 *         or -1 on error.
 */
static int32_t write_signed_params(TLS *tls, struct tls_hs_msg *msg);

/**
 * Read the server's key exchange parameters starting at `offset`
 * in msg->msg, and record them in tls->skeyexc_params /
 * tls->skeyexc_params_len (NULL / 0 when the method has none).
 *
 * @return number of bytes consumed, or -1 on error.
 */
static int32_t read_params(TLS *tls, struct tls_hs_msg *msg,
			   const uint32_t offset);

/**
 * Read (and verify via tls_digitally_signed_read_hash()) a signature
 * over the server's key exchange parameters, starting at `offset`.
 *
 * @return number of bytes consumed (0 when no signature is expected),
 *         or -1 on error.
 */
static int32_t read_signed_params(TLS *tls, struct tls_hs_msg *msg,
				  const uint32_t offset);

static int32_t write_params(TLS *tls, struct tls_hs_msg *msg)
{
	/*
	 * Remember where the params start as an offset, not a pointer.
	 * The pointer into msg->msg (used later by
	 * tls_hs_signature_get_digest()) is only formed after the writer
	 * has returned, so it stays valid even if the writer has to grow
	 * or move the message buffer (NOTE(review): whether it does is not
	 * visible from this file).
	 */
	const uint32_t start = msg->len;
	int32_t params_len = 0;

	switch(tls->keymethod) {
	case TLS_KXC_DH_anon:
	case TLS_KXC_DHE_DSS:
	case TLS_KXC_DHE_RSA:
		/* TODO: do implementation (ServerDHParams). */
		assert(!"unsupported.");
		/*
		 * assert() disappears under NDEBUG; fail explicitly instead of
		 * emitting an empty/unsigned ServerKeyExchange.
		 */
		params_len = -1;
		break;

	case TLS_KXC_RSA:
	case TLS_KXC_DH_DSS:
	case TLS_KXC_DH_RSA:
		/* message omitted */
		break;

	case TLS_KXC_ECDHE_ECDSA:
	case TLS_KXC_ECDHE_RSA:
	case TLS_KXC_ECDH_anon:
		params_len = tls_hs_ecdh_skeyexc_write_server_params(tls, msg);
		break;

	case TLS_KXC_ECDH_ECDSA:
	case TLS_KXC_ECDH_RSA:
		/* message omitted */
		break;

	default:
		/* unknown */
		assert(!"unknown key exchange algorithm.");
		params_len = -1;
		break;
	}

	if (params_len <= 0) {
		/* Nothing (valid) was written: do not leave a stale pointer. */
		tls->skeyexc_params = NULL;
		tls->skeyexc_params_len = 0;
		return (params_len < 0) ? -1 : 0;
	}

	tls->skeyexc_params = &(msg->msg[start]);
	tls->skeyexc_params_len = params_len;

	return params_len;
}

static int32_t write_signed_params(TLS *tls, struct tls_hs_msg *msg)
{
	int32_t written = 0;

	switch(tls->keymethod) {
	/* Ephemeral (EC)DH: the params are signed with tls->pkcs12_server. */
	case TLS_KXC_DHE_DSS:
	case TLS_KXC_DHE_RSA:
	case TLS_KXC_ECDHE_ECDSA:
	case TLS_KXC_ECDHE_RSA:
		written = tls_digitally_signed_write_hash(tls,
							  tls->pkcs12_server,
							  msg);
		break;

	/* Anonymous key exchange: no signature is needed. */
	case TLS_KXC_DH_anon:
	case TLS_KXC_ECDH_anon:
		break;

	/* No ServerKeyExchange message is sent for these methods. */
	case TLS_KXC_RSA:
	case TLS_KXC_DH_DSS:
	case TLS_KXC_DH_RSA:
	case TLS_KXC_ECDH_ECDSA:
	case TLS_KXC_ECDH_RSA:
		break;

	default:
		assert(!"unknown key exchange algorithm.");
		break;
	}

	/* Collapse every failure code to -1; otherwise report bytes written. */
	return (written < 0) ? -1 : written;
}

static int32_t read_params(TLS *tls, struct tls_hs_msg *msg,
			   const uint32_t offset)
{
	int32_t params_len = 0;

	switch(tls->keymethod) {
	case TLS_KXC_DH_anon:
		/* TODO : tls_hs_dh_skeyexc_read_server_params() */
		assert(!"unsuppoted.");
		break;

	case TLS_KXC_DHE_DSS:
	case TLS_KXC_DHE_RSA:
		/* TODO : tls_hs_dh_skeyexc_read_server_params() */
		assert(!"unsuppoted.");
		break;

	case TLS_KXC_RSA:
	case TLS_KXC_DH_DSS:
	case TLS_KXC_DH_RSA:
		/* message omitted */
		break;

	case TLS_KXC_ECDHE_ECDSA:
	case TLS_KXC_ECDHE_RSA:
	case TLS_KXC_ECDH_anon:
		params_len =
			tls_hs_ecdh_skeyexc_read_server_params(tls, msg,
							       offset);
		break;

	case TLS_KXC_ECDH_ECDSA:
	case TLS_KXC_ECDH_RSA:
		/* message omitted */
		break;

	default:
		/* unknown */
		assert(!"unknown key exchange algorithm.");
		break;
	}

	if (params_len < 0) {
		return -1;
	}

	if (params_len > 0) {
		/*
		 * The params were read starting at `offset`, so that is where
		 * they begin in the message (not necessarily at msg[0]).
		 */
		tls->skeyexc_params = &(msg->msg[offset]);
		tls->skeyexc_params_len = params_len;
	} else {
		tls->skeyexc_params = NULL;
		tls->skeyexc_params_len = 0;
	}

	return params_len;
}

static int32_t read_signed_params(TLS *tls, struct tls_hs_msg *msg,
				  const uint32_t offset)
{
	int32_t consumed = 0;

	switch(tls->keymethod) {
	/* Ephemeral (EC)DH: a signature follows the params. */
	case TLS_KXC_DHE_DSS:
	case TLS_KXC_DHE_RSA:
	case TLS_KXC_ECDHE_ECDSA:
	case TLS_KXC_ECDHE_RSA:
		consumed = tls_digitally_signed_read_hash(tls, tls->pkcs12_server,
							  msg, offset);
		break;

	/* Anonymous key exchange: no signature is present. */
	case TLS_KXC_DH_anon:
	case TLS_KXC_ECDH_anon:
		break;

	/* No ServerKeyExchange message is expected for these methods. */
	case TLS_KXC_RSA:
	case TLS_KXC_DH_DSS:
	case TLS_KXC_DH_RSA:
	case TLS_KXC_ECDH_ECDSA:
	case TLS_KXC_ECDH_RSA:
		break;

	default:
		assert(!"unknown key exchange algorithm.");
		break;
	}

	/* Collapse every failure code to -1; otherwise report bytes consumed. */
	return (consumed < 0) ? -1 : consumed;
}

bool tls_hs_skeyexc_need_to_send(TLS *tls)
{
	/*
	 * Decide whether a ServerKeyExchange message belongs in this
	 * handshake. tls->keymethod is expected to have been set already
	 * by shello.c:read_cipher_suite().
	 */
	bool needed;

	switch(tls->keymethod) {
	/* RFC 5246 section 7.4.3: ephemeral and anonymous DH. */
	case TLS_KXC_DH_anon:
	case TLS_KXC_DHE_DSS:
	case TLS_KXC_DHE_RSA:
	/* RFC 4492 section 5.4: ephemeral and anonymous ECDH. */
	case TLS_KXC_ECDH_anon:
	case TLS_KXC_ECDHE_ECDSA:
	case TLS_KXC_ECDHE_RSA:
		needed = true;
		break;

	default:
		needed = false;
		break;
	}

	return needed;
}

/*
 * Build a ServerKeyExchange handshake message for the negotiated key
 * exchange method.
 *
 * Returns the new message (owned by the caller), or NULL on failure.
 * On failure the partially built message is freed, the recorded
 * key exchange params are cleared and any tls->ecdh state is released.
 */
struct tls_hs_msg *tls_hs_skeyexc_compose(TLS *tls)
{
	uint32_t offset = 0;

	struct tls_hs_msg *msg;

	if ((msg = tls_hs_msg_init()) == NULL) {
		TLS_DPRINTF("tls_hs_msg_init");
		return NULL;
	}

	/* ServerKeyExchange message has following structure.
	 *
	 * switch (KeyExchangeAlgorithm)
	 * case dh_anon:
	 * | type                      (1) |
	 * | length of message         (3) |
	 * | params                    (x) | (TODO: do implementation)
	 *
	 * case dhe_dss:
	 * case dhe_rsa:
	 * | type                      (1) |
	 * | length of message         (3) |
	 * | params                    (x) | (TODO: do implementation)
	 * | signed_params             (x) | (TODO: do implementation)
	 *
	 * case rsa:
	 * case dh_dss:
	 * case dh_rsa:
	 *   (This function is not called.)
	 *
	 * case  ecdh_anon:
	 * | type                      (1) |
	 * | length of message         (3) |
	 * | params                    (n) |
	 *
	 * case  ecdhe_ecdsa:
	 * case  ecdhe_rsa:
	 * | type                      (1) |
	 * | length of message         (3) |
	 * | params                    (n) |
	 * | signed_params             (n) |
	 *
	 * case  ecdh_ecdsa:
	 * case  ecdh_rsa:
	 *   (This function is not called.)
	 */

	msg->type = TLS_HANDSHAKE_SERVER_KEY_EXCHANGE;

	int32_t params_len;
	params_len = write_params(tls, msg);
	if (params_len < 0) {
		TLS_DPRINTF("write_params");
		goto failed;
	}
	offset += params_len;

	int32_t signed_params_len;
	signed_params_len = write_signed_params(tls, msg);
	if (signed_params_len < 0) {
		TLS_DPRINTF("write_signed_params");
		goto failed;
	}
	offset += signed_params_len;

	msg->len = offset;

	return msg;

failed:
	tls_hs_msg_free(msg);

	/* tls->skeyexc_params may point into the message just freed. */
	tls->skeyexc_params = NULL;
	tls->skeyexc_params_len = 0;

	if (tls->ecdh != NULL) {
		tls_hs_ecdh_free(tls->ecdh);
		tls->ecdh = NULL;
	}

	/* Never hand the freed message back: callers test for NULL. */
	return NULL;
}

/*
 * Parse a received ServerKeyExchange message: check its type, read the
 * key exchange params and (when the method requires one) the signature
 * over them, then require that exactly msg->len bytes were consumed.
 *
 * Returns true on success; false on any error (a fatal alert is raised
 * here for a wrong message type or a length mismatch).
 */
bool tls_hs_skeyexc_parse(TLS *tls, struct tls_hs_msg *msg)
{
	uint32_t pos = 0;
	int32_t consumed;

	if (msg->type != TLS_HANDSHAKE_SERVER_KEY_EXCHANGE) {
		TLS_DPRINTF("! TLS_HANDSHAKE_SERVER_KEY_EXCHANGE");
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SKEYEXC + 0, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	/* params */
	consumed = read_params(tls, msg, pos);
	if (consumed < 0) {
		TLS_DPRINTF("read_params");
		return false;
	}
	pos += consumed;

	/* signed_params (may be empty) */
	consumed = read_signed_params(tls, msg, pos);
	if (consumed < 0) {
		TLS_DPRINTF("read_signed_params");
		return false;
	}
	pos += consumed;

	/* Every byte of the message must have been accounted for. */
	if (pos == msg->len) {
		return true;
	}

	OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
		     ERR_LC_TLS3, ERR_PT_TLS_HS_MSG_SKEYEXC + 1, NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
	return false;
}
