/*
 * Copyright (c) 2015-2020 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_cert.h"
#include "tls_cipher.h"
#include "tls_alert.h"

#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>

/* for ASN1_read_cert */
#include <aicrypto/ok_asn1.h>

/* for P12_get_usercert and P12_check_chain */
#include <aicrypto/ok_pkcs12.h>

/* for OK_set_passwd and OK_clear_passwd */
#include <aicrypto/ok_tool.h>

/** @see tls_stm.c */
extern PKCS12 * tls_stm_find(TLS *tls, char* keyid);

/** @see tls_stm.c */
extern int tls_stm_verify(TLS *tls, Cert *cert);

/**
 * check PKCS12. internally, check by P12_check_chain.
 */
static bool check_pkcs12(PKCS12 *pkcs12);

/**
 * Check the result of the tls_stm_verify function.
 *
 * Checks whether the certificate is revoked, expired, and so on.
 * This implementation was carried over from AiSSL.
 */
static bool check_verify_status(TLS *tls, const int status);

/**
 * Set pkcs12 in the tls structure.
 *
 * The entity given as the 2nd argument determines whether the
 * pkcs12_client member or the pkcs12_server member is assigned.
 */
static void set_pkcs12(TLS *tls,
		       const enum connection_end entity,
		       PKCS12 *pkcs12);

/**
 * Set the certificate (pkcs12) in the tls structure from a file path.
 */
static bool tls_cert_set_by_file(TLS *tls,
				 const enum connection_end entity,
				 char* filename, char* password);

/**
 * Set the certificate (pkcs12) in the tls structure from an id of the
 * store manager.
 */
static bool tls_cert_set_by_id(TLS *tls,
			       const enum connection_end entity,
			       char *id);

/**
 * Get curvetype from the public key.
 *
 * @param[in] pubkey
 * @param[in] info
 *
 * @ingroup tls_ecc
 */
static void get_pubkey_ecc_info(Key *pubkey, struct tls_cert_info *info);

/**
 * Default certificate types for TLS 1.2, used by the certificate
 * request handshake protocol message.
 *
 * Every entry ends with a comma so that enabling the currently disabled
 * entry below still yields a valid initializer list.
 */
static enum tls_cert_type tls_cert_type_tls12[] = {
	TLS_CERT_RSA_SIGN,
#if 0 /* unsupported */
	TLS_CERT_DSS_SIGN,
#endif /* 0 */
};

/**
 * Validate a PKCS#12 object through P12_check_chain().
 *
 * P12_check_chain() is called with its second argument set to 0
 * (named "quiet" here).
 *
 * @param[in] pkcs12 object handed to P12_check_chain().
 * @return true when P12_check_chain() returns a non-negative value,
 *         false (after recording an error) otherwise.
 */
static bool check_pkcs12(PKCS12 *pkcs12) {
	enum { quiet = 0, verbose = 1 };
	const int chain_status = P12_check_chain(pkcs12, quiet);

	if (chain_status >= 0) {
		return true;
	}

	TLS_DPRINTF("cert: P12_check_chain");
	OK_set_error(ERR_ST_TLS_P12_CHECK_CHAIN,
		     ERR_LC_TLS1, ERR_PT_TLS_CERT + 0, NULL);
	return false;
}

/**
 * Translate a verification status into an error record and a TLS alert.
 *
 * Only the bits selected by the mask 0xff00 are examined. A masked value
 * of 0 is treated as success. Every other value records an error with
 * OK_set_error(), invokes TLS_ALERT_FATAL() with the alert description
 * chosen for that case, and reports failure.
 *
 * @param[in] tls    connection passed to TLS_ALERT_FATAL().
 * @param[in] status value returned by tls_stm_verify().
 * @return true when the masked status is 0, false otherwise.
 */
static bool check_verify_status(TLS *tls, const int status) {
	/* keep only the bits that the X509_VFY_ERR_* cases are matched on. */
	int type = status & 0xff00;

	TLS_DPRINTF("cert: cert verification (%d)", type);
	switch(type){
	case 0:
		/* nothing flagged: fall out of the switch and return true. */
		break;

	case X509_VFY_ERR_REVOKED:
		OK_set_error(ERR_ST_TLS_X509_REVOKED,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 1, NULL);
		TLS_ALERT_FATAL(tls ,TLS_ALERT_DESC_CERTIFICATE_REVOKED);
		return false;

	/* the four validity/update-time cases below all use the
	 * certificate_expired alert description. */
	case X509_VFY_ERR_NOTBEFORE:
		OK_set_error(ERR_ST_TLS_X509_NOTBEFORE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_CERTIFICATE_EXPIRED);
		return false;

	case X509_VFY_ERR_NOTAFTER:
		OK_set_error(ERR_ST_TLS_X509_NOTAFTER,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 3, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_CERTIFICATE_EXPIRED);
		return false;

	case X509_VFY_ERR_LASTUPDATE:
		OK_set_error(ERR_ST_TLS_X509_LASTUPDATE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 4, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_CERTIFICATE_EXPIRED);
		return false;

	case X509_VFY_ERR_NEXTUPDATE:
		OK_set_error(ERR_ST_TLS_X509_NEXTUPDATE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 5, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_CERTIFICATE_EXPIRED);
		return false;

	/* the five cases below all use the bad_certificate alert
	 * description. */
	case X509_VFY_ERR_SIGNATURE:
		OK_set_error(ERR_ST_TLS_X509_SIGNATURE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 6, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_CERTIFICATE);
		return false;

	case X509_VFY_ERR_SIGNATURE_CRL:
		OK_set_error(ERR_ST_TLS_X509_SIGNATURE_CRL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 7, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_CERTIFICATE);
		return false;

	case X509_VFY_ERR_ISSUER_CRL:
		OK_set_error(ERR_ST_TLS_X509_ISSUER_CRL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 8, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_CERTIFICATE);
		return false;

	case X509_VFY_ERR_NOT_IN_CERTLIST:
		OK_set_error(ERR_ST_TLS_X509_NOT_IN_CERTLIST,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 9, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_CERTIFICATE);
		return false;

	case X509_VFY_ERR_SELF_SIGN:
		OK_set_error(ERR_ST_TLS_X509_SELF_SIGN,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 10, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_CERTIFICATE);
		return false;

	case X509_VFY_ERR_NOT_CACERT:
		OK_set_error(ERR_ST_TLS_X509_NOT_CACERT,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 11, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNKNOWN_CA);
		return false;

	/* any value not listed above is handled like a system error. */
	case X509_VFY_ERR_SYSTEMERR:
	default:
		OK_set_error(ERR_ST_TLS_X509_SYSTEMERR,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 12, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_CERTIFICATE_UNKNOWN);
		return false;
	}

	return true;
}

/**
 * Store a PKCS#12 pointer in the member of tls that matches entity.
 *
 * TLS_CONNECT_CLIENT selects pkcs12_client and TLS_CONNECT_SERVER
 * selects pkcs12_server. Any other value trips an assertion and stores
 * nothing. The previous member value is simply overwritten.
 */
static void set_pkcs12(TLS *tls,
		       const enum connection_end entity,
		       PKCS12 *pkcs12) {
	if (entity == TLS_CONNECT_CLIENT) {
		tls->pkcs12_client = pkcs12;
	} else if (entity == TLS_CONNECT_SERVER) {
		tls->pkcs12_server = pkcs12;
	} else {
		assert(!"unknown connection end specified.");
	}
}

/**
 * Check a PKCS#12 object and, when the check passes, store it in tls.
 *
 * @param[in,out] tls    structure that receives the pointer.
 * @param[in]     entity selects the client or the server member.
 * @param[in]     pkcs12 object to check and store.
 * @return true when check_pkcs12() accepted the object and it was
 *         stored, false (with an error recorded) otherwise.
 */
static bool tls_cert_set_by_pkcs12(TLS *tls,
				   const enum connection_end entity,
				   PKCS12 *pkcs12) {
	const bool chain_ok = check_pkcs12(pkcs12);

	if (chain_ok) {
		set_pkcs12(tls, entity, pkcs12);
		return true;
	}

	TLS_DPRINTF("cert: check_pkcs12");
	OK_set_error(ERR_ST_TLS_CHECK_PKCS12,
		     ERR_LC_TLS1, ERR_PT_TLS_CERT + 13, NULL);
	return false;
}

/**
 * Read a PKCS#12 file and store the result in tls.
 *
 * The password is registered with OK_set_passwd() only for the duration
 * of the read and is cleared again with OK_clear_passwd() whether or
 * not the read succeeded.
 *
 * @param[in,out] tls      structure that receives the PKCS#12 object.
 * @param[in]     entity   selects the client or the server member.
 * @param[in]     filename path handed to P12_read_file().
 * @param[in]     password password handed to OK_set_passwd().
 * @return true on success. false when tls is NULL, when the file could
 *         not be read, or when the object was rejected by
 *         tls_cert_set_by_pkcs12(); in the last case the object read
 *         from the file is released before returning.
 */
static bool tls_cert_set_by_file(TLS *tls,
				 const enum connection_end entity,
				 /* XXX: these arguments should be
				  * const. but, arg of OK_set_passwd is
				  * not const. so declare non-const. */
				 char* filename, char* password) {
	if (tls == NULL) {
		TLS_DPRINTF("cert: tls (null)");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 14, NULL);
		return false;
	}

	OK_set_passwd(password);
	PKCS12 *pkcs12 = P12_read_file(filename);
	/* the password must not stay registered once the read is over,
	 * including when the read failed. */
	OK_clear_passwd();

	if (pkcs12 == NULL) {
		TLS_DPRINTF("cert: P12_read_file");
		OK_set_error(ERR_ST_TLS_P12_READ_FILE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 15, NULL);
		return false;
	}

	if (! tls_cert_set_by_pkcs12(tls, entity, pkcs12)) {
		TLS_DPRINTF("cert: tls_cert_set_by_pkcs12");
		/* the object was not stored in tls, so this function is
		 * still its only owner. release it before recording the
		 * error so that the error set here is the last one. */
		P12_free(pkcs12);
		OK_set_error(ERR_ST_TLS_CERT_SET_BY_PKCS12,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 16, NULL);
		return false;
	}

	return true;
}

/**
 * Look up a PKCS#12 object by id through tls_stm_find() and store it in
 * tls.
 *
 * @param[in,out] tls    structure that receives the PKCS#12 object.
 * @param[in]     entity selects the client or the server member.
 * @param[in]     id     key id handed to tls_stm_find().
 * @return true on success. false when tls is NULL, when the lookup
 *         returned NULL, or when tls_cert_set_by_pkcs12() rejected the
 *         object.
 */
static bool tls_cert_set_by_id(TLS *tls,
			const enum connection_end entity,
			char *id) {
	if (tls == NULL) {
		TLS_DPRINTF("cert: tls (null)");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 17, NULL);
		return false;
	}

	PKCS12 *found = tls_stm_find(tls, id);
	if (found == NULL) {
		TLS_DPRINTF("cert: tls_stm_find");
		OK_set_error(ERR_ST_TLS_STM_FIND,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 18, NULL);
		return false;
	}

	if (! tls_cert_set_by_pkcs12(tls, entity, found)) {
		/* NOTE(review): found is not released on this path. whether
		 * that is a leak depends on who owns the object returned by
		 * tls_stm_find() -- confirm against tls_stm.c. */
		TLS_DPRINTF("cert: tls_cert_set_by_pkcs12");
		OK_set_error(ERR_ST_TLS_CERT_SET_BY_PKCS12,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 19, NULL);
		return false;
	}

	return true;
}

struct tls_cert_type_list * tls_cert_type_list(const TLS *tls) {
	struct tls_cert_type_list *ctlist;

	uint32_t len = 0;
	enum tls_cert_type *list;

	switch (tls->negotiated_version.minor) {
	case TLS_MINOR_SSL30:
	case TLS_MINOR_TLS10:
	case TLS_MINOR_TLS11:
		/* TODO: not implementation. */
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 7, NULL);
		return NULL;

	case TLS_MINOR_TLS12:
	default:
		list = tls_cert_type_tls12;
		len  = (sizeof (tls_cert_type_tls12) /
			sizeof (enum tls_cert_type));
		break;
	}

	if ((ctlist = malloc (1 *
			      sizeof (struct tls_cert_type_list))) == NULL) {
		TLS_DPRINTF("cert: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 20, NULL);
		return NULL;
	}

	ctlist->len = len;
	if ((ctlist->list = malloc (ctlist->len *
				    sizeof (enum tls_cert_type))) == NULL) {
		TLS_DPRINTF("cert: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 21, NULL);
		free(ctlist);
		return NULL;
	}

	for (uint16_t i = 0; i < len; ++i) {
		ctlist->list[i] = list[i];
	}

	return ctlist;
}

/**
 * Copy the curve type of an ECDSA public key into info.
 *
 * The key is asserted to have key_type KEY_ECDSA_PUB before it is
 * viewed as a Pubkey_ECDSA.
 *
 * @param[in]  pubkey public key whose E->curve_type is read.
 * @param[out] info   structure whose curve_type member is written.
 */
static void get_pubkey_ecc_info(Key *pubkey, struct tls_cert_info *info)
{
	assert(pubkey->key_type == KEY_ECDSA_PUB);

	const Pubkey_ECDSA *ec_key = (const Pubkey_ECDSA *) pubkey;

	info->curve_type = ec_key->E->curve_type;
}

/**
 * Release a list returned by tls_cert_type_list().
 *
 * Both the list member and the structure itself are freed. Passing NULL
 * is allowed and does nothing, in the same way as free(NULL).
 *
 * @param[in] list structure to release, or NULL.
 */
void tls_cert_type_free(struct tls_cert_type_list *list) {
	if (list == NULL) {
		return;
	}

	free(list->list);
	list->list = NULL;
	list->len  = 0;

	free(list);
}

/**
 * Tell whether a certificate type is accepted.
 *
 * @return true only for TLS_CERT_RSA_SIGN, false for every other value.
 */
bool tls_cert_type_availablep(enum tls_cert_type type) {
	return type == TLS_CERT_RSA_SIGN;
}

/**
 * Reset every member of a tls_cert_info structure that this file uses:
 * no key usage bits, OBJ_SIG_NULL, KEY_NULL and curve type 0.
 *
 * @param[out] info structure to reset.
 */
void tls_cert_info_init(struct tls_cert_info *info)
{
	info->curve_type = 0;
	info->pubkey_algo = KEY_NULL;
	info->signature_algo = OBJ_SIG_NULL;
	info->keyusage = 0;
}

/**
 * Fill info from the end entity certificate held in a PKCS#12 object.
 *
 * info is always reset first. When p12 is NULL the reset state is
 * returned as a success.
 *
 * @param[in]  p12  PKCS#12 object, may be NULL.
 * @param[out] info structure to fill.
 * @return true on success, false when P12_get_usercert() returned NULL.
 */
bool tls_cert_info_get(PKCS12 *p12, struct tls_cert_info *info)
{
	tls_cert_info_init(info);

	/* no certificate at all: the initialized state is the answer. */
	if (p12 == NULL) {
		return true;
	}

	Cert *usercert = P12_get_usercert(p12);
	if (usercert == NULL) {
		TLS_DPRINTF("cert: P12_get_usercert");
		OK_set_error(ERR_ST_TLS_P12_GET_USERCERT,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 0, NULL);
		return false;
	}

	/* start from 0xff (every bit set); this value is kept when the
	 * certificate carries no key usage extension. */
	info->keyusage = 0xff;

	/* walk the extension list up to the first key usage entry. */
	CertExt *ext = usercert->ext;
	while (ext != NULL && ext->extnID != OBJ_X509v3_KEY_Usage) {
		ext = ext->next;
	}
	if (ext != NULL) {
		info->keyusage = ((CE_KUsage *) ext)->flag;
	}

	info->signature_algo = usercert->signature_algo;
	info->pubkey_algo = usercert->pubkey_algo;

	/* curve information exists only for ECDSA public keys. */
	if (info->pubkey_algo == KEY_ECDSA_PUB) {
		get_pubkey_ecc_info(usercert->pubkey, info);
	}

	return true;
}

enum tls_hs_ecc_ec_curve_type tls_cert_info_ecc_get_type(struct tls_cert_info
							 *cinfo)
{
	return tls_hs_ecdh_get_curve_type(cinfo->curve_type);
}

enum tls_hs_named_curve tls_cert_info_ecc_get_curve(struct tls_cert_info
							*cinfo)
{
	return tls_hs_ecdh_get_named_curve(cinfo->curve_type);
}

bool tls_cert_info_can_use_ecc_cipher_suite(struct tls_cert_info *cinfo,
					    struct tls_hs_ecc_eclist *eclist,
					    struct tls_hs_ecc_pflist *pflist)
{
	assert(cinfo != NULL);

	if (eclist == NULL || pflist == NULL) {
		OK_set_error(ERR_ST_NULLPOINTER,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 8, NULL);
		return false;
	}

	if (cinfo->pubkey_algo == KEY_RSA_PUB) {
		return true;
	}

	if (cinfo->pubkey_algo != KEY_ECDSA_PUB) {
		OK_set_error(ERR_ST_TLS_NOT_ECDH_PUB,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 9, NULL);
		return false;
	}

	enum tls_hs_ecc_ec_curve_type type;
	enum tls_hs_named_curve name;

	type = tls_cert_info_ecc_get_type(cinfo);

	switch (type) {
	case TLS_ECC_CTYPE_NAMED_CURVE:
		name = tls_cert_info_ecc_get_curve(cinfo);
		if (name == 0) {
			OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE,
				     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 1,
				     NULL);
			return false;
		}
		for (int i = 0; i < eclist->len; i++) {
			if (name == eclist->list[i]) {
				return true;
			}
		}
		break;

	case TLS_ECC_CTYPE_EXPLICIT_CHAR2:
	case TLS_ECC_CTYPE_EXPLICIT_PRIME:
		/* TODO: unsupported */
		TLS_DPRINTF("cert: unspported curve type");
		break;
	default:
		/* Does not reach. */
		TLS_DPRINTF("cert: unknown curve type");
		break;
	}

	OK_set_error(ERR_ST_TLS_UNSUPPORTED_CURVE,
		     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 10, NULL);

	return false;
}

/**
 * Decide whether the certificate described by info can be used with the
 * key exchange method of a cipher suite.
 *
 * The key exchange method is obtained with tls_cipher_keymethod(). Only
 * the RSA, DHE_RSA, ECDHE_RSA, ECDH_ECDSA and ECDH_RSA branches can
 * return true; every other method is rejected as unsupported.
 *
 * NOTE(review): set_error is honoured only by the public key algorithm
 * and key usage checks. The "unsupported key exchange" branches record
 * an error unconditionally, and the can_use_ecc rejections never record
 * one. Confirm whether that asymmetry is intended.
 *
 * @param[in] suite       cipher suite to check against.
 * @param[in] info        certificate information to check.
 * @param[in] can_use_ecc whether ECC cipher suites may be used.
 * @param[in] set_error   whether the checks described above record an
 *                        error with OK_set_error() on rejection.
 * @return true when the certificate is usable, false otherwise.
 */
bool tls_cert_info_available(enum tls_cipher_suite suite,
			     struct tls_cert_info *info,  bool can_use_ecc,
			     bool set_error)
{
	enum tls_keyexchange_method keyexc_method;
	keyexc_method = tls_cipher_keymethod(suite);
	TLS_KXC_METHOD_DUMP(keyexc_method);

	/*
	 * RFC5246 section 7.4.2. Table
	 */
	switch (keyexc_method) {
	case TLS_KXC_RSA:
		/*
		 * RSA public key; the certificate MUST allow the key to
		 * be used for encryption (the keyEncipherment bit
		 * MUST be set if the key usage extension is present).
		 */
		if (info->pubkey_algo != KEY_RSA_PUB) {
			TLS_DPRINTF("cert: pubkey_algo=%d", info->pubkey_algo);
			if (set_error) {
				OK_set_error(ERR_ST_TLS_NOT_RSA_PUB,
					     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 2,
					     NULL);
			}
			return false;
		}

		if (!(info->keyusage & keyEncipherment)) {
			TLS_DPRINTF("cert: keyEncipherment bit");
			if (set_error) {
				OK_set_error(ERR_ST_TLS_CERT_KEYENCIPHERMENT,
					     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 3,
					     NULL);
			}
			return false;
		}
		return true;

	case TLS_KXC_ECDHE_RSA:
		/* rejected without recording an error, whatever set_error
		 * says. */
		if (can_use_ecc != true) {
			TLS_DPRINTF("cert: ecc cipher suite");
			return false;
		}
		/* the remaining checks are shared with DHE_RSA. */
		/* FALLTHROUGH */
	case TLS_KXC_DHE_RSA:
		/*
		 * RSA public key; the certificate MUST allow the
		 * key to be used for signing (the
		 * digitalSignature bit MUST be set if the key
		 * usage extension is present) with the signature
		 * scheme and hash algorithm that will be employed
		 * in the server key exchange message.
		 */
		if (info->pubkey_algo != KEY_RSA_PUB) {
			TLS_DPRINTF("cert: pubkey_algo=%d", info->pubkey_algo);
			if (set_error) {
				OK_set_error(ERR_ST_TLS_NOT_RSA_PUB,
					     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 4,
					     NULL);
			}
			return false;
		}

		if (!(info->keyusage & digitalSignature)) {
			TLS_DPRINTF("cert: digitalSignature bit");
			if (set_error) {
				OK_set_error(
					ERR_ST_TLS_CERT_FLAG_DIGITALSIGNATURE,
					ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 5,
					NULL);
			}
			return false;
		}
		return true;

	case TLS_KXC_DHE_DSS:
		/*
		 * DSA public key; the certificate MUST allow the
		 * key to be used for signing with the hash
		 * algorithm that will be employed in the server
		 * key exchange message.
		 */
		/* TODO: not implemented. the error below is recorded even
		 * when set_error is false. */
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEY_EXCHANGE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 11, NULL);
		return false;

	case TLS_KXC_DH_DSS:
	case TLS_KXC_DH_RSA:
		/*
		 * Diffie-Hellman public key; the keyAgreement bit
		 * MUST be set if the key usage extension is
		 * present.
		 */
		/* TODO: not implemented. the error below is recorded even
		 * when set_error is false. */
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEY_EXCHANGE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 12, NULL);
		return false;

	case TLS_KXC_ECDH_ECDSA:
	case TLS_KXC_ECDH_RSA:
		/*
		 * ECDH-capable public key; the public key MUST
		 * use a curve and point format supported by the
		 * client, as described in [TLSECC].
		 */
		/* rejected without recording an error, whatever set_error
		 * says. */
		if (can_use_ecc != true) {
			TLS_DPRINTF("cert: ecc cipher suite");
			return false;
		}

		if (info->pubkey_algo != KEY_ECDSA_PUB) {
			TLS_DPRINTF("cert: pubkey_algo=%d", info->pubkey_algo);
			if (set_error) {
				OK_set_error(ERR_ST_TLS_NOT_ECDH_PUB,
					     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 6,
					     NULL);
			}
			return false;
		}
		return true;

	case TLS_KXC_ECDHE_ECDSA:
		/*
		 * ECDSA-capable public key; the certificate MUST
		 * allow the key to be used for signing with the
		 * hash algorithm that will be employed in the
		 * server key exchange message.  The public key
		 * MUST use a curve and point format supported by
		 * the client, as described in  [TLSECC].
		 */
		/* TODO: not implemented. the error below is recorded even
		 * when set_error is false. */
		TLS_DPRINTF("cert: unsupported key excahnge");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEY_EXCHANGE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 13, NULL);
		return false;

	case TLS_KXC_DH_anon:
	case TLS_KXC_ECDH_anon:
		/* TODO: not implemented. the error below is recorded even
		 * when set_error is false. */
		TLS_DPRINTF("cert: unsupported key excahnge");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEY_EXCHANGE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 14, NULL);
		return false;
	default:
		/* TODO */
		TLS_DPRINTF("cert: unknown key exchange method.");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEY_EXCHANGE,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 15, NULL);
		return false;
	}

	/* not reached: every branch of the switch above returns. */
	return true;
}

/**
 * Verify the end entity certificate of a PKCS#12 object.
 *
 * TODO: verification goes through the store manager internally, so the
 * current implementation does not support verifying certificates
 * without a store manager. This behaviour is inherited from aissl.
 *
 * @param[in] tls connection handed to tls_stm_verify().
 * @param[in] p12 PKCS#12 object whose user certificate is verified.
 * @return true when verification passed, false otherwise.
 */
bool TLS_cert_verify(TLS *tls, PKCS12 *p12) {
	Cert *leaf = P12_get_usercert(p12);
	if (leaf == NULL) {
		TLS_DPRINTF("cert: P12_get_usercert");
		OK_set_error(ERR_ST_TLS_P12_GET_USERCERT,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 22, NULL);
		return false;
	}

	const int status = tls_stm_verify(tls, leaf);
	if (status < 0) {
		TLS_DPRINTF("cert: tls_stm_verify");
		OK_set_error(ERR_ST_TLS_STM_VERIFY,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 23, NULL);
		return false;
	}

	if (check_verify_status(tls, status)) {
		return true;
	}

	/*
	 * OK_set_error() is deliberately not called here: it would
	 * replace the reason already recorded by check_verify_status().
	 */
	TLS_DPRINTF("cert: check_verify_status");
	return false;
}

/**
 * Load the client PKCS#12 object from a file.
 *
 * @return 0 on success, -1 on failure.
 */
int TLS_set_clientkey_file(TLS *tls, char* filename, char* password) {
	return tls_cert_set_by_file(tls, TLS_CONNECT_CLIENT,
				    filename, password) ? 0 : -1;
}

/**
 * Load the server PKCS#12 object from a file.
 *
 * @return 0 on success, -1 on failure.
 */
int TLS_set_serverkey_file(TLS *tls, char* filename, char* password) {
	return tls_cert_set_by_file(tls, TLS_CONNECT_SERVER,
				    filename, password) ? 0 : -1;
}

/**
 * Use an already loaded PKCS#12 object as the client key.
 *
 * @return 0 on success, -1 on failure.
 */
int TLS_set_clientkey_p12(TLS *tls, PKCS12 *p12) {
	return tls_cert_set_by_pkcs12(tls, TLS_CONNECT_CLIENT, p12) ? 0 : -1;
}

/**
 * Use an already loaded PKCS#12 object as the server key.
 *
 * @return 0 on success, -1 on failure.
 */
int TLS_set_serverkey_p12(TLS *tls, PKCS12 *p12) {
	return tls_cert_set_by_pkcs12(tls, TLS_CONNECT_SERVER, p12) ? 0 : -1;
}

/**
 * Look up the client PKCS#12 object by id.
 *
 * @return 0 on success, -1 on failure.
 */
int TLS_set_clientkey_id(TLS *tls, char *id) {
	return tls_cert_set_by_id(tls, TLS_CONNECT_CLIENT, id) ? 0 : -1;
}

/**
 * Look up the server PKCS#12 object by id.
 *
 * @return 0 on success, -1 on failure.
 */
int TLS_set_serverkey_id(TLS *tls, char *id) {
	return tls_cert_set_by_id(tls, TLS_CONNECT_SERVER, id) ? 0 : -1;
}

/**
 * Return a duplicate of the user certificate held in tls->pkcs12_client.
 *
 * The result comes from Cert_dup().
 *
 * @return duplicated certificate, or NULL when tls is NULL, when no
 *         client PKCS#12 object is set, or when Cert_dup() returned
 *         NULL.
 */
Cert * TLS_get_client_cert(TLS *tls) {
	if (tls == NULL) {
		TLS_DPRINTF("cert: tls (null)");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 25, NULL);
		return NULL;
	}

	PKCS12 *store = tls->pkcs12_client;
	if (store == NULL) {
		TLS_DPRINTF("cert: pkcs12_client (null)");
		OK_set_error(ERR_ST_TLS_PKCS12_CLIENT_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 26, NULL);
		return NULL;
	}

	Cert *copy = Cert_dup(P12_get_usercert(store));
	if (copy == NULL) {
		/* no error is recorded by this function on this path. */
		TLS_DPRINTF("cert: usercert (null)");
	}

	return copy;
}

/**
 * Return a duplicate of the user certificate held in tls->pkcs12_server.
 *
 * The result comes from Cert_dup().
 *
 * @return duplicated certificate, or NULL when tls is NULL, when no
 *         server PKCS#12 object is set, or when Cert_dup() returned
 *         NULL.
 */
Cert * TLS_get_server_cert(TLS *tls) {
	if (tls == NULL) {
		TLS_DPRINTF("cert: tls (null)");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 27, NULL);
		return NULL;
	}

	PKCS12 *store = tls->pkcs12_server;
	if (store == NULL) {
		TLS_DPRINTF("cert: pkcs (null)");
		OK_set_error(ERR_ST_TLS_PKCS12_SERVER_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 28, NULL);
		return NULL;
	}

	Cert *copy = Cert_dup(P12_get_usercert(store));
	if (copy == NULL) {
		/* no error is recorded by this function on this path. */
		TLS_DPRINTF("cert: servercert (null)");
	}

	return copy;
}

/**
 * Return a duplicate of the certificate of the other connection end.
 *
 * When tls->entity is TLS_CONNECT_CLIENT the server certificate is
 * returned, and when it is TLS_CONNECT_SERVER the client certificate is
 * returned. Any other value trips an assertion and yields NULL.
 *
 * @return duplicated certificate or NULL on failure.
 */
Cert * TLS_get_peer_certificate(TLS *tls) {
	if (tls == NULL) {
		TLS_DPRINTF("cert: tls (null)");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT + 29, NULL);
		return NULL;
	}

	if (tls->entity == TLS_CONNECT_CLIENT) {
		return TLS_get_server_cert(tls);
	}

	if (tls->entity == TLS_CONNECT_SERVER) {
		return TLS_get_client_cert(tls);
	}

	assert(!"unknown connection end specified.");
	return NULL;
}

/**
 * Set the server name stored in tls.
 *
 * The name must be non-empty, at most 65535 bytes long and must not be
 * a literal IP address. Trailing dots are removed before the name is
 * checked and stored; a name made only of dots is rejected.
 *
 * The new name is built and validated in a private buffer first. The
 * previously stored name is released and replaced only on success, so
 * a failing call leaves tls->server_name untouched.
 *
 * @param[in,out] tls         structure whose server_name is replaced.
 * @param[in]     server_name NUL-terminated host name. It is copied;
 *                            the caller keeps ownership.
 * @return 0 on success, -1 on failure (with an error recorded).
 */
int TLS_set_server_name(TLS *tls, char *server_name) {
	if (tls == NULL) {
		TLS_DPRINTF("cert: tls (null)");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 16, NULL);
		return -1;
	}

	if (server_name == NULL) {
		TLS_DPRINTF("cert: server_name (null)");
		OK_set_error(ERR_ST_NULLPOINTER,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 17, NULL);
		return -1;
	}

	/* largest value that fits in 16 bits (65535). */
	const size_t tls_ext_size_max = ((size_t)1 << 16) - 1;
	size_t len = strnlen(server_name, tls_ext_size_max + 1);
	if (len == tls_ext_size_max + 1) {
		TLS_DPRINTF("cert: server_name too long");
		OK_set_error(ERR_ST_TLS_SERVER_NAME_TOOLONG,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 18, NULL);
		return -1;
	}

	if (len == 0) {
		TLS_DPRINTF("cert: empty hostname");
		OK_set_error(ERR_ST_TLS_SERVER_NAME_EMPTY,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 19, NULL);
		return -1;
	}

	/* remove trailing dots by shortening the length to copy. */
	while (len > 0 && server_name[len - 1] == '.') {
		len--;
	}

	if (len == 0) {
		TLS_DPRINTF("cert: hostname has only dots.");
		OK_set_error(ERR_ST_TLS_SERVER_NAME_INVALID,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 23, NULL);
		return -1;
	}

	/* one extra byte for the terminating NUL. */
	char *name = malloc(len + 1);
	if (name == NULL) {
		TLS_DPRINTF("cert: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 22, NULL);
		return -1;
	}
	memcpy(name, server_name, len);
	name[len] = '\0';

	/* literal ip address is not permitted. the check runs on the
	 * trimmed name so that it sees exactly what will be stored. */
	struct addrinfo hints = {
		.ai_family = AF_UNSPEC,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_NUMERICHOST,
		.ai_protocol = 0,
		.ai_canonname = NULL,
		.ai_addr = NULL,
		.ai_next = NULL,
	};
	struct addrinfo *aihead = NULL;

	const int gaicode = getaddrinfo(name, NULL, &hints, &aihead);
	if (gaicode == 0) {
		/* parsed as a numeric host: reject it. */
		OK_set_error(ERR_ST_TLS_SERVER_NAME_NUMHOST,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 20, NULL);
		freeaddrinfo(aihead);
		free(name);
		return -1;
	}

	/* EAI_NONAME is the expected answer for a non-numeric name;
	 * anything else is a failure of the check itself. */
	if (gaicode != EAI_NONAME) {
		OK_set_error(ERR_ST_TLS_GETADDRINFO,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 21, NULL);
		free(name);
		return -1;
	}

	/* everything passed: replace the previously stored name. */
	free(tls->server_name);
	tls->server_name = name;

	return 0;
}

/**
 * Return a copy of the server name stored in tls.
 *
 * The copy is made with strdup().
 *
 * @return newly allocated string, or NULL when tls is NULL, when no
 *         server name is set, or when strdup() failed.
 */
char *TLS_get_server_name(TLS *tls) {
	if (tls == NULL) {
		TLS_DPRINTF("cert: tls (null)");
		OK_set_error(ERR_ST_TLS_TLS_NULL,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 24, NULL);
		return NULL;
	}

	const char *stored = tls->server_name;
	if (stored == NULL) {
		TLS_DPRINTF("cert: server_name (null)");
		OK_set_error(ERR_ST_NULLPOINTER,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 25, NULL);
		return NULL;
	}

	char *copy = strdup(stored);
	if (copy == NULL) {
		TLS_DPRINTF("cert: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_STRDUP,
			     ERR_LC_TLS1, ERR_PT_TLS_CERT2 + 26, NULL);
		return NULL;
	}

	return copy;
}
