/* testserver.c - SSL test server (simple version) */
/*
 * Copyright (c) 2012-2015 National Institute of Informatics in Japan, 
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must 
 * include this notice in the file.
 */
/*
 * Copyright (C) 1998-2002
 * Akira Iwata & Takuto Okuno
 * Akira Iwata Laboratory,
 * Nagoya Institute of Technology in Japan.
 *
 * All rights reserved.
 *
 * This software is written by Takuto Okuno(usapato@anet.ne.jp)
 * And if you want to contact us, send an email to Kimitake Wakayama
 * (wakayama@elcom.nitech.ac.jp)
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. All advertising materials mentioning features or use of this software must
 *    display the following acknowledgment:
 *    "This product includes software developed by Akira Iwata Laboratory,
 *    Nagoya Institute of Technology in Japan (http://mars.elcom.nitech.ac.jp/)."
 *
 * 4. Redistributions of any form whatsoever must retain the following
 *    acknowledgment:
 *    "This product includes software developed by Akira Iwata Laboratory,
 *     Nagoya Institute of Technology in Japan (http://mars.elcom.nitech.ac.jp/)."
 *
 *   THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY.
 *   AKIRA IWATA LABORATORY DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
 *   SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS,
 *   IN NO EVENT SHALL AKIRA IWATA LABORATORY BE LIABLE FOR ANY SPECIAL,
 *   INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
 *   FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
 *   NEGLIGENCE OR OTHER TORTUOUS ACTION, ARISING OUT OF OR IN CONNECTION
 *   WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 */

#include "aiconfig.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>

#ifdef HAVE_NETDB_H
#include <netdb.h>
#endif

#include <sys/select.h>
#include <sys/types.h>
#ifdef __WINDOWS__
#undef ULONG
#include <winsock2.h>
#include <io.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#endif

#include <aicrypto/ok_err.h>
#include <aicrypto/ok_pkcs.h>
#include <aicrypto/ok_ssl.h>
#include <aicrypto/ok_x509.h>
#include "ssl.h"

#define SERVER_PORT "11112"

#ifdef __WINDOWS__
#define P12FNAME                ".\\00001.p12"
#else
#define P12FNAME	"00001.p12"
#endif
#ifndef PATH
# define PATH	"."
#endif

/* test/getfpath.c */
char *get_fpath(char *path, char *filename);

static int myread( int sock,char *buf,int len);
static int mywrite(int sock,char *buf,int len);
static int read_debug(SSL *ssl, int dat);
static int write_debug(SSL *ssl, int dat);
static int vfy_func(SSL *ssl, Cert *ct);

int test_do(SSL *ssl);

/*
 * Entry point of the simple SSL test server.
 *
 * Usage: testserver [-4 | -6]
 *   -4  listen on IPv4 addresses only
 *   -6  listen on IPv6 addresses only
 * Without an option, every passive address returned by getaddrinfo() for
 * SERVER_PORT is used. IPv6 sockets get IPV6_V6ONLY, which disables
 * IPv4-mapped addresses on them.
 *
 * Each listening SSL socket is set up as follows:
 *   - the server PKCS#12 file (P12FNAME under PATH) is loaded into it;
 *   - SSL_OPT_CERTREQ is set;
 *   - the debug callbacks defined in this file are installed.
 * The server then waits in select() for the first incoming connection and
 * accepts it. It stops listening on all sockets, performs the SSL handshake,
 * runs test_do() on the session and exits.
 *
 * Returns 0 on success and -1 on any error. Error paths return immediately;
 * anything still held at that point is reclaimed by process exit.
 */
int main(int argc, char **argv){
	SSL *ssls[FD_SETSIZE];
	SSL **s = ssls;
	int ssls_len = 0;
	Cert *ct;
	char buf[64];
	struct addrinfo hints;
	struct addrinfo *ai, *aitop;
	fd_set fds, readfds;
	int i, j, smax, error;
	int af = AF_UNSPEC;
	char *fp_P12FNAME = get_fpath(PATH, P12FNAME);

	/* parse command line options like check_opt() in testserver2.c */
	if(argc == 1+1) {
		if(strcmp(argv[1], "-4") == 0) {
			/* use IPv4 only */
			af = AF_INET;
		} else if(strcmp(argv[1], "-6") == 0) {
			/* use IPv6 only */
			af = AF_INET6;
		} else {
			printf("error : unknown option\n");
			free(fp_P12FNAME);
			return -1;
		}
	}

	/* the PKCS#12 path is used for every listening socket below */
	if(fp_P12FNAME == NULL) {
		printf("error : get_fpath\n");
		return -1;
	}

#ifdef __WINDOWS__
	WSADATA	wsaData;

	if(WSAStartup(MAKEWORD(1, 1), &wsaData) != 0) {
		printf("error : WSAStartup\n");
		return(-1);}
#endif

	/* get addrinfo list */
	memset(&hints, 0, sizeof(hints));
	hints.ai_family = af;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_flags = AI_PASSIVE;
	if((error = getaddrinfo(NULL, SERVER_PORT, &hints, &aitop)) != 0) {
		printf("error : getaddrinfo(%s)\n", gai_strerror(error));
		return -1;
	}

	for(ai = aitop; ai; ai = ai->ai_next) {
		/* 1. get server socket */
		*s = SSL_socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
		if(*s == NULL) {
			printf("error : SSL_socket\n");
			return -1;
		}
		if(ai->ai_family == AF_INET6) {
			/* disable IPv4 mapped address */
			int on = 1;
			if(SSL_setsockopt(*s, IPPROTO_IPV6, IPV6_V6ONLY, &on, sizeof(on)) == -1) {
				printf("error : SSL_setsockopt\n");
				return -1;
			}
		}

		/* 2. setup server information */
		/* 2.1 read server certificate & private key */
		if (SSL_set_server_p12(*s, fp_P12FNAME, "abcde")) {
			printf("error : SSL_set_server_p12 : %s\n",OK_get_errstr());
			return -1;
		}
		printf("read server certificate:\n");
		ct = SSL_get_server_cert(*s);	/* server PKCS#12 file */
		if(ct != NULL) {
			printf("subject: %s\n", ct->subject);
			printf("issuer:  %s\n", ct->issuer);
		}

		/* 2.2 set certificate request option. */
		printf("set certificate request option.\n");
		SSL_setopt(*s, SSL_OPT_CERTREQ);

		/* 2.3 set debug callback functions */
		SSL_set_read_cb(*s, myread);
		SSL_set_write_cb(*s, mywrite);
		SSL_set_readdebug_cb(*s, read_debug);
		SSL_set_writedebug_cb(*s, write_debug);
		SSL_set_vfy_cb(*s, vfy_func);

		/* 3. bind & listen socket*/
		if(-1 == SSL_bind(*s, ai->ai_addr, ai->ai_addrlen)) {
			printf("error : SSL_bind : %s\n", OK_get_errstr());
			return -1;
		}
		if(SSL_listen(*s, 1)) {
			printf("error : SSL_listen : %s\n", OK_get_errstr());
			return -1;
		}

		++s;
		++ssls_len;
		/* ssls[] has room for at most FD_SETSIZE listeners */
		if(ssls_len >= FD_SETSIZE) {
			break;
		}
	}
	freeaddrinfo(aitop);
	free(fp_P12FNAME);

	/* build the select() set from all listening sockets */
	FD_ZERO(&fds);
	smax = -1;
	for(i = 0; i < ssls_len; ++i) {
		FD_SET(ssls[i]->sock, &fds);
		if (ssls[i]->sock > smax) {
			smax = ssls[i]->sock;
		}
	}
	/* loops only to retry select() after EINTR; one client is served */
	while(1) {
		memcpy(&readfds, &fds, sizeof(fd_set));
		if(select(smax + 1, &readfds, NULL, NULL, NULL) < 0) {
			if(errno == EINTR)
				continue;
			printf("error : select\n");
			return -1;
		}
		for(i = 0; i < ssls_len; ++i) {
			if(FD_ISSET(ssls[i]->sock, &readfds)) {
				/* 4. accept the connection */
				struct sockaddr_storage sa;
				socklen_t sa_len = sizeof(sa);
				SSL *ssl = SSL_accept(ssls[i], (struct sockaddr*)&sa, &sa_len);
				if(ssl == NULL) {
					/* nothing was accepted: there is no SSL object to close */
					printf("error: SSL_accept : %s\n", OK_get_errstr());
					return -1;
				}
				/* only one client is served: stop listening everywhere,
				 * not just on the socket that fired */
				for(j = 0; j < ssls_len; ++j)
					SSL_close(ssls[j]);

				/* 5. do SSL handshake :-) */
				if(SSL_handshake(ssl)) {
					printf("error: SSL_handshake : %s\n", OK_get_errstr());
					SSL_close(ssl);
					return -1;
				}

				printf("SSL connection was established!\n");
				SSL_cspec_str(ssl->ctx, buf);
				printf("using cipher : %s\n", buf);

				if(ssl->ctx->cp12) {
					printf("read client certificate:\n");
					ct = SSL_get_client_cert(ssl);	/* client PKCS#12 file */
					if(ct != NULL) {
						printf("subject: %s\n", ct->subject);
						printf("issuer:  %s\n", ct->issuer);
					}
				}

				/* 6. and now, send messages with SSL_read & SSL_write !!
				 * (test_do() closes ssl itself) */
				if(test_do(ssl))
					return -1;

				/* free every listening socket and the session */
				for(j = 0; j < ssls_len; ++j)
					SSL_free(ssls[j]);
				SSL_free(ssl);
				break;
			}
		}
		break;
	}

#ifdef __WINDOWS__
	WSACleanup();
#endif
	return 0;
}

/*
 * Run one request/response exchange on an established SSL session.
 *
 * Steps:
 *   1. Read one message from the peer and print it.
 *   2. Answer with "I hear you!!".
 *   3. Perform one more SSL_read(), whose data is discarded. This is done
 *      only for testing, since the server does not need the peer's
 *      close_notify.
 *   4. Close the session with SSL_close().
 * The SSL object itself is not freed here.
 *
 * Returns 0 on success, -1 if the first SSL_read() or the SSL_write() fails.
 */
int test_do(SSL *ssl){
	unsigned char buf[256];
	int i;

	memset(buf, 0, sizeof(buf));
	/* read message first; keep one byte free so the data can always be
	 * NUL-terminated and printed as a C string */
	if((i = SSL_read(ssl, buf, (int)sizeof(buf) - 1)) < 0){
		printf("error : SSL_read() : %s\n", OK_get_errstr());
		return -1;
	}
	buf[i] = '\0';
	printf("now getting one message : %s\n", (char *)buf);

	strcpy((char *)buf, "I hear you!!");
	printf("now writing one message : %s\n", (char *)buf);
	if(SSL_write(ssl, buf, (int)strlen((char *)buf)) < 0){
		printf("error : SSL_write() : %s\n", OK_get_errstr());
		return -1;
	}

	/* server doesn't need to get close_notify.
	 * this one is just for test; its result is deliberately ignored.
	 */
	SSL_read(ssl, buf, (int)sizeof(buf));

	SSL_close(ssl);
	return 0;
}

/*
 * Read-side debug callback, installed with SSL_set_readdebug_cb().
 * Prints the session's mode, option bits (hex) and ssl->ctx->state,
 * followed by the integer `dat` handed in by the library.
 * Always returns 0 (meaning of the return value to the library: TODO confirm).
 */
static int read_debug(SSL *ssl, int dat)
{
	(void)printf("s:r_debug  : md=%d, op=%.2x, st=%d, %d\n",
	             ssl->mode, ssl->opt, ssl->ctx->state, dat);
	return 0;
}

/*
 * Write-side debug callback, installed with SSL_set_writedebug_cb().
 * Prints the session's mode, option bits (hex) and ssl->ctx->state,
 * followed by the integer `dat` handed in by the library.
 * Always returns 0 (meaning of the return value to the library: TODO confirm).
 */
static int write_debug(SSL *ssl, int dat)
{
	(void)printf("s:w_debug : md=%d, op=%.2x, st=%d, %d\n",
	             ssl->mode, ssl->opt, ssl->ctx->state, dat);
	return 0;
}

/*
 * Certificate verification callback, installed with SSL_set_vfy_cb().
 * Only logs the subject of the certificate it is given; the session
 * argument is unused. Always returns 0.
 */
static int vfy_func(SSL *ssl, Cert *ct)
{
	(void)ssl;
	(void)printf("s:vfy_func: %s\n", ct->subject);
	return 0;
}

/*
 * Raw transport read hook, installed with SSL_set_read_cb().
 * Reads up to `len` bytes from `sock` into `buf` (recv() on Windows,
 * read() elsewhere), prints the OS return value as "&n&" and returns it
 * unchanged (so -1 still signals an error and 0 end of stream).
 */
static int myread(int sock, char *buf, int len)
{
#ifdef __WINDOWS__
	int got = recv(sock, buf, len, 0);
#else
	int got = read(sock, buf, len);
#endif
	printf("&%d&", got);
	return got;
}

/*
 * Raw transport write hook, installed with SSL_set_write_cb().
 * Writes `len` bytes from `buf` to `sock` (send() on Windows, write()
 * elsewhere), prints the OS return value as "#n#" and returns it
 * unchanged.
 */
static int mywrite(int sock, char *buf, int len)
{
#ifdef __WINDOWS__
	int sent = send(sock, buf, len, 0);
#else
	int sent = write(sock, buf, len);
#endif
	printf("#%d#", sent);
	return sent;
}
