/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_cipher.h"
#include "tls_mac.h"
#include "tls_alert.h"

#include <aicrypto/nrg_modes.h>

/* for AES_set_iv, AES_gcm_decrypt and AES_gcm_encrypt */
#include <aicrypto/ok_aes.h>

/* for chacha20_poly1305_encrypt */
#include <aicrypto/nrg_chacha.h>

static void copy_seqnum_to_buf(uint64_t num, uint8_t *buf);

/*
 * TLS 1.2 AES-GCM nonce (RFC 5288 section 3): a 4-byte implicit salt taken
 * from the write/read IV, followed by the 8-byte explicit part that travels
 * in each record. Code in this file passes a pointer to this struct to the
 * GCM routines as a contiguous 12-byte nonce, which relies on there being no
 * padding between the two byte-array fields.
 */
struct GCMNonce {
	uint8_t salt[4];
	uint8_t nonce_explicit[8];
};
typedef struct GCMNonce GCMNonce_t;

/*
 * Encrypt one record fragment with the connection's active write AEAD.
 *
 * Layout written to dest:
 *   TLS 1.2: nonce_explicit (record_iv_length bytes) || ciphertext || tag
 *   TLS 1.3: ciphertext || tag
 *
 * TLS 1.2 supports AES-GCM only. TLS 1.3 supports AES-GCM and
 * ChaCha20-Poly1305.
 *
 * @param tls   connection; tls->active_write supplies the key, IV, cipher
 *              parameters and record sequence number.
 * @param dest  output buffer. It is assumed to hold at least the ciphered
 *              fragment maximum of the negotiated version (TODO confirm
 *              with callers). This function never writes more than that.
 * @param type  record content type. It goes into the TLS 1.2 additional
 *              data only; TLS 1.3 always authenticates application_data.
 * @param src   plaintext to protect.
 * @param len   length of src in bytes.
 * @return number of bytes written to dest, or -1 on failure.
 */
int32_t tls_cipher_aead(TLS *tls,
			uint8_t *dest,
			const enum tls_record_ctype type,
			const uint8_t *src,
			const int32_t len)
{
	uint32_t offset = 0;

	uint32_t fixed_iv_length = tls->active_write.cipher.fixed_iv_length;
	uint32_t record_iv_length = tls->active_write.cipher.record_iv_length;
	GCMNonce_t gcm_nonce;
	uint8_t nonce_buf[TLS_AES_GCM_NONCE_SIZE];
	uint8_t *nonce;
	uint8_t *iv;
	uint32_t iv_length;
	gcm_param_t gcm_param;
	Key *key = tls->active_write.key;
	int rc;
	uint16_t ciphered_fragment_size_max;
	uint32_t explicit_nonce_length;
	enum tls_record_ctype aad_type;
	uint16_t aad_len;
	uint16_t aad_arg_len;

	uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));
	switch (version) {
	case TLS_VER_SSL30:
	case TLS_VER_TLS10:
	case TLS_VER_TLS11:
		/* XXX: not implemented */
		return -1;

	case TLS_VER_TLS12:
		ciphered_fragment_size_max =
		    TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_UP_TO_TLS12;
		/* nonce_explicit is emitted in front of the ciphertext. */
		explicit_nonce_length = record_iv_length;
		aad_type = type;
		aad_len = AAD_SIZE_TLS12;
		aad_arg_len = len;
		break;

	case TLS_VER_TLS13:
		ciphered_fragment_size_max =
		    TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_TLS13;
		/* TLS 1.3 records carry no explicit nonce. */
		explicit_nonce_length = 0;
		aad_type = TLS_CTYPE_APPLICATION_DATA;
		aad_len = AAD_SIZE_TLS13;
		aad_arg_len = len + TLS_AES_GCM_AUTHENTICATION_TAG_SIZE;
		break;

	default:
		return -1;
	}

	/*
	 * The whole ciphered fragment (explicit nonce, ciphertext and tag)
	 * must fit within the per-version limit; otherwise dest would be
	 * written past that many bytes.
	 */
	if (len < 0 ||
	    explicit_nonce_length + (uint32_t) len
	    + TLS_AES_GCM_AUTHENTICATION_TAG_SIZE > ciphered_fragment_size_max) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_AEAD + 0, NULL);
		return -1;
	}

	switch (version) {
	case TLS_VER_TLS12:
		/*
		 * RFC 5288 3: nonce = salt || nonce_explicit. The salt is the
		 * implicit part of the write IV; nonce_explicit is the write
		 * sequence number and is copied in clear to the front of dest.
		 */
		if (tls->active_write.cipher.cipher_algorithm
		    != TLS_BULK_CIPHER_AES) {
			TLS_DPRINTF("selected cipher algorithm unknown");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}
		if (fixed_iv_length > sizeof(gcm_nonce.salt) ||
		    record_iv_length > sizeof(gcm_nonce.nonce_explicit)) {
			TLS_DPRINTF("aead: invalid iv length");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}
		memcpy(&(gcm_nonce.salt[0]), ((Key_AES *) key)->iv,
		       fixed_iv_length);

		/* 1. process nonce_explicit. */
		copy_seqnum_to_buf(tls->active_write.seqnum,
				   &(gcm_nonce.nonce_explicit[0]));
		memcpy(&(dest[offset]), &(gcm_nonce.nonce_explicit[0]),
		       record_iv_length);
		offset += record_iv_length;

		nonce = (uint8_t *) &gcm_nonce;
		break;

	case TLS_VER_TLS13:
		/*
		 * RFC8446 5.3.  Per-Record Nonce
		 *
		 *                          ...   The per-record nonce for the AEAD
		 *    construction is formed as follows:
		 *
		 *    1.  The 64-bit record sequence number is encoded in network byte
		 *        order and padded to the left with zeros to iv_length.
		 *
		 *    2.  The padded sequence number is XORed with either the static
		 *        client_write_iv or server_write_iv (depending on the role).
		 *
		 *    The resulting quantity (of length iv_length) is used as the
		 *    per-record nonce.
		 */
		/* The IV belongs to the write key, so select it by the write
		 * state's algorithm (not the read state's). */
		switch (tls->active_write.cipher.cipher_algorithm) {
		case TLS_BULK_CIPHER_AES:
			iv = ((Key_AES *) key)->iv;
			break;

		case TLS_BULK_CIPHER_CHACHA20:
			iv = ((Key_ChaCha *) key)->iv;
			break;

		default:
			TLS_DPRINTF("selected cipher algorithm unknown");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		iv_length = fixed_iv_length + record_iv_length;
		if (iv_length > sizeof(nonce_buf)) {
			TLS_DPRINTF("aead: invalid iv length");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		/* Left-pad: the 8-byte sequence number fills the last bytes. */
		memset(nonce_buf, 0, sizeof(nonce_buf));
		copy_seqnum_to_buf(tls->active_write.seqnum,
				   &(nonce_buf[sizeof(nonce_buf) - 8]));
		for (uint32_t i = 0; i < iv_length; i++) {
			nonce_buf[i] ^= iv[i];
		}

		nonce = nonce_buf;
		break;

	default:
		return -1;
	}

	/* 2. process content. */
	uint8_t aad[aad_len];
	if (!tls_aad_generate(tls, tls->active_write,
			      aad, aad_type, aad_arg_len)) {
		TLS_DPRINTF("aead: internal error");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/*
	 *    AEADEncrypted = AEAD-Encrypt(write_key, nonce, plaintext,
	 *                                 additional_data)
	 */
	switch (tls->active_write.cipher.cipher_algorithm) {
	case TLS_BULK_CIPHER_AES:
		gcm_param_set_key(&gcm_param, key);
		gcm_param_set_iv(&gcm_param, nonce, TLS_AES_GCM_NONCE_SIZE);
		gcm_param_set_aad(&gcm_param, aad, sizeof(aad));

		rc = AES_gcm_encrypt(&gcm_param, len, src, &(dest[offset]));
		if (rc == -1) {
			TLS_DPRINTF("aead: internal error");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}
		break;

	case TLS_BULK_CIPHER_CHACHA20: {
		uint8_t *chacha_key = ((Key_ChaCha *) key)->key;
		uint8_t tag[TLS_CHACHA20_POLY1305_AUTHENTICATION_TAG_SIZE];

		/* Ciphertext goes after any explicit nonce, tag right after. */
		chacha20_poly1305_encrypt(aad, aad_len, chacha_key, nonce, src,
					  len, &(dest[offset]), tag);
		memcpy(&(dest[offset + len]), tag, sizeof(tag));
		break;
	}

	default:
		TLS_DPRINTF("selected cipher algorithm unknown");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	offset += len + TLS_AES_GCM_AUTHENTICATION_TAG_SIZE;
	return offset;
}

/*
 * Compare n bytes of two authentication tags. Every byte is examined no
 * matter where a mismatch occurs, so that (unlike memcmp) the running time
 * is not meant to reveal the position of the first differing byte.
 * Returns 1 when the tags are equal, 0 otherwise.
 */
static int tls_aead_tag_equal(const uint8_t *a, const uint8_t *b, size_t n)
{
	uint8_t diff = 0;

	for (size_t i = 0; i < n; i++) {
		diff |= (uint8_t) (a[i] ^ b[i]);
	}

	return diff == 0;
}

/*
 * Decrypt and authenticate one record fragment with the connection's active
 * read AEAD.
 *
 * Layout expected in src:
 *   TLS 1.2: nonce_explicit (record_iv_length bytes) || ciphertext || tag
 *   TLS 1.3: ciphertext || tag
 *
 * @param tls   connection; tls->active_read supplies the key, IV, cipher
 *              parameters and record sequence number.
 * @param dest  receives the plaintext.
 * @param type  record content type placed in the additional data.
 * @param src   protected fragment.
 * @param len   length of src in bytes.
 * @return plaintext length, or -1 on failure. On an authentication failure
 *         TLS_ALERT_FATAL is raised with bad_record_mac (RFC 8446 5.2).
 */
int32_t tls_decipher_aead(TLS *tls,
			  uint8_t *dest,
			  const enum tls_record_ctype type,
			  const uint8_t *src,
			  const int32_t len)
{
	/* Kept signed so that "len minus overhead" checks cannot wrap. */
	const int32_t tag_size = TLS_AES_GCM_AUTHENTICATION_TAG_SIZE;
	uint32_t offset = 0;

	int32_t fixed_iv_length = tls->active_read.cipher.fixed_iv_length;
	int32_t record_iv_length = tls->active_read.cipher.record_iv_length;
	GCMNonce_t gcm_nonce;
	uint8_t nonce_buf[TLS_AES_GCM_NONCE_SIZE];
	uint8_t *nonce;
	uint8_t *iv;
	int32_t iv_length;
	gcm_param_t gcm_param;
	Key *key = tls->active_read.key;
	int rc;
	uint16_t aad_len;
	uint16_t aad_arg_len;

	uint16_t version = tls_util_convert_protover_to_ver(
		&(tls->negotiated_version));
	switch (version) {
	case TLS_VER_SSL30:
	case TLS_VER_TLS10:
	case TLS_VER_TLS11:
		/* XXX: not implemented */
		return -1;

	case TLS_VER_TLS12:
		/* The fragment must hold at least nonce_explicit and the tag. */
		if (len < record_iv_length + tag_size) {
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_AEAD + 1, NULL);
			return -1;
		}

		if (tls->active_read.cipher.cipher_algorithm
		    != TLS_BULK_CIPHER_AES) {
			TLS_DPRINTF("selected cipher algorithm unknown");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}
		if (fixed_iv_length < 0 ||
		    fixed_iv_length > (int32_t) sizeof(gcm_nonce.salt) ||
		    record_iv_length < 0 ||
		    record_iv_length > (int32_t) sizeof(gcm_nonce.nonce_explicit)) {
			TLS_DPRINTF("aead: invalid iv length");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		/* read nonce_explicit (RFC 5288 3: nonce = salt || explicit). */
		memcpy(&(gcm_nonce.nonce_explicit[0]), &(src[offset]),
		       record_iv_length);
		offset += record_iv_length;

		memcpy(&(gcm_nonce.salt[0]), ((Key_AES *) key)->iv,
		       fixed_iv_length);

		aad_len = AAD_SIZE_TLS12;
		aad_arg_len = len - tag_size - record_iv_length;
		nonce = (uint8_t *) &gcm_nonce;
		break;

	case TLS_VER_TLS13:
		/* The fragment must hold at least the tag. */
		if (len < tag_size) {
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_AEAD + 1, NULL);
			return -1;
		}

		/*
		 * RFC8446 5.3.  Per-Record Nonce
		 *
		 *                          ...   The per-record nonce for the AEAD
		 *    construction is formed as follows:
		 *
		 *    1.  The 64-bit record sequence number is encoded in network byte
		 *        order and padded to the left with zeros to iv_length.
		 *
		 *    2.  The padded sequence number is XORed with either the static
		 *        client_write_iv or server_write_iv (depending on the role).
		 *
		 *    The resulting quantity (of length iv_length) is used as the
		 *    per-record nonce.
		 */
		switch (tls->active_read.cipher.cipher_algorithm) {
		case TLS_BULK_CIPHER_AES:
			iv = ((Key_AES *) key)->iv;
			break;

		case TLS_BULK_CIPHER_CHACHA20:
			iv = ((Key_ChaCha *) key)->iv;
			break;

		default:
			TLS_DPRINTF("selected cipher algorithm unknown");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		iv_length = fixed_iv_length + record_iv_length;
		if (iv_length < 0 || iv_length > (int32_t) sizeof(nonce_buf)) {
			TLS_DPRINTF("aead: invalid iv length");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
			return -1;
		}

		/* Left-pad: the 8-byte sequence number fills the last bytes. */
		memset(nonce_buf, 0, sizeof(nonce_buf));
		copy_seqnum_to_buf(tls->active_read.seqnum,
				   &(nonce_buf[sizeof(nonce_buf) - 8]));
		for (int32_t i = 0; i < iv_length; i++) {
			nonce_buf[i] ^= iv[i];
		}

		aad_len = AAD_SIZE_TLS13;
		aad_arg_len = len;
		nonce = nonce_buf;
		break;

	default:
		return -1;
	}

	uint8_t aad[aad_len];
	if (!tls_aad_generate(tls, tls->active_read,
			      aad, type, aad_arg_len)) {
		return -1;
	}

	/* 1. process content. */
	/*
	 *    TLSCompressed.fragment = AEAD-Decrypt(write_key, nonce,
	 *                                          AEADEncrypted,
	 *                                          additional_data)
	 */
	/*
	 * RFC8446 5.2.  Record Payload Protection
	 *
	 *   If the decryption fails, the receiver MUST terminate the connection
	 *   with a "bad_record_mac" alert.
	 */
	switch (tls->active_read.cipher.cipher_algorithm) {
	case TLS_BULK_CIPHER_AES:
		gcm_param_set_key(&gcm_param, key);
		gcm_param_set_iv(&gcm_param, nonce, TLS_AES_GCM_NONCE_SIZE);
		gcm_param_set_aad(&gcm_param, aad, sizeof(aad));

		rc = AES_gcm_decrypt(&gcm_param, len - offset, &(src[offset]),
				     &(dest[0]));
		if (rc == 0) {
			break;
		}

		if (rc == -1) {
			TLS_DPRINTF("aead: internal error");
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		}
		if (rc == -2) {
			TLS_DPRINTF("aead: decryption fails.");
			OK_set_error(ERR_ST_TLS_DECRYPT_AEAD_FAILED,
				     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_AEAD + 2,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_RECORD_MAC);
		}
		return -1;

	case TLS_BULK_CIPHER_CHACHA20: {
		uint8_t *chacha_key = ((Key_ChaCha *) key)->key;
		uint8_t tag[TLS_CHACHA20_POLY1305_AUTHENTICATION_TAG_SIZE];
		int32_t ct_len = len - (int32_t) offset - (int32_t) sizeof(tag);

		if (ct_len < 0) {
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_AEAD + 1, NULL);
			return -1;
		}

		chacha20_poly1305_decrypt(aad, aad_len, chacha_key, nonce,
					  &(src[offset]), ct_len, dest, tag);

		/* Verify the received tag without an early-exit comparison. */
		if (!tls_aead_tag_equal(&(src[offset + ct_len]), tag,
					sizeof(tag))) {
			/* Do not leave unauthenticated plaintext behind. */
			memset(dest, 0, (size_t) ct_len);
			TLS_DPRINTF("aead: decryption fails.");
			OK_set_error(ERR_ST_TLS_DECRYPT_AEAD_FAILED,
				     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_AEAD + 2,
				     NULL);
			TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_RECORD_MAC);
			return -1;
		}
		break;
	}

	default:
		TLS_DPRINTF("selected cipher algorithm unknown");
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* NOTE(review): assumes both AEADs use a tag of tag_size bytes. */
	return len - (int32_t) offset - tag_size;
}

/*
 * Serialize a 64-bit record sequence number into buf in network
 * (big-endian) byte order. buf must have room for 8 bytes.
 */
static void copy_seqnum_to_buf(uint64_t num, uint8_t *buf)
{
	assert(buf);

	/* Fill from the least significant end, peeling one byte at a time. */
	for (int pos = 7; pos >= 0; pos--) {
		buf[pos] = (uint8_t) (num & 0xff);
		num >>= 8;
	}
}
