/*
 * Copyright (c) 2015-2020 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_cert.h"
#include "extension/tls_sighash.h"
#include "tls_alert.h"

#include <string.h>

/* for CertList, Cert, CertDN and cert_dn_free. */
#include <aicrypto/ok_x509.h>

/* for ASN1_step, ASN1_skip and ASN1_get_subject */
#include <aicrypto/ok_asn1.h>

/**
 * Store-manager predicate. write_cert_authorities() only consults the
 * store manager's certificate list when this returns false.
 * @see tls_stm.c
 */
extern bool tls_stm_nullp(TLS *tls);

/**
 * Certificate list of the store manager. write_cert_authorities() uses the
 * subjects of these certificates as the acceptable CA list and releases the
 * list with Certlist_free_all().
 * @see tls_stm.c
 */
extern CertList * tls_stm_get_cert_list(TLS *tls);

/**
 * write the acceptable certificate types (certificate_types vector) to the
 * send data. returns the number of bytes written, or -1 on error.
 */
static int32_t write_cert_type(TLS *tls, struct tls_hs_msg *msg);

/**
 * write the acceptable certificate authorities (certificate_authorities
 * vector) to the send data. returns the number of bytes written, or -1 on
 * error.
 */
static int32_t write_cert_authorities(TLS *tls, struct tls_hs_msg *msg);

/**
 * read the server's acceptable certificate types from the received
 * handshake, starting at offset. returns the number of bytes consumed, or
 * -1 on error.
 */
static int32_t read_cert_type(TLS *tls,
			      const struct tls_hs_msg *msg,
			      const uint32_t offset);

/**
 * read the server's acceptable hash/signature algorithm pairs from the
 * received handshake (TLS 1.2 only), starting at offset. returns the number
 * of bytes consumed, or -1 on error.
 */
static int32_t read_cert_sigalgo(TLS *tls,
				 const struct tls_hs_msg *msg,
				 const uint32_t offset);

/**
 * read the server's acceptable certificate authorities from the received
 * handshake, starting at offset. returns the number of bytes consumed, or
 * -1 on error.
 */
static int32_t read_cert_authorities(TLS *tls,
				     const struct tls_hs_msg *msg,
				     const uint32_t offset);

/**
 * check whether the received extension type is handled by this module.
 * only TLS 1.3 CertificateRequest extensions are accepted.
 */
static bool check_ext_availability_tls13(const enum tls_extension_type type);

static bool check_ext_availability(TLS *tls, const enum tls_extension_type type);

/**
 * write signature_algorithms extension to the send data.
 */
static int32_t write_ext_sigalgo(TLS *tls, struct tls_hs_msg *msg);

/**
 * write signature_algorithms_cert extension to the send data.
 */
static int32_t write_ext_sigalgo_cert(TLS *tls, struct tls_hs_msg *msg);

/**
 * write certificate_authorities extension to the send data.
 */
static int32_t write_ext_cert_authorities(TLS *tls, struct tls_hs_msg *msg);

/**
 * handle one extension taken from the interim extension list.
 */
static bool read_ext_list(TLS *tls,
			  const enum tls_extension_type type,
			  const struct tls_hs_msg *msg,
			  const uint32_t offset);

/**
 * interpret extensions stored in tls_interim_params.
 */
static bool interpret_ext_list(TLS *tls);

/**
 * write certificate request data to message structure, selecting the
 * layout by negotiated protocol version.
 */
static int32_t write_certreq_up_to_tls12(TLS *tls,
					 struct tls_hs_msg *msg);

static int32_t write_certreq_tls13(TLS *tls,
				   struct tls_hs_msg *msg);

static int32_t write_certreq(TLS *tls, struct tls_hs_msg *msg);

/**
 * read certificate request data from message structure, selecting the
 * layout by negotiated protocol version.
 */
static int32_t read_certreq_up_to_tls12(TLS *tls, struct tls_hs_msg *msg,
				  uint32_t offset);

static int32_t read_certreq_tls13(TLS *tls, struct tls_hs_msg *msg,
				  uint32_t offset);

static int32_t read_certreq(TLS *tls, struct tls_hs_msg *msg,
				  uint32_t offset);

static int32_t write_cert_type(TLS *tls, struct tls_hs_msg *msg) {
	/*
	 * Emit the acceptable ClientCertificateType vector:
	 *
	 * RFC5246 7.4.4.  Certificate Request
	 *
	 *           ClientCertificateType certificate_types<1..2^8-1>;
	 *
	 * The list comes from tls_cert_type_list() and is released before
	 * returning; nothing is remembered in the TLS structure.
	 *
	 * Returns the number of bytes appended to msg, or -1 on error.
	 */
	struct tls_cert_type_list *types = tls_cert_type_list(tls);
	if (types == NULL) {
		return -1;
	}

	int32_t written = -1;
	const int32_t min_types = 1;
	const int32_t max_types = TLS_VECTOR_1_BYTE_SIZE_MAX;
	int idx = 0;

	if (types->len < min_types || max_types < types->len) {
		TLS_DPRINTF("certreq: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTREQ + 9,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		goto out;
	}

	/* one-byte vector length, then one byte per certificate type. */
	if (! tls_hs_msg_write_1(msg, types->len)) {
		goto out;
	}

	while (idx < types->len) {
		if (! tls_hs_msg_write_1(msg, types->list[idx])) {
			goto out;
		}
		idx++;
	}

	written = 1 + idx;

out:
	tls_cert_type_free(types);
	return written;
}

/**
 * Append the certificate_authorities vector (RFC5246 7.4.4):
 *
 *           DistinguishedName certificate_authorities<0..2^16-1>;
 *
 * Each entry is the DER-encoded subject of a certificate taken from the
 * store manager's certificate list. When no store manager is attached, or
 * it yields no list, an empty vector (length field 0) is written so that
 * the message layout stays well-formed.
 *
 * @return number of bytes appended to msg (including the 2-byte vector
 *         length), or -1 on error.
 */
static int32_t write_cert_authorities(TLS *tls, struct tls_hs_msg *msg) {
	const int32_t ca_list_length_bytes = 2;

	int32_t offset = 0;

	/*
	 * keep the head of the list separate from the iteration cursor so the
	 * whole list can be released on every exit path.
	 */
	CertList *head = NULL;
	if (tls_stm_nullp(tls) == false) {
		/* a NULL result is encoded as an empty vector below. */
		head = tls_stm_get_cert_list(tls);
	}

	/* write dummy length; it is back-patched after the entries. */
	int32_t pos = msg->len;
	if (! tls_hs_msg_write_2(msg, 0)) {
		goto err;
	}
	offset += ca_list_length_bytes;

	for (CertList *cur = head; cur != NULL; cur = cur->next) {
		if (! cur->cert) {
			continue;
		}

		/* TODO: refactor this code. */

		/*
		 * walk the certificate DER to the subject field. ASN1_step(.., 2)
		 * presumably descends into Certificate and TBSCertificate; an
		 * explicit [0] version tag (0xa0) is skipped when present.
		 */
		uint8_t *cp;
		if (*(cp = ASN1_step(cur->cert->der, 2)) == 0xa0) {
			/* version 3 */
			cp = ASN1_skip(cp);
		}

		/* serial number, signature algo, issuer DN, validity. */
		for (int i = 0; i < 4; ++i) {
			cp = ASN1_skip(cp);
		}

		/* 0x30 is the DER SEQUENCE tag that starts the subject Name. */
		if (*(cp) == 0x30) {
			/*
			 * len is the size computed from the length octets; the copy
			 * starts at the tag, so it presumably covers the whole
			 * encoded Name.
			 */
			int32_t len = tls_util_asn1_length(cp + 1);
			if (! tls_hs_msg_write_2(msg, len)) {
				goto err;
			}
			offset += 2;

			if (! tls_hs_msg_write_n(msg, cp, len)) {
				goto err;
			}
			offset += len;
		}
	}

	/*
	 * write length of CA list.
	 * NOTE(review): the total is not checked against 2^16-1; a very large
	 * CA list would have its length silently truncated here.
	 */
	tls_util_write_2(&(msg->msg[pos]), offset - ca_list_length_bytes);

	if (head != NULL) {
		Certlist_free_all(head);
	}
	return offset;

err:
	if (head != NULL) {
		Certlist_free_all(head);
	}
	return -1;
}

/**
 * Read the server's certificate_types vector starting at offset and store
 * the supported subset in tls->certtype_list_server.
 *
 * RFC5246 7.4.4.  Certificate Request
 *
 *           ClientCertificateType certificate_types<1..2^8-1>;
 *
 * Types rejected by tls_cert_type_availablep() are dropped, so the stored
 * list may be empty.
 *
 * @return number of bytes consumed, or -1 on error (an alert is raised).
 */
static int32_t read_cert_type(TLS *tls,
			      const struct tls_hs_msg *msg,
			      const uint32_t offset) {
	uint32_t read_bytes = 0;

	const int32_t length_bytes = 1;
	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("hs: m: certreq: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTREQ + 0, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	uint8_t list_length = msg->msg[offset];

	const uint8_t list_length_min = 1;
	if (list_length < list_length_min) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTREQ + 10, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	if (msg->len < (offset + read_bytes + list_length)) {
		TLS_DPRINTF("hs: m: certreq: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTREQ + 1, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/* keep only the certificate types this implementation supports. */
	enum tls_cert_type type_list[list_length];
	int32_t len = 0;
	for (uint8_t i = 0; i < list_length; ++ i) {
		enum tls_cert_type type = msg->msg[offset + length_bytes + i];
		if (tls_cert_type_availablep(type) == true) {
			type_list[len] = type;
			len++;
		}
	}
	read_bytes += list_length;

	struct tls_cert_type_list *list;
	if ((list = malloc(sizeof(*list))) == NULL) {
		TLS_DPRINTF("hs: m: certreq: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTREQ + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	list->len = len;

	/*
	 * malloc(0) may return NULL, which would be misreported as an
	 * allocation failure when none of the offered types is supported;
	 * always request room for at least one element.
	 */
	const size_t alloc_count = (len > 0) ? (size_t) len : 1;
	if ((list->list = malloc(alloc_count * sizeof(*(list->list)))) == NULL) {
		TLS_DPRINTF("hs: m: certreq: malloc: %s", strerror(errno));
		OK_set_error(ERR_ST_TLS_MALLOC,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTREQ + 3, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		free(list);
		return -1;
	}

	/*
	 * copy whole elements; a memcpy of len *bytes* would copy only part of
	 * the multi-byte enum values.
	 */
	for (int32_t i = 0; i < len; ++i) {
		list->list[i] = type_list[i];
	}

	/*
	 * NOTE(review): any previous tls->certtype_list_server is overwritten
	 * without being released here — confirm ownership.
	 */
	tls->certtype_list_server = list;

	return read_bytes;
}

/**
 * Read the server's supported_signature_algorithms (TLS 1.2 only) starting
 * at offset and check whether the signature algorithm of the configured
 * client certificate is among them.
 *
 * For other versions nothing is read and 0 is returned. If the client
 * certificate's algorithm is not offered, tls->pkcs12_client is released
 * and set to NULL so that an empty Certificate message is sent later.
 *
 * @return number of bytes consumed, or -1 on error.
 */
static int32_t read_cert_sigalgo(TLS *tls,
				 const struct tls_hs_msg *msg,
				 const uint32_t offset)
{
	int sighashlen = 0;
	uint32_t read_bytes = 0;

	uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));
	switch (version) {
	case TLS_VER_TLS12:
		if ((sighashlen = tls_hs_sighash_read(tls, msg, offset)) < 0) {
			TLS_DPRINTF("tls_hs_sighash_read failed");
			return -1;
		}
		break;

	case TLS_VER_SSL30:
	case TLS_VER_TLS10:
	case TLS_VER_TLS11:
	default:
		TLS_DPRINTF("hs: m: certreq: version");
		goto done;
	}
	read_bytes += sighashlen;

	if (tls->pkcs12_client == NULL) {
		TLS_DPRINTF("hs: m: certreq: NULL");
		goto done;
	}

	Cert *cert = P12_get_usercert(tls->pkcs12_client);
	if (cert == NULL) {
		TLS_DPRINTF("P12_get_usercert failed");
		return -1;
	}

	TLS_DPRINTF("certreq: cert->signature_algo = %d", cert->signature_algo);

	for (int i = 0; i < tls->sighash_list->len; ++i) {
		struct tls_hs_sighash_algo sighash
			= tls->sighash_list->list[i];

		int ai_sig;
		if ((ai_sig = tls_hs_sighash_get_ai_sig_type(sighash)) < 0) {
			TLS_DPRINTF("hs: m: certreq: unsupported sig type: %d, %d",
				    sighash.sig, sighash.hash);
			OK_set_error(ERR_ST_TLS_AICRYPTO_SIG_TYPE,
				     ERR_LC_TLS2,
				     ERR_PT_TLS_HS_MSG_CERTREQ + 8, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		if (cert->signature_algo == ai_sig) {
			TLS_DPRINTF("hs: m: certreq: server supported: %d", ai_sig);
			goto done;
		}
	}

	/*
	 * log before releasing: cert was obtained from tls->pkcs12_client and
	 * is not freed separately, so it must not be dereferenced after
	 * P12_free().
	 */
	TLS_DPRINTF("hs: m: certreq: server unsupported: %d",
		    cert->signature_algo);

	/* available certificates are none. in this case, next client
	 * certificate handshake protocol message should send null
	 * certificate. therefore, free tls->pkcs12_client. */
	P12_free(tls->pkcs12_client);
	tls->pkcs12_client = NULL;

done:
	return read_bytes;
}

/**
 * Read a certificate_authorities vector starting at offset and decide
 * whether the configured client certificate is acceptable.
 *
 * The issuer of the certificate in tls->pkcs12_client is compared with the
 * subject of each DistinguishedName in the vector. If no entry matches,
 * tls->pkcs12_client is released and set to NULL so that an empty
 * Certificate message is sent later. For TLS 1.1/1.2 an empty vector is
 * accepted as "all certificates are available".
 *
 * @return number of bytes consumed (2-byte length plus the vector body),
 *         or -1 on error.
 */
static int32_t read_cert_authorities(TLS *tls,
				     const struct tls_hs_msg *msg,
				     const uint32_t offset) {
	uint32_t read_bytes = 0;

	const int32_t length_bytes = 2;

	/* in the case to use aistore, do not refer the
	 * tls->store_manager structure member. in that case(use
	 * aistore), application that use this tls library set the id of
	 * aistore by *_set_clientkey_id() function, and, use it.  */
	char* issuer = NULL;
	if (tls->pkcs12_client != NULL) {
		Cert *cert;
		if ((cert = P12_get_usercert(tls->pkcs12_client)) != NULL) {
			issuer = cert->issuer;
		}
	}
	/* passing NULL for %s is undefined behavior in printf-style output. */
	TLS_DPRINTF("certreq: ca: issuer = %s",
		    (issuer != NULL) ? issuer : "(null)");

	if (msg->len < (offset + length_bytes)) {
		TLS_DPRINTF("certreq: ca: length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 0, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}
	read_bytes += length_bytes;

	uint16_t list_length = tls_util_read_2(&(msg->msg[offset]));

	/*
	 * RFC8446 4.2.4.  Certificate Authorities
	 *
	 *           DistinguishedName authorities<3..2^16-1>;
	 *
	 * RFC5246 7.4.4.  Certificate Request
	 *
	 *           DistinguishedName certificate_authorities<0..2^16-1>;
	 *
	 * RFC4346 7.4.4. Certificate request
	 *
	 *           DistinguishedName certificate_authorities<0..2^16-1>;
	 *
	 * RFC2246 7.4.4. Certificate request
	 *
	 *            DistinguishedName certificate_authorities<3..2^16-1>;
	 *
	 * RFC6101 5.6.4.  Certificate Request
	 *
	 *             DistinguishedName certificate_authorities<3..2^16-1>;
	 */
	const uint16_t cert_authorities_size_min = 3;
	uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));
	switch (version) {
	case TLS_VER_SSL30:
		if (list_length < cert_authorities_size_min) {
			TLS_DPRINTF("certreq: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS6,
				     ERR_PT_TLS_HS_MSG_CERTREQ2 + 1, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
			return -1;
		}
		break;

	case TLS_VER_TLS11:
	case TLS_VER_TLS12:
		if (list_length == 0) {
			/* all certificates are available. */
			TLS_DPRINTF("certreq: ca: null");
			goto done;
		}
		break;

	case TLS_VER_TLS10:
	case TLS_VER_TLS13:
		if (list_length < cert_authorities_size_min) {
			TLS_DPRINTF("certreq: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS6,
				     ERR_PT_TLS_HS_MSG_CERTREQ2 + 2, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}
		break;

	default:
		TLS_DPRINTF("unknown version");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 * the whole vector body must lie inside the message. This matters for
	 * the TLS 1.3 extension path, where the caller does not compare the
	 * returned length with the extension length.
	 */
	const uint32_t list_end = offset + read_bytes + list_length;
	if (msg->len < list_end) {
		TLS_DPRINTF("certreq: ca: length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 0, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/* every entry must fit inside the vector, not merely the message. */
	const uint32_t subject_length = 2;
	for (uint32_t base = offset + read_bytes; base < list_end;) {
		if ((list_end - base) < subject_length) {
			TLS_DPRINTF("hs: m: certreq: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS2,
				     ERR_PT_TLS_HS_MSG_CERTREQ + 4, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}

		const uint16_t len = tls_util_read_2(&(msg->msg[base]));

		/*
		 * RFC5246 7.4.4.  Certificate Request
		 *
		 *           opaque DistinguishedName<1..2^16-1>;
		 */
		if (len == 0 || (list_end - base - subject_length) < len) {
			TLS_DPRINTF("hs: m: certreq: invalid record length");
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS2,
				     ERR_PT_TLS_HS_MSG_CERTREQ + 5, NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
			return -1;
		}

		/* if use just a &(msg->msg[base + subject_length]) as a
		 * argument of ASN1_get_subject function, valgrind
		 * produce the error message as overflow detection of
		 * malloc'ed space. So, in here, copy &(msg->msg[base +
		 * subject_length]) to auto variable once.
		 * NOTE(review): ASN1_get_subject() trusts the DER lengths
		 * inside buf; they are not checked against len here. */
		uint8_t buf[len + 1];
		memset(buf, 0, sizeof(buf));
		memcpy(buf, &(msg->msg[base + subject_length]), len);

		bool matched = false;

		/* zeroed so cert_dn_free() below sees a defined state even if
		 * ASN1_get_subject() fails before filling dn. */
		CertDN dn;
		memset(&dn, 0, sizeof(dn));
		char *subject = NULL;
		if ((subject = ASN1_get_subject(&(buf[0]), &dn)) != NULL) {
			if ((issuer != NULL) &&
			    (strcmp(subject, issuer) == 0)) {
				matched = true;
			}
		}

		cert_dn_free(&dn);
		free(subject);

		if (matched == true) {
			goto done;
		}

		base += subject_length + len;
	}

	/* available certificates are none. in this case, next client
	 * certificate handshake protocol message should send null
	 * certificate. therefore, free tls->pkcs12_client. */
	P12_free(tls->pkcs12_client);
	tls->pkcs12_client = NULL;

	TLS_DPRINTF("certreq: ca: unmatched");

done:
	return read_bytes + list_length;
}

static bool check_ext_availability_tls13(const enum tls_extension_type type) {
	/*
	 * Extension types this module accepts in a TLS 1.3
	 * CertificateRequest. For any other type false is returned and
	 * interpret_ext_list() skips the extension.
	 */
	bool accepted;

	switch (type) {
	case TLS_EXT_SIGNATURE_ALGO:
	case TLS_EXT_SIGNED_CERTIFICATE_TIMESTAMP:
	case TLS_EXT_CERTIFICATE_AUTHORITIES:
	case TLS_EXT_OID_FILTERS:
	case TLS_EXT_SIGNATURE_ALGO_CERT:
		accepted = true;
		break;

	default:
		accepted = false;
		break;
	}

	return accepted;
}

static bool check_ext_availability(TLS *tls,
					 const enum tls_extension_type type) {
	/*
	 * CertificateRequest extensions are handled only for TLS 1.3. For
	 * SSL 3.0 - TLS 1.2 and for unknown versions nothing is accepted.
	 */
	const uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	if (version != TLS_VER_TLS13) {
		return false;
	}

	return check_ext_availability_tls13(type);
}

static int32_t write_ext_sigalgo(TLS *tls, struct tls_hs_msg *msg) {
	/*
	 * Append a signature_algorithms extension:
	 *
	 *   | extension_type (2) | extension_data length (2) | data (n) |
	 *
	 * The length field is written as 0 first and back-patched once the
	 * payload produced by tls_hs_sighash_write() is known.
	 *
	 * Returns the total number of bytes appended, or -1 on error.
	 */
	const int32_t field_bytes = 2;
	const int32_t max_body = TLS_EXT_SIZE_MAX;

	if (! tls_hs_msg_write_2(msg, TLS_EXT_SIGNATURE_ALGO)) {
		return -1;
	}

	const int32_t len_pos = msg->len;
	if (! tls_hs_msg_write_2(msg, 0)) {
		return -1;
	}

	const int32_t body_len = tls_hs_sighash_write(tls, msg);
	if (body_len < 0) {
		return -1;
	}

	if (body_len > max_body) {
		TLS_DPRINTF("certreq: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 10,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls_util_write_2(&(msg->msg[len_pos]), body_len);

	return field_bytes + field_bytes + body_len;
}

static int32_t write_ext_sigalgo_cert(TLS *tls, struct tls_hs_msg *msg) {
	/*
	 * Append a signature_algorithms_cert extension:
	 *
	 *   | extension_type (2) | extension_data length (2) | data (n) |
	 *
	 * The length field is written as 0 first and back-patched once the
	 * payload produced by tls_hs_sighash_cert_write() is known.
	 *
	 * Returns the total number of bytes appended, or -1 on error.
	 */
	const int32_t field_bytes = 2;
	const int32_t max_body = TLS_EXT_SIZE_MAX;

	if (! tls_hs_msg_write_2(msg, TLS_EXT_SIGNATURE_ALGO_CERT)) {
		return -1;
	}

	const int32_t len_pos = msg->len;
	if (! tls_hs_msg_write_2(msg, 0)) {
		return -1;
	}

	const int32_t body_len = tls_hs_sighash_cert_write(tls, msg);
	if (body_len < 0) {
		return -1;
	}

	if (body_len > max_body) {
		TLS_DPRINTF("certreq: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 11,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls_util_write_2(&(msg->msg[len_pos]), body_len);

	return field_bytes + field_bytes + body_len;
}

static int32_t write_ext_cert_authorities(TLS *tls, struct tls_hs_msg *msg) {
	/*
	 * Append a certificate_authorities extension (TLS 1.3 only):
	 *
	 *   | extension_type (2) | extension_data length (2) | data (n) |
	 *
	 * RFC8446 4.2.4.  Certificate Authorities
	 *
	 *           DistinguishedName authorities<3..2^16-1>;
	 *
	 * Because the vector must not be empty in TLS 1.3, an extension whose
	 * payload is shorter than 3 bytes is removed again from msg and 0 is
	 * returned.
	 *
	 * Returns the number of bytes left appended, or -1 on error.
	 */
	const int32_t field_bytes = 2;

	if (! tls_hs_msg_write_2(msg, TLS_EXT_CERTIFICATE_AUTHORITIES)) {
		return -1;
	}

	const int32_t len_pos = msg->len;
	if (! tls_hs_msg_write_2(msg, 0)) {
		return -1;
	}

	const int32_t body_len = write_cert_authorities(tls, msg);
	if (body_len < 0) {
		return -1;
	}

	const int32_t total = field_bytes + field_bytes + body_len;

	const uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));
	switch (version) {
	case TLS_VER_TLS13:
		break;

	case TLS_VER_SSL30:
	case TLS_VER_TLS10:
	case TLS_VER_TLS11:
	case TLS_VER_TLS12:
		TLS_DPRINTF("Not supported version");
		return -1;

	default:
		TLS_DPRINTF("Unknown version");
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 3, NULL);
		return -1;
	}

	const int32_t min_body = 3;
	if (body_len < min_body) {
		/* roll back everything this function appended. */
		msg->len -= total;
		return 0;
	}

	TLS_DPRINTF("cert_authorities_len = %d", body_len);

	const int32_t max_body = TLS_EXT_SIZE_MAX;
	if (body_len > max_body) {
		TLS_DPRINTF("certreq: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 12,
			     NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls_util_write_2(&(msg->msg[len_pos]), body_len);

	return total;
}

/**
 * Process one TLS 1.3 CertificateRequest extension whose body is msg->msg
 * starting at offset.
 *
 * Each type may appear only once per extension block; the per-type flag in
 * tls->interim_params->recv_ext_flags records that it has been seen.
 * Types without a handler here are ignored.
 *
 * @return true on success; false on error. Callers rely on an alert having
 *         been raised: directly here for a duplicate, otherwise presumably by
 *         the called reader.
 */
static bool read_ext_list(TLS *tls,
			  const enum tls_extension_type type,
			  const struct tls_hs_msg *msg,
			  const uint32_t offset)
{
	/*
	 * assume unknown extensions never come because check
	 * is performed before this function is called.
	 */
	/*
	 * RFC8446 4.2.  Extensions
	 *
	 *                        There MUST NOT be more than one extension of the
	 *    same type in a given extension block.
	 */
	bool *recv_exts = tls->interim_params->recv_ext_flags;
	if (recv_exts[type] == true) {
		TLS_DPRINTF("certreq: extensions of same type come multiple times");
		/* the caller does not send an alert itself, so raise it here. */
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_ILLEGAL_PARAMETER);
		return false;
	}
	recv_exts[type] = true;

	switch(type) {
	case TLS_EXT_SIGNATURE_ALGO:
		if (tls_hs_sighash_read(tls, msg, offset) < 0) {
			TLS_DPRINTF("tls_hs_sighash_read");
			return false;
		}
		break;

	case TLS_EXT_SIGNATURE_ALGO_CERT:
		if (tls_hs_sighash_cert_read(tls, msg, offset) < 0) {
			TLS_DPRINTF("tls_hs_sighash_ext_cert_read");
			return false;
		}
		break;

	case TLS_EXT_CERTIFICATE_AUTHORITIES:
		if (read_cert_authorities(tls, msg, offset) < 0) {
			TLS_DPRINTF("read_cert_authorities");
			return false;
		}
		break;

	default:
		/*
		 * RFC8446 4.3.2.  Certificate Request
		 *
		 *    extensions:  A set of extensions describing the parameters of the
		 *       certificate being requested.  The "signature_algorithms" extension
		 *       MUST be specified, and other extensions may optionally be included
		 *       if defined for this message.  Clients MUST ignore unrecognized
		 *       extensions.
		 */
		break;
	}

	return true;
}

static bool interpret_ext_list(TLS *tls) {
	/*
	 * Apply the extensions collected in tls->interim_params->head.
	 *
	 * signature_algorithms_cert is handled first (and removed from the
	 * list), because it must be processed before signature_algorithms in
	 * TLS 1.3. The remaining extensions accepted by
	 * check_ext_availability() are then handled in list order. A missing
	 * signature_algorithms extension is fatal. On success the list is
	 * emptied.
	 */
	struct tls_hs_interim_params *params = tls->interim_params;
	struct tls_extension *cursor;

	/* locate the (last) signature_algorithms_cert extension. */
	struct tls_extension *sig_cert = NULL;
	TAILQ_FOREACH(cursor, &(params->head), link) {
		if (cursor->type == TLS_EXT_SIGNATURE_ALGO_CERT) {
			sig_cert = cursor;
		}
	}

	if (sig_cert != NULL) {
		const struct tls_hs_msg view = {
			.type = TLS_HANDSHAKE_CERTIFICATE_REQUEST,
			.len = sig_cert->len,
			.max = sig_cert->len,
			.msg = sig_cert->opaque,
		};

		if (! read_ext_list(tls, sig_cert->type, &view, 0)) {
			/* alerts is sent by internal of read_ext_list. */
			TLS_DPRINTF("read_ext_list");
			return false;
		}

		TAILQ_REMOVE(&(params->head), sig_cert, link);
		tls_extension_free(sig_cert);
	}

	TAILQ_FOREACH(cursor, &(params->head), link) {
		if (! check_ext_availability(tls, cursor->type)) {
			continue;
		}

		const struct tls_hs_msg view = {
			.type = TLS_HANDSHAKE_CERTIFICATE_REQUEST,
			.len = cursor->len,
			.max = cursor->len,
			.msg = cursor->opaque,
		};

		if (! read_ext_list(tls, cursor->type, &view, 0)) {
			TLS_DPRINTF("read_ext_list");
			/* alerts is sent by internal of read_ext_list. */
			return false;
		}
	}

	if (! params->recv_ext_flags[TLS_EXT_SIGNATURE_ALGO]) {
		TLS_DPRINTF("certreq: missing signature_algorithms extension");
		OK_set_error(ERR_ST_TLS_MISSING_EXTENSION,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 4, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_MISSING_EXTENSION);
		return false;
	}

	while ((cursor = TAILQ_FIRST(&(params->head))) != NULL) {
		TAILQ_REMOVE(&(params->head), cursor, link);
		tls_extension_free(cursor);
	}

	return true;
}

static int32_t write_certreq_up_to_tls12(TLS *tls,
					 struct tls_hs_msg *msg) {
	/*
	 * SSL 3.0 - TLS 1.2 CertificateRequest body, written in order:
	 * certificate_types, the signature/hash list produced by
	 * tls_hs_sighash_write(), then certificate_authorities.
	 *
	 * Returns the total number of bytes written, or -1 on error.
	 */
	const int type_len = write_cert_type(tls, msg);
	if (type_len < 0) {
		TLS_DPRINTF("write_certificate_type");
		return -1;
	}

	const int sighash_len = tls_hs_sighash_write(tls, msg);
	if (sighash_len < 0) {
		TLS_DPRINTF("write_sighash_algorithm");
		return -1;
	}

	const int ca_len = write_cert_authorities(tls, msg);
	if (ca_len < 0) {
		TLS_DPRINTF("tls_cert_write_authorities");
		return -1;
	}

	return type_len + sighash_len + ca_len;
}

static int32_t write_certreq_tls13(TLS *tls,
				   struct tls_hs_msg *msg) {
	/*
	 * TLS 1.3 CertificateRequest body (RFC8446 4.3.2):
	 *
	 *   | request context length (1) |
	 *   | request context        (n) |
	 *   | extension length       (2) |
	 *   | extension              (n) |
	 *
	 *           opaque certificate_request_context<0..2^8-1>;
	 *           Extension extensions<2..2^16-1>;
	 *
	 * The context SHALL be zero length unless it is used for
	 * post-handshake authentication (RFC8446 4.6.2), which is not
	 * implemented here, so an empty context is always sent.
	 *
	 * TODO: generate an unpredictable request context when
	 * post-handshake authentication is used.
	 *
	 * Returns the number of bytes written, or -1 on error. The final
	 * consistency check assumes msg was empty on entry.
	 */
	const uint32_t context_length_field = 1;
	const uint8_t context_len = 0;
	if (! tls_hs_msg_write_1(msg, context_len)) {
		return -1;
	}

	/* placeholder for the extensions vector length, patched below. */
	const int32_t ext_len_pos = msg->len;
	const uint32_t ext_length_field = 2;
	if (! tls_hs_msg_write_2(msg, 0)) {
		return -1;
	}

	/* extensions are emitted in this order. */
	static int32_t (*const ext_writers[])(TLS *, struct tls_hs_msg *) = {
		write_ext_sigalgo,
		write_ext_sigalgo_cert,
		write_ext_cert_authorities,
	};

	int32_t ext_total = 0;
	for (size_t k = 0; k < sizeof(ext_writers) / sizeof(ext_writers[0]); ++k) {
		const int32_t written = ext_writers[k](tls, msg);
		if (written < 0) {
			return -1;
		}
		ext_total += written;
	}

	const int32_t total =
		(int32_t) (context_length_field + ext_length_field) + ext_total;

	const uint16_t ext_len = (uint16_t) ext_total;
	const uint16_t ext_len_min = 2;
	if (ext_len < ext_len_min) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 13, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	tls_util_write_2(&(msg->msg[ext_len_pos]), ext_len);

	if (msg->len != (uint32_t) total) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 5, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	return total;
}

static int32_t write_certreq(TLS *tls, struct tls_hs_msg *msg) {
	/*
	 * Write the CertificateRequest body in the layout of the negotiated
	 * protocol version. Returns bytes written, or -1 on error.
	 */
	const uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	switch (version) {
	case TLS_VER_TLS13:
		return write_certreq_tls13(tls, msg);

	case TLS_VER_SSL30:
	case TLS_VER_TLS10:
	case TLS_VER_TLS11:
	case TLS_VER_TLS12:
		return write_certreq_up_to_tls12(tls, msg);

	default:
		break;
	}

	TLS_DPRINTF("certreq: unknown version");
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
	return -1;
}

static int32_t read_certreq_up_to_tls12(TLS *tls, struct tls_hs_msg *msg,
				  uint32_t offset) {
	/*
	 * Parse an SSL 3.0 - TLS 1.2 CertificateRequest body starting at
	 * offset: certificate_types, signature algorithms (TLS 1.2 only), then
	 * certificate_authorities. Returns bytes consumed, or -1 on error.
	 */
	int32_t consumed = 0;
	int part;

	part = read_cert_type(tls, msg, offset);
	if (part < 0) {
		TLS_DPRINTF("read_certificate_type");
		return -1;
	}
	consumed += part;

	part = read_cert_sigalgo(tls, msg, offset + consumed);
	if (part < 0) {
		TLS_DPRINTF("read_cert_sigalgo");
		return -1;
	}
	consumed += part;

	part = read_cert_authorities(tls, msg, offset + consumed);
	if (part < 0) {
		TLS_DPRINTF("tls_cert_read_authorities");
		return -1;
	}
	consumed += part;

	return consumed;
}

/**
 * Parse a TLS 1.3 CertificateRequest body starting at offset.
 *
 * A non-empty certificate_request_context is rejected, because
 * post-handshake authentication is not implemented. The extensions are
 * collected by tls_hs_extension_parse() and applied by
 * interpret_ext_list().
 *
 * @return number of bytes consumed, or -1 on error.
 */
static int32_t read_certreq_tls13(TLS *tls, struct tls_hs_msg *msg,
				  uint32_t offset) {
	int32_t read_bytes = 0;

	/*
	 * RFC8446 4.3.2.  Certificate Request
	 *
	 *           opaque certificate_request_context<0..2^8-1>;
	 */
	uint32_t request_context_length_bytes = 1;
	if (msg->len < (offset + request_context_length_bytes)) {
		TLS_DPRINTF("extension: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 6, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	uint8_t request_context_len = msg->msg[offset];
	read_bytes += request_context_length_bytes;

	TLS_DPRINTF("request_context_len = %d", request_context_len);

	if (request_context_len > 0) {
		/*
		 * TODO: store request_context when implementing post-handshake
		 * authentication.
		 */
		TLS_DPRINTF("receive request_context");
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 7, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t extlen;
	if ((extlen = tls_hs_extension_parse(tls, msg, offset + read_bytes)) < 0) {
		TLS_DPRINTF("tls_hs_extension_parse");
		return -1;
	}

	/*
	 * RFC8446 4.3.2.  Certificate Request
	 *
	 *           Extension extensions<2..2^16-1>;
	 *
	 * extlen counts the bytes consumed by tls_hs_extension_parse(). The
	 * msg->len consistency check below only holds if that includes the
	 * 2-byte vector length field, so the vector body is extlen - 2 bytes
	 * and the minimum applies to that body.
	 */
	const int32_t ext_vector_length_bytes = 2;
	const int32_t extlen_min = 2;
	if (extlen < ext_vector_length_bytes + extlen_min) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 14, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	read_bytes += extlen;

	if (msg->len != (uint32_t) read_bytes) {
		TLS_DPRINTF("extension: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 8, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	if (interpret_ext_list(tls) == false) {
		TLS_DPRINTF("interpret_ext_list");
		return -1;
	}

	return read_bytes;
}

static int32_t read_certreq(TLS *tls, struct tls_hs_msg *msg,
				  uint32_t offset) {
	/*
	 * Parse the CertificateRequest body with the layout of the negotiated
	 * protocol version. Returns bytes consumed, or -1 on error.
	 */
	const uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	switch (version) {
	case TLS_VER_TLS13:
		return read_certreq_tls13(tls, msg, offset);

	case TLS_VER_SSL30:
	case TLS_VER_TLS10:
	case TLS_VER_TLS11:
	case TLS_VER_TLS12:
		return read_certreq_up_to_tls12(tls, msg, offset);

	default:
		break;
	}

	TLS_DPRINTF("certreq: unknown version");
	OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
		     ERR_LC_TLS6, ERR_PT_TLS_HS_MSG_CERTREQ2 + 9,
		     NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
	return -1;
}

struct tls_hs_msg * tls_hs_certreq_compose(TLS *tls) {
	/*
	 * Build a CertificateRequest handshake message for the negotiated
	 * version. On success tls->certreq_used is set and the caller owns
	 * the returned message; on failure NULL is returned.
	 */
	struct tls_hs_msg *msg = tls_hs_msg_init();
	if (msg == NULL) {
		TLS_DPRINTF("tls_hs_msg_init");
		return NULL;
	}

	msg->type = TLS_HANDSHAKE_CERTIFICATE_REQUEST;

	const int written = write_certreq(tls, msg);
	if (written < 0) {
		TLS_DPRINTF("write_certreq");
		tls_hs_msg_free(msg);
		return NULL;
	}

	msg->len = (uint32_t) written;

	tls->certreq_used = true;

	return msg;
}

bool tls_hs_certreq_parse(TLS *tls, struct tls_hs_msg *msg) {
	/*
	 * Parse a received CertificateRequest message. The whole message body
	 * must be consumed. On success tls->certreq_used is set.
	 */
	if (msg->type != TLS_HANDSHAKE_CERTIFICATE_REQUEST) {
		TLS_DPRINTF("hs: m: certreq: invalid handshake type");
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTREQ + 6, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	const int consumed = read_certreq(tls, msg, 0);
	if (consumed < 0) {
		TLS_DPRINTF("read_certreq");
		return false;
	}

	if (msg->len != (uint32_t) consumed) {
		TLS_DPRINTF("hs: m: certreq: invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTREQ + 7, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return false;
	}

	/*
	 * RFC8446 4.4.2.  Certificate
	 *
	 *    The client MUST send a Certificate message if and only if the server
	 *    has requested client authentication via a CertificateRequest message
	 *    (Section 4.3.2).
	 */
	tls->certreq_used = true;

	return true;
}
