/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls.h"
#include "tls_session.h"
#include "tls_record.h"
/* TODO : XXX for tls_hs_*_alloc() in init_handshake() */
#include "tls_handshake.h"
#include "tls_cipher.h"
#include "tls_cert.h"

/* for free_u2j_table */
#include <aicrypto/ok_uconv.h>

/* for RAND_cleanup */
#include <aicrypto/ok_rand.h>

/* for Key_free */
#include <aicrypto/ok_x509.h>

/* for P12_dup and P12_free */
#include <aicrypto/ok_pkcs12.h>

/* for STM_close */
#include <aicrypto/ok_store.h>

/**
 * initialize TLS/SSL version of tls structure.
 */
static void init_version(TLS *tls);

/**
 * initialize session related member of tls structure.
 */
static void init_session(TLS *tls);

/**
 * initialize certificate related member of tls structure.
 */
static void init_certs(TLS *tls);

/**
 * initialize cipher parameter (active_read and active_write member) of
 * tls structure.
 */
static void init_cipher_param(TLS *tls);

/**
 * initialize handshake interim parameters in tls structure.
 */
static bool init_interim_params(TLS *tls);

/**
 * initialize queue for TLS/SSL record layer in tls structure.
 */
static bool init_record_queue(TLS *tls);

/**
 * initialize handshake related member of tls structure.
 */
static bool init_handshake(TLS *tls);

/**
 * initialize option of tls structure.
 */
static void init_option(TLS *tls);

/**
 * copy option of tls structure.
 */
static void copy_option(TLS *dest, const TLS *src);

/**
 * copy the certificate-related members of a tls structure.
 *
 * the store manager is copied here as well.  since the aicrypto store
 * manager API (the functions prefixed with STM_) provides no function to
 * clone a store manager, a new instance is created from the store manager
 * path that was saved beforehand.
 */
static bool copy_certs(TLS *dest, const TLS *src);

/**
 * free the memory for the cipher parameters that is allocated in the tls
 * structure (active_read and active_write members).
 */
static void free_cipher_param(TLS *tls);

/**
 * free the memory related with the certificate that is allocated in tls
 * structure.
 */
static void free_certs(TLS *tls);

/**
 * free the memory for the handshake interim parameters that are allocated
 * in the tls structure.
 */
static void free_interim_params(TLS *tls);

/**
 * free the memory related with queue of record layer that is allocated
 * in tls structure.
 */
static void free_record_queue(TLS *tls);

/**
 * free the memory related with cookie that is allocated in tls structure.
 */
static void free_cookie(TLS *tls);

/**
 * free the memory related with session layer that is allocated in tls
 * structure.
 */
static void free_session(TLS *tls);

/*
 * protocol versions offered by default, listed newest first.
 * init_version() installs this table as tls->supported_versions.list.
 */
uint16_t tls_protocol_version_list_default[] = {
	TLS_VER_TLS13,	/* TLS 1.3 */
	TLS_VER_TLS12,	/* TLS 1.2 */
};

/*
 * reset every version field of the tls structure.
 *
 * the negotiated, client and record versions start out as -1; they are
 * filled in later from the client hello and server hello.  the locally
 * supported versions default to tls_protocol_version_list_default.
 */
static void init_version(TLS *tls) {
	const size_t nversions = sizeof(tls_protocol_version_list_default)
		/ sizeof(tls_protocol_version_list_default[0]);

	tls->negotiated_version.major = -1;
	tls->negotiated_version.minor = -1;

	tls->client_version.major = -1;
	tls->client_version.minor = -1;

	tls->record_version.major = -1;
	tls->record_version.minor = -1;

	tls->supported_versions.list = tls_protocol_version_list_default;
	tls->supported_versions.len  = nversions;
}

/*
 * clear the session state: the resession flag is off and no pending
 * session is attached.
 */
static void init_session(TLS *tls) {
	tls->pending   = NULL;
	tls->resession = false;
}

/*
 * clear every certificate-related member: no server name, no store
 * manager (nor its saved path) and no PKCS#12 client/server credentials.
 */
static void init_certs(TLS *tls) {
	tls->server_name = NULL;

	tls->store_manager      = NULL;
	tls->store_manager_path = NULL;

	tls->pkcs12_client = NULL;
	tls->pkcs12_server = NULL;
}

/*
 * put both record directions (active_read / active_write) into the
 * initial state: NULL bulk cipher and NULL MAC (stream type, zero-length
 * MAC), SHA-256 PRF, no keys, sequence number 0, no compression, open.
 * the premaster secret length is reset as well.
 */
static void init_cipher_param(TLS *tls) {
	const struct tls_cipher_param null_cipher = {
		.cipher_algorithm = TLS_BULK_CIPHER_NULL,
		.cipher_type      = TLS_CIPHER_TYPE_STREAM,
		.mac_algorithm    = TLS_MAC_NULL,
		.mac_length       = 0,
		.prf_algorithm    = TLS_PRF_SHA256,
	};

	tls->premaster_secret_len = 0;

	/* each pair below sets the same field for both directions */
	tls->active_read.cipher  = null_cipher;
	tls->active_write.cipher = null_cipher;

	tls->active_read.mac_key  = NULL;
	tls->active_write.mac_key = NULL;

	tls->active_read.key  = NULL;
	tls->active_write.key = NULL;

	tls->active_read.seqnum  = 0;
	tls->active_write.seqnum = 0;

	tls->active_read.compression_algorithm  = TLS_COMPRESSION_NULL;
	tls->active_write.compression_algorithm = TLS_COMPRESSION_NULL;

	tls->active_read.closed  = false;
	tls->active_write.closed = false;
}

/*
 * allocate the handshake interim parameters into tls->interim_params.
 *
 * returns true on success, false when tls_hs_interim_params_init() fails.
 */
static bool init_interim_params(TLS *tls) {
	tls->interim_params = tls_hs_interim_params_init();
	if (tls->interim_params != NULL) {
		return true;
	}

	TLS_DPRINTF("tls_hs_interim_params_init");
	return false;
}

/*
 * allocate the two record-layer queues: tls->queue_control and
 * tls->queue_data.
 *
 * returns false (with an error set) when either allocation fails.  if the
 * second allocation fails, the first queue stays in tls for the caller to
 * release.
 */
static bool init_record_queue(TLS *tls) {
	tls->queue_control = tls_record_init();
	if (tls->queue_control == NULL) {
		TLS_DPRINTF("tls_record_init (control)");
		OK_set_error(ERR_ST_TLS_TLS_RECORD_INIT,
			     ERR_LC_TLS1, ERR_PT_TLS + 0, NULL);
		return false;
	}

	tls->queue_data = tls_record_init();
	if (tls->queue_data == NULL) {
		TLS_DPRINTF("tls_record_init (data)");
		OK_set_error(ERR_ST_TLS_TLS_RECORD_INIT,
			     ERR_LC_TLS1, ERR_PT_TLS + 2, NULL);
		return false;
	}

	return true;
}

/*
 * prepare the handshake-related members of the tls structure.
 *
 * allocates the handshake interim parameters and the record queues, clears
 * the certificate-type and signature/hash lists, calls tls_hs_ecdh_alloc()
 * (presumably populating tls->ecdh) and resets the handshake flags.
 *
 * returns false if the interim parameters or the record queues cannot be
 * allocated.  anything allocated before the failure is left in tls for the
 * caller to release.
 */
static bool init_handshake(TLS *tls) {
	if (!init_interim_params(tls)) {
		TLS_DPRINTF("init_interim_params");
		return false;
	}

	if (!init_record_queue(tls)) {
		TLS_DPRINTF("init_record_queue");
		OK_set_error(ERR_ST_TLS_INIT_RECORD_Q,
			     ERR_LC_TLS1, ERR_PT_TLS + 3, NULL);
		return false;
	}

	/* no lists are attached yet */
	tls->certtype_list_server = NULL;
	tls->sighash_list         = NULL;
	tls->sighash_list_cert    = NULL;

	/* TODO: XXX: It should be moved to a more appropriate place. */
	/* NOTE(review): any result of tls_hs_ecdh_alloc() is ignored here;
	 * confirm whether an allocation failure ought to abort. */
	tls->ecdh = NULL;
	tls_hs_ecdh_alloc(tls);

	tls->ccert_null     = false;
	tls->certreq_used   = false;
	tls->is_ccert_auth  = false;
	tls->handshake_over = false;

	return true;
}

/*
 * set the default options.
 *
 * timeout: 10 minutes.  verify_type: 0.  verify_depth: 8.
 * immediate_handshake, use_certreq and skip_ccertificate_check: all off.
 */
static void init_option(TLS *tls) {
	tls->opt.immediate_handshake     = false;
	tls->opt.use_certreq             = false;
	tls->opt.skip_ccertificate_check = false;

	tls->opt.verify_depth = 8;
	tls->opt.verify_type  = 0;

	tls->opt.timeout.tv_nsec = 0;
	tls->opt.timeout.tv_sec  = 10 * 60;	/* seconds */
}

/*
 * copy the option fields set up by init_option() from src to dest:
 * the three boolean switches, the verification type/depth and the timeout.
 */
static void copy_option(TLS *dest, const TLS *src) {
	dest->opt.immediate_handshake     = src->opt.immediate_handshake;
	dest->opt.use_certreq             = src->opt.use_certreq;
	dest->opt.skip_ccertificate_check = src->opt.skip_ccertificate_check;

	dest->opt.verify_depth = src->opt.verify_depth;
	dest->opt.verify_type  = src->opt.verify_type;

	dest->opt.timeout.tv_nsec = src->opt.timeout.tv_nsec;
	dest->opt.timeout.tv_sec  = src->opt.timeout.tv_sec;
}

/*
 * copy the certificate-related settings of src into dest.
 *
 * dest is expected to come fresh from TLS_new() (as in TLS_dup()), so its
 * certificate members are still NULL.  on failure some members may already
 * be populated; the caller releases them with TLS_free().
 *
 * - the PKCS#12 client / server credentials are deep-copied with P12_dup().
 * - the server name is duplicated with strdup().  this is independent of
 *   the store manager, so it is copied even when no store manager is set.
 * - the aicrypto STM_ API cannot clone a store manager, so a new instance
 *   is opened from the saved store_manager_path via TLS_stm_set().
 *
 * returns true on success, false on failure.
 */
static bool copy_certs(TLS *dest, const TLS *src) {
	/* a failed duplication must not silently drop the credentials */
	if (src->pkcs12_client != NULL) {
		if ((dest->pkcs12_client = P12_dup(src->pkcs12_client)) == NULL) {
			return false;
		}
	}

	if (src->pkcs12_server != NULL) {
		if ((dest->pkcs12_server = P12_dup(src->pkcs12_server)) == NULL) {
			return false;
		}
	}

	/* copied before the store-manager early return below: the server
	 * name was previously lost whenever no store manager was configured. */
	if (src->server_name != NULL) {
		if ((dest->server_name = strdup(src->server_name)) == NULL) {
			OK_set_error(ERR_ST_TLS_STRDUP,
				     ERR_LC_TLS1, ERR_PT_TLS + 1, NULL);
			return false;
		}
	}

	if ((src->store_manager == NULL) ||
	    (src->store_manager_path == NULL)) {
		return true;
	}

	if (TLS_stm_set(dest, src->store_manager_path) == false) {
		return false;
	}

	return true;
}

/*
 * release the key material held in the tls structure and close both
 * record directions.
 *
 * frees the per-direction MAC keys, cipher keys and secrets, the early and
 * handshake secrets, the saved client hello cipher suite list and the ECDH
 * context.  every released pointer is reset to NULL, sequence numbers are
 * reset to 0, and both directions are marked closed.
 */
static void free_cipher_param(TLS *tls) {
	/* plain heap buffers: free(NULL) is a no-op, so no NULL guard needed */
	free(tls->active_read.mac_key);
	tls->active_read.mac_key = NULL;

	if (tls->active_read.key != NULL) {
		Key_free(tls->active_read.key);
		tls->active_read.key = NULL;
	}

	free(tls->active_read.secret);
	tls->active_read.secret = NULL;

	free(tls->active_write.mac_key);
	tls->active_write.mac_key = NULL;

	if (tls->active_write.key != NULL) {
		Key_free(tls->active_write.key);
		tls->active_write.key = NULL;
	}

	free(tls->active_write.secret);
	tls->active_write.secret = NULL;

	free(tls->early_secret);
	tls->early_secret = NULL;

	free(tls->handshake_secret);
	tls->handshake_secret = NULL;

	/* project free functions keep their guards (NULL handling unknown) */
	if (tls->chello_cipher_suites != NULL) {
		tls_cipher_list_free(tls->chello_cipher_suites);
		tls->chello_cipher_suites = NULL;
	}

	if (tls->ecdh != NULL) {
		tls_hs_ecdh_free(tls->ecdh);
		tls->ecdh = NULL;
	}

	/*
	 * tls->skeyexc_params must not be freed here: it points into a
	 * tls_hs_msg buffer that is released elsewhere.
	 */

	tls->active_read.seqnum  = 0;
	tls->active_write.seqnum = 0;

	tls->active_read.closed  = true;
	tls->active_write.closed = true;
}


/*
 * release the certificate-related members of the tls structure.
 *
 * frees the PKCS#12 client and server credentials, closes the store
 * manager, frees the saved store manager path, the server certificate
 * type list and the server name.  every released pointer is reset to NULL.
 */
static void free_certs(TLS *tls) {
	if (tls->pkcs12_client != NULL) {
		P12_free(tls->pkcs12_client);
		tls->pkcs12_client = NULL;
	}

	if (tls->pkcs12_server != NULL) {
		P12_free(tls->pkcs12_server);
		tls->pkcs12_server = NULL;
	}

	if (tls->store_manager != NULL) {
		STM_close(tls->store_manager);
		tls->store_manager = NULL;
	}

	/*
	 * the saved path is released independently of the store manager.
	 * copy_certs() treats the two members as independently nullable, and a
	 * path without an opened store manager used to leak here.
	 */
	free(tls->store_manager_path);
	tls->store_manager_path = NULL;

	if (tls->certtype_list_server != NULL) {
		tls_cert_type_free(tls->certtype_list_server);
		tls->certtype_list_server = NULL;
	}

	free(tls->server_name);
	tls->server_name = NULL;
}

/*
 * release both handshake interim parameter sets held by tls
 * (interim_params and first_hello_params) and clear the pointers.
 */
static void free_interim_params(TLS *tls) {
	tls_hs_interim_params_free(tls->interim_params);
	tls_hs_interim_params_free(tls->first_hello_params);

	tls->interim_params     = NULL;
	tls->first_hello_params = NULL;
}

/*
 * release the control and data record queues and clear the pointers.
 */
static void free_record_queue(TLS *tls) {
	tls_record_free(tls->queue_control);
	tls_record_free(tls->queue_data);

	tls->queue_control = NULL;
	tls->queue_data    = NULL;
}

/*
 * release the cookie held by tls with tls_hs_cookie_free() and clear the
 * pointer.
 */
static void free_cookie(TLS *tls) {
	tls_hs_cookie_free(tls->cookie);
	tls->cookie = NULL;
}

/*
 * detach the pending session, if one is attached: disable it, drop this
 * structure's reference to it and clear the pointer.
 */
static void free_session(TLS *tls) {
	if (tls->pending == NULL) {
		return;
	}

	tls_session_disable(tls->pending);
	tls_session_unrefer(tls->pending);
	tls->pending = NULL;
}

/**
 * release the per-connection resources held by the tls structure and set
 * its state to TLS_STATE_CLOSED.
 *
 * releases the handshake interim parameters, the record queues, the cipher
 * parameters, the cookie and the pending session, then calls
 * tls_handshake_free().  certificates and options are not touched, and the
 * structure itself is not freed (TLS_free() does that after calling this).
 *
 * the tls module uses this whenever a connection must be torn down, for
 * example after a fatal alert has been received.
 *
 * no tls header declares a prototype for it because, per the original
 * note, outside this file it is called only from tls_alert.c.
 *
 * a NULL tls is ignored.
 */
void tls_free_connection(TLS *tls) {
	if (tls == NULL) {
		return;
	}

	free_interim_params(tls);
	free_record_queue(tls);
	free_cipher_param(tls);
	free_cookie(tls);
	free_session(tls);

	tls_handshake_free(tls);

	tls->state = TLS_STATE_CLOSED;
}

/*
 * library-wide initialization hook; currently a no-op.
 * TODO: initialize the session cache here?
 */
void TLS_init(void) {
	/* nothing to do in the current implementation */
}

/*
 * library-wide teardown.
 *
 * in this order: frees all sessions (TLS_session_free_all), then calls
 * the aicrypto cleanups RAND_cleanup() and free_u2j_table().
 */
void TLS_cleanup(void) {
	/* sessions first, then the aicrypto global state */
	TLS_session_free_all();
	RAND_cleanup();
	free_u2j_table();
}

/*
 * allocate and initialize a new tls structure.
 *
 * the structure starts as a client (TLS_CONNECT_CLIENT) in
 * TLS_STATE_CLOSED, with default versions, options and NULL cipher
 * parameters, and with its handshake interim parameters and record queues
 * allocated.
 *
 * returns the new structure, or NULL (with an error set) on failure.  on
 * failure nothing is leaked: a partially initialized structure is released
 * with TLS_free().
 */
TLS* TLS_new (void) {
	TLS *tls;

#if TLS_DEBUG
	setbuf(stdout, NULL);
	setbuf(stderr, NULL);
#endif /* TLS_DEBUG */

	if ((tls = calloc(1, sizeof(*tls))) == NULL) {
		TLS_DPRINTF("calloc %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_CALLOC,
			     ERR_LC_TLS1, ERR_PT_TLS + 4, NULL);
		return NULL;
	}

	tls->entity = TLS_CONNECT_CLIENT;

	init_version(tls);

	init_session(tls);

	init_certs(tls);

	init_option(tls);

	init_cipher_param(tls);

	if (! init_handshake(tls)) {
		TLS_DPRINTF("init_handshake");
		/*
		 * previously the structure (and whatever init_handshake()
		 * already allocated, e.g. the interim parameters or the control
		 * record queue) leaked here.  the free helpers reached from
		 * TLS_free() already run on NULL members (tls_free_connection()
		 * nulls them and may run again from TLS_free()), so they are
		 * expected to cope with the calloc'ed NULLs. NOTE(review):
		 * confirm tls_handshake_free() on a never-started handshake.
		 */
		TLS_free(tls);
		OK_set_error(ERR_ST_TLS_INIT_HS,
			     ERR_LC_TLS1, ERR_PT_TLS + 5, NULL);
		return NULL;
	}

	tls->state  = TLS_STATE_CLOSED;
	tls->errnum = 0;

	return tls;
}

/*
 * create a new tls structure configured like tls.
 *
 * the copy comes from TLS_new() and receives tls's entity, its certificate
 * settings (see copy_certs()) and its options.  no connection state is
 * copied.
 *
 * returns the new structure, or NULL when tls is NULL or on failure (an
 * error is set in the failure cases).
 */
TLS* TLS_dup(const TLS *tls) {
	TLS *copy = NULL;

	if (tls == NULL) {
		return NULL;
	}

	copy = TLS_new();
	if (copy == NULL) {
		TLS_DPRINTF("TLS_new");
		OK_set_error(ERR_ST_TLS_TLS_NEW,
			     ERR_LC_TLS1, ERR_PT_TLS + 6, NULL);
		return NULL;
	}

	copy->entity = tls->entity;

	if (!copy_certs(copy, tls)) {
		TLS_DPRINTF("copy_certs");
		OK_set_error(ERR_ST_TLS_COPY_CERTS,
			     ERR_LC_TLS1, ERR_PT_TLS + 7, NULL);
		TLS_free(copy);
		return NULL;
	}

	copy_option(copy, tls);

	return copy;
}

/*
 * release a tls structure and everything it owns.
 *
 * connection resources are released first (tls_free_connection()), then
 * the certificates, the signature/hash lists, the peer's supported version
 * list and finally the structure itself.
 */
void TLS_free (TLS* tls) {
	/*
	 * NULL is accepted without setting an error because TLS_free() is
	 * called multiple times.  This mirrors SSL_free().
	 */
	if (tls == NULL) {
		return;
	}

	tls_free_connection(tls);
	free_certs(tls);

	if (tls->sighash_list != NULL) {
		tls_hs_sighash_free(tls->sighash_list);
		tls->sighash_list = NULL;
	}

	if (tls->sighash_list_cert != NULL) {
		tls_hs_sighash_free(tls->sighash_list_cert);
		tls->sighash_list_cert = NULL;
	}

	free(tls->peer_supported_versions.list);

	/* the caller's pointer is not ours to clear; just release the memory */
	free(tls);
}

/*
 * return the last error number recorded in tls.
 *
 * a NULL tls sets ERR_ST_TLS_TLS_NULL and yields 0.
 */
enum tls_errno TLS_get_error(TLS *tls) {
	if (tls != NULL) {
		return tls->errnum;
	}

	TLS_DPRINTF("tls (null)");
	OK_set_error(ERR_ST_TLS_TLS_NULL,
		     ERR_LC_TLS1, ERR_PT_TLS + 9, NULL);
	return 0;
}
