/*
 * Copyright (c) 2015 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls.h"
#include "tls_alert.h"

#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <string.h>
#include <errno.h>

/**
 * send a CLOSE NOTIFY alert to the peer.
 *
 * this function is called when the connection finishes successfully.
 *
 * @param tls  TLS handle; it is dereferenced without a NULL check, so
 *             the caller must pass a valid pointer.
 */
static void send_close_notify(TLS *tls);

static void send_close_notify(TLS *tls) {
	/* a state of TLS_STATE_CLOSED is taken to mean that the connection
	 * has already finished abnormally, and no alert is sent in that
	 * case.  for any other state the client (server) is finishing
	 * normally, so the close notify alert is sent to the peer without
	 * waiting for a response.
	 */
	if (tls->state != TLS_STATE_CLOSED) {
		TLS_ALERT_WARN(tls, TLS_ALERT_DESC_CLOSE_NOTIFY);
	}
}

/**
 * create a new TLS handle together with its underlying socket.
 *
 * the arguments are passed unchanged to socket(2).
 *
 * @param domain    communication domain given to socket(2).
 * @param type      socket type given to socket(2).
 * @param protocol  protocol given to socket(2).
 * @return the new TLS handle, or NULL when either TLS_new() or
 *         socket(2) fails (an error is recorded with OK_set_error).
 */
TLS* TLS_socket(int domain, int type, int protocol) {
	TLS* tls = NULL;

	if ((tls = TLS_new()) == NULL) {
		TLS_DPRINTF("tls_new");
		OK_set_error(ERR_ST_TLS_TLS_NEW,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 0, NULL);
		return NULL;
	}

	if ((tls->sockfd = socket(domain, type, protocol)) < 0) {
		/* errno is reported before anything else can overwrite it. */
		TLS_DPRINTF("socket: %s", strerror(errno));
		/* the handle is not returned to the caller on this path, so
		 * release it here instead of leaking it. */
		TLS_free(tls);
		OK_set_error(ERR_ST_TLS_SOCKET,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 1, NULL);
		return NULL;
	}

	return tls;
}

/**
 * getsockopt(2) on the socket that belongs to a TLS handle.
 *
 * @return the result of getsockopt(2), or -1 when tls is NULL (an
 *         error is recorded with OK_set_error in that case).
 */
int TLS_getsockopt(TLS *tls,
		   int level, int optname, void * optval, socklen_t *optlen) {
	if (tls != NULL) {
		const int fd = tls->sockfd;
		return getsockopt(fd, level, optname, optval, optlen);
	}

	/* no handle: report the error and fail. */
	TLS_DPRINTF("tls");
	OK_set_error(ERR_ST_TLS_TLS_NULL,
		     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 2, NULL);
	return -1;
}


/**
 * setsockopt(2) on the socket that belongs to a TLS handle.
 *
 * @return the result of setsockopt(2), or -1 when tls is NULL (an
 *         error is recorded with OK_set_error in that case).
 */
int TLS_setsockopt(TLS *tls,
		   int level, int optname,
		   const void * optval, socklen_t optlen) {
	if (tls != NULL) {
		const int fd = tls->sockfd;
		return setsockopt(fd, level, optname, optval, optlen);
	}

	/* no handle: report the error and fail. */
	TLS_DPRINTF("tls");
	OK_set_error(ERR_ST_TLS_TLS_NULL,
		     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 3, NULL);
	return -1;
}

/**
 * getsockname(2) on the socket that belongs to a TLS handle.
 *
 * @return the result of getsockname(2), or -1 when tls is NULL (an
 *         error is recorded with OK_set_error in that case).
 */
int TLS_getsockname(TLS *tls, struct sockaddr *addr, socklen_t *addrlen) {
	int rc = -1;

	if (tls != NULL) {
		rc = getsockname(tls->sockfd, addr, addrlen);
	} else {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 4, NULL);
	}

	return rc;
}

/**
 * getpeername(2) on the socket that belongs to a TLS handle.
 *
 * @return the result of getpeername(2), or -1 when tls is NULL (an
 *         error is recorded with OK_set_error in that case).
 */
int TLS_getpeername(TLS *tls, struct sockaddr *addr, socklen_t *addrlen) {
	int rc = -1;

	if (tls != NULL) {
		rc = getpeername(tls->sockfd, addr, addrlen);
	} else {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 5, NULL);
	}

	return rc;
}


/**
 * connect the socket of a TLS handle to a peer as a client.
 *
 * after connect(2) succeeds the handle is marked as the client side.
 * when the immediate_handshake option is set, the TLS handshake is
 * performed before returning.
 *
 * @return the result of connect(2) on success; the (negative) result
 *         of connect(2) when it fails; -1 when tls is NULL or when the
 *         immediate handshake fails.
 */
int TLS_connect(TLS *tls, const struct sockaddr *addr, socklen_t addrlen) {
	if (!tls) {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 6, NULL);
		return -1;
	}

	const int rc = connect(tls->sockfd, addr, addrlen);
	if (rc < 0) {
		TLS_DPRINTF("connect: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_CONNECT,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 7, NULL);
		return rc;
	}

	/* the side that calls connect acts as the client. */
	tls->entity = TLS_CONNECT_CLIENT;

	if (!tls->opt.immediate_handshake) {
		return rc;
	}

	TLS_DPRINTF("immediate_handshake");
	if (TLS_handshake(tls)) {
		return rc;
	}

	return -1;
}

/**
 * bind the socket of a TLS handle to a local address.
 *
 * after bind(2) succeeds the handle is marked as the server side.
 * in TLS_DEBUG builds SO_REUSEADDR is enabled on the socket before
 * binding.
 *
 * @return the result of bind(2), or -1 when tls is NULL (an error is
 *         recorded with OK_set_error on every failure).
 */
int TLS_bind(TLS* tls, const struct sockaddr *addr, socklen_t addrlen) {
	if (tls == NULL) {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 8, NULL);
		return -1;
	}

#ifdef TLS_DEBUG
	/* SO_REUSEADDR takes an int flag.  passing a char-sized value
	 * makes the call fail (e.g. EINVAL on Linux), so the option would
	 * silently stay off.  the call is best-effort: a failure here is
	 * deliberately ignored. */
	const int on = 1;
	(void)setsockopt(tls->sockfd, SOL_SOCKET, SO_REUSEADDR,
			 &on, sizeof(on));
#endif

	int res;
	if ((res = bind(tls->sockfd, addr, addrlen)) < 0) {
		TLS_DPRINTF("bind: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_TLS_BIND,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 9, NULL);
		return res;
	}

	/* the side that calls bind acts as the server. */
	tls->entity = TLS_CONNECT_SERVER;

	return res;
}

/**
 * listen(2) on the socket that belongs to a TLS handle.
 *
 * @return the result of listen(2), or -1 when tls is NULL (an error
 *         is recorded with OK_set_error on every failure).
 */
int TLS_listen(TLS *tls, int backlog) {
	if (!tls) {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 10, NULL);
		return -1;
	}

	const int rc = listen(tls->sockfd, backlog);
	if (rc >= 0) {
		return rc;
	}

	TLS_DPRINTF("listen: %s", strerror(errno));
	OK_set_error(ERR_ST_TLS_LISTEN,
		     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 11, NULL);
	return rc;
}

/**
 * accept a connection on the socket of a listening TLS handle.
 *
 * the listening handle is duplicated with TLS_dup() and the duplicate
 * receives the accepted descriptor and is marked as the server side.
 * when the immediate_handshake option of the listening handle is set,
 * the TLS handshake is performed on the new handle before returning.
 *
 * @param tls      listening TLS handle.
 * @param addr     passed to accept(2).
 * @param addrlen  passed to accept(2).
 * @return the new TLS handle for the accepted connection, or NULL when
 *         tls is NULL, accept(2) fails, TLS_dup() fails or the
 *         immediate handshake fails.
 */
TLS * TLS_accept(TLS *tls, struct sockaddr *addr, socklen_t *addrlen) {
	if (tls == NULL) {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 12, NULL);
		return NULL;
	}

	int sockfd;
	if ((sockfd = accept(tls->sockfd, addr, addrlen)) < 0) {
		TLS_DPRINTF("accept: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_ACCEPT,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 13, NULL);
		return NULL;
	}

	TLS *newone;
	if ((newone = TLS_dup(tls)) == NULL) {
		TLS_DPRINTF("TLS_dup");
		OK_set_error(ERR_ST_TLS_TLS_DUP,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 14, NULL);
		/* no handle owns the accepted descriptor yet, so close it
		 * here; otherwise it would be leaked on this path. */
		close(sockfd);
		return NULL;
	}

	newone->sockfd = sockfd;
	newone->entity = TLS_CONNECT_SERVER;

#ifdef TLS_DEBUG
	/* SO_REUSEADDR takes an int flag.  passing a char-sized value
	 * makes the call fail (e.g. EINVAL on Linux).  the call is
	 * best-effort: a failure here is deliberately ignored. */
	const int on = 1;
	(void)setsockopt(newone->sockfd, SOL_SOCKET, SO_REUSEADDR,
			 &on, sizeof(on));
#endif

	if (tls->opt.immediate_handshake) {
		TLS_DPRINTF("immediate_handshake");
		if (! TLS_handshake(newone)) {
			/* NOTE(review): whether TLS_free() also closes
			 * newone->sockfd is not visible from this file --
			 * confirm, otherwise the descriptor leaks here. */
			TLS_free(newone);
			return NULL;
		}
	}

	return newone;
}

/**
 * close the socket that belongs to a TLS handle.
 *
 * send_close_notify() is called first, then the descriptor is closed
 * with close(2).
 *
 * @return the result of close(2), or -1 when tls is NULL (an error is
 *         recorded with OK_set_error in that case).
 */
int TLS_close(TLS *tls) {
	if (!tls) {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 15, NULL);
		return -1;
	}

	/* the alert has to go out while the descriptor is still open. */
	send_close_notify(tls);

	const int rc = close(tls->sockfd);
	return rc;
}

/**
 * shut down the socket that belongs to a TLS handle.
 *
 * send_close_notify() is called first, regardless of the value of
 * how, then shutdown(2) is applied to the descriptor.
 *
 * @param tls  TLS handle.
 * @param how  passed to shutdown(2).
 * @return the result of shutdown(2), or -1 when tls is NULL (an error
 *         is recorded with OK_set_error in that case).
 */
int TLS_shutdown(TLS *tls, int how) {
	if (!tls) {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 16, NULL);
		return -1;
	}

	send_close_notify(tls);

	const int rc = shutdown(tls->sockfd, how);
	return rc;
}

/**
 * get the socket descriptor held by a TLS handle.
 *
 * @return the descriptor stored in the handle, or 0 when tls is NULL
 *         (an error is recorded with OK_set_error in that case).
 *
 * NOTE(review): 0 is also a valid descriptor value, so the return
 * value alone cannot distinguish the NULL case -- confirm that callers
 * do not rely on it.
 */
int TLS_get_fd(TLS *tls) {
	if (tls != NULL) {
		return tls->sockfd;
	}

	TLS_DPRINTF("tls");
	OK_set_error(ERR_ST_TLS_TLS_NULL,
		     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 17, NULL);
	return 0;
}

/**
 * store a socket descriptor in a TLS handle.
 *
 * @param tls     TLS handle.
 * @param sockfd  descriptor to store.
 * @return the value stored in the handle, or 0 when tls is NULL (an
 *         error is recorded with OK_set_error in that case).
 *
 * NOTE(review): storing descriptor 0 also returns 0, so the return
 * value alone cannot distinguish that from the NULL case -- confirm
 * that callers do not rely on it.
 */
int TLS_set_fd(TLS *tls, int sockfd) {
	if (!tls) {
		TLS_DPRINTF("tls");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_SOCKET + 18, NULL);
		return 0;
	}

	tls->sockfd = sockfd;
	return tls->sockfd;
}
