/*
 * Copyright (c) 2015 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls.h"

#include <string.h>

/* for Key_free, Cert, CertList, Certlist_join and Certlist_free_all */
#include <aicrypto/ok_x509.h>

/* for PKCS12, P12_new, P12_add_cert and P12_add_key */
#include <aicrypto/ok_pkcs12.h>

/* for STM_open, STM_find_byName and STM_find_byID, cs_get_keyhash,
 * CSBag, CStore, CStore_2certlist, CStore_find_byKeyHash,
 * CStore_get_key, STM_get_pathcert and STM_verify_cert. */
#include <aicrypto/ok_store.h>

/* for OK_get_passwd, OK_set_passwd and OK_clear_passwd */
#include <aicrypto/ok_tool.h>

#ifndef STOREDIR
/**
 * default directory of the store manager (aistore).
 *
 * TLS_stm_set() uses it as its default store directory. it can be
 * overridden at build time by defining STOREDIR; otherwise a compile-time
 * note is emitted.
 *
 * this definition is from aissl.
 */
#define STOREDIR	"/etc/naregi/store"
#pragma message("NOTE: using default value of STOREDIR, " STOREDIR "")
#endif

/**
 * prompt for the passphrase protecting a private key.
 *
 * the passphrase is read with OK_get_passwd() into the caller-supplied
 * buffer. read_key() passes a PWD_BUFLEN-byte buffer.
 *
 * NOTE: the function name is misspelled ("passwrod"). It is kept as-is
 * because the definition and its caller use this spelling.
 *
 * this function is from aissl.
 */
static void read_passwrod(char* password);

/**
 * load the private key that belongs to the certificate in csbag.
 *
 * the key is looked up in the STORE_MY key store of tls->store_manager
 * by the hash of the certificate's public key. the user is prompted for
 * the passphrase that protects it.
 *
 * this function is from aissl.
 */
static Key * read_key(const TLS *tls, const CSBag *csbag);

/**
 * build a PKCS12 object from the certificate path of csbag's certificate
 * (as returned by the store manager) and the given private key.
 *
 * this function is from aissl.
 */
static PKCS12 * read_pkcs12(const TLS *tls, CSBag *csbag, Key *key);

static void read_passwrod(char* password) {
	/* OK_get_passwd() takes the destination as unsigned char *, while
	 * OK_set_passwd() (used by read_key) takes char *. The conversion
	 * is done once, here. */
	unsigned char *dest = (unsigned char *)password;
	char label[] = "Open Private Key: ";

	OK_get_passwd(label, dest, 0);
}

/**
 * overwrite a buffer that held secret material with zeros.
 *
 * the stores go through a volatile pointer. A plain memset() on a buffer
 * that is about to go out of scope may be removed by the optimizer as a
 * dead store.
 */
static void stm_wipe_secret(void *buf, size_t len) {
	volatile unsigned char *p = buf;

	while (len-- > 0) {
		*p++ = 0;
	}
}

/**
 * load the private key that belongs to the certificate in csbag.
 *
 * the user is prompted for the key passphrase. The passphrase is handed to
 * aicrypto with OK_set_passwd(), and the key is looked up in the STORE_MY
 * key store by the hash of the certificate's public key.
 *
 * on every exit path after the passphrase is set, it is cleared with
 * OK_clear_passwd() and the local copy is wiped.
 *
 * @param tls   TLS object whose store_manager is searched.
 * @param csbag certificate bag; its cached certificate must be present.
 * @return the private key, or NULL on failure (error recorded with
 *         OK_set_error()).
 */
static Key * read_key(const TLS *tls, const CSBag *csbag) {
	CStore *cstore;
	if ((cstore = STM_find_byName(tls->store_manager, STORE_MY,
				      CSTORE_ON_STORAGE,
				      CSTORE_CTX_KEY)) == NULL) {
		OK_set_error(ERR_ST_TLS_STM_FIND_BYNAME,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 0, NULL);
		return NULL;
	}

	/* PWD_BUFLEN defined in aicrypto/ok_tool.h. */
	char password[PWD_BUFLEN];
	read_passwrod(password);

	OK_set_passwd(password);

	Key *key = NULL;
	CSBag *keybag = NULL;
	int32_t len = 0;
	/* the key hash gets its own buffer instead of overwriting the
	 * passphrase buffer, so secret and non-secret data never share
	 * storage. The size matches the buffer the hash was written to
	 * before. */
	unsigned char keyhash[PWD_BUFLEN];

	/* check private key password */
	Cert *cert = csbag->cache;
	/* NOTE(review): cs_get_keyhash() is used as if it were a public
	 * aicrypto function; confirm it is part of the supported API. */
	if (cs_get_keyhash(cert->pubkey, keyhash, &len) < 0) {
		OK_set_error(ERR_ST_TLS_STM_CS_GET_KEYHASH,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 1, NULL);
		goto clear;
	}
	/* XXX: check len */

	if ((keybag = CStore_find_byKeyHash(cstore->bags, keyhash)) == NULL) {
		OK_set_error(ERR_ST_TLS_STM_CSTORE_FIND_BYKEYHASH,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 2, NULL);
		goto clear;
	}

	if ((key = CStore_get_key(cstore, keybag)) == NULL) {
		OK_set_error(ERR_ST_TLS_STM_CSTORE_GET_KEY,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 3, NULL);
	}

clear:
	/* the passphrase must not stay registered with aicrypto or on the
	 * stack, whether or not the key was found. */
	OK_clear_passwd();
	stm_wipe_secret(password, sizeof(password));

	return key;
}

/**
 * build a PKCS12 object containing the certificate path of csbag's
 * certificate and the given private key.
 *
 * @param tls   TLS object whose store_manager resolves the path.
 * @param csbag certificate bag; its cached certificate starts the path.
 * @param key   private key. It is added to the PKCS12 only on success.
 *              When NULL is returned, the key was not added and the caller
 *              still owns it.
 * @return the new PKCS12, or NULL on any failure. A partially built object
 *         is freed rather than returned.
 */
static PKCS12 * read_pkcs12(const TLS *tls, CSBag *csbag, Key *key) {
	PKCS12 *p12 = NULL;

	if ((p12 = P12_new()) == NULL) {
		OK_set_error(ERR_ST_TLS_P12_NEW,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 4, NULL);
		return NULL;
	}

	Cert *cert = csbag->cache;
	CertList *list = NULL;
	if ((list = STM_get_pathcert(tls->store_manager, cert)) == NULL) {
		OK_set_error(ERR_ST_TLS_STM_GET_PATHCERT,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 5, NULL);
		goto error;
	}

	for (CertList *c = list; c != NULL; c = c->next) {
		if (P12_add_cert(p12, c->cert, NULL, 0xff)) {
			goto error;
		}
		/* the certificate now belongs to p12; detach it so that
		 * Certlist_free_all() below does not free it as well. */
		c->cert = NULL;
	}

	if (P12_add_key(p12, key, NULL, 0xff)) {
		goto error;
	}

	Certlist_free_all(list);
	return p12;

error:
	/* returning the incomplete object would make the caller treat the
	 * failure as success; release it (and the certificates already
	 * added to it) instead. */
	if (list != NULL) {
		Certlist_free_all(list);
	}
	P12_free(p12);
	return NULL;
}

bool tls_stm_nullp(TLS *tls) {
	/* true when no store manager has been attached to this TLS object */
	return tls->store_manager == NULL;
}

CertList * tls_stm_get_cert_list(TLS *tls) {
	/* trusted roots first ... */
	CStore *root_store = STM_find_byName(tls->store_manager,
					     STORE_ROOT,
					     CSTORE_ON_STORAGE,
					     CSTORE_CTX_CERT);
	if (root_store == NULL) {
		OK_set_error(ERR_ST_TLS_STM_FIND_BYNAME,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 6, NULL);
		return NULL;
	}
	CertList *roots = CStore_2certlist(root_store);

	/* ... then intermediate CAs appended after them. */
	CStore *midca_store = STM_find_byName(tls->store_manager,
					      STORE_MIDCA,
					      CSTORE_ON_STORAGE,
					      CSTORE_CTX_CERT);
	if (midca_store == NULL) {
		OK_set_error(ERR_ST_TLS_STM_FIND_BYNAME,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 7, NULL);
		if (roots != NULL) {
			Certlist_free_all(roots);
		}
		return NULL;
	}
	CertList *midcas = CStore_2certlist(midca_store);

	return Certlist_join(roots, midcas);
}

PKCS12 * tls_stm_find(TLS *tls, char* keyid) {
	/* ported from SSL_keyid2p12 (ssl_tool.c): look up the certificate
	 * with the given id in STORE_MY, load its private key and bundle
	 * both (with the certificate path) into a PKCS12. */
	if (tls_stm_nullp(tls)) {
		return NULL;
	}

	CSBag *cert_bag = STM_find_byID(tls->store_manager, STORE_MY,
					CSTORE_ON_STORAGE, CSTORE_CTX_CERT,
					keyid);
	if (cert_bag == NULL) {
		OK_set_error(ERR_ST_TLS_STM_FIND_ID,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 8, NULL);
		return NULL;
	}

	Key *priv = read_key(tls, cert_bag);
	if (priv == NULL) {
		return NULL;
	}

	PKCS12 *bundle = read_pkcs12(tls, cert_bag, priv);
	if (bundle == NULL) {
		/* the key did not make it into a PKCS12; release it here */
		Key_free(priv);
	}
	return bundle;
}

int tls_stm_verify(TLS *tls, Cert *cert) {
	/* without a store manager there is nothing to verify against;
	 * report 0 and skip the check. */
	if (tls_stm_nullp(tls)) {
		return 0;
	}

	int rc = STM_verify_cert(tls->store_manager, cert,
				 tls->opt.verify_type);
	if (rc < 0) {
		TLS_DPRINTF("STM_verify_cert");
		OK_set_error(ERR_ST_TLS_STM_VERIFY_CERT,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 9, NULL);
	}
	return rc;
}

/**
 * attach a store manager (aistore) to a TLS object.
 *
 * @param tls  TLS object to configure; NULL is rejected.
 * @param path store directory. When NULL, STOREDIR is used. The string is
 *             duplicated into tls->store_manager_path, so the caller keeps
 *             ownership of path.
 * @return true on success. false on failure, with the reason recorded
 *         through OK_set_error().
 */
bool TLS_stm_set(TLS *tls, char *path) {
	/* this default value follows the default of the ssl
	 * implementation. */
	const char default_store_directory[] = STOREDIR;
	/* size of the buffer receiving the default path; keeps the
	 * 126-byte limit the default path was checked against. */
	enum { STM_PATH_BUFLEN = 126 };
	char default_path[STM_PATH_BUFLEN];

	if (tls == NULL) {
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 10, NULL);
		return false;
	}

	if (path == NULL) {
		/* no directory given: fall back to the default. path is
		 * NULL here, so the default is formatted into a local
		 * buffer rather than written through path. */
		int32_t n = snprintf(default_path, sizeof(default_path),
				     "%s", default_store_directory);
		if (n < 0) {
			TLS_DPRINTF("snprintf: %s", strerror(errno));
			OK_set_error(ERR_ST_TLS_SNPRINTF,
				     ERR_LC_TLS4, ERR_PT_TLS_STM + 11, NULL);
			return false;
		}
		if ((size_t)n >= sizeof(default_path)) {
			/* the default path does not fit; report it instead
			 * of failing silently. */
			OK_set_error(ERR_ST_TLS_SNPRINTF,
				     ERR_LC_TLS4, ERR_PT_TLS_STM + 11, NULL);
			return false;
		}
		path = default_path;
	}

	if ((tls->store_manager_path = strdup(path)) == NULL) {
		TLS_DPRINTF("strdup: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_STRDUP,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 12, NULL);
		return false;
	}

	/* open from the heap copy so the store manager never sees the
	 * address of the local default_path buffer. */
	if ((tls->store_manager = STM_open(tls->store_manager_path)) == NULL) {
		OK_set_error(ERR_ST_TLS_STM_OPEN,
			     ERR_LC_TLS4, ERR_PT_TLS_STM + 13, NULL);
		return false;
	}

	return true;
}
