/*
 * Copyright (c) 2016-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_digitally_signed.h"
#include "tls_cipher.h"
#include "tls_alert.h"

/* for P1_sign_digest and OK_do_verify. */
#include <aicrypto/ok_tool.h>

/* for RSA_PSS_params_set_recommend and RSA_PSS_sign_digest */
#include <aicrypto/ok_rsa.h>

/**
 * Check whether the signature scheme passed as an argument is an RSA
 * signature scheme that may be used for a TLS 1.3 signature
 * (only the rsa_pss_rsae_* schemes are accepted).
 */
static bool check_rsa_sigscheme_availability(uint16_t sigscheme);

/**
 * Check whether the signature scheme passed as an argument is the ECDSA
 * signature scheme that matches the given curve type.
 */
static bool check_ecdsa_sigscheme_availability(uint16_t sigscheme,
					       int curve_type);

/**
 * Check whether the signature scheme passed as an argument is an RSASSA-PSS
 * (rsa_pss_pss_*) signature scheme. Currently unused.
 */
static bool check_rsassa_pss_sigscheme_availability(uint16_t sigscheme) UNUSED;

/**
 * Search a signature/hash pair list for the given signature scheme.
 */
static bool search_sigscheme_in_sighash(struct tls_hs_sighash_list *list,
				   uint16_t sigscheme);

/**
 * Compose the data to be signed or verified for a TLS 1.3
 * CertificateVerify message.
 */
static int32_t compose_verification_data(
    TLS *tls, uint8_t *data, enum tls_hs_sighash_hash_algo hash_algo);

/**
 * Get a signature scheme usable for signing with the public key in the
 * certificate.
 */
static bool get_signature_scheme(Cert *cert, 
				 struct tls_hs_sighash_list *sighash_list,
				 uint16_t *sigscheme);

/**
 * Check whether the certificate is compatible with the signature scheme
 * used for signature verification.
 */
static bool check_cert(TLS *tls, Cert *cert, uint16_t sigscheme);

/**
 * Write the TLS 1.2 digitally-signed hash to the handshake data being sent.
 */
static int32_t write_digitally_signed_hash_tls12(TLS *tls, PKCS12 *p12,
						 struct tls_hs_msg *msg);

/**
 * Write the TLS 1.3 digitally-signed hash to the handshake data being sent.
 */
static int32_t write_digitally_signed_hash_tls13(TLS *tls, PKCS12 *p12,
						 struct tls_hs_msg *msg);

/**
 * Read and verify the TLS 1.2 digitally-signed hash in a received handshake.
 */
static int32_t read_digitally_signed_hash_tls12(TLS *tls, PKCS12 *p12,
						struct tls_hs_msg *msg,
						const uint32_t offset);

/**
 * Read and verify the TLS 1.3 digitally-signed hash in a received handshake.
 */
static int32_t read_digitally_signed_hash_tls13(TLS *tls, PKCS12 *p12,
						struct tls_hs_msg *msg,
						const uint32_t offset);

static bool check_rsa_sigscheme_availability(uint16_t sigscheme) {
	/*
	 * RFC8446 4.4.3.  Certificate Verify
	 *
	 *                                       ...   RSA signatures MUST use an
	 *    RSASSA-PSS algorithm, regardless of whether RSASSA-PKCS1-v1_5
	 *    algorithms appear in "signature_algorithms".  The SHA-1 algorithm
	 *    MUST NOT be used in any signatures of CertificateVerify messages.
	 *
	 *    All SHA-1 signature algorithms in this specification are defined
	 *    solely for use in legacy certificates and are not valid for
	 *    CertificateVerify signatures.
	 *
	 * Hence only the rsa_pss_rsae_* schemes (none of which use SHA-1) are
	 * acceptable for an RSA key.
	 */
	return sigscheme == TLS_SS_RSA_PSS_RSAE_SHA256 ||
	       sigscheme == TLS_SS_RSA_PSS_RSAE_SHA384 ||
	       sigscheme == TLS_SS_RSA_PSS_RSAE_SHA512;
}

static bool check_ecdsa_sigscheme_availability(uint16_t sigscheme,
					       int curve_type) {
	/*
	 * Each supported named curve admits exactly one ECDSA scheme; the
	 * scheme is acceptable only if it is the one bound to the key's curve.
	 */
	switch (tls_hs_ecdh_get_named_curve(curve_type)) {
	case TLS_ECC_CURVE_SECP256R1:
		return sigscheme == TLS_SS_ECDSA_SECP256R1_SHA256;

	case TLS_ECC_CURVE_SECP384R1:
		return sigscheme == TLS_SS_ECDSA_SECP384R1_SHA384;

	case TLS_ECC_CURVE_SECP521R1:
		return sigscheme == TLS_SS_ECDSA_SECP521R1_SHA512;

	default:
		/* any other curve has no usable scheme here */
		return false;
	}
}

UNUSED
static bool check_rsassa_pss_sigscheme_availability(uint16_t sigscheme) {
	/* rsa_pss_pss_* schemes (key carried as RSASSA-PSS in the certificate) */
	static const uint16_t pss_schemes[] = {
		TLS_SS_RSA_PSS_PSS_SHA256,
		TLS_SS_RSA_PSS_PSS_SHA384,
		TLS_SS_RSA_PSS_PSS_SHA512,
	};
	const size_t count = sizeof(pss_schemes) / sizeof(pss_schemes[0]);

	for (size_t k = 0; k < count; k++) {
		if (pss_schemes[k] == sigscheme) {
			return true;
		}
	}
	return false;
}

static bool search_sigscheme_in_sighash(struct tls_hs_sighash_list *list,
				   uint16_t sigscheme) {
	uint16_t ss;
	struct tls_hs_sighash_algo sighash;
	for (int i = 0; i < list->len; i++) {
		sighash = list->list[i];
		ss = tls_hs_sighash_convert_sighash_to_sigscheme(&sighash);
		if (sigscheme == ss) {
			return true;
		}
	}

	return false;
}

/**
 * Build the byte string covered by a TLS 1.3 CertificateVerify signature
 * (RFC 8446 section 4.4.3) into data.
 *
 * Layout written: TLS_DS_PADDING_LENGTH bytes of 0x20, the first
 * TLS_DS_CONTEXT_STRING_LENGTH bytes of the role-specific context string,
 * one 0x00 separator byte, then the transcript digest produced by
 * tls_hs_hash_get_digest() for hash_algo.
 *
 * @param tls        connection; tls->state selects the server or client
 *                   context string.
 * @param data       output buffer; must have room for the padding, context
 *                   string, separator and the hash_algo digest.
 * @param hash_algo  transcript hash algorithm.
 * @return number of bytes written to data, or -1 after a fatal alert.
 */
static int32_t compose_verification_data(
    TLS *tls, uint8_t *data, enum tls_hs_sighash_hash_algo hash_algo) {
	/*
	 * RFC8446 4.4.3.  Certificate Verify
	 *
	 *    The digital signature is then computed over the concatenation of:
	 *
	 *    -  A string that consists of octet 32 (0x20) repeated 64 times
	 *
	 *    -  The context string
	 *
	 *    -  A single 0 byte which serves as the separator
	 *
	 *    -  The content to be signed
	 *
	 *    This structure is intended to prevent an attack on previous versions
	 *    of TLS in which the ServerKeyExchange format meant that attackers
	 *    could obtain a signature of a message with a chosen 32-byte prefix
	 *    (ClientHello.random).  The initial 64-byte pad clears that prefix
	 *    along with the server-controlled ServerHello.random.
	 *
	 *    The context string for a server signature is
	 *    "TLS 1.3, server CertificateVerify".  The context string for a
	 *    client signature is "TLS 1.3, client CertificateVerify".  It is
	 *    used to provide separation between signatures made in different
	 *    contexts, helping against potential cross-protocol attacks.
	 */

	/* pick the context string from the handshake state (literal: const) */
	const char *context_string;
	switch (tls->state) {
	case TLS_STATE_HS_RECV_SCERT:
	case TLS_STATE_HS_SEND_SCERT:
		context_string = "TLS 1.3, server CertificateVerify";
		break;

	case TLS_STATE_HS_SEND_CCERT:
	case TLS_STATE_HS_RECV_CCERT:
		context_string = "TLS 1.3, client CertificateVerify";
		break;

	default:
		TLS_DPRINTF("unknown state");
		OK_set_error(ERR_ST_TLS_INVALID_STATUS,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 17,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* resolve the digest size before writing so a failure leaves data untouched */
	int32_t hash_len;
	if ((hash_len = tls_hs_sighash_get_hash_size(hash_algo)) < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_hash_size");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t offset = 0;

	/* padding: octet 0x20 repeated */
	memset(data, 0x20, TLS_DS_PADDING_LENGTH);
	offset += TLS_DS_PADDING_LENGTH;

	/* context string */
	memcpy(&(data[offset]), context_string, TLS_DS_CONTEXT_STRING_LENGTH);
	offset += TLS_DS_CONTEXT_STRING_LENGTH;

	/* single 0 byte separator */
	data[offset] = 0x00;
	offset += TLS_DS_SEPARATOR_LENGTH;

	/* content to be signed: the transcript hash */
	tls_hs_hash_get_digest(hash_algo, tls, &(data[offset]));
	offset += hash_len;

	return offset;
}

static bool get_signature_scheme(Cert *cert,
				 struct tls_hs_sighash_list *sighash_list,
				 uint16_t *sigscheme) {
	/*
	 * Walk the list in order and stop at the first signature scheme that
	 * the certificate's public key can produce. On success the scheme is
	 * stored in *sigscheme; otherwise *sigscheme is left untouched.
	 */
	for (int idx = 0; idx < sighash_list->len; idx++) {
		struct tls_hs_sighash_algo candidate = sighash_list->list[idx];
		uint16_t ss =
			tls_hs_sighash_convert_sighash_to_sigscheme(&candidate);
		bool usable;

		switch (cert->pubkey_algo) {
		case KEY_RSA_PUB:
			usable = check_rsa_sigscheme_availability(ss);
			break;

		case KEY_ECDSA_PUB: {
			ECParam *ec = ((Pubkey_ECDSA *) cert->pubkey)->E;
			usable = check_ecdsa_sigscheme_availability(
				ss, ec->curve_type);
			break;
		}

		default:
			/* TODO: implement EdDSA and RSASSA-PSS */
			usable = false;
			break;
		}

		if (usable) {
			*sigscheme = ss;
			TLS_DPRINTF("signature scheme = %.4x", ss);
			return true;
		}
	}

	return false;
}

static bool check_cert(TLS *tls, Cert *cert, uint16_t sigscheme) {
	/*
	 * RFC8446 4.4.3.  Certificate Verify
	 *
	 *    In addition, the signature algorithm MUST be compatible with the key
	 *    in the sender's end-entity certificate.  RSA signatures MUST use an
	 *    RSASSA-PSS algorithm, regardless of whether RSASSA-PKCS1-v1_5
	 *    algorithms appear in "signature_algorithms".
	 *
	 * Returns true when compatible; otherwise records an error, sends a
	 * fatal illegal_parameter alert and returns false.
	 */
	if (cert->pubkey_algo == KEY_RSA_PUB) {
		if (check_rsa_sigscheme_availability(sigscheme)) {
			return true;
		}
		TLS_DPRINTF("mismatch between signature scheme and"
			    " certificate key algorithm");
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 18,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return false;
	}

	if (cert->pubkey_algo == KEY_ECDSA_PUB) {
		ECParam *ec = ((Pubkey_ECDSA *) cert->pubkey)->E;
		if (check_ecdsa_sigscheme_availability(sigscheme,
						       ec->curve_type)) {
			return true;
		}
		TLS_DPRINTF("mismatch between signature scheme and"
			    " certificate key algorithm");
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 19,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return false;
	}

	/* TODO: implement EdDSA and RSASSA-PSS */
	TLS_DPRINTF("unknown key algo");
	OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEYALGO,
		     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 20,
		     NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
	return false;
}

/**
 * Append a TLS 1.2 DigitallySigned structure (RFC 5246 section 4.7) to msg.
 *
 * Only RSA certificates are handled: the first RSA entry of
 * tls->sighash_list with MD5/SHA-1/SHA-2 hashing is chosen. The digest to
 * sign is the signature digest for ServerKeyExchange or the handshake
 * messages digest for CertificateVerify. It is signed with the private key
 * from p12 via P1_sign_digest(). Hash id, signature id, a 2-byte length and
 * the signature are then written to msg.
 *
 * @param tls  connection; tls->sighash_list supplies candidate algorithms.
 * @param p12  own certificate and private key.
 * @param msg  handshake message being built; msg->type selects the digest.
 * @return number of bytes appended to msg, or -1 on error.
 */
static int32_t write_digitally_signed_hash_tls12(TLS *tls, PKCS12 *p12,
						 struct tls_hs_msg *msg)
{
	uint32_t offset = 0;

	/* NOTE: might be to choose signature algorithm of certificate that
	 * sent by the client certificate handshake. */

	/* RFC 5246 section 7.4.8 says
	 *
	 * The hash and signature algorithms used in the signature MUST be
	 * one of those present in the supported_signature_algorithms field
	 * of the CertificateRequest message.  In addition, the hash and
	 * signature algorithms MUST be compatible with the key in the
	 * client's end-entity certificate.
	 */

	Cert *cert;
	if ((cert = P12_get_usercert(p12)) == NULL) {
		TLS_DPRINTF("P12_get_usercert");
		OK_set_error(ERR_ST_TLS_P12_GET_USERCERT,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 0,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_HANDSHAKE_FAILURE);
		return -1;
	}

	struct tls_hs_sighash_algo sighash;
	switch (cert->pubkey_algo) {
	case KEY_RSA_PUB:
		for (int i = 0; i < tls->sighash_list->len; ++i) {
			sighash = tls->sighash_list->list[i];

			if (sighash.sig   == TLS_SIG_ALGO_RSA     &&
			    (sighash.hash == TLS_HASH_ALGO_MD5    ||
			     sighash.hash == TLS_HASH_ALGO_SHA1   ||
			     sighash.hash == TLS_HASH_ALGO_SHA224 ||
			     sighash.hash == TLS_HASH_ALGO_SHA256 ||
			     sighash.hash == TLS_HASH_ALGO_SHA384 ||
			     sighash.hash == TLS_HASH_ALGO_SHA512)) {
				goto found_sighash_algo;
			}
		}

		TLS_DPRINTF("not found sig/hash algo");
		OK_set_error(ERR_ST_TLS_UNMATCH_SIGHASH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 1,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_HANDSHAKE_FAILURE);
		return -1;

	case KEY_DSA_PUB:
	case KEY_ECDSA_PUB:
		/* TODO: not implementation. */
		TLS_DPRINTF("unsupported key algo");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEYALGO,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 2,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_HANDSHAKE_FAILURE);
		return -1;

	default:
		TLS_DPRINTF("unknown key algo");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEYALGO,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 3,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	/*
	 * RFC5246 4.7.  Cryptographic Attributes
	 *
	 *    A digitally-signed element is encoded as a struct DigitallySigned:
	 *
	 *       struct {
	 *          SignatureAndHashAlgorithm algorithm;
	 *          opaque signature<0..2^16-1>;
	 *       } DigitallySigned;
	 */

found_sighash_algo:
	if (! tls_hs_msg_write_1(msg, sighash.hash)) {
		return -1;
	}
	offset++;

	if (! tls_hs_msg_write_1(msg, sighash.sig)) {
		return -1;
	}
	offset++;

	/* write dummy length once; patched after the signature is known. */
	int32_t cert_len_pos = msg->len;
	if (! tls_hs_msg_write_2(msg, 0)) {
		return -1;
	}
	offset += 2;

	/*
	 * tls_hs_sighash_get_hash_size() is checked for a negative (error)
	 * result elsewhere in this file; keep the result signed and validate
	 * it so a failure cannot become a huge or zero-length VLA.
	 */
	int32_t size;
	if ((size = tls_hs_sighash_get_hash_size(sighash.hash)) <= 0) {
		TLS_DPRINTF("tls_hs_sighash_get_hash_size");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	uint8_t  seed[size];

	switch (msg->type) {
	case TLS_HANDSHAKE_SERVER_KEY_EXCHANGE:
		/* signature digest */
		tls_hs_signature_get_digest(sighash.hash, tls, seed);
		break;
	case TLS_HANDSHAKE_CERTIFICATE_VERIFY:
		/* handshake_messages digest */
		tls_hs_hash_get_digest(sighash.hash, tls, seed);
		break;
	default:
		/*
		 * assert() is compiled out under NDEBUG; bail out explicitly so
		 * an uninitialized seed is never signed.
		 */
		assert(!"message type error");
		TLS_DPRINTF("message type error");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	TLS_DPRINTF("digitally_signed: hash = %d (len = %d), sig = %d",
		    sighash.hash, size, sighash.sig);

	Key *privkey;
	if((privkey = P12_get_privatekey(p12)) == NULL) {
		TLS_DPRINTF("P12_get_privatekey");
		OK_set_error(ERR_ST_TLS_P12_GET_PRIVATEKEY,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 4,
			     NULL);
		return -1;
	}

	int32_t hash_type = tls_hs_sighash_get_ai_hash_type(sighash.hash);
	if (hash_type < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_ai_hash_type %d", hash_type);
		return -1;
	}

	uint8_t *dest = NULL;
	switch(sighash.sig) {
	case TLS_SIG_ALGO_RSA:
		if ((dest = P1_sign_digest(privkey, seed, size,
					   hash_type)) == NULL) {
			TLS_DPRINTF("P1_sign_digest");
			OK_set_error(ERR_ST_TLS_P12_SIGN_DIGEST,
				     ERR_LC_TLS1,
				     ERR_PT_TLS_DIGITALLY_SIGNED + 5, NULL);
			return -1;
		}
		/* the written signature is privkey->size bytes long */
		offset += privkey->size;
		break;

	default:
		return -1;
	}

	if (dest == NULL) {
		TLS_DPRINTF("dest == NULL");
		return -1;
	}

	/* signature<0..2^16-1>: the length must fit in the 2-byte prefix */
	const int32_t sig_len_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (privkey->size > sig_len_max) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_DIGITALLY_SIGNED2 + 0,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		free(dest);
		return -1;
	}

	if (! tls_hs_msg_write_n(msg, dest, privkey->size)) {
		free(dest);
		return -1;
	}

	/* patch the dummy length written above */
	tls_util_write_2(&(msg->msg[cert_len_pos]), privkey->size);

	free(dest);
	return offset;
}

/**
 * Append a TLS 1.3 CertificateVerify body (RFC 8446 section 4.4.3) to msg.
 *
 * A signature scheme compatible with the certificate key is chosen from
 * tls->sighash_list. A server that finds none there falls back to the list
 * returned by tls_hs_sighash_list(); a client fails instead. The
 * RFC 8446 signing input is composed and signed with the private key from
 * p12 via NRG_do_signature(); RSA keys use
 * RSA_PSS_params_set_recommend() parameters. The scheme, a 2-byte length and
 * the signature are written to msg.
 *
 * @param tls  connection state.
 * @param p12  own certificate and private key.
 * @param msg  handshake message being built.
 * @return number of bytes appended to msg, or -1 on error.
 */
static int32_t write_digitally_signed_hash_tls13(TLS *tls, PKCS12 *p12,
						 struct tls_hs_msg *msg)
{
	uint32_t offset = 0;

	Cert *cert;
	if ((cert = P12_get_usercert(p12)) == NULL) {
		TLS_DPRINTF("P12_get_usercert");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* select signature scheme for signature */
	uint16_t sigscheme;
	if (get_signature_scheme(cert, tls->sighash_list, &sigscheme)
	    == false) {
		if (tls->entity == TLS_CONNECT_CLIENT) {
			TLS_DPRINTF("get_signature_scheme");
			OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
				     ERR_LC_TLS1,
				     ERR_PT_TLS_DIGITALLY_SIGNED + 21, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		/*
		 * fallback may happen in certificate message. signature scheme
		 * for signing doesn't exist in tls->sighash_list, then use
		 * default signature scheme list to search.
		 */
		struct tls_hs_sighash_list *sighash_list = NULL;
		if ((sighash_list = tls_hs_sighash_list(tls)) == NULL) {
			TLS_DPRINTF("hs: m: cert: tls_hs_sighash_list");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		if (get_signature_scheme(cert, sighash_list, &sigscheme)
		    == false) {
			TLS_DPRINTF("get_signature_scheme");
			OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
				     ERR_LC_TLS1,
				     ERR_PT_TLS_DIGITALLY_SIGNED + 22, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			tls_hs_sighash_free(sighash_list);
			return -1;
		}

		tls_hs_sighash_free(sighash_list);
	}

	/* write signature scheme */
	uint32_t sigscheme_length = 2;
	if (tls_hs_msg_write_2(msg, sigscheme) == false) {
		return -1;
	}
	offset += sigscheme_length;

	/* write dummy length; patched once the signature is known */
	int32_t cert_len_pos = msg->len;
	if (tls_hs_msg_write_2(msg, 0) == false) {
		return -1;
	}
	offset += 2;

	/* compose data for signing; the transcript hash follows the cipher suite */
	enum tls_hs_sighash_hash_algo hash_algo_thash;
	if ((hash_algo_thash = tls_cipher_hashalgo(tls->pending->cipher_suite))
	    == TLS_HASH_ALGO_NONE) {
		TLS_DPRINTF("tls_cipher_hashalgo");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t hash_len;
	if ((hash_len = tls_hs_sighash_get_hash_size(hash_algo_thash)) < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_hash_size");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	TLS_DPRINTF("transcript hash length = %d", hash_len);

	int32_t data_len = TLS_DS_PREFIX_LENGTH + hash_len;
	uint8_t data[data_len];
	int32_t len;
	if ((len = compose_verification_data(tls, data, hash_algo_thash)) < 0) {
		TLS_DPRINTF("compose_verification_data");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	TLS_DPRINTF("data length = %d", data_len);

	if (data_len != len) {
		TLS_DPRINTF("compose_verification_data");
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 23,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* get private key related to certificate */
	Key *privkey;
	if((privkey = P12_get_privatekey(p12)) == NULL) {
		TLS_DPRINTF("P12_get_privatekey");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* prepare parameters for signing */
	int32_t sig_type;
	if ((sig_type = tls_hs_sighash_get_ai_sig_type_by_ss(sigscheme)) < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_ai_sig_type_by_ss");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	enum tls_hs_sighash_hash_algo hash_algo_sign;
	if (tls_hs_sighash_get_hash_type(sigscheme, &hash_algo_sign) < 0) {
		TLS_DPRINTF("unknown hash type: 0x%.4x", sigscheme);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	int32_t hash_type;
	if ((hash_type = tls_hs_sighash_get_ai_hash_type(hash_algo_sign)) < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_ai_hash_type");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 * Zero-initialized: only RSA keys fill these in, but the address is
	 * passed to NRG_do_signature() for every key type, so it must never
	 * refer to indeterminate memory.
	 */
	rsassa_pss_params_t params = {0};
	switch (privkey->key_type) {
	case KEY_RSA_PRV:
		if (RSA_PSS_params_set_recommend(&params, hash_type) < 0) {
			TLS_DPRINTF("RSA_PSS_params_set_recommend");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		TLS_DPRINTF("rsassa_psa_params.hashAlgorithm    = %d",
			    params.hashAlgorithm);
		TLS_DPRINTF("rsassa_psa_params.maskGenAlgorithm = %d",
			    params.maskGenAlgorithm);
		TLS_DPRINTF("rsassa_psa_params.saltLength       = %d",
			    params.saltLength);
		TLS_DPRINTF("rsassa_psa_params.trailerField     = %d",
			    params.trailerField);
		break;

	default:
		break;
	}

	/* sign data; sig is allocated by NRG_do_signature and freed below */
	uint8_t *sig = NULL;
	int sig_len;
	if (NRG_do_signature(privkey, data, data_len, &sig, &sig_len,
			     sig_type, &params) < 0) {
		TLS_DPRINTF("NRG_do_signature");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.4.3.  Certificate Verify
	 *
	 *           opaque signature<0..2^16-1>;
	 */
	const int32_t sig_len_max = TLS_VECTOR_2_BYTE_SIZE_MAX;
	if (sig_len > sig_len_max) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_DIGITALLY_SIGNED2 + 1,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		free(sig);
		return -1;
	}

	offset += sig_len;

	if (tls_hs_msg_write_n(msg, sig, sig_len) == false) {
		free(sig);
		return -1;
	}

	/* patch the dummy length written above */
	tls_util_write_2(&(msg->msg[cert_len_pos]), sig_len);
	TLS_DPRINTF("sig_len = %d", sig_len);

	free(sig);
	return offset;
}

/**
 * Parse and verify a TLS 1.2 DigitallySigned structure (RFC 5246 section 4.7)
 * that starts at msg->msg[offset].
 *
 * Checks performed before the signature is verified with OK_do_verify():
 *  - the announced signature/hash pair is accepted by
 *    tls_hs_sighash_availablep();
 *  - the pair is compatible with the RSA key of the certificate in p12;
 *  - the signature length equals the public key size and fits within msg.
 * The digest checked is the signature digest for ServerKeyExchange or the
 * handshake messages digest for CertificateVerify.
 *
 * @param tls     connection state.
 * @param p12     holds the certificate whose public key checks the signature.
 * @param msg     received handshake message.
 * @param offset  position of the DigitallySigned structure in msg->msg.
 * @return number of bytes consumed, or -1 after recording an error/alert.
 */
static int32_t read_digitally_signed_hash_tls12(TLS *tls, PKCS12 *p12,
						struct tls_hs_msg *msg,
						const uint32_t offset)
{
	int32_t read_bytes = 0;

	/* two bytes: hash algorithm id + signature algorithm id */
	const uint32_t length_bytes = 2;
	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 6,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	/*
	 * RFC5246 4.7.  Cryptographic Attributes
	 *
	 *    A digitally-signed element is encoded as a struct DigitallySigned:
	 *
	 *       struct {
	 *          SignatureAndHashAlgorithm algorithm;
	 *          opaque signature<0..2^16-1>;
	 *       } DigitallySigned;
	 */
	struct tls_hs_sighash_algo sighash;
	sighash.hash = msg->msg[offset + 0];
	sighash.sig  = msg->msg[offset + 1];

	TLS_DPRINTF("digitally_signed: hash = %d, sig = %d",
		    sighash.hash, sighash.sig);

	Cert *cert;
	if ((cert = P12_get_usercert(p12)) == NULL) {
		TLS_DPRINTF("P12_get_usercert");
		OK_set_error(ERR_ST_TLS_P12_GET_USERCERT,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 7,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_HANDSHAKE_FAILURE);
		return -1;
	}

	/* RFC 5246 section 7.4.8 says
	 *
	 * The hash and signature algorithms used in the signature MUST be
	 * one of those present in the supported_signature_algorithms field
	 * of the CertificateRequest message.  In addition, the hash and
	 * signature algorithms MUST be compatible with the key in the
	 * client's end-entity certificate.
	 */
	if (tls_hs_sighash_availablep(tls, sighash) == false) {
		TLS_DPRINTF("certreq: invalid sighash algo.");
		OK_set_error(ERR_ST_TLS_INVALID_SIGHASH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 8,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	switch (cert->pubkey_algo) {
	case KEY_RSA_PUB:
		if (sighash.sig   != TLS_SIG_ALGO_RSA     ||
		    (sighash.hash != TLS_HASH_ALGO_MD5    &&
		     sighash.hash != TLS_HASH_ALGO_SHA1   &&
		     sighash.hash != TLS_HASH_ALGO_SHA224 &&
		     sighash.hash != TLS_HASH_ALGO_SHA256 &&
		     sighash.hash != TLS_HASH_ALGO_SHA384 &&
		     sighash.hash != TLS_HASH_ALGO_SHA512)) {
			TLS_DPRINTF("not compatible sig/hash algo");
			OK_set_error(ERR_ST_TLS_INVALID_SIGHASH,
				     ERR_LC_TLS1,
				     ERR_PT_TLS_DIGITALLY_SIGNED + 9, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_HANDSHAKE_FAILURE);
			return -1;
		}
		break;

	case KEY_DSA_PUB:
	case KEY_ECDSA_PUB:
		/* TODO: not implementation. */
		TLS_DPRINTF("unsupported key algo");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEYALGO,
			     ERR_LC_TLS1,
			     ERR_PT_TLS_DIGITALLY_SIGNED + 10, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_HANDSHAKE_FAILURE);
		return -1;

	default:
		TLS_DPRINTF("unknown key algo");
		OK_set_error(ERR_ST_TLS_UNSUPPORTED_KEYALGO,
			     ERR_LC_TLS1,
			     ERR_PT_TLS_DIGITALLY_SIGNED + 11, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	const uint32_t sig_length_bytes = 2;
	if (msg->len < (offset + read_bytes + sig_length_bytes)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 12,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * get signature length. OK_do_verify() takes no length argument, so
	 * the signature must be exactly the public key size.
	 */
	uint16_t siglen = tls_util_read_2(&(msg->msg[offset + read_bytes]));
	if (siglen != cert->pubkey->size) {
		TLS_DPRINTF("pubkey size");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 13,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += sig_length_bytes;

	if (msg->len < (offset + read_bytes + siglen)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 14,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * tls_hs_sighash_get_hash_size() is checked for a negative (error)
	 * result elsewhere in this file; validate it before sizing the VLA.
	 */
	int32_t size;
	if ((size = tls_hs_sighash_get_hash_size(sighash.hash)) <= 0) {
		TLS_DPRINTF("tls_hs_sighash_get_hash_size");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}
	uint8_t  digest[size];

	switch (msg->type) {
	case TLS_HANDSHAKE_SERVER_KEY_EXCHANGE:
		/* signature digest */
		tls_hs_signature_get_digest(sighash.hash, tls, digest);
		break;
	case TLS_HANDSHAKE_CERTIFICATE_VERIFY:
		/* handshake_messages digest */
		tls_hs_hash_get_digest(sighash.hash, tls, digest);
		break;
	default:
		/*
		 * assert() is compiled out under NDEBUG; bail out explicitly so
		 * the signature is never checked against an uninitialized digest.
		 */
		assert(!"message type error");
		TLS_DPRINTF("message type error");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* verify signature. */
	int32_t aicrypto_sighash_type = tls_hs_sighash_get_ai_sig_type(sighash);
	if (aicrypto_sighash_type < 0) {
		TLS_DPRINTF("aicrypto signature type");
		OK_set_error(ERR_ST_TLS_INVALID_SIGHASH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 15,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	uint8_t *signature = &(msg->msg[offset + read_bytes]);
	if (OK_do_verify(cert->pubkey, &(digest[0]), &(signature[0]),
			 aicrypto_sighash_type) != 0) {
		TLS_DPRINTF("OK_do_verify");
		OK_set_error(ERR_ST_TLS_OK_DO_VERIFY,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 16,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECRYPT_ERROR);
		return -1;
	}
	read_bytes += siglen;

	return read_bytes;
}

/**
 * Parse and verify a TLS 1.3 CertificateVerify body (RFC 8446 section 4.4.3)
 * that starts at msg->msg[offset].
 *
 * Steps:
 *  - reject SHA-1 schemes and schemes absent from tls_hs_sighash_list();
 *  - check the scheme against the key of the certificate in p12
 *    (check_cert());
 *  - bounds-check the signature field; an RSA signature must also be exactly
 *    the key size;
 *  - rebuild the signed content and verify it with NRG_do_verify().
 *
 * @param tls     connection state.
 * @param p12     holds the certificate whose public key checks the signature.
 * @param msg     received handshake message.
 * @param offset  position of the CertificateVerify body in msg->msg.
 * @return number of bytes consumed, or -1 after a fatal alert.
 */
static int32_t read_digitally_signed_hash_tls13(TLS *tls, PKCS12 *p12,
						struct tls_hs_msg *msg,
						const uint32_t offset)
{
	int32_t read_bytes = 0;

	const uint32_t sigscheme_bytes = 2;
	if (msg->len < (offset + sigscheme_bytes)) {
		TLS_DPRINTF("record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 24,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.4.3.  Certificate Verify
	 *
	 *           SignatureScheme algorithm;
	 */
	uint16_t sigscheme = tls_util_read_2(&(msg->msg[offset]));
	read_bytes += sigscheme_bytes;

	/*
	 * RFC8446 4.4.3.  Certificate Verify
	 *
	 *                                             ...  The SHA-1 algorithm
	 *    MUST NOT be used in any signatures of CertificateVerify messages.
	 *
	 *    All SHA-1 signature algorithms in this specification are defined
	 *    solely for use in legacy certificates and are not valid for
	 *    CertificateVerify signatures.
	 */
	enum tls_hs_sighash_hash_algo hash_algo_sign;
	if (tls_hs_sighash_get_hash_type(sigscheme, &hash_algo_sign) < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_hash_type");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	if (hash_algo_sign == TLS_HASH_ALGO_SHA1) {
		TLS_DPRINTF("SHA1 is used for signature");
		OK_set_error(ERR_ST_TLS_FORBIDDEN_HASH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 25,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return -1;
	}

	/*
	 * RFC8446 4.4.3.  Certificate Verify
	 *
	 *    If the CertificateVerify message is sent by a server, the signature
	 *    algorithm MUST be one offered in the client's "signature_algorithms"
	 *    extension unless no valid certificate chain can be produced without
	 *    unsupported algorithms (see Section 4.2.3).
	 *
	 *    If sent by a client, the signature algorithm used in the signature
	 *    MUST be one of those present in the supported_signature_algorithms
	 *    field of the "signature_algorithms" extension in the
	 *    CertificateRequest message.
	 */

	/*
	 * TODO: use sighash_list that was sent in client hello or certificate
	 * request message. it is not implemented now, then use
	 * tls_hs_sighash_list() because it has same values.
	 */
	struct tls_hs_sighash_list *sighash_list;
	if ((sighash_list = tls_hs_sighash_list(tls)) == NULL) {
		/* same handling as the fallback path of the writer */
		TLS_DPRINTF("tls_hs_sighash_list");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	if (search_sigscheme_in_sighash(sighash_list, sigscheme) == false) {
		TLS_DPRINTF("signature scheme not found");
		OK_set_error(ERR_ST_TLS_ILLEGAL_PARAMETER,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 26,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		tls_hs_sighash_free(sighash_list);
		return -1;
	}

	tls_hs_sighash_free(sighash_list);

	TLS_DPRINTF("digitally_signed: sigscheme = 0x%.4x", sigscheme);

	/* get peer end entity certificate */
	Cert *cert;
	if ((cert = P12_get_usercert(p12)) == NULL) {
		TLS_DPRINTF("P12_get_usercert");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* check compatibility of certificate with signature scheme. */
	if (check_cert(tls, cert, sigscheme) == false) {
		TLS_DPRINTF("check_cert");
		return -1;
	}

	const uint32_t sig_length_bytes = 2;
	if (msg->len < (offset + read_bytes + sig_length_bytes)) {
		TLS_DPRINTF("record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 27,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	uint16_t sig_len = tls_util_read_2(&(msg->msg[offset + read_bytes]));
	read_bytes += sig_length_bytes;

	if (msg->len < (offset + read_bytes + sig_len)) {
		TLS_DPRINTF("record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 28,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	TLS_DPRINTF("sig_len = %d", sig_len);

	/*
	 * NRG_do_verify() below receives no signature length. An RSA
	 * signature must be exactly the key size (RFC 8017 8.1.2 step 1; the
	 * TLS 1.2 reader enforces the same equality). Without this check a
	 * short signature field would let verification read beyond it and
	 * possibly beyond the received message.
	 */
	if (cert->pubkey_algo == KEY_RSA_PUB && sig_len != cert->pubkey->size) {
		TLS_DPRINTF("signature length does not match pubkey size");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_DIGITALLY_SIGNED2 + 2,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECRYPT_ERROR);
		return -1;
	}

	/* compose data for verifying; the transcript hash follows the cipher suite */
	enum tls_hs_sighash_hash_algo hash_algo_thash;
	if ((hash_algo_thash = tls_cipher_hashalgo(tls->pending->cipher_suite))
	    == TLS_HASH_ALGO_NONE) {
		TLS_DPRINTF("tls_cipher_hashalgo");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t hash_len_thash;
	if ((hash_len_thash = tls_hs_sighash_get_hash_size(hash_algo_thash))
	    < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_hash_size");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t data_len = TLS_DS_PREFIX_LENGTH + hash_len_thash;
	uint8_t data[data_len];
	int32_t len;
	if ((len = compose_verification_data(tls, data, hash_algo_thash)) < 0) {
		TLS_DPRINTF("compose_verification_data");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	if (data_len != len) {
		TLS_DPRINTF("compose_verification_data");
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 29,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t sig_type;
	if ((sig_type = tls_hs_sighash_get_ai_sig_type_by_ss(sigscheme)) < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_ai_sig_type_by_ss");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* prepare parameters for verification */
	int32_t hash_type;
	if ((hash_type = tls_hs_sighash_get_ai_hash_type(hash_algo_sign)) < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_ai_hash_type");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 * Zero-initialized: only RSA keys fill these in, but the address is
	 * passed to NRG_do_verify() for every key type.
	 */
	rsassa_pss_params_t params = {0};
	switch (cert->pubkey->key_type) {
	case KEY_RSA_PUB:
		if (RSA_PSS_params_set_recommend(&params, hash_type) < 0) {
			TLS_DPRINTF("RSA_PSS_params_set_recommend");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		TLS_DPRINTF("rsassa_psa_params.hashAlgorithm    = %d",
			    params.hashAlgorithm);
		TLS_DPRINTF("rsassa_psa_params.maskGenAlgorithm = %d",
			    params.maskGenAlgorithm);
		TLS_DPRINTF("rsassa_psa_params.saltLength       = %d",
			    params.saltLength);
		TLS_DPRINTF("rsassa_psa_params.trailerField     = %d",
			    params.trailerField);
		break;

	default:
		break;
	}

	/* verify signature over the digest of the composed data */
	int32_t hash_len_sign;
	if ((hash_len_sign = tls_hs_sighash_get_hash_size(hash_algo_sign)) < 0) {
		TLS_DPRINTF("tls_hs_sighash_get_hash_size");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	uint8_t hash[hash_len_sign];
	int32_t ret_len;
	if (OK_do_digest(hash_type, data, data_len, hash, &ret_len) == NULL) {
		TLS_DPRINTF("OK_do_digest");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 * RFC8446 4.4.3.  Certificate Verify
	 *
	 *    If the verification fails, the receiver MUST terminate the handshake
	 *    with a "decrypt_error" alert.
	 */
	uint8_t *sig = &(msg->msg[offset + read_bytes]);
	if (NRG_do_verify(cert->pubkey, hash, sig, sig_type, &params) < 0) {
		TLS_DPRINTF("NRG_do_verify");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECRYPT_ERROR);
		return -1;
	}
	read_bytes += sig_len;

	return read_bytes;
}

/*
 * Write the digitally-signed element for the negotiated protocol version.
 * TLS 1.2 and TLS 1.3 are dispatched to their writers. SSL 3.0 / TLS 1.0 /
 * TLS 1.1 are not implemented and return -1. Any other version records an
 * internal error and returns -1.
 */
int32_t tls_digitally_signed_write_hash(TLS *tls, PKCS12 *p12,
					struct tls_hs_msg *msg)
{
	const uint16_t ver =
		tls_util_convert_protover_to_ver(&(tls->negotiated_version));

	if (ver == TLS_VER_TLS13) {
		return write_digitally_signed_hash_tls13(tls, p12, msg);
	}

	if (ver == TLS_VER_TLS12) {
		return write_digitally_signed_hash_tls12(tls, p12, msg);
	}

	if (ver == TLS_VER_SSL30 || ver == TLS_VER_TLS10 ||
	    ver == TLS_VER_TLS11) {
		/* TODO: not implementation. */
		return -1;
	}

	OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
		     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 30,
		     NULL);
	return -1;
}

/*
 * Read and verify the digitally-signed element at msg->msg[offset] for the
 * negotiated protocol version. TLS 1.2 and TLS 1.3 are dispatched to their
 * readers. SSL 3.0 / TLS 1.0 / TLS 1.1 are not implemented and return -1.
 * Any other version records a protocol-version error and returns -1.
 */
int32_t tls_digitally_signed_read_hash(TLS *tls, PKCS12 *p12,
				       struct tls_hs_msg *msg,
				       const uint32_t offset)
{
	const uint16_t ver =
		tls_util_convert_protover_to_ver(&(tls->negotiated_version));

	if (ver == TLS_VER_TLS13) {
		return read_digitally_signed_hash_tls13(tls, p12, msg, offset);
	}

	if (ver == TLS_VER_TLS12) {
		return read_digitally_signed_hash_tls12(tls, p12, msg, offset);
	}

	if (ver == TLS_VER_SSL30 || ver == TLS_VER_TLS10 ||
	    ver == TLS_VER_TLS11) {
		/* TODO: not implementation. */
		return -1;
	}

	OK_set_error(ERR_ST_TLS_PROTOCOL_VERSION,
		     ERR_LC_TLS1, ERR_PT_TLS_DIGITALLY_SIGNED + 31,
		     NULL);
	return -1;
}
