/*
 * Copyright (c) 2015-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "tls_cipher.h"
#include "tls_mac.h"
#include "tls_alert.h"

#include <string.h>

#ifdef HAVE_ARC4
/* for RC4_do_crypt and Key_RC4. */
#include <aicrypto/ok_rc4.h>
#endif

/* This file handles the stream ciphers of the TLS record layer. A
 * stream-ciphered fragment has the following structure (RFC 5246 section 6.2.3.1):
 *
 *   stream-ciphered struct generic_stream_cipher GenericStreamCipher;
 *
 *   struct generic_stream_cipher {
 *       opaque content[TLSCompressed.length];
 *       opaque MAC[SecurityParameters.mac_length];
 *   }
 */

/**
 * Build one plaintext stream fragment (content followed by its MAC).
 *
 * This function does not encipher anything. It only assembles the data
 * that is to be enciphered: the content is copied to dest and, when the
 * MAC length of the active write cipher is greater than zero, the MAC
 * calculated by tls_mac_init() is appended after it.
 *
 * Returns the number of bytes stored in dest, or -1 on failure.
 */
static int32_t write_stream(const TLS *tls,
			    uint8_t *dest,
			    const enum tls_record_ctype type,
			    const uint8_t *src,
			    const int32_t len);

/**
 * Read one stream fragment.
 *
 * This function works on already deciphered text: it splits the
 * fragment into content and MAC, copies the content to dest and checks
 * the received MAC against the one calculated by tls_mac_init().
 *
 * Returns the length of the content, or -1 on failure.
 */
static int32_t read_stream(TLS *tls,
			   uint8_t *dest,
			   const enum tls_record_ctype type,
			   const uint8_t *src,
			   const int32_t len);

/**
 * Assemble the plaintext of one stream fragment in dest.
 *
 * @param tls  connection; active_write supplies the MAC parameters.
 * @param dest output buffer. Receives the content followed by the MAC
 *             (if the MAC length of the active write cipher is > 0).
 * @param type record content type, passed to the MAC calculation.
 * @param src  content to be sent.
 * @param len  length of src in bytes.
 * @return number of bytes stored in dest (len + MAC length), or -1 if
 *         a length is out of range or the MAC calculation failed.
 */
static int32_t write_stream(const TLS *tls,
			    uint8_t *dest,
			    const enum tls_record_ctype type,
			    const uint8_t *src,
			    const int32_t len) {
	/* 1. process content.
	 *
	 * A negative length would be converted to a huge size_t by
	 * memcpy, so it is rejected together with oversized input. */
	if ((len < 0) ||
	    (len > TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_UP_TO_TLS12)) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_STREAM + 0, NULL);
		return -1;
	}

	memcpy(&(dest[0]), &(src[0]), len);
	int32_t offset = len;

	/* 2. process MAC. nothing is appended when the MAC length is
	 * zero (null MAC). */
	const int32_t mlen = tls->active_write.cipher.mac_length;
	if (mlen > 0) {
		uint8_t mac[mlen];

		if (! tls_mac_init(tls->active_write,
				   tls->negotiated_version,
				   mac, mlen, type, src, len)) {
			return -1;
		}

		/* content + MAC must still fit into one fragment. */
		if ((offset + mlen) > TLS_RECORD_CIPHERED_FRAGMENT_SIZE_MAX_UP_TO_TLS12) {
			OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
				     ERR_LC_TLS1,
				     ERR_PT_TLS_CIPHER_STREAM + 1, NULL);
			return -1;
		}

		memcpy(&(dest[offset]), &(mac[0]), mlen);
		offset += mlen;
	}

	return offset;
}

/**
 * Verify the MAC of a deciphered stream fragment and extract its content.
 *
 * The fragment consists of the content followed by a MAC whose length
 * is the MAC length of the active read cipher.
 *
 * @param tls  connection; active_read supplies the MAC parameters.
 * @param dest receives the content (len - MAC length bytes). It is
 *             written before the MAC has been verified, so its contents
 *             must be ignored when -1 is returned.
 * @param type record content type, passed to the MAC calculation.
 * @param src  deciphered fragment.
 * @param len  length of src in bytes.
 * @return length of the content on success. -1 on failure; in that
 *         case TLS_ALERT_FATAL has been invoked on tls.
 */
static int32_t read_stream(TLS *tls,
			   uint8_t *dest,
			   const enum tls_record_ctype type,
			   const uint8_t *src,
			   const int32_t len) {
	const int32_t mac_length     = tls->active_read.cipher.mac_length;
	const int32_t content_length = len - mac_length;

	/* the fragment is too short to contain a MAC. */
	if (content_length < 0) {
		OK_set_error(ERR_ST_TLS_INVALID_RECORD_LENGTH,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_STREAM + 2, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_DECODE_ERROR);
		return -1;
	}

	/* the received MAC is the tail of the fragment. it is compared
	 * in place, so no copy is needed. */
	const uint8_t *received_mac = &(src[content_length]);

	/* get content from fragment */
	memcpy(&(dest[0]), &(src[0]), content_length);

	/* calculate the expected MAC. the size of a variable length
	 * array must be greater than zero, so reserve at least one byte
	 * even when the MAC length is zero (null MAC). */
	uint8_t expected_mac[(mac_length > 0) ? mac_length : 1];
	if (! tls_mac_init(tls->active_read,
			   tls->negotiated_version,
			   expected_mac, mac_length, type,
			   dest, content_length)) {
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* verify MAC. every byte is always examined so that the time
	 * spent here does not depend on the position of the first
	 * mismatch (memcmp may return early). */
	uint8_t diff = 0;
	for (int32_t i = 0; i < mac_length; i++) {
		diff |= (uint8_t)(received_mac[i] ^ expected_mac[i]);
	}

	if (diff != 0) {
		OK_set_error(ERR_ST_TLS_INVALID_MAC,
			     ERR_LC_TLS1, ERR_PT_TLS_CIPHER_STREAM + 3, NULL);
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_BAD_RECORD_MAC);
		return -1;
	}

	return content_length;
}

/**
 * Encipher one record fragment with the active write stream cipher.
 *
 * The plaintext stream (content + MAC) is first assembled in dest by
 * write_stream() and then enciphered according to
 * tls->active_write.cipher.cipher_algorithm.
 *
 * @param tls  connection; active_write supplies cipher, key and MAC.
 * @param dest output buffer for the enciphered fragment.
 * @param type record content type, passed to the MAC calculation.
 * @param src  content to be sent.
 * @param len  length of src in bytes.
 * @return length of the data stored in dest, or -1 on failure.
 */
int32_t tls_cipher_stream(TLS *tls,
			  uint8_t *dest,
			  const enum tls_record_ctype type,
			  const uint8_t *src,
			  const int32_t len) {
	/* init stream (contents + mac). the stream is built directly in
	 * dest because its length is not known beforehand. */
	const int32_t stream_length = write_stream(tls, dest, type, src, len);
	if (stream_length < 0) {
		return -1;
	}

	/* encipher stream */
	switch (tls->active_write.cipher.cipher_algorithm) {
	case TLS_BULK_CIPHER_NULL:
		/* A stream data has been already stored in the dest
		 * array. So do nothing here. */
		break;

#ifdef HAVE_ARC4
	case TLS_BULK_CIPHER_RC4: {
		/* move the plaintext to a scratch buffer so that input
		 * and output of RC4_do_crypt are distinct arrays. the
		 * size of a variable length array must be greater than
		 * zero, so reserve at least one byte for an empty
		 * stream. */
		uint8_t stream[(stream_length > 0) ? stream_length : 1];
		memcpy(&(stream[0]), &(dest[0]), stream_length);

		RC4_do_crypt((Key_RC4 *)(tls->active_write.key),
			     stream_length, &(stream[0]), &(dest[0]));
		break;
	}
#endif

	default:
		assert(!"selected cipher algorithm is unknown.");
		/* assert is compiled out when NDEBUG is defined. never
		 * report success in that case: dest still holds the
		 * plaintext stream. */
		return -1;
	}

	/* encryption function of aicrypto do not return length of
	 * encryption data. so consider the length same as input
	 * length. */
	return stream_length;
}

/**
 * Decipher one record fragment with the active read stream cipher.
 *
 * The fragment is deciphered according to
 * tls->active_read.cipher.cipher_algorithm into a temporary buffer,
 * then read_stream() verifies the MAC and copies the content to dest.
 *
 * @param tls  connection; active_read supplies cipher, key and MAC.
 * @param dest output buffer for the content.
 * @param type record content type, passed to the MAC calculation.
 * @param src  enciphered fragment.
 * @param len  length of src in bytes.
 * @return length of the content stored in dest, or -1 on failure.
 */
int32_t tls_decipher_stream(TLS *tls,
			    uint8_t *dest,
			    const enum tls_record_ctype type,
			    const uint8_t *src,
			    const int32_t len) {
	/* a negative length must not reach the array declaration, memcpy
	 * or RC4_do_crypt below. read_stream() rejects a fragment that
	 * is shorter than the MAC before it touches any buffer, so let
	 * it report this case. */
	if (len < 0) {
		return read_stream(tls, dest, type, src, len);
	}

	/* the size of a variable length array must be greater than
	 * zero, so reserve at least one byte for an empty fragment. */
	uint8_t stream[(len > 0) ? len : 1];

	switch (tls->active_read.cipher.cipher_algorithm) {
	case TLS_BULK_CIPHER_NULL:
		memcpy(&(stream[0]), &(src[0]), len);
		break;

#ifdef HAVE_ARC4
	case TLS_BULK_CIPHER_RC4:
		/* RC4_do_crypt is called with a non-const input pointer,
		 * hence the cast. */
		RC4_do_crypt((Key_RC4 *)(tls->active_read.key),
			     len, (uint8_t *) &(src[0]), &(stream[0]));
		break;
#endif

	default:
		assert(!"selected cipher algorithm is unknown.");
		/* assert is compiled out when NDEBUG is defined. stop
		 * here in that case: stream has not been filled in and
		 * must not be handed to read_stream(). */
		TLS_ALERT_FATAL(tls, TLS_ALERT_DESC_INTERNAL_ERROR);
		return -1;
	}

	/* verify the MAC and extract the content. */
	return read_stream(tls, dest, type, stream, len);
}
