/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_handshake.h"
#include "tls_digitally_signed.h"
#include "tls_alert.h"

#include <string.h>

/**
 * Write certificate verify data to message structure.
 *
 * Each writer returns the length value produced by
 * tls_digitally_signed_write_hash() (the caller adds it to its running
 * byte offset), or -1 on failure. The *_up_to_tls12 and *_tls13 variants
 * each handle one protocol generation; write_certvfy() selects between
 * them based on the negotiated version.
 */
static int32_t write_certvfy_up_to_tls12(TLS *tls, struct tls_hs_msg *msg);

static int32_t write_certvfy_tls13(TLS *tls, struct tls_hs_msg *msg);

static int32_t write_certvfy(TLS *tls, struct tls_hs_msg *msg);

/**
 * Read certificate verify data from message structure.
 *
 * Each reader starts at byte position `offset` within msg and returns the
 * value produced by tls_digitally_signed_read_hash() (the caller treats it
 * as the number of bytes consumed), or -1 on failure. read_certvfy()
 * selects the per-version reader based on the negotiated version.
 */
static int32_t read_certvfy_up_to_tls12(TLS *tls, struct tls_hs_msg *msg,
					uint32_t offset);

static int32_t read_certvfy_tls13(TLS *tls, struct tls_hs_msg *msg,
				  uint32_t offset);

static int32_t read_certvfy(TLS *tls, struct tls_hs_msg *msg, uint32_t offset);

static int32_t write_certvfy_up_to_tls12(TLS *tls, struct tls_hs_msg *msg) {
	/*
	 * Before TLS 1.3 only the client sends CertificateVerify, so the
	 * signature is always made with the client's own credential.
	 */
	int32_t signed_len = tls_digitally_signed_write_hash(
		tls, tls->pkcs12_client, msg);
	if (signed_len < 0) {
		TLS_DPRINTF("tls_digitally_signed_write_hash");
		return -1;
	}

	/* note on the connection that client certificate auth was used. */
	tls->is_ccert_auth = true;
	return signed_len;
}

static int32_t write_certvfy_tls13(TLS *tls, struct tls_hs_msg *msg) {
	/*
	 * In TLS 1.3 either side may send CertificateVerify; sign with the
	 * credential belonging to this endpoint.
	 */
	PKCS12 *own_cred;
	bool signer_is_client;

	switch (tls->entity) {
	case TLS_CONNECT_CLIENT:
		own_cred = tls->pkcs12_client;
		signer_is_client = true;
		break;

	case TLS_CONNECT_SERVER:
		own_cred = tls->pkcs12_server;
		signer_is_client = false;
		break;

	default:
		TLS_DPRINTF("unknown entity");
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTVFY + 0, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t signed_len = tls_digitally_signed_write_hash(tls, own_cred, msg);
	if (signed_len < 0) {
		TLS_DPRINTF("tls_digitally_signed_write_hash");
		return -1;
	}

	/* only a client-side signature counts as client certificate auth. */
	if (signer_is_client) {
		tls->is_ccert_auth = true;
	}

	return signed_len;
}

static int32_t write_certvfy(TLS *tls, struct tls_hs_msg *msg) {
	/* route to the writer matching the negotiated protocol version. */
	const uint16_t ver =
		tls_util_convert_protover_to_ver(&tls->negotiated_version);

	if (ver == TLS_VER_SSL30 || ver == TLS_VER_TLS10 ||
	    ver == TLS_VER_TLS11 || ver == TLS_VER_TLS12) {
		return write_certvfy_up_to_tls12(tls, msg);
	}

	if (ver == TLS_VER_TLS13) {
		return write_certvfy_tls13(tls, msg);
	}

	/* a version we do not know how to sign for. */
	TLS_DPRINTF("unknown version");
	OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
		     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTVFY + 1, NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
	return -1;
}

static int32_t read_certvfy_up_to_tls12(TLS *tls, struct tls_hs_msg *msg,
					uint32_t offset) {
	/*
	 * Before TLS 1.3 CertificateVerify comes only from the client, so the
	 * signature is checked against the client's credential.
	 */
	int32_t consumed = tls_digitally_signed_read_hash(
		tls, tls->pkcs12_client, msg, offset);
	if (consumed < 0) {
		TLS_DPRINTF("tls_digitally_signed_read_hash");
		return -1;
	}

	/* note on the connection that client certificate auth was used. */
	tls->is_ccert_auth = true;
	return consumed;
}

static int32_t read_certvfy_tls13(TLS *tls, struct tls_hs_msg *msg,
				  uint32_t offset) {
	/*
	 * In TLS 1.3 the message was signed by the peer, so use the peer's
	 * credential: the server's when we are the client, and vice versa.
	 */
	PKCS12 *peer_cred;
	bool peer_is_client;

	switch (tls->entity) {
	case TLS_CONNECT_CLIENT:
		peer_cred = tls->pkcs12_server;
		peer_is_client = false;
		break;

	case TLS_CONNECT_SERVER:
		peer_cred = tls->pkcs12_client;
		peer_is_client = true;
		break;

	default:
		TLS_DPRINTF("unknown entity");
		OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTVFY + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	int32_t consumed = tls_digitally_signed_read_hash(tls, peer_cred,
							  msg, offset);
	if (consumed < 0) {
		TLS_DPRINTF("tls_digitally_signed_read_hash");
		return -1;
	}

	/* a signature from the client means client certificate auth. */
	if (peer_is_client) {
		tls->is_ccert_auth = true;
	}

	return consumed;
}

static int32_t read_certvfy(TLS *tls, struct tls_hs_msg *msg, uint32_t offset) {
	/* route to the reader matching the negotiated protocol version. */
	const uint16_t ver =
		tls_util_convert_protover_to_ver(&tls->negotiated_version);

	if (ver == TLS_VER_SSL30 || ver == TLS_VER_TLS10 ||
	    ver == TLS_VER_TLS11 || ver == TLS_VER_TLS12) {
		return read_certvfy_up_to_tls12(tls, msg, offset);
	}

	if (ver == TLS_VER_TLS13) {
		return read_certvfy_tls13(tls, msg, offset);
	}

	/* a version we do not know how to verify for. */
	TLS_DPRINTF("unknown version");
	OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
		     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTVFY + 3, NULL);
	TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
	return -1;
}

/**
 * Build a CertificateVerify handshake message for this endpoint.
 *
 * Allocates a fresh message with tls_hs_msg_init(), sets its type to
 * TLS_HANDSHAKE_CERTIFICATE_VERIFY and fills the body using the writer
 * for the negotiated protocol version.
 *
 * @param tls  connection state.
 * @return the new message (caller owns it), or NULL on failure. On a
 *         write failure the partially built message is released with
 *         tls_hs_msg_free() before returning.
 */
struct tls_hs_msg * tls_hs_certvfy_compose(TLS *tls) {
	struct tls_hs_msg *msg = tls_hs_msg_init();
	if (msg == NULL) {
		TLS_DPRINTF("tls_hs_msg_init");
		return NULL;
	}

	msg->type = TLS_HANDSHAKE_CERTIFICATE_VERIFY;

	/*
	 * The body is a single digitally-signed structure, so the byte count
	 * returned by write_certvfy() is not needed beyond the error check.
	 */
	if (write_certvfy(tls, msg) < 0) {
		TLS_DPRINTF("write_certvfy");
		tls_hs_msg_free(msg);
		return NULL;
	}

	return msg;
}

/**
 * Parse a received CertificateVerify handshake message.
 *
 * Checks the handshake type, then reads the digitally-signed body via the
 * reader for the negotiated protocol version. The body must account for
 * every byte of the message; a length mismatch is a decode error.
 *
 * @param tls  connection state.
 * @param msg  received handshake message.
 * @return true if the message was read successfully and fully consumed;
 *         false otherwise. For the type and length checks performed here,
 *         an error is recorded with OK_set_error() and a fatal alert is
 *         raised with TLS_ALERT_FATAL() before returning false.
 */
bool tls_hs_certvfy_parse(TLS *tls, struct tls_hs_msg *msg) {
	uint32_t offset = 0;

	if (msg->type != TLS_HANDSHAKE_CERTIFICATE_VERIFY) {
		TLS_DPRINTF("invalid handshake type");
		OK_set_error(ERR_ST_TLS_UNEXPECTED_MSG,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTVFY + 17, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_UNEXPECTED_MESSAGE);
		return false;
	}

	int32_t read_bytes = read_certvfy(tls, msg, offset);
	if (read_bytes < 0) {
		TLS_DPRINTF("read_certvfy");
		return false;
	}
	offset += (uint32_t) read_bytes;

	/*
	 * Trailing (or missing) bytes mean the message is malformed. After
	 * raising the fatal alert we must report failure; previously this
	 * path fell through and returned true.
	 */
	if (msg->len != offset) {
		TLS_DPRINTF("invalid record length");
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS2, ERR_PT_TLS_HS_MSG_CERTVFY + 18, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return false;
	}

	return true;
}
