/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_record.h"
#include "tls_handshake.h"

#include <string.h>

/**
 * Grow the message buffer of a tls_hs_msg by one step.
 *
 * The buffer starts at TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX bytes (see
 * tls_hs_msg_init()). Each call enlarges it by another
 * TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX bytes. Callers that need more room
 * than a single step provides must call this repeatedly.
 *
 * @param msg message whose buffer is enlarged; msg->msg and msg->max are
 *            updated on success.
 * @return true on success. On false, msg is left untouched and the
 *         original buffer remains valid.
 */
static bool realloc_message(struct tls_hs_msg *msg);

static bool realloc_message(struct tls_hs_msg *msg) {
	uint32_t len;
	uint8_t *buf;

	/* Refuse to grow past 32-bit capacity. Otherwise the addition below
	 * would wrap, and realloc() would *shrink* the buffer while msg->len
	 * still refers to the old, larger size. */
	if (msg->max > UINT32_MAX - TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX) {
		TLS_DPRINTF("hs: u: msg: realloc: size overflow");
		OK_set_error(ERR_ST_TLS_REALLOC, ERR_LC_TLS3,
			     ERR_PT_TLS_HS_UTIL_MSG + 0, NULL);
		return false;
	}
	len = msg->max + TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX;

	/* TODO: should retry if failed ? */
	if ((buf = realloc(msg->msg, len)) == NULL) {
		/* realloc() failed, but the originally allocated memory is
		 * still valid and still owned by msg. So do not free it
		 * here. */
		TLS_DPRINTF("hs: u: msg: realloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_REALLOC, ERR_LC_TLS3,
			     ERR_PT_TLS_HS_UTIL_MSG + 0, NULL);
		return false;
	}

	msg->msg = buf;
	msg->max = len;

	return true;
}

/**
 * Allocate a new, empty handshake message.
 *
 * The returned object owns a zero-filled buffer of
 * TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX bytes, with len = 0. Release it with
 * tls_hs_msg_free().
 *
 * @return the new message, or NULL on allocation failure (an error is
 *         recorded via OK_set_error()).
 */
struct tls_hs_msg * tls_hs_msg_init(void) {
	struct tls_hs_msg *hs;
	uint8_t *storage;

	hs = malloc(sizeof *hs);
	if (hs == NULL) {
		TLS_DPRINTF("hs: u: msg: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC, ERR_LC_TLS3,
			     ERR_PT_TLS_HS_UTIL_MSG + 1, NULL);
		return NULL;
	}

	storage = calloc(1, TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX);
	if (storage == NULL) {
		TLS_DPRINTF("hs: u: msg: calloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_CALLOC, ERR_LC_TLS3,
			     ERR_PT_TLS_HS_UTIL_MSG + 2, NULL);
		/* Do not leak the container when its buffer cannot be
		 * obtained. */
		free(hs);
		return NULL;
	}

	hs->msg = storage;
	hs->len = 0;
	hs->max = TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX;
	return hs;
}

/**
 * Release a handshake message created by tls_hs_msg_init(), together with
 * its buffer.
 *
 * Like free(), passing NULL is a no-op. This lets error-cleanup paths call
 * it unconditionally.
 *
 * @param msg message to destroy (may be NULL). It must not be used after
 *            this call.
 */
void tls_hs_msg_free(struct tls_hs_msg *msg) {
	if (msg == NULL) {
		return;
	}

	free(msg->msg);
	/* Clear the fields so a stale copy of the struct does not look
	 * usable. */
	msg->msg = NULL;
	msg->len = 0;
	msg->max = 0;

	free(msg);
}

/**
 * Append one byte to the handshake message.
 *
 * If the buffer is full it is first enlarged with realloc_message().
 *
 * @return true on success; false if enlarging the buffer failed (msg->len
 *         is then unchanged).
 */
bool tls_hs_msg_write_1(struct tls_hs_msg *msg, const uint8_t dat) {
	const uint8_t width = 1;
	bool fits = (msg->len + width) <= msg->max;

	if (!fits && !realloc_message(msg)) {
		TLS_DPRINTF("realloc_message");
		return false;
	}

	msg->msg[msg->len] = dat;
	msg->len += width;
	return true;
}

/**
 * Append a 2-byte value to the handshake message.
 *
 * The encoding is delegated to tls_util_write_2() (presumably network byte
 * order; confirm against tls_util_write_2). If fewer than 2 bytes remain,
 * the buffer is first enlarged with realloc_message().
 *
 * @return true on success; false if enlarging the buffer failed.
 */
bool tls_hs_msg_write_2(struct tls_hs_msg *msg, const uint16_t dat) {
	const uint8_t width = 2;
	bool fits = (msg->len + width) <= msg->max;

	if (!fits && !realloc_message(msg)) {
		TLS_DPRINTF("realloc_message");
		return false;
	}

	tls_util_write_2(msg->msg + msg->len, dat);
	msg->len += width;
	return true;
}

/**
 * Append a 3-byte value to the handshake message.
 *
 * The encoding is delegated to tls_util_write_3(). Presumably only the low
 * 24 bits of dat are written; confirm against tls_util_write_3. If fewer
 * than 3 bytes remain, the buffer is first enlarged with realloc_message().
 *
 * @return true on success; false if enlarging the buffer failed.
 */
bool tls_hs_msg_write_3(struct tls_hs_msg *msg, const uint32_t dat) {
	const uint8_t width = 3;
	bool fits = (msg->len + width) <= msg->max;

	if (!fits && !realloc_message(msg)) {
		TLS_DPRINTF("realloc_message");
		return false;
	}

	tls_util_write_3(msg->msg + msg->len, dat);
	msg->len += width;
	return true;
}

/**
 * Append len bytes from dat to the handshake message.
 *
 * The buffer is enlarged with realloc_message() as many times as
 * necessary. realloc_message() grows by only TLS_RECORD_PLAIN_FRAGMENT_SIZE_MAX
 * bytes per call, so a single call is not enough for large payloads; the
 * earlier single-grow version could overflow the heap buffer here.
 *
 * @param msg destination message.
 * @param dat source bytes; may be NULL when len is 0.
 * @param len number of bytes to append.
 * @return true on success; false if the resulting length would not fit in
 *         32 bits or the buffer could not be enlarged. On failure msg->len
 *         and the existing contents are unchanged.
 */
bool tls_hs_msg_write_n(struct tls_hs_msg *msg, const uint8_t *dat,
			const uint32_t len) {
	/* Nothing to copy. This also avoids memcpy() with a possibly NULL
	 * dat, which is undefined behaviour even for a zero length. */
	if (len == 0) {
		return true;
	}

	/* Reject lengths whose sum with msg->len would wrap. Otherwise the
	 * capacity check below would be fooled into skipping the growth. */
	if (len > UINT32_MAX - msg->len) {
		TLS_DPRINTF("hs: u: msg: write_n: length overflow");
		return false;
	}

	while ((msg->len + len) > msg->max) {
		if (! realloc_message(msg)) {
			TLS_DPRINTF("realloc_message");
			return false;
		}
	}

	memcpy(&(msg->msg[msg->len]), dat, len);

	msg->len += len;

	return true;
}
