/*
 * Copyright (c) 2015 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include <aicrypto/nrg_tls.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <netinet/in.h>
#include <netdb.h>

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <getopt.h>
#include <errno.h>
#include <limits.h>

/* for OK_set_passwd and OK_clear_passwd */
#include <aicrypto/ok_tool.h>

/*
 * Command-line settings of the test server.  main() fills them in from the
 * options; the setup and serving code only reads them.  Every member that
 * is not named in the initializer starts out as false / NULL (static
 * storage is zero-initialized).
 */
static struct server_option {
	bool     client_auth;     /* -a: set TLS_OPT_USE_CLIENT_CERT_AUTH */
	bool     may_client_auth; /* -m: set TLS_OPT_MAY_USE_CLIENT_CERT_AUTH */
	bool     loop;            /* -l: keep serving instead of one round */
	bool     echo;            /* -e: run the echo loop per connection */
	char    *port;            /* -p: service name passed to getaddrinfo() */
	char    *stm_path;        /* -s: path passed to TLS_stm_set() */
	char    *stm_id;          /* -i: key id for TLS_set_serverkey_id() */
	char    *p12_path;        /* -c: file for TLS_set_serverkey_file() */
	char    *pass;            /* -P: password for the server key */
	int32_t  session_num;     /* -n: negative means "do not call
	                           *     TLS_session_set_list_size()" */
} g_server_option = {
	.session_num = -1,
};

/*
 * Convert STR to an int with strtol(3) using radix BASE.
 *
 * The whole string must be a number: leading whitespace and a sign are
 * accepted (as by strtol), but trailing characters are rejected.
 *
 * Returns the converted value.  On error it prints a diagnostic to stderr
 * and returns -1.  Errors are: no digits, trailing garbage, strtol failure,
 * or a value that does not fit in an int.  Because -1 is also a valid
 * result, callers that accept negative numbers cannot tell the two apart.
 */
static int string_to_int(const char *str, int base)
{
	char *end;
	long  val;

	errno = 0;
	val = strtol(str, &end, base);
	if (errno != 0) {
		perror("strtol");
		return -1;
	}

	/* Compare against our own argument, not the getopt global optarg. */
	if (end == str) {
		fprintf(stderr, "No digits were found.\n");
		return -1;
	}

	if (*end != '\0') {
		fprintf(stderr, "Further characters after number: %s\n", end);
		return -1;
	}

	/* long may be wider than int; refuse to truncate silently. */
	if (val < INT_MIN || val > INT_MAX) {
		fprintf(stderr, "Number out of range: %s\n", str);
		return -1;
	}

	return (int)val;
}

/*
 * Switch on TLS_OPT_MAY_USE_CLIENT_CERT_AUTH for TLS (selected with -m).
 * NOTE(review): presumably this makes the client certificate optional.
 */
static void use_may_client_auth(TLS *tls)
{
	bool enable = true;

	TLS_opt_set(tls, TLS_OPT_MAY_USE_CLIENT_CERT_AUTH, &enable);
}

/*
 * Switch on TLS_OPT_USE_CLIENT_CERT_AUTH for TLS (selected with -a).
 */
static void use_client_auth(TLS *tls)
{
	bool enable = true;

	TLS_opt_set(tls, TLS_OPT_USE_CLIENT_CERT_AUTH, &enable);
}

/*
 * Select the server key by its id (-i) in the store manager.
 *
 * The password (-P) is installed with OK_set_passwd() only around the
 * TLS_set_serverkey_id() call and cleared again on every path.
 * NOTE(review): presumably TLS_set_serverkey_id() picks the password up
 * from there, since it takes no password argument.
 *
 * Returns true on success; false (after a diagnostic) otherwise.
 */
static bool set_stm_id(TLS *tls)
{
	OK_set_passwd(g_server_option.pass);

	const bool ok = TLS_set_serverkey_id(tls, g_server_option.stm_id) >= 0;
	if (!ok) {
		fprintf(stderr, "TLS_set_serverkey_id()\n");
	}

	OK_clear_passwd();
	return ok;
}

/*
 * Load the server key and certificate from the file given with -c, using
 * the password given with -P.
 *
 * Returns true on success; false (after a diagnostic) otherwise.
 */
static bool set_stm_p12(TLS *tls)
{
	if (TLS_set_serverkey_file(tls, g_server_option.p12_path,
				   g_server_option.pass) >= 0) {
		return true;
	}

	fprintf(stderr, "TLS_set_serverkey_file()\n");
	/*
	 * NOTE(review): nothing on this path called OK_set_passwd(); this
	 * clear looks purely defensive -- confirm whether it is needed.
	 */
	OK_clear_passwd();
	return false;
}

/*
 * Line-oriented echo loop on an established TLS connection.
 *
 * Each chunk read with TLS_gets() is logged to stdout and written back to
 * the peer unchanged.  Lines longer than the buffer arrive, and are echoed,
 * in several pieces.  The loop ends when the peer sends exactly "quit\n",
 * which is not echoed.
 *
 * Returns true when "quit" was received.  Returns false when TLS_gets() or
 * TLS_write() fails, or when TLS_gets() returns 0.
 * NOTE(review): a return of 0 is assumed to mean the peer closed the
 * connection -- confirm against the TLS_gets() contract.
 */
static bool do_echo(TLS *tls) {
	enum { ECHO_BUF_LEN = 80 };
	static const char quit[] = "quit\n";
	const int quit_len = (int)(sizeof(quit) - 1);

	for (;;) {
		/*
		 * One spare byte keeps buff NUL-terminated even if TLS_gets()
		 * stores all ECHO_BUF_LEN bytes.
		 */
		char buff[ECHO_BUF_LEN + 1] = {0};

		int n = TLS_gets(tls, buff, ECHO_BUF_LEN);
		if (n < 0) {
			return false;
		}
		if (n == 0) {
			/*
			 * Nothing was read: buff[n - 1] below would be buff[-1],
			 * and looping again would spin forever.
			 */
			fprintf(stderr, "TLS_gets(): no data, connection closed?\n");
			return false;
		}

		printf("SERVER> (r) len = %d, %s", n, buff);

		if ((n == quit_len) && (strncmp(quit, buff, n) == 0)) {
			printf("SERVER> (q)\n");
			return true;
		}

		/*
		 * Log the chunk without its newline, then put the newline back
		 * so the peer receives exactly what it sent.  A chunk without
		 * a newline (part of a long line) is echoed unmodified.
		 */
		const bool has_newline = (buff[n - 1] == '\n');
		if (has_newline) {
			buff[n - 1] = '\0';
		}

		printf("SERVER> (w) len = %d, %s\n", n, buff);

		if (has_newline) {
			buff[n - 1] = '\n';
		}

		if (TLS_write(tls, buff, n) < 0) {
			fprintf(stderr, "TLS_write()\n");
			return false;
		}
	}
}

/*
 * Create one listening TLS socket for the address AI.
 *
 * The socket is configured in this order:
 *   - SO_REUSEADDR, plus IPV6_V6ONLY for IPv6 (no IPv4-mapped addresses);
 *   - the client-authentication options chosen on the command line;
 *   - the store manager and server key chosen on the command line.
 * It is then bound and put into listen mode with a backlog of 1.
 *
 * Returns the listener, or NULL after printing a diagnostic.  On failure the
 * partially configured socket is closed and freed here, so the caller has
 * nothing to clean up.
 */
static TLS *open_listener(const struct addrinfo *ai) {
	TLS *tls = TLS_socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
	if (tls == NULL) {
		perror("socket");
		return NULL;
	}

	int on = 1;
	if (TLS_setsockopt(tls, SOL_SOCKET, SO_REUSEADDR,
			   (const char *)&on, sizeof(on)) != 0) {
		fprintf(stderr, "setsockopt\n");
		goto fail;
	}

	if (ai->ai_family == AF_INET6) {
		/* disable IPv4 mapped address */
		if (TLS_setsockopt(tls, IPPROTO_IPV6, IPV6_V6ONLY,
				   (const char *)&on, sizeof(on)) != 0) {
			fprintf(stderr, "setsockopt\n");
			goto fail;
		}
	}

	if (g_server_option.may_client_auth) {
		use_may_client_auth(tls);
	}

	if (g_server_option.client_auth) {
		use_client_auth(tls);
	}

	if (g_server_option.stm_path != NULL &&
	    !TLS_stm_set(tls, g_server_option.stm_path)) {
		fprintf(stderr, "TLS_stm_set()\n");
		goto fail;
	}

	if ((g_server_option.stm_id != NULL) &&
	    (g_server_option.pass != NULL)) {
		if (!set_stm_id(tls)) {
			goto fail;
		}

		/* don't check whether a certificate is revoked because
		 * the store manager for test doesn't have a CRL.
		 */
		enum tls_cert_verify_type type = TLS_DONT_CHECK_REVOKED;
		TLS_opt_set(tls, TLS_OPT_CERT_VERIFY_TYPE, &type);
	} else if ((g_server_option.p12_path != NULL) &&
		   (g_server_option.pass != NULL)) {
		if (!set_stm_p12(tls)) {
			goto fail;
		}
	}

	if (TLS_bind(tls, ai->ai_addr, ai->ai_addrlen) < 0) {
		perror("bind");
		goto fail;
	}

	if (TLS_listen(tls, 1) < 0) {
		perror("listen");
		goto fail;
	}

	return tls;

fail:
	/* Same release sequence run_server() uses for finished listeners. */
	TLS_close(tls);
	TLS_free(tls);
	return NULL;
}

/*
 * Accept one connection on LISTENER and run the TLS handshake on it.  With
 * --echo, the echo loop then runs as well.  The accepted connection (if
 * any) is always closed and freed before returning.
 *
 * Returns false if the accept, the handshake or the echo loop failed.
 */
static bool serve_connection(TLS *listener) {
	struct sockaddr_storage sa;
	socklen_t sa_len = sizeof(sa);

	TLS *conn = TLS_accept(listener, (struct sockaddr *)&sa, &sa_len);
	if (conn == NULL) {
		/* Nothing was accepted, so there is nothing to close. */
		perror("accept");
		return false;
	}

	bool ok = true;
	if (TLS_handshake(conn) < 0) {
		fprintf(stderr, "TLS_handshake()\n");
		ok = false;
	} else if (g_server_option.echo && !do_echo(conn)) {
		ok = false;
	}

	TLS_close(conn);
	TLS_free(conn);
	return ok;
}

/*
 * Wait in select() until one or more of the COUNT listeners are readable,
 * then serve one connection on each readable listener.  This happens once,
 * or repeatedly with --loop until select() fails.
 *
 * Returns false if select() fails, or -- without --loop -- if serving any
 * connection failed.  With --loop, per-connection failures are only
 * reported on stderr.
 */
static bool serve_listeners(TLS *listeners[], uint32_t count) {
	bool   ret = true;
	fd_set readfds;
	fd_set fds;
	int    maxfd = -1;

	FD_ZERO(&readfds);
	for (uint32_t i = 0; i < count; i++) {
		int fd = TLS_get_fd(listeners[i]);
		FD_SET(fd, &readfds);
		if (fd > maxfd) {
			maxfd = fd;
		}
	}

	do {
		int n;
		do {
			/* select() modifies fds (and leaves it unspecified on
			 * error), so rebuild it before every attempt. */
			memcpy(&fds, &readfds, sizeof(fd_set));
			n = select(maxfd + 1, &fds, NULL, NULL, NULL);
		} while (n < 0 && errno == EINTR);

		if (n <= 0) {
			perror("select");
			return false;
		}

		for (uint32_t i = 0; i < count; i++) {
			if (!FD_ISSET(TLS_get_fd(listeners[i]), &fds)) {
				continue;
			}
			if (!serve_connection(listeners[i]) &&
			    !g_server_option.loop) {
				ret = false;
			}
		}
	} while (g_server_option.loop);

	return ret;
}

/*
 * Run the test server.
 *
 * Resolves passive addresses of any family for the configured port and
 * opens one listener per address (at most FD_SETSIZE).  It then serves
 * connections, and finally releases every listener and the TLS library.
 *
 * Returns true on success, false on any setup or serving failure.
 */
static bool run_server(void) {
	struct addrinfo hints = {
		.ai_family   = AF_UNSPEC,
		.ai_socktype = SOCK_STREAM,
		.ai_flags    = AI_PASSIVE
	};

	struct addrinfo *aihead;
	int gai_err = getaddrinfo(NULL, g_server_option.port, &hints, &aihead);
	if (gai_err != 0) {
		/* getaddrinfo() reports errors via its return value, not errno. */
		fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(gai_err));
		return false;
	}

	if (g_server_option.session_num >= 0) {
		TLS_session_set_list_size(g_server_option.session_num);
	}

	bool ret = true;
	TLS *tls[FD_SETSIZE];
	uint32_t max_tls = 0;

	TLS_init();
	for (struct addrinfo *ai = aihead;
	     ai != NULL && max_tls < FD_SETSIZE;
	     ai = ai->ai_next) {
		TLS *listener = open_listener(ai);
		if (listener == NULL) {
			ret = false;
			break;
		}
		/* Count only fully set-up listeners; open_listener() has
		 * already released a failed one. */
		tls[max_tls++] = listener;
	}
	freeaddrinfo(aihead);

	if (ret) {
		ret = serve_listeners(tls, max_tls);
	}

	for (uint32_t i = 0; i < max_tls; i++) {
		TLS_close(tls[i]);
		TLS_free(tls[i]);
		tls[i] = NULL;
	}
	TLS_cleanup();

	return ret;
}


/*
 * Parse the command-line options into g_server_option and run the server.
 *
 * Exit status is EXIT_SUCCESS when the server run succeeds and EXIT_FAILURE
 * otherwise.  Unknown options and an invalid -n value also cause
 * EXIT_FAILURE.  -v and -h currently exit with EXIT_SUCCESS without
 * printing anything.
 */
int
main (int argc, char* argv[]) {
	/*
	 * getopt_long() matches the "name" field against the text that follows
	 * the "--" typed on the command line, so the names must not contain
	 * the dashes themselves.  -m has no long form.
	 */
	static const struct option opts[] = {
		{"cauth", no_argument,       NULL, 'a'},
		{"loop",  no_argument,       NULL, 'l'},
		{"echo",  no_argument,       NULL, 'e'},
		{"port",  required_argument, NULL, 'p'},
		{"stm",   required_argument, NULL, 's'},
		{"svid",  required_argument, NULL, 'i'},
		{"cert",  required_argument, NULL, 'c'},
		{"pass",  required_argument, NULL, 'P'},
		{"snum",  required_argument, NULL, 'n'},
		{NULL,    0,                 NULL, 0}
	};

	int c;
	while ((c = getopt_long (argc, argv,
				 "amelp:s:i:c:P:n:vh", opts, NULL)) != -1) {
		switch (c) {
		case 'a':
			g_server_option.client_auth = true;
			break;

		case 'm':
			g_server_option.may_client_auth = true;
			break;

		case 'e':
			g_server_option.echo = true;
			break;

		case 'l':
			g_server_option.loop = true;
			break;

		case 'p':
			g_server_option.port = optarg;
			break;

		case 's':
			g_server_option.stm_path = optarg;
			break;

		case 'i':
			g_server_option.stm_id = optarg;
			break;

		case 'c':
			g_server_option.p12_path = optarg;
			break;

		case 'P':
			g_server_option.pass = optarg;
			break;

		case 'n':
			g_server_option.session_num =
				string_to_int(optarg, 10);
			/*
			 * string_to_int() returns -1 on parse errors, and a
			 * negative list size is meaningless anyway: reject it
			 * instead of silently running with the default.
			 */
			if (g_server_option.session_num < 0) {
				fprintf(stderr, "invalid session number: %s\n",
					optarg);
				exit (EXIT_FAILURE);
			}
			break;

		case 'v':
		case 'h':
			exit (EXIT_SUCCESS);

		default:
			exit (EXIT_FAILURE);
		}
	}

	return run_server() ? EXIT_SUCCESS : EXIT_FAILURE;
}
