/*
 * Copyright (c) 2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

/* template for unit testing with Check <https://libcheck.github.io/check/> */
#include <stdlib.h>
#include <check.h>

#include <aicrypto/nrg_tls.h>
#include <stdbool.h>
#include "tls_handshake.h"

#include "client.c"

/*
 * Last values passed to the OK_set_error() stub, captured so that tests
 * can assert which error was reported.
 */
static int ut_error;
static int ut_location;
static int ut_point;

/* Reset the captured OK_set_error() values before each test. */
static void ut_error_setup(void)
{
	ut_error = 0;
	ut_location = 0;
	ut_point = 0;
}

/* Last alert description passed to the tls_alert_send() stub. */
static enum tls_alert_desc ut_alert_desc;

/* Reset the captured alert description before each test. */
static void ut_alert_setup(void)
{
	ut_alert_desc = 0;
}

/*
 * fixture
 */
/* Default fixture setup; intentionally does nothing. */
void setup(void)
{
}
/* Default fixture teardown; intentionally does nothing. */
void teardown(void)
{
}

/* TLS context shared by the tests; reinitialized by setup_tls(). */
static TLS global_tls;

/*
 * Checked fixture: put global_tls into a known state before each test.
 *
 * All versions start at the *_MAX values. The failure tests assert that
 * these values are still present afterwards, which shows that
 * establish_protocol_version() left them untouched. A fresh interim
 * parameter set is allocated here and released by teardown_tls().
 * The captured error and alert values are also reset.
 */
void setup_tls(void)
{
	global_tls.client_version.major = TLS_MAJOR_MAX;
	global_tls.client_version.minor = TLS_MINOR_MAX;

	global_tls.negotiated_version.major = TLS_MAJOR_MAX;
	global_tls.negotiated_version.minor = TLS_MINOR_MAX;

	global_tls.record_version.major = TLS_MAJOR_MAX;
	global_tls.record_version.minor = TLS_MINOR_MAX;

	global_tls.errnum = 0;

	global_tls.interim_params = tls_hs_interim_params_init();

	ut_error_setup();
	ut_alert_setup();

	/*
	 * Every test dereferences interim_params. Report an allocation
	 * failure here as a clear fixture failure, not a crash in the test.
	 */
	ck_assert_ptr_ne(global_tls.interim_params, NULL);
}

/*
 * Checked fixture: release the interim parameters allocated by setup_tls().
 * The pointer is cleared afterwards. Otherwise no dangling reference would
 * remain in global_tls when Check runs tests in the same process (no-fork
 * mode).
 */
void teardown_tls(void)
{
	/* tls_hs_interim_params_free() returns early on NULL; no guard needed. */
	tls_hs_interim_params_free(global_tls.interim_params);
	global_tls.interim_params = NULL;
}

/*
 * unit testing for establish_protocol_version()
 */
/* normal case */
START_TEST (test_establish_protocol_version_normal_tls12)
{
	/*
	 * The client offered TLS 1.2, and the version recorded in
	 * interim_params (presumably the peer's choice) is also TLS 1.2.
	 */
	global_tls.client_version.major = TLS_MAJOR_TLS;
	global_tls.client_version.minor = TLS_MINOR_TLS12;
	global_tls.interim_params->version.major = TLS_MAJOR_TLS;
	global_tls.interim_params->version.minor = TLS_MINOR_TLS12;

	const bool ok = establish_protocol_version(&global_tls);
	ck_assert_int_eq(ok, true);

	/* The negotiated and record-layer versions must both become TLS 1.2. */
	ck_assert_int_eq(global_tls.negotiated_version.major, TLS_MAJOR_TLS);
	ck_assert_int_eq(global_tls.record_version.major, TLS_MAJOR_TLS);
	ck_assert_int_eq(global_tls.negotiated_version.minor, TLS_MINOR_TLS12);
	ck_assert_int_eq(global_tls.record_version.minor, TLS_MINOR_TLS12);
}
END_TEST

/* if (tls->client_version.major != major) {... */
START_TEST (test_establish_protocol_version_failure_protocol_version_major)
{
	/* The major version in interim_params differs from the client's. */
	global_tls.client_version.major = TLS_MAJOR_TLS;
	global_tls.interim_params->version.major = TLS_MAJOR_MAX;

	const bool ok = establish_protocol_version(&global_tls);

	/* The mismatch is rejected, and a protocol_version alert is sent. */
	ck_assert_int_eq(ok, false);
	ck_assert_int_eq(global_tls.errnum, TLS_ERR_PROTOCOL_VERSION);
	ck_assert_int_eq(ut_error, ERR_ST_TLS_PROTOCOL_VERSION);
	ck_assert_int_eq(ut_location, ERR_LC_TLS5);
	ck_assert_int_eq(ut_point, ERR_PT_TLS_HS_CS_CLIENT2 + 4);
	ck_assert_int_eq(ut_alert_desc, TLS_ALERT_DESC_PROTOCOL_VERSION);

	/* The versions keep the values set by setup_tls(). */
	ck_assert_int_eq(global_tls.negotiated_version.major, TLS_MAJOR_MAX);
	ck_assert_int_eq(global_tls.record_version.major, TLS_MAJOR_MAX);
	ck_assert_int_eq(global_tls.negotiated_version.minor, TLS_MINOR_MAX);
	ck_assert_int_eq(global_tls.record_version.minor, TLS_MINOR_MAX);
}
END_TEST

/* if (tls->client_version.minor != minor) {... */
START_TEST (test_establish_protocol_version_failure_protocol_version_minor)
{
	/*
	 * The major versions agree, but the minor version in interim_params
	 * differs from the client's TLS 1.2.
	 */
	global_tls.client_version.major = TLS_MAJOR_TLS;
	global_tls.client_version.minor = TLS_MINOR_TLS12;
	global_tls.interim_params->version.major = TLS_MAJOR_TLS;
	global_tls.interim_params->version.minor = TLS_MINOR_MAX;

	const bool ok = establish_protocol_version(&global_tls);

	/* The mismatch is rejected, and a protocol_version alert is sent. */
	ck_assert_int_eq(ok, false);
	ck_assert_int_eq(global_tls.errnum, TLS_ERR_PROTOCOL_VERSION);
	ck_assert_int_eq(ut_error, ERR_ST_TLS_PROTOCOL_VERSION);
	ck_assert_int_eq(ut_location, ERR_LC_TLS5);
	ck_assert_int_eq(ut_point, ERR_PT_TLS_HS_CS_CLIENT2 + 4);
	ck_assert_int_eq(ut_alert_desc, TLS_ALERT_DESC_PROTOCOL_VERSION);

	/* The versions keep the values set by setup_tls(). */
	ck_assert_int_eq(global_tls.negotiated_version.major, TLS_MAJOR_MAX);
	ck_assert_int_eq(global_tls.record_version.major, TLS_MAJOR_MAX);
	ck_assert_int_eq(global_tls.negotiated_version.minor, TLS_MINOR_MAX);
	ck_assert_int_eq(global_tls.record_version.minor, TLS_MINOR_MAX);
}
END_TEST

/*
 * Build the suite for establish_protocol_version(). Both test cases use
 * the setup_tls()/teardown_tls() checked fixture.
 */
Suite *establish_protocol_version_suite(void)
{
	Suite *suite = suite_create("establish_protocol_version()");

	/* Core: the successful negotiation path. */
	TCase *core = tcase_create("Core");
	tcase_add_checked_fixture(core, setup_tls, teardown_tls);
	tcase_add_test(core, test_establish_protocol_version_normal_tls12);
	suite_add_tcase(suite, core);

	/* Limits: version mismatches that must be rejected. */
	TCase *limits = tcase_create("Limits");
	tcase_add_checked_fixture(limits, setup_tls, teardown_tls);
	tcase_add_test(limits, test_establish_protocol_version_failure_protocol_version_major);
	tcase_add_test(limits, test_establish_protocol_version_failure_protocol_version_minor);
	suite_add_tcase(suite, limits);

	return suite;
}

/* Create the top-level (initially empty) suite for the runner. */
Suite *client_suite(void)
{
	return suite_create("client_suites");
}

/*
 * Run every registered suite. Exit with EXIT_SUCCESS only if no test
 * failed; otherwise exit with EXIT_FAILURE.
 */
int main(void)
{
	SRunner *runner = srunner_create(client_suite());
	srunner_add_suite(runner, establish_protocol_version_suite());

	srunner_run_all(runner, CK_NORMAL);
	const int failures = srunner_ntests_failed(runner);
	srunner_free(runner);

	if (failures != 0) {
		return EXIT_FAILURE;
	}
	return EXIT_SUCCESS;
}

/* ------------------------------------------------------------------------- */
#pragma GCC diagnostic ignored "-Wunused-parameter"
/* Stub functions go below this line. */

/*
 * Stub of OK_set_error(): record the reported error in the ut_* globals
 * for the tests to inspect, and print it (plus *info if given) to stderr.
 */
void OK_set_error(int error, int location, int point, CK_RV *info)
{
	ut_error = error;
	ut_location = location;
	ut_point = point;

	fprintf(stderr, "error:%0x location:%0x point:%0x\n",
		error, location, point);

	if (info == NULL) {
		fprintf(stderr, "info=NULL\n");
		return;
	}
	fprintf(stderr, "*info=%zu\n", (size_t)(*info));
}

/* -----------------tls_cert.c---------------------------------------------- */

/* Stub: nothing is ever allocated, so there is nothing to free. */
void tls_cert_type_free(struct tls_cert_type_list *list)
{
	(void)list;
}

/* -----------------tls_alert.c--------------------------------------------- */

/*
 * Stub of tls_alert_send(): capture the alert description for the tests
 * and report success without sending anything.
 */
bool tls_alert_send(TLS *tls,
		    const enum tls_alert_level level,
		    const enum tls_alert_desc desc)
{
	(void)tls;
	(void)level;

	ut_alert_desc = desc;
	return true;
}

/* -----------------tls_util.c---------------------------------------------- */

/* Read a big-endian 16-bit value from buf[0..1]. */
uint16_t tls_util_read_2(uint8_t *buf)
{
	const uint16_t hi = buf[0];
	const uint16_t lo = buf[1];

	return (uint16_t)((hi << 8) | lo);
}

/* Split a 16-bit wire version into major (high byte) and minor (low byte). */
void tls_util_convert_ver_to_protover(uint16_t version,
				      struct tls_protocol_version *version_st)
{
	const unsigned int hi = (version >> 8) & 0xffu;
	const unsigned int lo = version & 0xffu;

	version_st->major = hi;
	version_st->minor = lo;
}

/* Combine major (high byte) and minor (low byte) into a 16-bit version. */
uint16_t tls_util_convert_protover_to_ver(struct tls_protocol_version *version_st)
{
	uint16_t ver = (uint16_t)(version_st->major << 8);

	/* The low byte of ver is zero, so OR-ing equals adding here. */
	ver |= version_st->minor & 0x00ff;
	return ver;
}

/* Return true if version appears among the first vlist->len list entries. */
bool tls_util_check_version_in_supported_version(
	struct tls_protocol_version_list *vlist,
	uint16_t version)
{
	int idx = 0;

	while (idx < vlist->len) {
		if (vlist->list[idx] == version) {
			return true;
		}
		idx++;
	}
	return false;
}

/* -----------------handshake/util/message.c-------------------------------- */

/* Stub: message allocation is not exercised by these tests; always NULL. */
struct tls_hs_msg *tls_hs_msg_init(void)
{
	return NULL;
}

/* Stub: tls_hs_msg_init() never allocates, so freeing is a no-op. */
void tls_hs_msg_free(struct tls_hs_msg *p)
{
	(void)p;
}

/* -----------------tls_handshake.c----------------------------------------- */

/* Stub: state transitions are ignored. */
void tls_hs_change_state(TLS *tls, enum tls_state state)
{
	(void)tls;
	(void)state;
}

/* Stub: every state is treated as the expected one. */
bool tls_hs_check_state(TLS *tls, enum tls_state state)
{
	(void)tls;
	(void)state;
	return true;
}

/* Stub: the handshake hash is not maintained in these tests. */
void tls_hs_update_hash(TLS *tls)
{
	(void)tls;
}
/* Stub: no handshake message is ever available; always NULL. */
struct tls_hs_msg *tls_handshake_read(TLS *tls)
{
	(void)tls;
	return NULL;
}

/* Stub: pretend every handshake message was written successfully. */
bool tls_handshake_write(TLS *tls, struct tls_hs_msg *msg)
{
	(void)tls;
	(void)msg;
	return true;
}

/*
 * Stub of tls_extension_init(): allocate an empty extension.
 *
 * calloc() zero-fills every field, not only opaque, so no member is left
 * indeterminate. This matches tls_hs_interim_params_init(). opaque is
 * still set to NULL explicitly, because C does not guarantee that a null
 * pointer is all-bits-zero.
 *
 * Returns NULL on allocation failure.
 */
struct tls_extension *tls_extension_init(void)
{
	struct tls_extension *ext;

	if ((ext = calloc(1, sizeof(*ext))) == NULL) {
		TLS_DPRINTF("extension: malloc: %s", strerror(errno));
		return NULL;
	}
	ext->opaque = NULL;

	return ext;
}

/* Release an extension and its opaque payload; NULL is accepted. */
void tls_extension_free(struct tls_extension *ext)
{
	if (ext != NULL) {
		free(ext->opaque);
		ext->opaque = NULL;
		free(ext);
	}
}

/*
 * Allocate a zero-filled interim parameter set with an empty extension
 * queue. Returns NULL on allocation failure.
 */
struct tls_hs_interim_params *tls_hs_interim_params_init(void)
{
	struct tls_hs_interim_params *params = calloc(1, sizeof(*params));

	if (params == NULL) {
		TLS_DPRINTF("interim_params: malloc: %s", strerror(errno));
		return NULL;
	}

	TAILQ_INIT(&params->head);
	return params;
}

/*
 * Release an interim parameter set: its session, cmplist, every queued
 * extension, and finally the set itself. NULL is accepted.
 */
void tls_hs_interim_params_free(struct tls_hs_interim_params *params)
{
	struct tls_extension *ext;

	if (params == NULL) {
		return;
	}

	free(params->session);
	params->session = NULL;
	free(params->cmplist);
	params->cmplist = NULL;

	/* Unlink and free the head element until the queue is drained. */
	while ((ext = TAILQ_FIRST(&params->head)) != NULL) {
		TAILQ_REMOVE(&params->head, ext, link);
		tls_extension_free(ext);
	}

	free(params);
}


/* -----------------handshake/extension/sighash.c--------------------------- */

/* Stub: signature/hash lists are never allocated here. */
void tls_hs_sighash_free(struct tls_hs_sighash_list *sighash)
{
	(void)sighash;
}

/* -----------------handshake/extension/supported_versions.c---------------- */
/* Stub: parse nothing and report that zero bytes were consumed. */
int32_t tls_hs_supported_versions_read(TLS *tls,
				       const struct tls_hs_msg *msg,
				       const uint32_t offset)
{
	(void)tls;
	(void)msg;
	(void)offset;
	return 0;
}

/* -----------------handshake/ecdh/ecdh.c----------------------------------- */

/* Stub: ECDH state is never allocated here. */
void tls_hs_ecdh_free(struct tls_hs_ecdh *ecdh)
{
	(void)ecdh;
}

/* -----------------handshake/client-server/finale.c------------------------ */

/* Stub: report the finale phase as already complete. */
enum hs_phase tls_hs_finale_write_first(TLS *tls)
{
	(void)tls;
	return TLS_HS_PHASE_DONE;
}

/* Stub: report the finale phase as already complete. */
enum hs_phase tls_hs_finale_read_first(TLS *tls)
{
	(void)tls;
	return TLS_HS_PHASE_DONE;
}

/* -----------------handshake/message/ckeyexc.c----------------------------- */

/* Stub: no ClientKeyExchange message is composed; always NULL. */
struct tls_hs_msg *tls_hs_ckeyexc_compose(TLS *tls)
{
	(void)tls;
	return NULL;
}

/* -----------------handshake/message/cert.c-------------------------------- */

/* Stub: no client Certificate message is composed; always NULL. */
struct tls_hs_msg *tls_hs_ccert_compose(TLS *tls)
{
	(void)tls;
	return NULL;
}

/* -----------------handshake/message/certvfy.c----------------------------- */
/* Stub: no CertificateVerify message is composed; always NULL. */
struct tls_hs_msg *tls_hs_certvfy_compose(TLS *tls)
{
	(void)tls;
	return NULL;
}

/* -----------------handshake/message/finished.c---------------------------- */
/* Stub: no Finished message is composed; always NULL. */
struct tls_hs_msg *tls_hs_finished_compose(TLS *tls)
{
	(void)tls;
	return NULL;
}
