#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <check.h>

/*
 * Allocation-failure injection for the code under test.
 *
 * Each wrapper logs its arguments to stderr. When the matching *_failure
 * flag is set, it returns NULL without allocating. Otherwise it forwards
 * to the real allocator.
 */
static bool malloc_failure = false;
static bool calloc_failure = false;
static bool realloc_failure = false;

/* calloc() replacement: returns NULL while calloc_failure is set. */
static void *test_calloc(size_t nmemb, size_t size)
{
	fprintf(stderr, "test_calloc(%zu, %zu)\n", nmemb, size);
	return calloc_failure ? NULL : calloc(nmemb, size);
}

/* malloc() replacement: returns NULL while malloc_failure is set. */
static void *test_malloc(size_t size)
{
	fprintf(stderr, "test_malloc(%zu)\n", size);
	return malloc_failure ? NULL : malloc(size);
}

/*
 * realloc() replacement: returns NULL while realloc_failure is set.
 * In that case ptr is neither freed nor resized.
 */
static void *test_realloc(void *ptr, size_t size)
{
	fprintf(stderr, "test_realloc(%p, %zu)\n", ptr, size);
	return realloc_failure ? NULL : realloc(ptr, size);
}

/*
 * From here on, route calloc/malloc/realloc through the failure-injecting
 * wrappers above. The macros are defined only after the wrappers, so the
 * wrappers themselves still call the real allocators.
 */
#define calloc(nmemb, size) test_calloc(nmemb, size)
#define malloc(size) test_malloc(size)
#define realloc(ptr, size) test_realloc(ptr, size)

/*
 * Include the unit under test directly, so its allocations see the macros
 * above. Presumably this is also so that file-local (static) helpers such
 * as realloc_message() are reachable from the tests. TODO confirm.
 */
#include "message.c"

/* Last error/location/point captured by the OK_set_error() stub. */
static int ut_error;
static int ut_location;
static int ut_point;

/*
 * Clear the captured OK_set_error() values so each test starts clean.
 * Declared with (void): an empty () list is an old-style declaration that
 * does not form a prototype in C.
 */
static void ut_error_setup(void)
{
	ut_error = 0;
	ut_location = 0;
	ut_point = 0;
}

/*
 * fixture
 */
void setup(void)
{
	malloc_failure = false;
	calloc_failure = false;

	ut_error_setup();
}
/* Common checked teardown. setup() acquires no resources, so there is nothing to release. */
void teardown(void)
{
}

/* Message object used by the tls_hs_msg_write_*() tests (see setup2/teardown2). */
static struct tls_hs_msg *msg0 = NULL;

/*
 * Fixture for the write tests: run the common setup(), then create a
 * fresh message in msg0.
 */
void setup2(void)
{
	setup();
	msg0 = tls_hs_msg_init();
}

/* Release msg0 if one was created, and reset it to NULL. */
void teardown2(void)
{
	if (msg0 == NULL) {
		return;
	}

	tls_hs_msg_free(msg0);
	msg0 = NULL;
}

#if 0
/*
 * unit testing for realloc_message().
 *
 * TODO: placeholder only. The body is empty, the block is disabled with
 * #if 0, and the test is not registered in any suite. The commented-out
 * line inside records the signature of the function under test.
 */
START_TEST (test_realloc_message)
{
//static bool realloc_message(struct tls_hs_msg *msg)
}
END_TEST
#endif

/*
 * unit testing for tls_hs_msg_init().
 *
 * A freshly created message is empty, has TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX
 * bytes of capacity, and owns a data buffer.
 */
START_TEST (test_tls_hs_msg_init_normal)
{
	struct tls_hs_msg *created;

	created = tls_hs_msg_init();
	ck_assert_ptr_nonnull(created);

	/*
	 * msg->link and msg->type are deliberately not checked: their
	 * values are indeterminate after tls_hs_msg_init().
	 */
	ck_assert_uint_eq(created->len, 0);
	ck_assert_uint_eq(created->max, TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX);
	ck_assert_ptr_nonnull(created->msg);

	tls_hs_msg_free(created);
}
END_TEST

/*
 * When allocating the struct itself fails, tls_hs_msg_init() returns
 * NULL and reports ERR_ST_TLS_MALLOC at point ERR_PT_TLS_HS_UTIL_MSG + 1:
 *   if ((msg = malloc(1 * sizeof (struct tls_hs_msg))) == NULL) {...
 */
START_TEST (test_tls_hs_msg_init_failure_malloc)
{
	struct tls_hs_msg *result;

	malloc_failure = true;
	result = tls_hs_msg_init();
	ck_assert_ptr_null(result);

	/* The error must have been reported through the OK_set_error() stub. */
	ck_assert_int_eq(ut_error, ERR_ST_TLS_MALLOC);
	ck_assert_int_eq(ut_location, ERR_LC_TLS3);
	ck_assert_int_eq(ut_point, ERR_PT_TLS_HS_UTIL_MSG + 1);
}
END_TEST

/*
 * When allocating the data buffer fails, tls_hs_msg_init() returns NULL
 * and reports ERR_ST_TLS_CALLOC at point ERR_PT_TLS_HS_UTIL_MSG + 2:
 *   if ((msg->msg = calloc(1, TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX)) == NULL) {...
 */
START_TEST (test_tls_hs_msg_init_failure_calloc)
{
	struct tls_hs_msg *result;

	calloc_failure = true;
	result = tls_hs_msg_init();
	ck_assert_ptr_null(result);

	/* The error must have been reported through the OK_set_error() stub. */
	ck_assert_int_eq(ut_error, ERR_ST_TLS_CALLOC);
	ck_assert_int_eq(ut_location, ERR_LC_TLS3);
	ck_assert_int_eq(ut_point, ERR_PT_TLS_HS_UTIL_MSG + 2);
}
END_TEST

/*
 * Build the tls_hs_msg_init suite. It has a "Core" case for the success
 * path and a "Limits" case for the allocation-failure paths. Both cases
 * use the setup()/teardown() fixture.
 */
Suite *tls_hs_msg_init_suite(void)
{
	Suite *suite = suite_create("tls_hs_msg_init");
	TCase *core = tcase_create("Core");

	tcase_add_checked_fixture(core, setup, teardown);
	tcase_add_test(core, test_tls_hs_msg_init_normal);
	suite_add_tcase(suite, core);

	TCase *limits = tcase_create("Limits");

	tcase_add_checked_fixture(limits, setup, teardown);
	tcase_add_test(limits, test_tls_hs_msg_init_failure_malloc);
	tcase_add_test(limits, test_tls_hs_msg_init_failure_calloc);
	suite_add_tcase(suite, limits);

	return suite;
}

/*
 * unit testing for tls_hs_msg_free().
 *
 * Freeing a freshly created message must not crash. Assert that creation
 * succeeded first, so the test really exercises freeing an allocated
 * object rather than calling tls_hs_msg_free(NULL).
 */
START_TEST (test_tls_hs_msg_free_normal)
{
	struct tls_hs_msg *msg = tls_hs_msg_init();

	ck_assert_ptr_nonnull(msg);
	tls_hs_msg_free(msg);
}
END_TEST

/* Build the tls_hs_msg_free suite: one "Core" case using setup()/teardown(). */
Suite *tls_hs_msg_free_suite(void)
{
	Suite *suite = suite_create("tls_hs_msg_free");
	TCase *core = tcase_create("Core");

	tcase_add_checked_fixture(core, setup, teardown);
	tcase_add_test(core, test_tls_hs_msg_free_normal);
	suite_add_tcase(suite, core);

	return suite;
}

/*
 * unit testing for tls_hs_msg_write_1().
 *
 * Two one-byte writes land at offsets 0 and 1, advance len by one each,
 * and leave the following bytes untouched.
 */
START_TEST (test_tls_hs_msg_write_1_normal)
{
	const uint8_t first = 0x61;
	const uint8_t second = 0x7a;
	bool ok;

	/* setup2() must have produced an empty, zero-filled message. */
	for (int idx = 0; idx < 4; idx++) {
		ck_assert_int_eq(msg0->msg[idx], 0);
	}
	ck_assert_int_eq(msg0->len, 0);

	/* First byte. */
	ok = tls_hs_msg_write_1(msg0, first);
	ck_assert_int_eq(ok, true);
	ck_assert_int_eq(msg0->len, 1);
	ck_assert_int_eq(msg0->msg[0], first);
	ck_assert_int_eq(msg0->msg[1], 0);
	ck_assert_int_eq(msg0->msg[2], 0);
	ck_assert_int_eq(msg0->msg[3], 0);

	/* Second byte is appended right after the first. */
	ok = tls_hs_msg_write_1(msg0, second);
	ck_assert_int_eq(ok, true);
	ck_assert_int_eq(msg0->len, 2);
	ck_assert_int_eq(msg0->msg[0], first);
	ck_assert_int_eq(msg0->msg[1], second);
	ck_assert_int_eq(msg0->msg[2], 0);
	ck_assert_int_eq(msg0->msg[3], 0);
}
END_TEST

/*
 * Writing one byte into a full buffer must grow it:
 *   if ((msg->len + size) > msg->max) {...
 * Existing contents must survive, the byte must land at the old capacity,
 * and the new capacity must cover the new length.
 */
START_TEST (test_tls_hs_msg_write_1_normal_realloc)
{
	bool rc;
	uint32_t oldmax;
	const uint8_t dat4 = 0x64;

	/* Mark the head of the buffer so we can check it survives the realloc. */
	msg0->msg[0] = 0x60;
	msg0->msg[1] = 0x61;
	msg0->msg[2] = 0x62;
	msg0->msg[3] = 0x63;

	oldmax = msg0->max;

	/* Make the buffer look full, so the next write has to reallocate. */
	msg0->len = msg0->max;
	rc = tls_hs_msg_write_1(msg0, dat4);

	ck_assert_int_eq(rc, true);
	ck_assert_uint_eq(msg0->len, oldmax + 1);
	/* The grown buffer must be able to hold everything written so far. */
	ck_assert_uint_ge(msg0->max, msg0->len);

	ck_assert_int_eq(msg0->msg[0], 0x60);
	ck_assert_int_eq(msg0->msg[1], 0x61);
	ck_assert_int_eq(msg0->msg[2], 0x62);
	ck_assert_int_eq(msg0->msg[3], 0x63);

	/* The new byte is written at the start of the expanded area. */
	ck_assert_int_eq(msg0->msg[oldmax], dat4);
}
END_TEST

/*
 * Build the tls_hs_msg_write_1() suite: one "Core" case using the
 * setup2()/teardown2() fixture, which supplies msg0.
 */
Suite *tls_hs_msg_write_1_suite(void)
{
	Suite *suite = suite_create("tls_hs_msg_write_1()");
	TCase *core = tcase_create("Core");

	tcase_add_checked_fixture(core, setup2, teardown2);
	tcase_add_test(core, test_tls_hs_msg_write_1_normal);
	tcase_add_test(core, test_tls_hs_msg_write_1_normal_realloc);
	suite_add_tcase(suite, core);

	return suite;
}

/*
 * unit testing for tls_hs_msg_write_2().
 *
 * Each 16-bit value is stored most significant byte first and advances
 * len by two. The second write follows the first.
 */
START_TEST (test_tls_hs_msg_write_2_normal)
{
	const uint16_t first = 0x6263;
	const uint16_t second = 0x6465;
	bool ok;

	/* setup2() must have produced an empty, zero-filled message. */
	for (int idx = 0; idx < 4; idx++) {
		ck_assert_int_eq(msg0->msg[idx], 0);
	}
	ck_assert_int_eq(msg0->len, 0);

	/* First value: 0x62 0x63. */
	ok = tls_hs_msg_write_2(msg0, first);
	ck_assert_int_eq(ok, true);
	ck_assert_int_eq(msg0->len, 2);
	ck_assert_int_eq(msg0->msg[0], 0x62);
	ck_assert_int_eq(msg0->msg[1], 0x63);
	ck_assert_int_eq(msg0->msg[2], 0);
	ck_assert_int_eq(msg0->msg[3], 0);

	/* Second value: 0x64 0x65, appended after the first. */
	ok = tls_hs_msg_write_2(msg0, second);
	ck_assert_int_eq(ok, true);
	ck_assert_int_eq(msg0->len, 4);
	ck_assert_int_eq(msg0->msg[0], 0x62);
	ck_assert_int_eq(msg0->msg[1], 0x63);
	ck_assert_int_eq(msg0->msg[2], 0x64);
	ck_assert_int_eq(msg0->msg[3], 0x65);
}
END_TEST

/*
 * Writing two bytes into a full buffer must grow it:
 *   if ((msg->len + size) > msg->max) {...
 * Existing contents must survive, the value must land big-endian at the
 * old capacity, and the new capacity must cover the new length.
 */
START_TEST (test_tls_hs_msg_write_2_normal_realloc)
{
	bool rc;
	uint32_t oldmax;
	const uint16_t dat4 = 0x6465;

	/* Mark the head of the buffer so we can check it survives the realloc. */
	msg0->msg[0] = 0x60;
	msg0->msg[1] = 0x61;
	msg0->msg[2] = 0x62;
	msg0->msg[3] = 0x63;

	oldmax = msg0->max;

	/* Make the buffer look full, so the next write has to reallocate. */
	msg0->len = msg0->max;
	rc = tls_hs_msg_write_2(msg0, dat4);

	ck_assert_int_eq(rc, true);
	ck_assert_uint_eq(msg0->len, oldmax + 2);
	/* The grown buffer must be able to hold everything written so far. */
	ck_assert_uint_ge(msg0->max, msg0->len);

	ck_assert_int_eq(msg0->msg[0], 0x60);
	ck_assert_int_eq(msg0->msg[1], 0x61);
	ck_assert_int_eq(msg0->msg[2], 0x62);
	ck_assert_int_eq(msg0->msg[3], 0x63);

	/* The value is written, most significant byte first, in the expanded area. */
	ck_assert_int_eq(msg0->msg[oldmax+0], 0x64);
	ck_assert_int_eq(msg0->msg[oldmax+1], 0x65);
}
END_TEST

/*
 * Build the tls_hs_msg_write_2() suite: one "Core" case using the
 * setup2()/teardown2() fixture, which supplies msg0.
 */
Suite *tls_hs_msg_write_2_suite(void)
{
	Suite *suite = suite_create("tls_hs_msg_write_2()");
	TCase *core = tcase_create("Core");

	tcase_add_checked_fixture(core, setup2, teardown2);
	tcase_add_test(core, test_tls_hs_msg_write_2_normal);
	tcase_add_test(core, test_tls_hs_msg_write_2_normal_realloc);
	suite_add_tcase(suite, core);

	return suite;
}

/*
 * unit testing for tls_hs_msg_write_3().
 *
 * Each 24-bit value is stored most significant byte first and advances
 * len by three. The second write follows the first.
 */
START_TEST (test_tls_hs_msg_write_3_normal)
{
	const uint32_t first = 0x666768;
	const uint32_t second = 0x696a6b;
	bool ok;

	/* setup2() must have produced an empty, zero-filled message. */
	for (int idx = 0; idx < 8; idx++) {
		ck_assert_int_eq(msg0->msg[idx], 0);
	}
	ck_assert_int_eq(msg0->len, 0);

	/* First value: 0x66 0x67 0x68. */
	ok = tls_hs_msg_write_3(msg0, first);
	ck_assert_int_eq(ok, true);
	ck_assert_int_eq(msg0->len, 3);
	ck_assert_int_eq(msg0->msg[0], 0x66);
	ck_assert_int_eq(msg0->msg[1], 0x67);
	ck_assert_int_eq(msg0->msg[2], 0x68);
	ck_assert_int_eq(msg0->msg[3], 0);

	/* Second value: 0x69 0x6a 0x6b, appended after the first. */
	ok = tls_hs_msg_write_3(msg0, second);
	ck_assert_int_eq(ok, true);
	ck_assert_int_eq(msg0->len, 6);
	ck_assert_int_eq(msg0->msg[0], 0x66);
	ck_assert_int_eq(msg0->msg[1], 0x67);
	ck_assert_int_eq(msg0->msg[2], 0x68);
	ck_assert_int_eq(msg0->msg[3], 0x69);
	ck_assert_int_eq(msg0->msg[4], 0x6a);
	ck_assert_int_eq(msg0->msg[5], 0x6b);
	ck_assert_int_eq(msg0->msg[6], 0);
	ck_assert_int_eq(msg0->msg[7], 0);
}
END_TEST

/*
 * Writing three bytes into a full buffer must grow it:
 *   if ((msg->len + size) > msg->max) {...
 * Existing contents must survive, the value must land big-endian at the
 * old capacity, and the new capacity must cover the new length.
 */
START_TEST (test_tls_hs_msg_write_3_normal_realloc)
{
	bool rc;
	uint32_t oldmax;
	const uint32_t dat4 = 0x646566;

	/* Mark the head of the buffer so we can check it survives the realloc. */
	msg0->msg[0] = 0x60;
	msg0->msg[1] = 0x61;
	msg0->msg[2] = 0x62;
	msg0->msg[3] = 0x63;

	oldmax = msg0->max;

	/* Make the buffer look full, so the next write has to reallocate. */
	msg0->len = msg0->max;
	rc = tls_hs_msg_write_3(msg0, dat4);

	ck_assert_int_eq(rc, true);
	ck_assert_uint_eq(msg0->len, oldmax + 3);
	/* The grown buffer must be able to hold everything written so far. */
	ck_assert_uint_ge(msg0->max, msg0->len);

	ck_assert_int_eq(msg0->msg[0], 0x60);
	ck_assert_int_eq(msg0->msg[1], 0x61);
	ck_assert_int_eq(msg0->msg[2], 0x62);
	ck_assert_int_eq(msg0->msg[3], 0x63);

	/* The value is written, most significant byte first, in the expanded area. */
	ck_assert_int_eq(msg0->msg[oldmax+0], 0x64);
	ck_assert_int_eq(msg0->msg[oldmax+1], 0x65);
	ck_assert_int_eq(msg0->msg[oldmax+2], 0x66);
}
END_TEST

/*
 * Build the tls_hs_msg_write_3() suite: one "Core" case using the
 * setup2()/teardown2() fixture, which supplies msg0.
 */
Suite *tls_hs_msg_write_3_suite(void)
{
	Suite *suite = suite_create("tls_hs_msg_write_3()");
	TCase *core = tcase_create("Core");

	tcase_add_checked_fixture(core, setup2, teardown2);
	tcase_add_test(core, test_tls_hs_msg_write_3_normal);
	tcase_add_test(core, test_tls_hs_msg_write_3_normal_realloc);
	suite_add_tcase(suite, core);

	return suite;
}

#if 0
/*
 * unit testing for tls_hs_msg_write_n().
 *
 * TODO: not implemented yet. Disabled with #if 0; main() adds the suite
 * under a matching #if 0. Signature under test:
 *   bool tls_hs_msg_write_n(struct tls_hs_msg *msg, const uint8_t *dat,
 *                           const uint32_t len)
 *
 * The test is named ..._normal to match the tcase_add_test() call below.
 * Previously the names differed, so enabling this block would not compile.
 */
START_TEST (test_tls_hs_msg_write_n_normal)
{
}
END_TEST

/*
 * Build the tls_hs_msg_write_n() suite: one "Core" case using the
 * setup2()/teardown2() fixture, which supplies msg0.
 */
Suite *tls_hs_msg_write_n_suite(void)
{
	Suite *s;
	TCase *tc_core;
	s = suite_create("tls_hs_msg_write_n()");

	/* Core test case */
	tc_core = tcase_create("Core");

	tcase_add_checked_fixture(tc_core, setup2, teardown2);
	tcase_add_test(tc_core, test_tls_hs_msg_write_n_normal);
	suite_add_tcase(s, tc_core);

	return s;
}
#endif

/**
 * Create the empty top-level suite named after this test program's source
 * file (message.c). The suites that hold the actual tests are added to
 * the runner in main().
 */
Suite *message_suite(void)
{
	return suite_create("message_suites");
}

/*
 * Register every suite, run them all, and exit with EXIT_SUCCESS only if
 * no test failed.
 */
int main(void)
{
	SRunner *runner = srunner_create(message_suite());
	int failures;

	srunner_add_suite(runner, tls_hs_msg_init_suite());
	srunner_add_suite(runner, tls_hs_msg_free_suite());
	srunner_add_suite(runner, tls_hs_msg_write_1_suite());
	srunner_add_suite(runner, tls_hs_msg_write_2_suite());
	srunner_add_suite(runner, tls_hs_msg_write_3_suite());
#if 0
	srunner_add_suite(runner, tls_hs_msg_write_n_suite());
#endif

	/* Switch CK_NORMAL to CK_VERBOSE to print every test result. */
	srunner_run_all(runner, CK_NORMAL);
	failures = srunner_ntests_failed(runner);
	srunner_free(runner);

	return failures == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}

/* ------------------------------------------------------------------------- */
#pragma GCC diagnostic ignored "-Wunused-parameter"
/* Stub functions go below this line. */

/*
 * Stub for OK_set_error().
 *
 * Records error, location and point in ut_error, ut_location and ut_point
 * so tests can assert on them. Also logs them to stderr in hex, followed
 * by *info when info is non-NULL.
 */
void OK_set_error(int error, int location, int point, CK_RV *info)
{
	ut_error = error;
	ut_location = location;
	ut_point = point;

	/* %x consumes an unsigned int, so convert explicitly rather than passing int. */
	fprintf(stderr, "error:%0x location:%0x point:%0x\n",
		(unsigned int)error, (unsigned int)location, (unsigned int)point);
	if (info != NULL) {
		fprintf(stderr, "*info=%zu\n", (size_t)(*info));
	} else {
		fprintf(stderr, "info=NULL\n");
	}
}

/*
 * Stub matching tls_util_write_2() in tls/tls_util.c. It stores the low
 * 16 bits of val into buf[0..1], most significant byte first.
 */
void tls_util_write_2(uint8_t *buf, int32_t val)
{
	const uint32_t bits = (uint32_t)val;

	buf[0] = (uint8_t)((bits >> 8) & 0xffu);
	buf[1] = (uint8_t)(bits & 0xffu);
}

/*
 * Stub matching tls_util_write_3() in tls/tls_util.c. It stores the low
 * 24 bits of val into buf[0..2], most significant byte first.
 */
void tls_util_write_3(uint8_t *buf, int32_t val)
{
	/*
	 * Shift an unsigned copy: right-shifting a negative int32_t is
	 * implementation-defined. The resulting bytes are the same.
	 */
	const uint32_t bits = (uint32_t)val;

	buf[0] = (uint8_t)((bits >> 16) & 0xffu);
	buf[1] = (uint8_t)((bits >>  8) & 0xffu);
	buf[2] = (uint8_t)( bits        & 0xffu);
}
