/* ok_ssl.h */
/*
 * Modified by National Institute of Informatics in Japan, 2014-2016.
 *
 */
/*
 * Copyright (C) 1998-2002
 * Akira Iwata & Takuto Okuno
 * Akira Iwata Laboratory,
 * Nagoya Institute of Technology in Japan.
 *
 * All rights reserved.
 *
 * This software is written by Takuto Okuno(usapato@anet.ne.jp)
 * And if you want to contact us, send an email to Kimitake Wakayama
 * (wakayama@elcom.nitech.ac.jp)
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *	this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *	this list of conditions and the following disclaimer in the documentation
 *	and/or other materials provided with the distribution.
 *
 * 3. All advertising materials mentioning features or use of this software must
 *	display the following acknowledgment:
 *	"This product includes software developed by Akira Iwata Laboratory,
 *	Nagoya Institute of Technology in Japan (http://mars.elcom.nitech.ac.jp/)."
 *
 * 4. Redistributions of any form whatsoever must retain the following
 *	acknowledgment:
 *	"This product includes software developed by Akira Iwata Laboratory,
 *	 Nagoya Institute of Technology in Japan (http://mars.elcom.nitech.ac.jp/)."
 *
 *
 *   THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY.
 *   AKIRA IWATA LABORATORY DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
 *   SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS,
 *   IN NO EVENT SHALL AKIRA IWATA LABORATORY BE LIABLE FOR ANY SPECIAL,
 *   INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
 *   FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
 *   NEGLIGENCE OR OTHER TORTUOUS ACTION, ARISING OUT OF OR IN CONNECTION
 *   WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <aicrypto/ok_err.h>
#include <aicrypto/ok_tool.h>
#include <aicrypto/ok_x509.h>
#include <aicrypto/ok_ssl.h>
#include "ssl.h"

/* ssl_bind.c */
/* Called by ssl_check_sslctx() below when ssl->ctx is NULL; a nonzero
 * return is treated there as failure. */
int SSL_alloc_contexts(SSL *);

/*-----------------------------------------
  alloc check of SSLCTX
-----------------------------------------*/
/*
 * Make sure an SSL handle is usable and has a context attached.
 *
 * ssl : SSL handle to check.
 *
 * A NULL handle records a null-pointer error via OK_set_error().
 * When ssl->ctx is NULL, SSL_alloc_contexts() is called for the handle.
 *
 * Returns 0 when ssl->ctx was already set or SSL_alloc_contexts()
 * returned zero, -1 otherwise.
 */
int ssl_check_sslctx(SSL *ssl){
	if(ssl==NULL){
		OK_set_error(ERR_ST_NULLPOINTER,ERR_LC_SSL,ERR_PT_SSL_TOOL,NULL);
		return -1;
	}

	/* nothing to do when a context is already present */
	if(ssl->ctx!=NULL){
		return 0;
	}

	return SSL_alloc_contexts(ssl) ? -1 : 0;
}

/*-----------------------------------------
  SSL set server certificate & key
-----------------------------------------*/
/*
 * Load the server PKCS#12 from a file and attach it to the SSL context.
 *
 * ssl    : SSL handle; a context is allocated if it has none.
 * fname  : path handed to P12_read_file().
 * passwd : password registered with OK_set_passwd() for the read.
 *
 * The registered password is cleared as soon as P12_read_file() returns,
 * on success and on failure alike, so it is never left set when this
 * function reports an error.
 *
 * As before, ssl->ctx->sp12 is overwritten with the result of the read
 * (NULL on read failure), and the server flag is set before the chain
 * check runs.
 *
 * Returns 0 on success, -1 on failure.
 */
int SSL_set_server_p12(SSL *ssl,char *fname,char *passwd){

	if(ssl_check_sslctx(ssl)) goto error;

	OK_set_passwd(passwd);
	ssl->ctx->sp12 = P12_read_file(fname);
	/* clear before testing the result so the error path cannot skip it */
	OK_clear_passwd();

	if(ssl->ctx->sp12==NULL) goto error;

	ssl->opt |= SSL_SYS_SERVER;     /* set server flag */
	if(P12_check_chain(ssl->ctx->sp12,0)) goto error;

	return 0;
error:
	return -1;
}

/*
 * Alternate name for SSL_set_server_p12(); arguments and return value
 * are passed through unchanged.
 */
int SSL_set_serverkey_file(SSL *ssl,char *fname,char *passwd){
	int rc;

	rc = SSL_set_server_p12(ssl,fname,passwd);
	return rc;
}

/*
 * Attach an already-loaded PKCS#12 as the server credentials.
 *
 * ssl : SSL handle; a context is allocated if it has none.
 * p12 : PKCS#12 object stored into ssl->ctx->sp12.
 *
 * The pointer is stored and the server flag is set before the chain is
 * checked, so both remain in place even when -1 is returned because of
 * a chain check failure.
 *
 * Returns 0 on success, -1 on failure.
 */
int SSL_set_serverkey_p12(SSL *ssl,PKCS12 *p12){
	if(ssl_check_sslctx(ssl)){
		return -1;
	}

	ssl->ctx->sp12 = p12;
	ssl->opt |= SSL_SYS_SERVER;     /* mark this handle as a server */

	return P12_check_chain(ssl->ctx->sp12,0) ? -1 : 0;
}

/*
 * Build a PKCS#12 object from a certificate ID in the "MY" store.
 *
 * ssl : SSL handle; ssl->ctx->stm is used for all store lookups, so the
 *       context must already exist (callers in this file run
 *       ssl_check_sslctx() first).
 * id  : certificate ID passed to STM_find_byID().
 * p12 : out parameter; set to NULL on entry, then to a new PKCS#12.
 *
 * The private key password is read with OK_get_passwd() and registered
 * with OK_set_passwd() while the key is fetched.  It is cleared again on
 * every exit taken after it was registered, including the failure exits.
 *
 * Returns 0 on success, -1 on failure.  On failure *p12 may still point
 * to a partially filled object allocated here; this function does not
 * free it (the callers in this file pass it to P12_free()).
 */
int SSL_keyid2p12(SSL *ssl, char *id, PKCS12 **p12){
	CertList *cl,*top=NULL;
	Key *key=NULL;
	CStore *cs;
	CSBag *cbg,*kbg;
	char buf[PWD_BUFLEN];
	int i,ok = -1;
	int pwd_set = 0;	/* nonzero while our password is registered */

	*p12 = NULL;

	/* find a certificate */
	if((cbg=STM_find_byID(ssl->ctx->stm,STORE_MY,CSTORE_ON_STORAGE,CSTORE_CTX_CERT,id))==NULL){
		OK_set_error(ERR_ST_BADCERTID,ERR_LC_SSL,ERR_PT_SSL_TOOL+1,NULL);
		goto done;
	}
	/* find a key */
	if((cs=STM_find_byName(ssl->ctx->stm,STORE_MY,CSTORE_ON_STORAGE,CSTORE_CTX_KEY))==NULL){
		OK_set_error(ERR_ST_BADKEYID,ERR_LC_SSL,ERR_PT_SSL_TOOL+1,NULL);
		goto done;
	}
	/* check private key password */
	OK_get_passwd("Open Private Key: ",(unsigned char*)buf,0);
	OK_set_passwd(buf);
	pwd_set = 1;

	/* buf is reused from here on: it now receives the public key hash */
	if(cs_get_keyhash(((Cert*)cbg->cache)->pubkey,buf,&i)) goto done;
	if((kbg=CStore_find_byKeyHash(cs->bags,(unsigned char*)buf))==NULL) goto done;
	if((key=CStore_get_key(cs,kbg))==NULL) goto done;

	OK_clear_passwd();
	pwd_set = 0;

	/* get PKCS#12 */
	if((*p12=P12_new())==NULL) goto done;
	if((top=STM_get_pathcert(ssl->ctx->stm,(Cert*)cbg->cache))==NULL) goto done;

	for(cl=top; cl ; cl=cl->next){
		if(P12_add_cert(*p12,cl->cert,NULL,0xff)) goto done;
		/* detach so Certlist_free_all() below does not see this cert */
		cl->cert=NULL;
	}
	if(P12_add_key(*p12,key,NULL,0xff)) goto done;
	key = NULL;	/* detach so Key_free() below does not see this key */

	ok = 0;
done:
	/* only clear a password registered here; earlier exits leave it alone */
	if(pwd_set) OK_clear_passwd();
	Key_free(key);
	Certlist_free_all(top);
	return ok;
}

/*
 * Set the server credentials from a certificate ID.
 *
 * ssl : SSL handle; a context is allocated if it has none.
 * id  : certificate ID handed to SSL_keyid2p12().
 *
 * Returns the result of SSL_set_serverkey_p12() when the PKCS#12 could
 * be built, -1 otherwise.
 */
int SSL_set_serverkey_id(SSL *ssl,char *id){
	PKCS12 *p12=NULL;
	int ret = -1;

	if(!ssl_check_sslctx(ssl) && !SSL_keyid2p12(ssl,id,&p12)){
		ret = SSL_set_serverkey_p12(ssl,p12);
		/* the pointer was stored in ssl->ctx->sp12; do not free it here */
		p12 = NULL;
	}

	P12_free(p12);
	return ret;
}

/*-----------------------------------------
  SSL set client certificate & key
-----------------------------------------*/
/*
 * Load the client PKCS#12 from a file and attach it to the SSL context.
 *
 * ssl    : SSL handle; a context is allocated if it has none.
 * fname  : path handed to P12_read_file().
 * passwd : password registered with OK_set_passwd() for the read.
 *
 * The registered password is cleared as soon as P12_read_file() returns,
 * on success and on failure alike, so it is never left set when this
 * function reports an error.
 *
 * As before, ssl->ctx->cp12 is overwritten with the result of the read
 * (NULL on read failure).
 *
 * Returns 0 on success, -1 on failure.
 */
int SSL_set_client_p12(SSL *ssl,char *fname,char *passwd){

	if(ssl_check_sslctx(ssl)) goto error;

	OK_set_passwd(passwd);
	ssl->ctx->cp12 = P12_read_file(fname);
	/* clear before testing the result so the error path cannot skip it */
	OK_clear_passwd();

	if(ssl->ctx->cp12==NULL) goto error;

	if(P12_check_chain(ssl->ctx->cp12,0)) goto error;
	return 0;
error:
	return -1;
}

/*
 * Alternate name for SSL_set_client_p12(); arguments and return value
 * are passed through unchanged.
 */
int SSL_set_clientkey_file(SSL *ssl,char *fname,char *passwd){
	int rc;

	rc = SSL_set_client_p12(ssl,fname,passwd);
	return rc;
}

/*
 * Attach an already-loaded PKCS#12 as the client credentials.
 *
 * ssl : SSL handle; a context is allocated if it has none.
 * p12 : PKCS#12 object stored into ssl->ctx->cp12.
 *
 * The pointer is stored before the chain is checked, so it remains in
 * the context even when the chain check makes this return -1.
 *
 * Returns 0 on success, -1 on failure.
 */
int SSL_set_clientkey_p12(SSL *ssl,PKCS12 *p12){
	int rc = -1;

	if(ssl_check_sslctx(ssl)==0){
		ssl->ctx->cp12 = p12;
		if(P12_check_chain(ssl->ctx->cp12,0)==0){
			rc = 0;
		}
	}

	return rc;
}

/*
 * Set the client credentials from a certificate ID.
 *
 * ssl : SSL handle; a context is allocated if it has none.
 * id  : certificate ID handed to SSL_keyid2p12().
 *
 * Returns the result of SSL_set_clientkey_p12() when the PKCS#12 could
 * be built, -1 otherwise.
 */
int SSL_set_clientkey_id(SSL *ssl,char *id){
	PKCS12 *p12=NULL;
	int ret = -1;

	if(!ssl_check_sslctx(ssl) && !SSL_keyid2p12(ssl,id,&p12)){
		ret = SSL_set_clientkey_p12(ssl,p12);
		/* the pointer was stored in ssl->ctx->cp12; do not free it here */
		p12 = NULL;
	}

	P12_free(p12);
	return ret;
}

/*-----------------------------------------
  SSL get certificate
-----------------------------------------*/
/*
 * Return the user certificate of the server PKCS#12 (ctx->sp12).
 *
 * Returns NULL when ctx is NULL or no server PKCS#12 is set; otherwise
 * whatever P12_get_usercert() returns.
 */
Cert *SSL_get_scert(SSLCTX *ctx){
	if(ctx==NULL || ctx->sp12==NULL){
		return NULL;
	}
	return P12_get_usercert(ctx->sp12);
}

/*
 * Return the user certificate of the client PKCS#12 (ctx->cp12).
 *
 * Returns NULL when ctx is NULL or no client PKCS#12 is set; otherwise
 * whatever P12_get_usercert() returns.
 */
Cert *SSL_get_ccert(SSLCTX *ctx){
	if(ctx==NULL || ctx->cp12==NULL){
		return NULL;
	}
	return P12_get_usercert(ctx->cp12);
}

/*
 * Return the user certificate of the other side's PKCS#12.
 *
 * When ctx->serv is nonzero the client PKCS#12 (ctx->cp12) is consulted,
 * otherwise the server one (ctx->sp12).
 *
 * Returns NULL when ctx is NULL or the selected PKCS#12 is not set;
 * otherwise whatever P12_get_usercert() returns.
 */
Cert *SSL_get_peer_cert(SSLCTX *ctx){
	PKCS12 *other;

	if(ctx==NULL){
		return NULL;
	}

	other = ctx->serv ? ctx->cp12 : ctx->sp12;
	if(other==NULL){
		return NULL;
	}
	return P12_get_usercert(other);
}

/*
 * Return the client user certificate for an SSL handle.
 *
 * Forwards ssl->ctx to SSL_get_ccert().  A NULL handle yields NULL
 * instead of being dereferenced, matching the NULL tolerance that
 * SSL_get_ccert() already has for the context.
 */
Cert *SSL_get_client_cert(SSL *ssl)
{
	if(ssl==NULL) return NULL;
	return SSL_get_ccert(ssl->ctx);
}

/*
 * Return the server user certificate for an SSL handle.
 *
 * Forwards ssl->ctx to SSL_get_scert().  A NULL handle yields NULL
 * instead of being dereferenced, matching the NULL tolerance that
 * SSL_get_scert() already has for the context.
 */
Cert *SSL_get_server_cert(SSL *ssl)
{
	if(ssl==NULL) return NULL;
	return SSL_get_scert(ssl->ctx);
}

/*-----------------------------------------
  SSL got certificate request
-----------------------------------------*/
/*
 * Report whether the SSL_SYS_GOT_CERTREQ bit is set in ssl->opt.
 *
 * Returns the masked bit itself (nonzero when set, 0 when clear); the
 * value is not normalized to 1.  ssl must not be NULL.
 */
int SSL_got_certreq(SSL *ssl)
{
	int flag;

	flag = ssl->opt & SSL_SYS_GOT_CERTREQ;
	return flag;
}
