/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_mac.h"

#include <string.h>

/* for HMAC_MD5, HMAC_SHA1, HMAC_SHA256, HMAC_SHA384 and HMAC_SHA512. */
#include <aicrypto/ok_hmac.h>

/**
 * size of md5 digest (16 bytes). tls_mac_init requires mac_len to be
 * equal to this value when the MAC algorithm is TLS_MAC_HMAC_MD5.
 */
static const int32_t TLS_MAC_HMAC_MD5_LEN = 16;

/**
 * size of sha1 digest (20 bytes). tls_mac_init requires mac_len to be
 * equal to this value when the MAC algorithm is TLS_MAC_HMAC_SHA1.
 */
static const int32_t TLS_MAC_HMAC_SHA1_LEN = 20;

/**
 * size of sha256 digest (32 bytes). tls_mac_init requires mac_len to
 * be equal to this value when the MAC algorithm is TLS_MAC_HMAC_SHA256.
 */
static const int32_t TLS_MAC_HMAC_SHA256_LEN = 32;

/**
 * size of sha384 digest (48 bytes). tls_mac_init requires mac_len to
 * be equal to this value when the MAC algorithm is TLS_MAC_HMAC_SHA384.
 */
static const int32_t TLS_MAC_HMAC_SHA384_LEN = 48;

/**
 * size of sha512 digest (64 bytes). tls_mac_init requires mac_len to
 * be equal to this value when the MAC algorithm is TLS_MAC_HMAC_SHA512.
 */
static const int32_t TLS_MAC_HMAC_SHA512_LEN = 64;

/**
 * write the sequence number of the connection into the text over which
 * the MAC is calculated.
 *
 * conn.seqnum is stored in buf as 8 bytes, most significant byte
 * first. returns the number of bytes written, or -1 when buf_len is
 * too small.
 */
static int32_t init_seqnum(const struct tls_connection conn,
			   uint8_t *buf, const uint32_t buf_len);

/**
 * write the record type into the text over which the MAC is
 * calculated.
 *
 * the type is stored in buf as a single byte. returns the number of
 * bytes written (1), or -1 when buf_len is too small.
 */
static int32_t init_type(const enum tls_record_ctype type,
			 uint8_t *buf, const uint32_t buf_len);

/**
 * write the version that is used by the record layer into the text
 * over which the MAC is calculated.
 *
 * version.major followed by version.minor is stored in buf. returns
 * the number of bytes written (2), or -1 when buf_len is too small.
 */
static int32_t init_ver(const struct tls_protocol_version version,
			uint8_t *buf, const uint32_t buf_len);

/**
 * write the length that is used by the record layer into the text
 * over which the MAC is calculated.
 *
 * the low 16 bits of len are stored in buf, high byte first. returns
 * the number of bytes written (2), or -1 when buf_len is too small.
 */
static int32_t init_len(const uint32_t len,
			uint8_t *buf, const uint32_t buf_len);

/**
 * build the text over which the MAC is calculated, or the AAD.
 *
 * the sequence number, type, version and length fields are written to
 * text in that order. when cbuf is not NULL, clen bytes of cbuf are
 * appended after those fields.
 *
 * To calculate the AAD, cbuf must be NULL (only the header fields are
 * written in that case).
 *
 * returns the number of bytes written to text, or -1 when text_len is
 * too small.
 */
static int32_t init_text_tls12(const struct tls_connection conn,
			 const struct tls_protocol_version version,
			 uint8_t *text,
			 const uint32_t text_len,
			 const enum tls_record_ctype type,
			 const uint8_t *cbuf,
			 const int32_t clen);

/**
 * build the AAD for TLS 1.3.
 *
 * the type, version and length fields are written to text in that
 * order (no sequence number is written). returns the number of bytes
 * written to text, or -1 when text_len is too small.
 */
static int32_t init_text_tls13(uint8_t *text, const uint32_t text_len,
			       const struct tls_protocol_version version,
			       const enum tls_record_ctype type,
			       const int32_t len);

/**
 * Serialize the sequence number of the connection into buf.
 *
 * conn.seqnum is read as a 64-bit value and written as 8 bytes, most
 * significant byte first.
 *
 * @param[in]  conn    connection whose seqnum member is serialized.
 * @param[out] buf     destination buffer.
 * @param[in]  buf_len number of bytes available in buf.
 * @return number of bytes written (8), or -1 when buf_len is less
 *         than 8. Nothing is written on failure.
 */
static int32_t init_seqnum(const struct tls_connection conn,
			   uint8_t *buf, const uint32_t buf_len) {
	const uint64_t num = conn.seqnum;

	/*
	 * All 8 bytes of the 64-bit sequence number are emitted below, so
	 * the bound check must require 8 bytes. A smaller limit lets the
	 * stores run past the end of buf.
	 */
	const uint32_t seqnum_len = 8;
	if (buf_len < seqnum_len) {
		return -1;
	}

	for (uint32_t off = 0; off < seqnum_len; off++) {
		/* shift counts are 56, 48, ..., 8, 0. */
		const uint32_t shift = 8 * (seqnum_len - 1 - off);
		buf[off] = (uint8_t)((num >> shift) & 0xff);
	}

	return (int32_t)seqnum_len;
}

/**
 * Store the record content type in the first byte of buf.
 *
 * @param[in]  type    record content type to store.
 * @param[out] buf     destination buffer.
 * @param[in]  buf_len number of bytes available in buf.
 * @return 1 (the number of bytes written), or -1 when buf_len is 0.
 */
static int32_t init_type(const enum tls_record_ctype type,
			 uint8_t *buf, const uint32_t buf_len) {
	/* the type field occupies exactly one byte. */
	if (buf_len < 1) {
		return -1;
	}

	buf[0] = type;
	return 1;
}

/**
 * Store the protocol version in the first two bytes of buf.
 *
 * version.major is written first, followed by version.minor.
 *
 * @param[in]  version protocol version to store.
 * @param[out] buf     destination buffer.
 * @param[in]  buf_len number of bytes available in buf.
 * @return 2 (the number of bytes written), or -1 when buf_len is
 *         less than 2.
 */
static int32_t init_ver(const struct tls_protocol_version version,
			uint8_t *buf, const uint32_t buf_len) {
	/* one byte for major and one byte for minor. */
	if (buf_len < 2) {
		return -1;
	}

	buf[0] = version.major;
	buf[1] = version.minor;
	return 2;
}

/**
 * Store the record length in the first two bytes of buf.
 *
 * Only the low 16 bits of len are stored, high byte first. Like the
 * other field helpers in this file, a too small buffer is reported
 * only through the return value.
 *
 * @param[in]  len     length value to store.
 * @param[out] buf     destination buffer.
 * @param[in]  buf_len number of bytes available in buf.
 * @return 2 (the number of bytes written), or -1 when buf_len is
 *         less than 2. Nothing is written on failure.
 */
static int32_t init_len(const uint32_t len,
			uint8_t *buf, const uint32_t buf_len) {
	uint32_t off = 0;

	const uint32_t length_bytes = 2;
	if (buf_len < length_bytes) {
		return -1;
	}

	buf[off++] = (len >> 8) & 0xff;
	buf[off++] = (len)      & 0xff;

	return off;
}

/**
 * Build the text over which the MAC is calculated, or the AAD.
 *
 * The following fields are written to text, in this order:
 * sequence number, record type, version, length (clen) and, when
 * cbuf is not NULL, clen bytes copied from cbuf.
 *
 * To calculate the AAD, cbuf must be NULL; only the header fields
 * are written in that case.
 *
 * @param[in]  conn     connection that provides the sequence number.
 * @param[in]  version  version written into the version field.
 * @param[out] text     destination buffer.
 * @param[in]  text_len number of bytes available in text.
 * @param[in]  type     record content type.
 * @param[in]  cbuf     fragment to append, or NULL.
 * @param[in]  clen     fragment length written into the length field.
 * @return number of bytes written to text, or -1 when a field does
 *         not fit in text or when cbuf is given with a negative clen.
 */
static int32_t init_text_tls12(const struct tls_connection conn,
			 const struct tls_protocol_version version,
			 uint8_t *text,
			 const uint32_t text_len,
			 const enum tls_record_ctype type,
			 const uint8_t *cbuf,
			 const int32_t clen) {
	int32_t offset = 0;
	int32_t written;

	if ((written = init_seqnum(conn, &(text[offset]),
				   text_len - offset)) < 0) {
		return -1;
	}
	offset += written;

	if ((written = init_type(type, &(text[offset]),
				 text_len - offset)) < 0) {
		return -1;
	}
	offset += written;

	if ((written = init_ver(version, &(text[offset]),
				text_len - offset)) < 0) {
		return -1;
	}
	offset += written;

	if ((written = init_len(clen, &(text[offset]),
				text_len - offset)) < 0) {
		return -1;
	}
	offset += written;

	if (cbuf) {
		/*
		 * A negative clen would be converted to a huge size by
		 * memcpy, so it is rejected here.
		 */
		if (clen < 0) {
			return -1;
		}

		/*
		 * Compare against the remaining space instead of computing
		 * offset + clen, which can overflow as a signed addition.
		 */
		if (text_len < (uint32_t)offset) {
			return -1;
		}
		if ((uint32_t)clen > text_len - (uint32_t)offset) {
			return -1;
		}

		memcpy(&(text[offset]), &(cbuf[0]), (size_t)clen);
		offset += clen;
	}

	return offset;
}

/**
 * Build the AAD for TLS 1.3.
 *
 * The record type, the version and the length are written to text,
 * in this order. No sequence number is written.
 *
 * @param[out] text     destination buffer.
 * @param[in]  text_len number of bytes available in text.
 * @param[in]  version  version written into the version field.
 * @param[in]  type     record content type.
 * @param[in]  len      value written into the length field.
 * @return number of bytes written to text, or -1 when a field does
 *         not fit in text.
 */
static int32_t init_text_tls13(uint8_t *text, const uint32_t text_len,
			       const struct tls_protocol_version version,
			       const enum tls_record_ctype type,
			       const int32_t len) {
	int32_t pos = 0;
	int32_t n;

	/* record type */
	n = init_type(type, &(text[pos]), text_len - pos);
	if (n < 0) {
		return -1;
	}
	pos += n;

	/* version */
	n = init_ver(version, &(text[pos]), text_len - pos);
	if (n < 0) {
		return -1;
	}
	pos += n;

	/* length */
	n = init_len(len, &(text[pos]), text_len - pos);
	if (n < 0) {
		return -1;
	}
	pos += n;

	return pos;
}

/**
 * Calculate the MAC of a record and store it in mac.
 *
 * When the MAC algorithm of conn.cipher is TLS_MAC_NULL nothing is
 * calculated and true is returned. Otherwise the HMAC selected by
 * conn.cipher.mac_algorithm is calculated with conn.mac_key over the
 * text built by init_text_tls12 (header fields followed by the
 * fragment).
 *
 * @param[in]  conn    connection that provides the cipher parameters,
 *                     the MAC key and the sequence number.
 * @param[in]  version version written into the MAC text.
 * @param[out] mac     buffer that receives the MAC.
 * @param[in]  mac_len size of mac. It must be equal to the digest
 *                     size of the selected algorithm.
 * @param[in]  type    record content type.
 * @param[in]  cbuf    fragment that is authenticated.
 * @param[in]  clen    length of cbuf.
 * @return true on success (or when no MAC is used). false when clen
 *         is negative, when cbuf is NULL while clen is positive, when
 *         the MAC text cannot be built, when mac_len does not match
 *         the algorithm, or when the algorithm is unknown.
 */
bool tls_mac_init(const struct tls_connection conn,
		  const struct tls_protocol_version version,
		  uint8_t *mac,
		  const int32_t mac_len,
		  const enum tls_record_ctype type,
		  const uint8_t *cbuf,
		  const int32_t clen) {
	/* generate MAC. calculating formula is as follows (RFC 5246
	 * section 6.2.3.1).
	 *
	 *   MAC(MAC_write_key, text)
	 *
	 *   text = seq_num +
	 *            TLSCompressed.type    +
	 *            TLSCompressed.version +
	 *            TLSCompressed.length  +
	 *            TLSCompressed.fragment
	 */

	enum mac_algorithm algo;
	if ((algo = conn.cipher.mac_algorithm) == TLS_MAC_NULL) {
		return true;
	}

	/*
	 * A negative clen would wrap the size of the text buffer below.
	 * A NULL cbuf with a positive clen would leave the fragment part
	 * of the text uninitialized, because init_text_tls12 skips the
	 * copy when cbuf is NULL.
	 */
	if (clen < 0 || (cbuf == NULL && clen > 0)) {
		return false;
	}

	uint8_t *key     = conn.mac_key;
	uint32_t key_len = conn.cipher.mac_key_length;

	/* text_len is as follows.
	 *
	 * |--------------|
	 * | seq_num  (8) |
	 * | type     (1) |
	 * | version  (2) |
	 * | length   (2) |
	 * | fragment (n) |
	 * |--------------|
	 */
	uint32_t text_len = (8 + 1 + 2 + 2 + clen);
	uint8_t  text[text_len];
	if (init_text_tls12(conn, version,
		      &(text[0]), text_len, type, cbuf, clen) < 0) {
		return false;
	}

	switch (algo) {
	case TLS_MAC_HMAC_MD5:
		if (mac_len != TLS_MAC_HMAC_MD5_LEN) {
			OK_set_error(ERR_ST_TLS_INVALID_MAC_LENGTH,
				     ERR_LC_TLS4, ERR_PT_TLS_MAC + 0, NULL);
			return false;
		}
		HMAC_MD5(text_len, text, key_len, key, mac);
		break;

	case TLS_MAC_HMAC_SHA1:
		if (mac_len != TLS_MAC_HMAC_SHA1_LEN) {
			OK_set_error(ERR_ST_TLS_INVALID_MAC_LENGTH,
				     ERR_LC_TLS4, ERR_PT_TLS_MAC + 1, NULL);
			return false;
		}
		HMAC_SHA1(text_len, text, key_len, key, mac);
		break;

	case TLS_MAC_HMAC_SHA256:
		if (mac_len != TLS_MAC_HMAC_SHA256_LEN) {
			OK_set_error(ERR_ST_TLS_INVALID_MAC_LENGTH,
				     ERR_LC_TLS4, ERR_PT_TLS_MAC + 2, NULL);
			return false;
		}
		HMAC_SHA256(text_len, text, key_len, key, mac);
		break;

	case TLS_MAC_HMAC_SHA384:
		if (mac_len != TLS_MAC_HMAC_SHA384_LEN) {
			OK_set_error(ERR_ST_TLS_INVALID_MAC_LENGTH,
				     ERR_LC_TLS4, ERR_PT_TLS_MAC + 3, NULL);
			return false;
		}
		HMAC_SHA384(text_len, text, key_len, key, mac);
		break;

	case TLS_MAC_HMAC_SHA512:
		if (mac_len != TLS_MAC_HMAC_SHA512_LEN) {
			OK_set_error(ERR_ST_TLS_INVALID_MAC_LENGTH,
				     ERR_LC_TLS4, ERR_PT_TLS_MAC + 4, NULL);
			return false;
		}
		HMAC_SHA512(text_len, text, key_len, key, mac);
		break;

	default:
		assert(!"mac algorithm is unknown.");
		/*
		 * assert() is removed when NDEBUG is defined. Never report
		 * success while mac has not been written.
		 */
		return false;
	}

	return true;
}

/**
 * Generate the Additional Authenticated Data (AAD) of a record.
 *
 * The layout is selected by the negotiated version of tls:
 *
 * - TLS 1.2: AAD_SIZE_TLS12 bytes are copied to aad. Calculating
 *   formula is as follows (RFC 5246 section 6.2.3.3).
 *
 *     additional_data = seq_num + TLSCompressed.type +
 *                       TLSCompressed.version + TLSCompressed.length;
 *
 * - TLS 1.3: AAD_SIZE_TLS13 bytes, built by init_text_tls13 from the
 *   type, tls->record_version and clen, are copied to aad.
 *
 * The AAD is built in a local buffer first, so aad is written only
 * on success.
 *
 * @param[in]  tls  tls structure that provides the versions.
 * @param[in]  conn connection that provides the sequence number.
 * @param[out] aad  buffer that receives the AAD. It must not be NULL.
 * @param[in]  type record content type.
 * @param[in]  clen value written into the length field.
 * @return true on success. false when the AAD cannot be built or
 *         when the negotiated version is neither TLS 1.2 nor TLS 1.3.
 */
bool tls_aad_generate(TLS * tls,
		      const struct tls_connection conn,
		      uint8_t *aad,
		      const enum tls_record_ctype type,
		      const int32_t clen)
{
	assert(aad != NULL);

	const uint16_t negotiated = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));

	if (negotiated == TLS_VER_TLS12) {
		uint8_t buf[AAD_SIZE_TLS12];

		/* cbuf is NULL: only the header fields are generated. */
		if (init_text_tls12(conn, tls->negotiated_version,
				    buf, AAD_SIZE_TLS12,
				    type, NULL, clen) < 0) {
			OK_set_error(ERR_ST_TLS_INVALID_MAC_LENGTH,
				     ERR_LC_TLS4, ERR_PT_TLS_MAC + 5, NULL);
			return false;
		}

		memcpy(aad, buf, AAD_SIZE_TLS12);
		return true;
	}

	if (negotiated == TLS_VER_TLS13) {
		uint8_t buf[AAD_SIZE_TLS13];

		if (init_text_tls13(buf, AAD_SIZE_TLS13,
				    tls->record_version, type, clen) < 0) {
			OK_set_error(ERR_ST_TLS_INTERNAL_ERROR,
				     ERR_LC_TLS4, ERR_PT_TLS_MAC + 6, NULL);
			return false;
		}

		memcpy(aad, buf, AAD_SIZE_TLS13);
		return true;
	}

	/* any other version is not supported. */
	return false;
}
