/*
 * Copyright (c) 2015-2017 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls.h"
#include "tls_session.h"

#include <string.h>

/**
 * bookkeeping for the session cache (the list of stored sessions).
 *
 * exactly one instance exists: the file-scope variable session_list.
 */
struct tls_session_list {
	/** whether head has been set up by session_init() (TAILQ_INIT). */
	bool init;

	/** number of sessions currently linked into head. */
	uint32_t len;

	/** maximum number of sessions tls_session_unrefer() will keep in
	 * head; adjustable with TLS_session_set_list_size(). */
	uint32_t max;

	/** head of the queue of stored sessions (linked via their `link`
	 * member). */
	TAILQ_HEAD(tls_session_head, tls_session_param) head;
};

/**
 * initialize the session list.
 *
 * runs TAILQ_INIT on session_list.head and marks session_list as
 * initialized.
 */
static void session_init(void);

/**
 * free the unreferenced sessions in session_list that are disabled or
 * older than 24 hours. sessions that still have references are kept.
 *
 * @return the number of sessions freed.
 */
static uint32_t session_clean(void);

/**
 * wipe the master secret of the specified session, unlink it from
 * session_list if it had been inserted there, and free its memory.
 */
static void session_free(struct tls_session_param *session);

/**
 * default maximum length of the session_list variable.
 */
#define SESSION_LIST_MAX_LENGTH_DEFAULT	100

/**
 * the variable that manages the list of sessions (the session cache).
 *
 * head is deliberately left out of the initializer; session_init() sets
 * it up lazily on the first call to tls_session_new().
 */
static struct tls_session_list session_list = {
	.init = false,
	.len  = 0,
	.max  = SESSION_LIST_MAX_LENGTH_DEFAULT
};

/**
 * set up the (empty) session queue and flag the list as ready.
 *
 * the two statements are independent: nothing reads the init flag
 * while the queue is being initialized.
 */
static void session_init(void) {
	session_list.init = true;
	TAILQ_INIT(&session_list.head);
}

/**
 * free the unreferenced sessions in session_list that are either
 * disabled or older than 24 hours.
 *
 * sessions with references > 0 are never touched. each eligible session
 * is freed exactly once. the session pointer is never dereferenced after
 * session_free(), which both unlinks and frees it.
 *
 * @return the number of sessions freed.
 */
static uint32_t session_clean(void) {
	/* sessions whose age exceeds this many seconds are outdated. */
	const uint32_t epochtime_24hours = 24 * 60 * 60;

	uint64_t epochtime = tls_util_get_epochtime();
	uint32_t count = 0;

	struct tls_session_param *session, *next_session;

	/* the _SAFE variant fetches the next element before the body runs,
	 * so the current element may be unlinked and freed inside it. */
	TAILQ_FOREACH_SAFE(session, &(session_list.head), link, next_session) {
		if (session->references > 0) {
			continue;
		}

		/* decide on a reason first, then free once. checking the
		 * expiry after a free would read freed (or NULL) memory. */
		if (session->disabled == true) {
			TLS_DPRINTF("session: free disable session");
		} else if (epochtime >
			   (session->created_epochtime + epochtime_24hours)) {
			TLS_DPRINTF("session: free outdated session");
		} else {
			continue;
		}

		session_free(session);
		count++;
	}

	TLS_DPRINTF("session: free session of the %d num.", count);
	return count;
}

/**
 * release one session.
 *
 * the master secret is wiped first, then the session is unlinked from
 * session_list if it was ever inserted there, and finally its memory is
 * freed. only sessions with created_epochtime > 0 are unlinked:
 * tls_session_new() creates sessions with 0, and tls_session_unrefer()
 * sets the timestamp at the moment it inserts a session into the list.
 *
 * @param session session to release; NULL is ignored.
 */
static void session_free(struct tls_session_param *session) {
	if (session == NULL) {
		return;
	}

	/* a plain memset() right before free() is a dead store that the
	 * optimizer may remove, leaving key material in freed heap memory.
	 * writing through a volatile pointer forces every store to happen. */
	volatile unsigned char *secret =
		(volatile unsigned char *) session->master_secret;
	for (size_t i = 0; i < sizeof (session->master_secret); i++) {
		secret[i] = 0x0U;
	}

	if (session->created_epochtime > 0) {
		TAILQ_REMOVE(&(session_list.head), session, link);
		session_list.len--;
	}

	free(session);
}

/**
 * allocate a fresh session that is not yet stored in session_list.
 *
 * the session list is initialized lazily on the first call. the returned
 * session starts disabled, with no references, an empty session id, and
 * created_epochtime 0. it uses the NULL compression and the
 * TLS_NULL_WITH_NULL_NULL cipher suite.
 *
 * @return the new session, or NULL (with an error set) if calloc fails.
 */
struct tls_session_param * tls_session_new(void) {
	if (!session_list.init) {
		session_init();
	}

	struct tls_session_param *fresh = calloc(1, sizeof *fresh);
	if (fresh == NULL) {
		TLS_DPRINTF("session: calloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_CALLOC,
			     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 0, NULL);
		return NULL;
	}

	/* algorithm selection: nothing negotiated yet. */
	fresh->cipher_suite          = TLS_NULL_WITH_NULL_NULL;
	fresh->compression_algorithm = TLS_COMPRESSION_NULL;

	/* lifecycle state: not usable, not referenced, not cached. */
	fresh->disabled          = true;
	fresh->references        = 0;
	fresh->session_id_length = 0;
	fresh->created_epochtime = 0;

	TLS_DPRINTF("session: make initial session.");

	return fresh;
}

/**
 * free a session that nobody references any more.
 *
 * @param session session to free. NULL records ERR_ST_TLS_SESSION_NULL
 *                but still counts as success.
 * @return true if the session was freed (or was NULL); false, with
 *         ERR_ST_TLS_SESSION_USED set, if it still has references.
 */
bool tls_session_free(struct tls_session_param *session) {
	bool still_referenced;

	if (session == NULL) {
		OK_set_error(ERR_ST_TLS_SESSION_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 1, NULL);
		return true;
	}

	still_referenced = (session->references != 0);
	if (still_referenced) {
		TLS_DPRINTF("session: free failed.");
		OK_set_error(ERR_ST_TLS_SESSION_USED,
			     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 2, NULL);
		return false;
	}

	session_free(session);
	TLS_DPRINTF("session: freed.");

	return true;
}

/**
 * look up a stored session by its session id.
 *
 * a miss is a normal outcome, not an error. entries whose id length
 * differs from len are simply skipped, and no error is recorded for them,
 * so a successful lookup never leaves a stale error behind.
 *
 * NOTE(review): disabled or older-than-24h sessions that are still in the
 * list are returned as well; confirm callers check `disabled` before
 * resuming a session.
 *
 * @param session_id id bytes to search for; must not be NULL.
 * @param len        number of bytes in session_id; must be > 0.
 * @return the matching session, or NULL if none matches. an invalid
 *         argument (NULL id or zero length) sets
 *         ERR_ST_TLS_SESSION_ID_LENGTH.
 */
struct tls_session_param * tls_session_find_by_id(const uint8_t *session_id,
						  const uint32_t len) {
	/* memcmp() on a NULL pointer is undefined behavior. */
	if (session_id == NULL || len == 0) {
		OK_set_error(ERR_ST_TLS_SESSION_ID_LENGTH,
			     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 3, NULL);
		return NULL;
	}

	struct tls_session_param *session;
	TAILQ_FOREACH(session, &(session_list.head), link) {
		/* the length check comes first so memcmp() never reads past
		 * the stored id. */
		if (session->session_id_length == len &&
		    memcmp(session_id, session->session_id, len) == 0) {
			return session;
		}
	}

	return NULL;
}

/**
 * take one additional reference to a session.
 *
 * @param session session being referenced; NULL records
 *                ERR_ST_TLS_SESSION_NULL and does nothing else.
 */
void tls_session_refer(struct tls_session_param *session) {
	if (session != NULL) {
		session->references += 1;
		TLS_DPRINTF("session: refer ++ => %d.", session->references);
		return;
	}

	OK_set_error(ERR_ST_TLS_SESSION_NULL,
		     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 5, NULL);
}

/**
 * drop one reference to a session and decide what happens to it.
 *
 * - a disabled session whose last reference is dropped is freed.
 * - a session already stored in session_list (created_epochtime > 0) is
 *   left in place.
 * - any other session is appended to session_list and stamped with the
 *   current epoch time. if the list is full and session_clean() cannot
 *   free anything, the session is freed instead of being stored.
 *
 * NOTE(review): references is decremented without checking for 0 first;
 * an unbalanced call wraps or goes negative depending on its type.
 *
 * @param session session being released; NULL records
 *                ERR_ST_TLS_SESSION_NULL and does nothing else.
 */
void tls_session_unrefer(struct tls_session_param *session) {
	if (session == NULL) {
		OK_set_error(ERR_ST_TLS_SESSION_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 6, NULL);
		return;
	}

	session->references -= 1;
	TLS_DPRINTF("session: refer -- => %d.", session->references);

	bool last_ref_of_disabled =
		(session->references == 0) && (session->disabled == true);
	if (last_ref_of_disabled) {
		TLS_DPRINTF("session: cleanup disabled session");
		session_free(session);
		return;
	}

	/* a non-zero timestamp means the session is already in the list. */
	bool already_stored = (session->created_epochtime > 0);
	if (already_stored) {
		return;
	}

	if (session_list.len >= session_list.max) {
		/* try to make room; if nothing was freed the list is still
		 * full, so this session is dropped rather than stored. */
		if (session_clean() == 0) {
			TLS_DPRINTF("session: session list is full.");
			session_free(session);
			return;
		}
	}

	TAILQ_INSERT_TAIL(&(session_list.head), session, link);
	session_list.len++;
	session->created_epochtime = tls_util_get_epochtime();

	TLS_DPRINTF("session: save session (list num = %d/%d).",
		    session_list.len, session_list.max);
}

/**
 * mark a session as usable by clearing its disabled flag.
 *
 * @param session session to enable; NULL records
 *                ERR_ST_TLS_SESSION_NULL and does nothing else.
 */
void tls_session_enable(struct tls_session_param *session) {
	if (session != NULL) {
		session->disabled = false;
		return;
	}

	OK_set_error(ERR_ST_TLS_SESSION_NULL,
		     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 7, NULL);
}


/**
 * mark a session as unusable by setting its disabled flag.
 *
 * tls_session_unrefer() frees a disabled session once its last
 * reference is dropped, and session_clean() frees unreferenced disabled
 * sessions from the list.
 *
 * @param session session to disable; NULL records
 *                ERR_ST_TLS_SESSION_NULL and does nothing else.
 */
void tls_session_disable(struct tls_session_param *session) {
	if (session != NULL) {
		session->disabled = true;
		return;
	}

	OK_set_error(ERR_ST_TLS_SESSION_NULL,
		     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 8, NULL);
}

/**
 * @return the maximum number of sessions kept in the session list.
 */
uint32_t TLS_session_get_list_size(void) {
	const uint32_t limit = session_list.max;

	return limit;
}

/**
 * set the maximum number of sessions kept in the session list.
 *
 * sessions already stored are not trimmed here. the new limit only
 * takes effect the next time tls_session_unrefer() checks it.
 *
 * @param num new maximum list length.
 */
void TLS_session_set_list_size(uint32_t num) {
	uint32_t *limit = &session_list.max;

	*limit = num;
}

bool TLS_session_free_all(void) {
	struct tls_session_param *session;

	while(!(TAILQ_EMPTY(&(session_list.head)))) {
		session = TAILQ_FIRST(&(session_list.head));

		if (session->references == 0) {
			session_free(session);
		} else {
			TLS_DPRINTF("session currently in use.");
			OK_set_error(ERR_ST_TLS_SESSION_USED,
				     ERR_LC_TLS4, ERR_PT_TLS_SESSION + 9, NULL);
			return false;
		}

		session = NULL;
	}

	TLS_DPRINTF("session: all session freed.");
	return true;
}
