/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_cipher.h"
#include "tls_mac.h"
#include "tls_alert.h"

#include <string.h>

/* for DES3_set_iv, DES3_cbc_decrypt and DES3_cbc_encrypt */
#include <aicrypto/ok_des.h>

/* for AES_set_iv, AES_cbc_decrypt and AES_cbc_encrypt */
#include <aicrypto/ok_aes.h>

/* A record protected by a CBC block cipher has the following structure
 * (RFC 5246 section 6.2.3.2).
 *
 *   struct {
 *       opaque IV[SecurityParameters.record_iv_length];
 *       block-ciphered struct {
 *           opaque content[TLSCompressed.length];
 *           opaque MAC[SecurityParameters.mac_length];
 *           uint8 padding[GenericBlockCipher.padding_length];
 *           uint8 padding_length;
 *       };
 *   } GenericBlockCipher;
 *
 *   IV
 *      The Initialization Vector (IV) SHOULD be chosen at random, and
 *      MUST be unpredictable.  Note that in versions of TLS prior to
 *      1.1, there was no IV field, and the last ciphertext block of the
 *      previous record (the "CBC residue") was used as the IV.  This
 *      was changed to prevent the attacks described in [CBCATT].  For
 *      block ciphers, the IV length is of length
 *      SecurityParameters.record_iv_length, which is equal to the
 *      SecurityParameters.block_size.
 *
 *   padding
 *      Padding that is added to force the length of the plaintext to be
 *      an integral multiple of the block cipher's block length.  The
 *      padding MAY be any length up to 255 bytes, as long as it results
 *      in the TLSCiphertext.length being an integral multiple of the
 *      block length.  Lengths longer than necessary might be desirable
 *      to frustrate attacks on a protocol that are based on analysis of
 *      the lengths of exchanged messages.  Each uint8 in the padding
 *      data vector MUST be filled with the padding length value.  The
 *      receiver MUST check this padding and MUST use the bad_record_mac
 *      alert to indicate padding errors.
 *
 *   padding_length
 *      The padding length MUST be such that the total size of the
 *      GenericBlockCipher structure is a multiple of the cipher's block
 *      length.  Legal values range from zero to 255, inclusive.  This
 *      length specifies the length of the padding field exclusive of
 *      the padding_length field itself.
 */

/**
 * Build the plaintext block that is handed to the block cipher.
 *
 * The definition in this file writes, in order, the content (len bytes
 * of src), the MAC, the padding bytes and the padding length byte into
 * dest.
 *
 * @return number of bytes written to dest, or -1 on error.
 */
static int32_t init_block(TLS *tls,
			  uint8_t *dest,
			  const enum tls_record_ctype type,
			  const uint8_t *src,
			  const int32_t len);

/**
 * Read a ciphered block and decrypt it.
 *
 * src_len bytes of src are decrypted into dest with the read-side key,
 * using iv as the CBC initialization vector.
 *
 * @return number of bytes processed. The caller in this file treats a
 *         negative value as failure.
 */
static int32_t read_block(TLS *tls,
			  uint8_t *dest,
			  const uint8_t *src, const uint32_t src_len,
			  uint8_t *iv);

/**
 * Read a decrypted block and extract the content from it.
 *
 * To avoid the timing attack on CBC padding, this function always
 * performs the MAC calculation, even when the padding has already been
 * found to be invalid.
 *
 * @return length of the content copied to dest, or -1 if the padding
 *         or the MAC is invalid.
 */
static int32_t read_content(TLS *tls,
			    uint8_t *dest,
			    const enum tls_record_ctype type,
			    const uint8_t *src,
			    const int32_t len);

/**
 * Build the plaintext block that is handed to the block cipher.
 *
 * Layout written to dest:
 *
 *   | contents       (len)            |
 *   | MAC            (mac_length)     |
 *   | padding        (padding_length) |
 *   | padding length (1)              |
 *
 * @param tls  connection state; the write-side cipher parameters
 *             (mac_length, block_length) are read from it.
 * @param dest output buffer. Writes are bounded by
 *             TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_UP_TO_TLS12, so the
 *             buffer must be at least that large.
 * @param type record content type, passed through to the MAC.
 * @param src  content to protect.
 * @param len  length of src in bytes.
 * @return number of bytes written to dest, or -1 on error.
 */
static int32_t init_block(TLS *tls,
			  uint8_t *dest,
			  const enum tls_record_ctype type,
			  const uint8_t *src,
			  const int32_t len) {
	uint32_t offset = 0;

	/* 1. process content.
	 *
	 * a negative length is rejected explicitly instead of relying on
	 * the wrap-around of the unsigned addition below. */
	if ((len < 0) ||
	    ((offset + len) > TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_UP_TO_TLS12)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_BLOCK + 0, NULL);
		return -1;
	}

	memcpy(&(dest[offset]), &(src[0]), len);
	offset += len;

	/* 2. process MAC. */
	int32_t mlen = tls->active_write.cipher.mac_length;
	if (mlen > 0) {
		/* the array is declared only after mlen is known to be
		 * positive: a variable length array whose size is not
		 * greater than zero is undefined behavior. */
		uint8_t mac[mlen];

		if (! tls_mac_init(tls->active_write,
				   tls->negotiated_version,
				   mac, mlen, type, src, len)) {
			return -1;
		}

		if ((offset + mlen) > TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_UP_TO_TLS12) {
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS1,
				     ERR_PT_TLS_CIPHER_BLOCK + 1, NULL);
			return -1;
		}

		memcpy(&(dest[offset]), &(mac[0]), mlen);
		offset += mlen;
	}

	/* 3. process padding. */
	uint32_t block_length = tls->active_write.cipher.block_length;

	/* a block length of zero would make the modulo below divide by
	 * zero. no padding can be computed in that case. */
	if (block_length == 0) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_BLOCK + 2, NULL);
		return -1;
	}

	/* calculate padding length.
	 *
	 * the padding makes the total size (contents + MAC + padding +
	 * padding length byte) an integral multiple of block_length.
	 *
	 * calculating formula is as follows.
	 *
	 *  1. r = (contents + MAC) % block_length
	 *
	 *  2. padding_length = block_length - 1 - r
	 *
	 * where the "1" is the padding length byte itself.
	 */
	uint32_t padding_length = block_length - 1 - (offset % block_length);

	if ((offset + padding_length + 1) >
	    TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_UP_TO_TLS12) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_BLOCK + 2, NULL);
		return -1;
	}

	/* fill padding. every padding byte holds the padding length. */
	for (uint32_t i = 0; i < padding_length; ++i) {
		dest[offset++] = padding_length & 0xff;
	}

	/* fill padding length. */
	dest[offset++] = padding_length & 0xff;

	TLS_DPRINTF("block: content length = %d", len);
	TLS_DPRINTF("block: mac     length = %d", mlen);
	TLS_DPRINTF("block: block   length = %d", block_length);
	TLS_DPRINTF("block: padding length = %d", padding_length);

	return offset;
}

/**
 * Protect a record fragment with the write-side CBC block cipher.
 *
 * For TLS 1.1 and TLS 1.2 a random IV of record_iv_length bytes is
 * written to the head of dest, followed by the encrypted block
 * (content, MAC, padding, padding length). SSL 3.0 and TLS 1.0 are not
 * implemented and fail.
 *
 * @param tls  connection state (write-side cipher parameters and key).
 * @param dest output buffer for the IV followed by the ciphertext.
 * @param type record content type, passed through to the MAC.
 * @param src  plaintext content.
 * @param len  length of src in bytes.
 * @return number of bytes written to dest, or -1 on error.
 */
int32_t tls_cipher_block(TLS *tls,
			 uint8_t *dest,
			 const enum tls_record_ctype type,
			 const uint8_t *src,
			 const int32_t len) {
	uint32_t offset = 0;

	uint32_t record_iv_length = tls->active_write.cipher.record_iv_length;
	uint8_t *iv = NULL;

	TLS_DPRINTF("block: record_iv_length = %d", record_iv_length);

	switch(tls->negotiated_version.minor) {
	case TLS_MINOR_SSL30:
	case TLS_MINOR_TLS10:
		/* XXX: not implemented */
		break;

	case TLS_MINOR_TLS11:
	case TLS_MINOR_TLS12:
		/* generate IV with random value. */
		if (! tls_util_get_random(&(dest[offset]), record_iv_length)) {
			OK_set_error(ERR_ST_TLS_GET_RANDOM, ERR_LC_TLS1,
				     ERR_PT_TLS_CIPHER_BLOCK + 3, NULL);
			return -1;
		}
		iv = dest;
		offset += record_iv_length;
		break;

	default:
		break;
	}

	/* no IV means that the negotiated version is not supported. */
	if (iv == NULL) {
		return -1;
	}

	/* generate block. */
	uint8_t block[TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_UP_TO_TLS12];
	int32_t block_len;
	if ((block_len = init_block(tls, &(block[0]), type, src, len)) < 0) {
		return -1;
	}

	switch (tls->active_write.cipher.cipher_algorithm) {
	case TLS_BULK_CIPHER_3DES:
		DES3_set_iv((Key_3DES *)(tls->active_write.key), &(iv[0]));
		DES3_cbc_encrypt((Key_3DES *)(tls->active_write.key),
				 block_len, &(block[0]), &(dest[offset]));
		break;

	case TLS_BULK_CIPHER_AES:
		AES_set_iv((Key_AES *)(tls->active_write.key), &(iv[0]));
		AES_cbc_encrypt((Key_AES *)(tls->active_write.key),
				 block_len, &(block[0]), &(dest[offset]));
		break;

	default:
		assert(!"selected cipher algorithm is unknown.");
		/* when assert is compiled out (NDEBUG), do not report
		 * success: nothing has been written after the IV. */
		return -1;
	}

	/* encryption function of aicrypto do not return length of
	 * encrypted data. so, consider the length same as input
	 * length. */
	offset += block_len;

	return offset;
}

/**
 * Decrypt a ciphered block with the read-side CBC block cipher.
 *
 * @param tls     connection state (read-side cipher algorithm and key).
 * @param dest    output buffer; receives src_len bytes.
 * @param src     ciphertext.
 * @param src_len length of src in bytes. must be greater than zero and
 *                representable as int32_t.
 * @param iv      initialization vector set on the key before
 *                decryption.
 * @return src_len on success, -1 if src_len is out of range or the
 *         cipher algorithm is unknown.
 */
static int32_t read_block(TLS *tls,
			  uint8_t *dest,
			  const uint8_t *src, const uint32_t src_len,
			  uint8_t *iv) {
	/* a zero length would declare a zero sized variable length array
	 * (undefined behavior), and a length above INT32_MAX could not be
	 * returned to the caller as a non-negative value. */
	if ((src_len == 0) || (src_len > INT32_MAX)) {
		return -1;
	}

	uint8_t block[src_len];

	/* declare block array because arg of AES_cbc_decrypt and
	 * DES3_cbc_decrypt that correspond to src are not const. so, to
	 * avoid warning, copy src to block. */
	memcpy(&(block[0]), &(src[0]), src_len);

	switch (tls->active_read.cipher.cipher_algorithm) {
	case TLS_BULK_CIPHER_3DES:
		DES3_set_iv((Key_3DES *)(tls->active_read.key), &(iv[0]));
		DES3_cbc_decrypt((Key_3DES *)(tls->active_read.key),
				 src_len, &(block[0]), &(dest[0]));
		break;

	case TLS_BULK_CIPHER_AES:
		AES_set_iv((Key_AES *)(tls->active_read.key), &(iv[0]));
		AES_cbc_decrypt((Key_AES *)(tls->active_read.key),
				src_len, &(block[0]), &(dest[0]));
		break;

	default:
		assert(!"selected cipher algorithm is unknown.");
		/* when assert is compiled out (NDEBUG), do not report
		 * success: dest has not been written. */
		return -1;
	}

	return src_len;
}

/**
 * Extract the content from a decrypted block and verify padding and
 * MAC.
 *
 * Expected layout of src:
 *
 *   | contents       (content_length) |
 *   | MAC            (mac_length)     |
 *   | padding        (padding value)  |
 *   | padding length (1)              |
 *
 * @param tls  connection state (read-side mac_length and MAC keys).
 * @param dest output buffer; receives at most len bytes.
 * @param type record content type, passed through to the MAC.
 * @param src  decrypted block.
 * @param len  length of src in bytes.
 * @return length of the content copied to dest, or -1 if the block is
 *         empty, the padding is invalid or the MAC does not match.
 */
static int32_t read_content(TLS *tls,
			    uint8_t *dest,
			    const enum tls_record_ctype type,
			    const uint8_t *src,
			    const int32_t len) {
	/* in this routine, must perform mac calculation certainly. this
	 * is by RFC 5246 section 6.2.3.2.
	 *
	 *   Implementation note: Canvel et al. [CBCTIME] have
	 *   demonstrated a timing attack on CBC padding based on the
	 *   time required to compute the MAC.  In order to defend
	 *   against this attack, implementations MUST ensure that
	 *   record processing time is essentially the same whether or
	 *   not the padding is correct.  In general, the best way to do
	 *   this is to compute the MAC even if the padding is
	 *   incorrect, and only then reject the packet.
	 */

	/* without at least the padding length byte there is nothing to
	 * parse. the length is not secret, so returning early is fine. */
	if (len <= 0) {
		return -1;
	}

	const uint32_t total      = (uint32_t)len;
	const uint32_t mac_length = tls->active_read.cipher.mac_length;
	bool status = true;

	/* check padding.
	 *
	 * the trailer consists of the padding bytes and the padding
	 * length byte, and every one of those (pad_value + 1) bytes must
	 * hold pad_value. all of them are inspected and the differences
	 * are accumulated, instead of stopping at the first mismatch. */
	const uint8_t pad_value = src[total - 1];
	uint32_t trailer_length = (uint32_t)pad_value + 1;

	if (trailer_length > total) {
		/* padding claims to be longer than the block. */
		status = false;
		trailer_length = 0;
	} else {
		uint8_t pad_diff = 0;
		for (uint32_t i = 0; i < trailer_length; ++i) {
			pad_diff |= (uint8_t)(src[total - 1 - i] ^ pad_value);
		}

		if (pad_diff != 0) {
			/* unmatched padding. continue as if there was no
			 * padding so that the MAC is still calculated. */
			status = false;
			trailer_length = 0;
		}
	}

	/* decide where content and MAC are. in every branch
	 * content_length <= len, and whenever the MAC is read from src,
	 * content_length + mac_length <= len holds. */
	const bool mac_in_bounds = (mac_length <= total);
	int32_t content_length;
	if (! mac_in_bounds) {
		/* invalid mac length. */
		status = false;
		content_length = (int32_t)total;
	} else if (trailer_length > (total - mac_length)) {
		/* invalid content length: padding and MAC do not fit. */
		status = false;
		content_length = (int32_t)(total - mac_length);
	} else {
		content_length = (int32_t)(total - mac_length - trailer_length);
	}

	/* get content. */
	memcpy(&(dest[0]), &(src[0]), content_length);

	/* the arrays hold at least one element because a variable length
	 * array of size zero is undefined behavior. */
	const uint32_t mac_buf_length = (mac_length > 0) ? mac_length : 1;
	uint8_t mac[mac_buf_length];
	uint8_t calced_mac[mac_buf_length];
	memset(&(mac[0]), 0, mac_buf_length);
	memset(&(calced_mac[0]), 0, mac_buf_length);

	/* get mac. */
	if (mac_in_bounds) {
		memcpy(&(mac[0]), &(src[content_length]), mac_length);
	}

	/* calc mac from content. */
	if (! tls_mac_init(tls->active_read,
			   tls->negotiated_version,
			   calced_mac, mac_length,
			   type, dest, content_length)) {
		/* calculating mac failed. */
		status = false;
	}

	/* verify mac. every byte is compared so that the time spent does
	 * not depend on the position of the first mismatch. */
	uint8_t mac_diff = 0;
	for (uint32_t i = 0; i < mac_length; ++i) {
		mac_diff |= (uint8_t)(mac[i] ^ calced_mac[i]);
	}
	if (mac_diff != 0) {
		status = false;
	}

	/* if error has occurred in the above process, return error even
	 * if mac matched by any chance. */
	if (status == false) {
		return -1;
	}

	return content_length;
}

/**
 * Decrypt and verify a record fragment protected by the read-side CBC
 * block cipher.
 *
 * For TLS 1.1 and TLS 1.2 the first record_iv_length bytes of src are
 * the explicit IV and the rest is the ciphertext. SSL 3.0 and TLS 1.0
 * are not implemented and fail, as they do on the encryption side.
 *
 * @param tls  connection state (read-side cipher parameters and key).
 * @param dest output buffer for the content.
 * @param type record content type, passed through to the MAC.
 * @param src  received fragment (IV followed by ciphertext).
 * @param len  length of src in bytes.
 * @return length of the content written to dest, or -1 on error. when
 *         the fragment length is malformed, or padding/MAC
 *         verification fails, a fatal bad_record_mac alert is raised.
 */
int32_t tls_decipher_block(TLS *tls,
			   uint8_t *dest,
			   const enum tls_record_ctype type,
			   const uint8_t *src,
			   const int32_t len) {
	const uint32_t record_iv_length = tls->active_read.cipher.record_iv_length;
	const uint32_t block_length     = tls->active_read.cipher.block_length;

	TLS_DPRINTF("block: record_iv_length = %d", record_iv_length);

	switch(tls->negotiated_version.minor) {
	case TLS_MINOR_TLS11:
	case TLS_MINOR_TLS12:
		break;

	case TLS_MINOR_SSL30:
	case TLS_MINOR_TLS10:
		/* XXX: not implemented */
		/* fallthrough */
	default:
		/* no IV is available for these versions. decrypting with
		 * an uninitialized IV must not happen. */
		return -1;
	}

	/* the IV buffer below is a variable length array, whose size
	 * must be greater than zero. */
	if (record_iv_length == 0) {
		return -1;
	}

	/* validate the fragment length before touching src:
	 *
	 *  - it must hold the IV and at least one byte of ciphertext.
	 *  - the ciphertext must be a multiple of the block length
	 *    (RFC 5246 section 6.2.3.2).
	 */
	if ((len <= 0) ||
	    ((uint32_t)len <= record_iv_length) ||
	    (block_length == 0) ||
	    ((((uint32_t)len - record_iv_length) % block_length) != 0)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_BLOCK + 4, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_RECORD_MAC);
		return -1;
	}

	/* read IV. it is copied because read_block takes a non-const
	 * pointer while src is const. */
	uint8_t iv[record_iv_length];
	memcpy(&(iv[0]), &(src[0]), record_iv_length);

	/* read encrypted block. */
	const uint32_t block_len = (uint32_t)len - record_iv_length;
	uint8_t block[block_len];
	if (read_block(tls,
		       &(block[0]), &(src[record_iv_length]), block_len,
		       &(iv[0])) < 0) {
		return -1;
	}

	/* read content from decrypted block */
	int32_t content_length;
	if ((content_length = read_content(tls, &(dest[0]),
					   type,
					   &(block[0]), block_len)) < 0) {
		OK_set_error(ERR_ST_TLS_READ_BLOCK,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_BLOCK + 4, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_RECORD_MAC);
		return -1;
	}

	return content_length;
}
