/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include <aicrypto/nrg_tls.h>
#include "tls_handshake.h"

#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>

#include <errno.h>
#include <getopt.h>
#include <limits.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* for OK_set_passwd and OK_clear_passwd */
#include <aicrypto/ok_tool.h>

/*
 * Command-line settings filled in by main() and read by run_client() and
 * its helpers.  The string fields point at getopt's optarg (i.e. into argv);
 * NULL means the option was not given.
 */
struct client_option {
	bool     echo;        /* -e: run do_echo() after the handshake */
	bool     www;         /* -w: run do_www() after the handshake */
	bool     keyupdate;   /* -k: run do_key_update() after the handshake */
	char    *addr;        /* -a: host passed to getaddrinfo() and used as SNI */
	char    *port;        /* -p: service/port passed to getaddrinfo() */
	char    *stm_path;    /* -s: path handed to TLS_stm_set() */
	char    *stm_id;      /* -i: key id for TLS_set_clientkey_id() (needs -P) */
	char    *p12_path;    /* -c: file for TLS_set_clientkey_file() (needs -P) */
	char    *pass;        /* -P: password for the -i / -c client key */
	int32_t  session_num; /* -n: TLS_session_set_list_size() arg; <0 = not called */
};

static struct client_option g_client_option = {
	.echo        = false,
	.www         = false,
	.keyupdate   = false,
	.addr        = NULL,
	.port        = NULL,
	.stm_path    = NULL,
	.stm_id      = NULL,
	.p12_path    = NULL,
	.pass        = NULL,
	.session_num = -1,
};

/*
 * Parse STR as an integer in BASE (any base accepted by strtol(3)).
 *
 * The entire string must be a number (leading whitespace is skipped by
 * strtol, trailing characters are rejected) and the value must fit in an
 * int.  On success the value is returned.  On any failure a diagnostic is
 * written to stderr and -1 is returned -- so a literal "-1" cannot be
 * distinguished from an error by the caller.
 */
static int string_to_int(const char *str, int base)
{
	char *ep = NULL;
	long  val;

	if (str == NULL) {
		fprintf(stderr, "No number given.\n");
		return -1;
	}

	/* strtol only sets errno on failure, so clear it first. */
	errno = 0;
	val = strtol(str, &ep, base);
	if (errno != 0) {
		/* ERANGE (overflow) or EINVAL (unsupported base). */
		perror("strtol");
		return -1;
	}

	/* Compare against our own input, not getopt's optarg. */
	if (ep == str) {
		fprintf(stderr, "No digits were found.\n");
		return -1;
	}

	if (*ep != '\0') {
		fprintf(stderr, "Trailing characters after number: %s\n", ep);
		return -1;
	}

	/* long may be wider than int; refuse to truncate silently. */
	if (val < INT_MIN || val > INT_MAX) {
		fprintf(stderr, "%s: value out of range for int.\n", str);
		return -1;
	}

	return (int) val;
}

/*
 * Select the client key by id (g_client_option.stm_id).
 *
 * g_client_option.pass is installed with OK_set_passwd() only for the
 * duration of TLS_set_clientkey_id() and is cleared again on every path.
 * Returns false (after a message on stderr) if the lookup fails.
 */
static bool set_stm_id(TLS *tls) {
	OK_set_passwd(g_client_option.pass);

	bool found = TLS_set_clientkey_id(tls, g_client_option.stm_id) >= 0;
	if (!found) {
		fprintf(stderr, "TLS_set_clientkey_id()\n");
	}

	OK_clear_passwd();
	return found;
}

/*
 * Load the client key from the file g_client_option.p12_path, unlocking it
 * with g_client_option.pass.  Returns false (after a message on stderr) if
 * TLS_set_clientkey_file() fails.
 */
static bool set_stm_p12(TLS *tls) {
	const int rc = TLS_set_clientkey_file(tls,
					      g_client_option.p12_path,
					      g_client_option.pass);
	if (rc >= 0) {
		return true;
	}

	fprintf(stderr, "TLS_set_clientkey_file()\n");
	/* NOTE(review): this function never calls OK_set_passwd(), so this
	 * clear looks defensive or left over from set_stm_id() -- confirm. */
	OK_clear_passwd();
	return false;
}

/*
 * Echo round-trip over an established connection.
 *
 * Sends "hello\n", expects exactly the same bytes back in one TLS_read(),
 * then sends "quit\n".  Returns true only if both writes complete in full
 * and the reply matches; otherwise prints a reason to stderr and returns
 * false.
 */
static bool do_echo(TLS *tls) {
	int n;

	const char message[] = "hello\n";
	int32_t message_size = sizeof (message) - 1;

	printf("CLIENT> (w) len = %d, %s", message_size, message);

	if ((n = TLS_write(tls, message, message_size)) < 0) {
		fprintf(stderr, "TLS_write()\n");
		return false;
	}

	if (n != message_size) {
		fprintf(stderr, "invalid message length.\n");
		return false;
	}

	/* Zero-filled, and one byte is never handed to TLS_read(), so buff is
	 * always NUL-terminated when printed with %s. */
	char buff[20] = {0};
	const int32_t buff_len = (int32_t) sizeof(buff) - 1;

	if ((n = TLS_read(tls, buff, buff_len)) < 0) {
		fprintf(stderr, "TLS_read()\n");
		return false;
	}

	if (n != message_size) {
		fprintf(stderr, "E: read length (message %d != %d).\n",
			n, message_size);
		return false;
	}

	if (memcmp(buff, message, message_size) != 0) {
		fprintf(stderr, "E: echo not match.\n");
		return false;
	}

	/* Report the number of bytes actually received, not the buffer size. */
	printf("CLIENT> (r) len = %d, %s", n, buff);

	const char quit[] = "quit\n";
	int32_t quit_size = sizeof (quit) - 1;
	if ((n = TLS_write(tls, quit, quit_size)) < 0) {
		fprintf(stderr, "TLS_write()\n");
		return false;
	}

	if (n != quit_size) {
		fprintf(stderr, "E: write length (quit, %d != %d).\n",
			n, quit_size);
		return false;
	}

	printf("CLIENT> (q) len = %d, %s", quit_size, quit);

	return true;
}

/*
 * Minimal web request over an established connection.
 *
 * Sends "GET /\n\n", prints whatever a single TLS_read() returns (so a long
 * response may be shown only partially), then sends "quit\n".  Returns false
 * with a message on stderr if any write/read fails or a write is short.
 */
static bool do_www(TLS *tls) {
	int n;

	const char message[] = "GET /\n\n";
	int32_t message_size = sizeof (message) - 1;

	printf("CLIENT> (w) len = %d, %s", message_size, message);

	if ((n = TLS_write(tls, message, message_size)) < 0) {
		fprintf(stderr, "TLS_write()\n");
		return false;
	}

	if (n != message_size) {
		fprintf(stderr, "invalid message length.\n");
		return false;
	}

	/* Keep the last byte out of TLS_read()'s reach: a full-buffer read
	 * would otherwise leave buff without a terminator for the %s below. */
	char buff[4096] = {0};
	const int32_t buff_len = (int32_t) sizeof(buff) - 1;

	if ((n = TLS_read(tls, buff, buff_len)) < 0) {
		fprintf(stderr, "TLS_read()\n");
		return false;
	}

	/* Report the number of bytes actually received, not the buffer size. */
	printf("CLIENT> (r) len = %d\n", n);
	printf("buf:\n---\n%s\n---\n", buff);

	const char quit[] = "quit\n";
	int32_t quit_size = sizeof (quit) - 1;
	if ((n = TLS_write(tls, quit, quit_size)) < 0) {
		fprintf(stderr, "TLS_write()\n");
		return false;
	}

	if (n != quit_size) {
		fprintf(stderr, "E: write length (quit, %d != %d).\n",
			n, quit_size);
		return false;
	}

	printf("CLIENT> (q) len = %d, %s", quit_size, quit);

	return true;
}

/*
 * Exercise a key update in the middle of a data exchange.
 *
 * Writes "before key update\n", sends a KeyUpdate via
 * tls_hs_send_key_update(tls, TLS_KEYUPDATE_REQUESTED), writes
 * "after key update\n", and then reads back two messages whose lengths must
 * equal those of the two strings sent (presumably the peer echoes them --
 * contents are printed but not compared).  Returns false, with a message on
 * stderr, on the first failure.
 */
static bool do_key_update(TLS *tls) {
	const char test_string1[] = "before key update\n";
	const ssize_t test_string1_len = sizeof(test_string1) - 1;
	if (TLS_write(tls, test_string1, test_string1_len) < 0) {
		fprintf(stderr, "TLS_write()\n");
		return false;
	}

	printf("Send> (w) len = %zd, %s", test_string1_len, test_string1);

	if (tls_hs_send_key_update(tls, TLS_KEYUPDATE_REQUESTED) == false) {
		fprintf(stderr, "tls_hs_send_key_update()\n");
		return false;
	}

	const char test_string2[] = "after key update\n";
	const ssize_t test_string2_len = sizeof(test_string2) - 1;
	if (TLS_write(tls, test_string2, test_string2_len) < 0) {
		fprintf(stderr, "TLS_write()\n");
		return false;
	}

	printf("Send> (w) len = %zd, %s", test_string2_len, test_string2);

	/* Each receive buffer is sized (and zeroed) from its own string, with
	 * one spare byte so the %s below always finds a terminator. */
	char test_string1_buf[sizeof(test_string1)] = {0};

	ssize_t recv1len = TLS_read(tls, test_string1_buf, test_string1_len);
	if (recv1len < 0) {
		fprintf(stderr, "TLS_read()\n");
		return false;
	}

	if (recv1len != test_string1_len) {
		fprintf(stderr, "test_string1 is not equal length\n");
		return false;
	}
	printf("Recv> (r) len = %zd, %s", recv1len, test_string1_buf);

	char test_string2_buf[sizeof(test_string2)] = {0};

	ssize_t recv2len = TLS_read(tls, test_string2_buf, test_string2_len);
	if (recv2len < 0) {
		fprintf(stderr, "TLS_read()\n");
		return false;
	}

	if (recv2len != test_string2_len) {
		fprintf(stderr, "test_string2 is not equal length\n");
		return false;
	}
	printf("Recv> (r) len = %zd, %s", recv2len, test_string2_buf);

	return true;
}

/*
 * Resolve g_client_option.addr/port, open a TLS socket to the first address
 * that accepts a connection, perform the handshake, and then run the
 * exchanges requested on the command line (echo, www, keyupdate -- in that
 * order).
 *
 * TLS_init() and TLS_cleanup() are paired inside this function.  Returns
 * true only if every requested step succeeded.
 */
static bool run_client(void) {
	bool ret = true;

	struct addrinfo hints = {
		.ai_family   = AF_UNSPEC,
		.ai_socktype = SOCK_STREAM
	};

	struct addrinfo *aihead = NULL;
	/* getaddrinfo() reports failure through its return value (not errno),
	 * so perror() would print an unrelated message. */
	int gai_err = getaddrinfo(g_client_option.addr, g_client_option.port,
				  &hints, &aihead);
	if (gai_err != 0) {
		fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(gai_err));
		return false;
	}

	if (g_client_option.session_num >= 0) {
		TLS_session_set_list_size(g_client_option.session_num);
	}

	TLS_init();

	TLS *tls = NULL;
	for (struct addrinfo *ai = aihead; ai != NULL; ai = ai->ai_next) {
		if ((tls = TLS_socket(ai->ai_family,
				      ai->ai_socktype,
				      ai->ai_protocol)) == NULL) {
			perror("socket");
			goto warn_loop;
		}

		if (g_client_option.stm_path != NULL) {
			if (! TLS_stm_set(tls,
					  g_client_option.stm_path)) {
				fprintf(stderr, "TLS_stm_set()\n");
				/* A bad store path will not work for any other
				 * address either: stop trying.  tls is still
				 * open and is released at done. */
				ret = false;
				break;
			}
		}

		if ((g_client_option.stm_id != NULL) &&
		    (g_client_option.pass != NULL)) {
			if (set_stm_id(tls) == false) {
				goto warn_loop;
			}

			/* don't check whether a certificate is revoked because
			 * the store manager for test don't have CRL.
			 */
			enum tls_cert_verify_type type = TLS_DONT_CHECK_REVOKED;
			TLS_opt_set(tls, TLS_OPT_CERT_VERIFY_TYPE, &type);
		} else if ((g_client_option.p12_path != NULL) &&
			   (g_client_option.pass != NULL)) {
			if (set_stm_p12(tls) == false) {
				goto warn_loop;
			}
		}

		if (TLS_connect(tls, ai->ai_addr, ai->ai_addrlen) < 0) {
			fprintf(stderr, "E: TLS_connect()\n");
			goto warn_loop;
		}

		if (TLS_set_server_name(tls, g_client_option.addr) < 0) {
			fprintf(stderr, "W: failed to set server name\n");
		}
		break;

warn_loop:
		/* This address failed: close the socket and try the next one. */
		if (tls != NULL) {
			TLS_close(tls);
			TLS_free(tls);
			tls = NULL;
		}
	}
	freeaddrinfo(aihead);

	/* A configuration error aborted the loop; do not attempt a handshake
	 * on the (unconnected) socket that is still held in tls. */
	if (ret == false) {
		goto done;
	}

	if (tls == NULL) {
		fprintf(stderr, "E: no socket connection to %s\n",
			g_client_option.addr);
		ret = false;
		goto done;
	}

	if (TLS_handshake(tls) < 0) {
		fprintf(stderr, "E: TLS_handshake()\n");
		ret = false;
		goto done;
	}

	/* established. */
	if (g_client_option.echo == true) {
		if (do_echo(tls) == false) {
			ret = false;
			goto done;
		}
	}

	if (g_client_option.www == true) {
		if (do_www(tls) == false) {
			ret = false;
			goto done;
		}
	}

	if (g_client_option.keyupdate == true) {
		if (do_key_update(tls) == false) {
			ret = false;
			goto done;
		}
	}

done:
	if (tls != NULL) {
		TLS_close(tls);
		TLS_free(tls);
		tls = NULL;
	}
	TLS_cleanup();
	return ret;
}

/* Write a short description of the command-line options to OUT. */
static void print_usage(FILE *out, const char *prog)
{
	fprintf(out,
		"usage: %s [options]\n"
		"  -e, --echo          send \"hello\" and expect it echoed back\n"
		"  -w, --www           send \"GET /\" and print the reply\n"
		"  -k, --keyupdate     exchange data around a key update\n"
		"  -a, --addr HOST     server host name or address\n"
		"  -p, --port PORT     server port or service name\n"
		"  -s, --stm PATH      store manager path\n"
		"  -i, --clid ID       client key id (used only with -P)\n"
		"  -c, --cert FILE     client key file (used only with -P)\n"
		"  -P, --pass PASS     password for the -i / -c client key\n"
		"  -n, --snum NUM      session list size\n"
		"  -h                  show this help\n",
		prog);
}

/*
 * Parse the command line into g_client_option and run the client.
 * Exits with EXIT_SUCCESS if run_client() succeeds, EXIT_FAILURE otherwise.
 */
int
main (int argc, char* argv[]) {
	/* getopt_long() matches long names WITHOUT the leading "--"; the user
	 * still types "--echo" etc. on the command line. */
	static const struct option opts[] = {
		{"echo",      no_argument,       NULL, 'e'},
		{"www",       no_argument,       NULL, 'w'},
		{"keyupdate", no_argument,       NULL, 'k'},
		{"addr",      required_argument, NULL, 'a'},
		{"port",      required_argument, NULL, 'p'},
		{"stm",       required_argument, NULL, 's'},
		{"clid",      required_argument, NULL, 'i'},
		{"cert",      required_argument, NULL, 'c'},
		{"pass",      required_argument, NULL, 'P'},
		{"snum",      required_argument, NULL, 'n'},
		{NULL, 0, NULL, 0}
	};

	const char *prog = (argc > 0 && argv[0] != NULL) ? argv[0] : "client";

	int c;
	while ((c = getopt_long (argc, argv,
				 "ewka:p:s:i:c:n:P:vh", opts, NULL)) != -1) {
		switch (c) {
		case 'e':
			g_client_option.echo = true;
			break;

		case 'w':
			g_client_option.www = true;
			break;

		case 'k':
			g_client_option.keyupdate = true;
			break;

		case 'a':
			g_client_option.addr = optarg;
			break;

		case 'p':
			g_client_option.port = optarg;
			break;

		case 's':
			g_client_option.stm_path = optarg;
			break;

		case 'i':
			g_client_option.stm_id = optarg;
			break;

		case 'c':
			g_client_option.p12_path = optarg;
			break;

		case 'P':
			g_client_option.pass = optarg;
			break;

		case 'n':
			/* string_to_int() yields -1 on a parse error, which
			 * run_client() treats as "not set". */
			g_client_option.session_num =
				string_to_int(optarg, 10);
			break;

		case 'v':
			exit (EXIT_SUCCESS);

		case 'h':
			print_usage(stdout, prog);
			exit (EXIT_SUCCESS);

		default:
			/* getopt_long() has already reported the bad option. */
			print_usage(stderr, prog);
			exit (EXIT_FAILURE);
		}
	}

	return run_client() ? EXIT_SUCCESS : EXIT_FAILURE;
}
