/* ok_rsa.h */
/*
 * Modified by National Institute of Informatics in Japan, 2013-2017.
 *
 */
/*
 * Copyright (C) 1998-2002
 * Akira Iwata & Takuto Okuno
 * Akira Iwata Laboratory,
 * Nagoya Institute of Technology in Japan.
 *
 * All rights reserved.
 *
 * This software is written by Takuto Okuno(usapato@anet.ne.jp)
 * And if you want to contact us, send an email to Kimitake Wakayama
 * (wakayama@elcom.nitech.ac.jp)
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. All advertising materials mentioning features or use of this software must
 *    display the following acknowledgment:
 *    "This product includes software developed by Akira Iwata Laboratory,
 *    Nagoya Institute of Technology in Japan (http://mars.elcom.nitech.ac.jp/)."
 *
 * 4. Redistributions of any form whatsoever must retain the following
 *    acknowledgment:
 *    "This product includes software developed by Akira Iwata Laboratory,
 *     Nagoya Institute of Technology in Japan (http://mars.elcom.nitech.ac.jp/)."
 *
 *   THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY.
 *   AKIRA IWATA LABORATORY DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS
 *   SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS,
 *   IN NO EVENT SHALL AKIRA IWATA LABORATORY BE LIABLE FOR ANY SPECIAL,
 *   INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
 *   FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
 *   NEGLIGENCE OR OTHER TORTUOUS ACTION, ARISING OUT OF OR IN CONNECTION
 *   WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 */

/**
 * @file ok_rsa.h
 * This file defines functions, structures and macros to RSA encryption.
 */

/**
 * @defgroup rsa RSA
 * This module provides an API for RSA public key encryption.
 *
 * For the specifications of RSA, see IEEE P1363.
 */

#ifndef INCLUSION_GUARD_UUID_614A7D71_7E80_426F_A570_64F51F84160E
#define INCLUSION_GUARD_UUID_614A7D71_7E80_426F_A570_64F51F84160E

#include <aicrypto/large_num.h>
#include <aicrypto/ok_x509.h>	/* "Key" object composition */

#ifdef __cplusplus
extern "C" {
#endif

/**
 * RSA public key object.
 *
 * This structure holds the key information of an RSA public key
 * (modulus n and public exponent e).
 * Its first two members (key_type, size) are laid out the same way as
 * in Prvkey_RSA; presumably this lets both be handled through the
 * generic Key type (see RSAkey_free()) -- confirm against ok_x509.h.
 * @ingroup rsa
 */
typedef struct Public_key_RSA {
	int key_type;	/**< key identifier */
	int size;	/**< key length in bytes */

	LNm *n;		/**< public modulus */
	LNm *e;		/**< public exponent */
} Pubkey_RSA;

/**
 * RSA private key object.
 *
 * This structure holds the key information of an RSA private key.
 * The member comments follow the RSAPrivateKey structure of PKCS #1
 * (RFC 8017, Appendix A.1.2): modulus, publicExponent, privateExponent,
 * prime1, prime2, exponent1, exponent2 and coefficient.
 * Its first two members (key_type, size) are laid out the same way as
 * in Pubkey_RSA.
 * @ingroup rsa
 */
typedef struct Private_key_RSA {
	int key_type;	/**< key identifier */
	int size;	/**< key length in bytes */

	/** key version; presumably RSAPrivateKey.version -- TODO confirm */
	int version;
	LNm *n;		/**< public modulus */
	LNm *e;		/**< public exponent */
	LNm *d;		/**< private exponent */
	LNm *p;		/**< prime1 */
	LNm *q;		/**< prime2 */
	LNm *e1;	/**< exponent1 -- d mod (p-1) */
	LNm *e2;	/**< exponent2 -- d mod (q-1) */
	LNm *cof;	/**< coefficient -- q^(-1) mod p (CRT coefficient, inverse of q modulo p) */

	/**
	 * DER encoded string of the key.
	 * NOTE(review): presumably a cached encoding owned by this object --
	 * confirm ownership in rsa_key.c / rsa_asn1.c.
	 */
	unsigned char *der;
} Prvkey_RSA;

/* Minimum and maximum RSA key lengths in bits.
 * The maximum is half of LN_BITLENGTH_MAX (presumably defined in
 * large_num.h). Kept as macros so they can be used in #if. */
#define RSA_KEY_BITLENGTH_MIN	512
#define RSA_KEY_BITLENGTH_MAX	(LN_BITLENGTH_MAX / 2)

/* Encoding methods for signatures with appendix.
 * These are the values returned by RSA_get_encoding_method(). */
#define RSA_EMSA_PSS   1
#define RSA_EMSA_PKCS1 2

/* rsa.c */
/**
 * RSA encryption/decryption with public key.
 *
 * This function performs an RSA encryption/decryption with the public
 * key @a key.
 *
 * @param[in] len length of @a from in bytes.
 * @param[in] from plaintext if encryption. ciphertext if decryption.
 * @param[out] to ciphertext if encryption. plaintext if decryption.
 *                NOTE(review): the required size of this buffer is not
 *                documented here (presumably at least key->size bytes)
 *                -- confirm in rsa.c.
 * @param[in] key Pubkey_RSA object.
 * @retval 0 success
 * @retval -1 error
 * @par Error state:
 * If the error is caused by bad parameter, this function
 * sets the error state to ERR_ST_BADPARAM.
 * @ingroup rsa
 */
int RSApub_doCrypt(int len, unsigned char *from, unsigned char *to,
		   Pubkey_RSA *key);

/**
 * RSA encryption/decryption with private key.
 *
 * This function performs an RSA encryption/decryption with the private
 * key @a key.
 *
 * @param[in] len length of @a from in bytes.
 * @param[in] from plaintext if encryption. ciphertext if decryption.
 * @param[out] to ciphertext if encryption. plaintext if decryption.
 *                NOTE(review): the required size of this buffer is not
 *                documented here (presumably at least key->size bytes)
 *                -- confirm in rsa.c.
 * @param[in] key Prvkey_RSA object.
 * @retval 0 success
 * @retval -1 error
 * @par Error state:
 * If the error is caused by bad parameter, this function
 * sets the error state to ERR_ST_BADPARAM.
 * @ingroup rsa
 */
int RSAprv_doCrypt(int len, unsigned char *from, unsigned char *to,
		   Prvkey_RSA *key);

/**
 * Set RSA public key parameters.
 *
 * This function sets the following parameters of @a key:
 * - Pubkey_RSA::n
 * - Pubkey_RSA::e
 *
 * NOTE(review): whether @a n and @a e are copied into @a key or stored
 * by reference is not visible here -- check rsa.c before freeing or
 * reusing them.
 *
 * @param[in,out] key Pubkey_RSA object to be set.
 * @param[in] n public modulus.
 * @param[in] e public exponent.
 * @ingroup rsa
 */
void RSA_set_pubkey(Pubkey_RSA *key, LNm *n, LNm *e);

/**
 * Set RSA private key parameters.
 *
 * This function sets the following parameters of @a key:
 * - Prvkey_RSA::n
 * - Prvkey_RSA::d
 *
 * NOTE(review): whether @a n and @a d are copied into @a key or stored
 * by reference is not visible here -- check rsa.c before freeing or
 * reusing them.
 *
 * @param[in,out] key Prvkey_RSA object to be set.
 * @param[in] n public modulus.
 * @param[in] d private exponent.
 * @ingroup rsa
 */
void RSA_set_prvkey(Prvkey_RSA *key, LNm *n, LNm *d);

/* Old names, kept for backward compatibility with existing callers. */
#define OK_RSA_docrypt_pubkey	RSApub_doCrypt
#define OK_RSA_docrypt_prvkey	RSAprv_doCrypt
#define OK_RSA_set_pubkey	RSA_set_pubkey
#define OK_RSA_set_prvkey	RSA_set_prvkey

/**
 * Get the RSA signature encoding method for a signature algorithm.
 *
 * @param[in] sig_algo signature algorithm id (e.g. OBJ_SIG_SHA1RSA).
 * @retval RSA_EMSA_PSS EMSA-PSS encoding
 * @retval RSA_EMSA_PKCS1 EMSA-PKCS1-v1_5 encoding
 * @retval -1 error
 * @ingroup rsa
 */
int RSA_get_encoding_method(int sig_algo);

/* rsa_asn1.c */
/**
 * Convert an RSA private key to a DER encoded string.
 *
 * This function encodes @a prv as a DER encoded string.
 *
 * @param[in] prv Prvkey_RSA object.
 * @param[in,out] buf buffer pointer or NULL.
 *                If @a buf is not NULL, the returned pointer is @a buf.
 *                NOTE(review): presumably the encoding is written into
 *                @a buf in that case and a new buffer is allocated when
 *                @a buf is NULL; the required size of @a buf is not
 *                documented here -- confirm in rsa_asn1.c.
 * @param[out] ret_len length of the DER encoded string in bytes.
 * @returns DER encoded string, or NULL on error.
 * @par Error state:
 * If the error is caused by memory allocation, this function
 * sets the error state to ERR_ST_MEMALLOC.
 * @ingroup rsa
 */
unsigned char *RSAprv_toDER(Prvkey_RSA *prv, unsigned char *buf, int *ret_len);

/**
 * Convert an RSA public key to a DER encoded string.
 *
 * This function encodes @a pub as a DER encoded string.
 *
 * @param[in] pub Pubkey_RSA object.
 * @param[in,out] buf buffer pointer or NULL.
 *                If @a buf is not NULL, the returned pointer is @a buf.
 *                NOTE(review): presumably the encoding is written into
 *                @a buf in that case and a new buffer is allocated when
 *                @a buf is NULL; the required size of @a buf is not
 *                documented here -- confirm in rsa_asn1.c.
 * @param[out] ret_len length of the DER encoded string in bytes.
 * @returns DER encoded string, or NULL on error.
 * @par Error state:
 * If the error is caused by memory allocation, this function
 * sets the error state to ERR_ST_MEMALLOC.
 * @ingroup rsa
 */
unsigned char *RSApub_toDER(Pubkey_RSA *pub, unsigned char *buf, int *ret_len);

/* rsa_key.c */
/**
 * Allocate a Pubkey_RSA object.
 *
 * This function allocates and initializes a Pubkey_RSA object.
 *
 * @returns newly allocated Pubkey_RSA object, or NULL on error.
 * @par Error state:
 * If the error is caused by memory allocation, this function
 * sets the error state to ERR_ST_MEMALLOC.
 * @attention
 * When the Pubkey_RSA object is no longer used, free it with
 * RSAkey_free().
 * @ingroup rsa
 */
Pubkey_RSA *RSApubkey_new(void);

/**
 * Allocate a Prvkey_RSA object.
 *
 * This function allocates and initializes a Prvkey_RSA object.
 *
 * @returns newly allocated Prvkey_RSA object, or NULL on error.
 * @par Error state:
 * If the error is caused by memory allocation, this function
 * sets the error state to ERR_ST_MEMALLOC.
 * @attention
 * When the Prvkey_RSA object is no longer used, free it with
 * RSAkey_free().
 * @ingroup rsa
 */
Prvkey_RSA *RSAprvkey_new(void);

/**
 * Free an RSA key object (Prvkey_RSA or Pubkey_RSA).
 *
 * This function frees an RSA key object passed as a generic Key pointer.
 * NOTE(review): presumably the key_type member tells the two kinds
 * apart -- confirm in rsa_key.c.
 *
 * @param[in] key RSA key object to be freed:
 *            a Prvkey_RSA or a Pubkey_RSA.
 * @ingroup rsa
 */
void RSAkey_free(Key *key);

/**
 * Generate an RSA private key.
 *
 * This function generates an RSA private key and stores it in @a ret.
 *
 * @param[out] ret generated private key.
 * @param[in] byte half of the key length in bytes.
 * @returns status code. NOTE(review): presumably 0 on success and -1 on
 *          error like the other functions in this header -- confirm in
 *          rsa_key.c.
 * @ingroup rsa
 */
int RSAprv_generate(Prvkey_RSA *ret, int byte);

/**
 * Copy public key parameters from Prvkey_RSA to Pubkey_RSA.
 *
 * This function copies the public key parameters of @a prv to @a pub.
 *
 * @param[in] prv Prvkey_RSA to copy from.
 * @param[out] pub Pubkey_RSA to copy to.
 * @ingroup rsa
 */
void RSAprv_2pub(Prvkey_RSA *prv, Pubkey_RSA *pub);

/**
 * Duplicate a Pubkey_RSA object.
 *
 * This function creates a new Pubkey_RSA object that duplicates @a src.
 *
 * @param[in] src Pubkey_RSA object to duplicate.
 * @returns newly duplicated Pubkey_RSA object, or NULL on error.
 * @par Error state:
 * - If the error is caused by a null pointer, this function
 * sets the error state to ERR_ST_NULLPOINTER.
 * - If the error is caused by memory allocation, this function
 * sets the error state to ERR_ST_MEMALLOC.
 * @attention
 * When the duplicated object is no longer used, free it with
 * RSAkey_free().
 * @ingroup rsa
 */
Pubkey_RSA *RSApubkey_dup(Pubkey_RSA *src);

/**
 * Duplicate a Prvkey_RSA object.
 *
 * This function creates a new Prvkey_RSA object that duplicates @a src.
 *
 * @param[in] src Prvkey_RSA object to duplicate.
 * @returns newly duplicated Prvkey_RSA object, or NULL on error.
 * @par Error state:
 * - If the error is caused by a null pointer, this function
 * sets the error state to ERR_ST_NULLPOINTER.
 * - If the error is caused by memory allocation, this function
 * sets the error state to ERR_ST_MEMALLOC.
 * @attention
 * When the duplicated object is no longer used, free it with
 * RSAkey_free().
 * @ingroup rsa
 */
Prvkey_RSA *RSAprvkey_dup(Prvkey_RSA *src);

/**
 * Compare two RSA public keys.
 *
 * This function compares the RSA public keys @a k1 and @a k2.
 *
 * @param[in] k1 RSA public key 1.
 * @param[in] k2 RSA public key 2.
 * @returns 0 if the keys are the same, 1 or -1 if they differ.
 * @ingroup rsa
 */
int RSApubkey_cmp(Pubkey_RSA *k1, Pubkey_RSA *k2);

/**
 * Compare two RSA private keys.
 *
 * This function compares the RSA private keys @a k1 and @a k2.
 *
 * @param[in] k1 RSA private key 1.
 * @param[in] k2 RSA private key 2.
 * @returns 0 if the keys are the same, 1 or -1 if they differ.
 * @ingroup rsa
 */
int RSAprvkey_cmp(Prvkey_RSA *k1, Prvkey_RSA *k2);

/**
 * Validate a pair of RSA keys.
 *
 * This function checks that @a prv and @a pub form a valid key pair.
 *
 * @param[in] prv RSA private key.
 * @param[in] pub RSA public key.
 * @returns 0 if the keys form a valid pair, 1 or -1 otherwise.
 * @ingroup rsa
 */
int RSA_pair_cmp(Prvkey_RSA *prv, Pubkey_RSA *pub);


/* rsa_pss.c */

/**
 * Type for RSASSA-PSS-params.
 *
 * AlgorithmIdentifier.parameters for id-RSASSA-PSS
 * (RFC 8017, Appendix A.2.3):
 *
 * @verbatim
 RSASSA-PSS-params ::= SEQUENCE {
     hashAlgorithm      [0] HashAlgorithm      DEFAULT sha1,
     maskGenAlgorithm   [1] MaskGenAlgorithm   DEFAULT mgf1SHA1,
     saltLength         [2] INTEGER            DEFAULT 20,
     trailerField       [3] TrailerField       DEFAULT trailerFieldBC
 }
 @endverbatim
 * @ingroup rsa
 */
typedef struct rsassa_pss_params {
	/** hash algorithm id; presumably an OBJ_HASH_* value (e.g. OBJ_HASH_SHA1) */
	int hashAlgorithm;
	/** mask generation function id from PKCS1MGFAlgorithms (e.g. OBJ_MGF1_SHA1) */
	int maskGenAlgorithm;
	/** salt length in octets */
	int saltLength;
	/** trailer field number; SHALL be 1 (trailerFieldBC) */
	int trailerField;
} rsassa_pss_params_t;

/**
 * Type of a salt generation function.
 *
 * A function of this type fills @a salt with @a sLen octets.
 * It can be installed with set_RSA_PSS_salt_gen_function().
 *
 * @param[out] salt buffer that receives the generated salt.
 * @param[in] sLen number of octets to generate.
 * @ingroup rsa
 */
typedef void (*SGF)(uint8_t *salt, const int sLen);

/**
 * Check whether an algorithm is in OAEP-PSSDigestAlgorithms.
 *
 * OAEP-PSSDigestAlgorithms is defined in Appendix A.2.1 of RFC 8017.
 *
 * @param[in] hash_algo hash algorithm id (e.g. OBJ_HASH_SHA1).
 * @returns 0 if @a hash_algo is valid, -1 if invalid.
 *
 * @ingroup rsa
 */
int is_in_OAEP_PSSDigestAlgorithms(const int hash_algo);

/**
 * Check whether an algorithm is in PKCS1MGFAlgorithms.
 *
 * PKCS1MGFAlgorithms is defined in Appendix A.2.1 of RFC 8017.
 *
 * @param[in] algo mask generation function id (e.g. OBJ_MGF1_SHA1).
 * @returns 0 if @a algo is valid, -1 if invalid.
 *
 * @ingroup rsa
 */
int is_in_PKCS1MGFAlgorithms(const int algo);

/**
 * Get an MGF1 function id.
 *
 * Returns the MGF1 id (e.g. OBJ_MGF1_SHA1) corresponding to the hash
 * algorithm @a algo.
 *
 * @param[in] algo hash algorithm id (e.g. OBJ_HASH_SHA1).
 * @returns a positive MGF1 id, or -1 on error.
 *
 * @ingroup rsa
 */
int get_PKCS1MGFAlgorithms(const int algo);

/**
 * Get the hash algorithm on which an MGF1 function is based.
 *
 * Returns the hash algorithm id (e.g. OBJ_HASH_SHA1) on which the MGF1
 * function @a mgf1algo is based.
 *
 * @param[in] mgf1algo mask generation function id (e.g. OBJ_MGF1_SHA1).
 * @returns a positive hash algorithm id, or -1 on error.
 *
 * @ingroup rsa
 */
int get_HashAlgorithm(const int mgf1algo);

/**
 * Set the salt generation function.
 *
 * If @a func is NULL, the default setting is restored.
 * The default function is RSA_PSS_salt_gen().
 * NOTE(review): this appears to change a setting shared by later PSS
 * operations; its thread-safety is not documented here.
 *
 * @param[in] func pointer to the salt generation function, or NULL.
 *
 * @ingroup rsa
 */
void set_RSA_PSS_salt_gen_function(SGF func);

/**
 * Generate a random octet string salt of length sLen.
 *
 * If @a sLen is 0, the salt is the empty string.
 *
 * @param[out] salt buffer that receives the generated salt.
 * @param[in] sLen number of octets to generate.
 * @ingroup rsa
 */
void RSA_PSS_salt_gen(uint8_t *salt, const int sLen);

/**
 * Set recommended RSA-PSS-params values based on hash_algo.
 *
 * - maskGenAlgorithm is MGF1 with @a hash_algo.
 * - saltLength is set to the octet length of the hash value.
 * - trailerField is trailerFieldBC.
 *
 * NOTE(review): hashAlgorithm is presumably set to @a hash_algo as
 * well -- confirm in rsa_pss.c.
 *
 * @param[out] params RSA-PSS-params to be set.
 * @param[in] hash_algo hash algorithm.
 * @returns 0 if succeeded, -1 if arguments are invalid.
 * @ingroup rsa
 */
int RSA_PSS_params_set_recommend(rsassa_pss_params_t *params, const int hash_algo);

/**
 * Set RSA-PSS-params.
 *
 * @param[out] params RSA-PSS-params to be set.
 * @param[in] hash_algo hash algorithm.
 * @param[in] mgfalgo mask generation function in the set
 *                    PKCS1MGFAlgorithms (e.g. OBJ_MGF1_SHA1).
 * @param[in] sLen octet length of the salt.
 * @param[in] trailerfld the trailer field number. It SHALL be 1.
 * @returns 0 if succeeded, -1 if arguments are invalid.
 * @ingroup rsa
 */
int RSA_PSS_params_set(rsassa_pss_params_t *params, const int hash_algo,
		       const int mgfalgo, const int sLen, const int trailerfld);

/**
 * Set RSA-PSS-params to the default values.
 *
 * - hashAlgorithm is sha1.
 * - maskGenAlgorithm is mgf1SHA1.
 * - saltLength is 20.
 * - trailerField is trailerFieldBC.
 *
 * @param[out] params RSA-PSS-params to be set.
 * @returns 0 if succeeded, -1 if arguments are invalid.
 * @ingroup rsa
 */
int RSA_PSS_params_set_default(rsassa_pss_params_t *params);

/**
 * Set the mask generation function.
 *
 * Only maskGenAlgorithm is updated; the other fields are left unchanged.
 *
 * @param[out] params RSA-PSS-params to be updated.
 * @param[in] mgfalgo mask generation function in the set
 *                    PKCS1MGFAlgorithms (e.g. OBJ_MGF1_SHA1).
 * @returns 0 if succeeded, -1 if arguments are invalid.
 * @ingroup rsa
 */
int RSA_PSS_params_set_maskGenAlgorithm(rsassa_pss_params_t *params,
					const int mgfalgo);

/**
 * Set the salt length.
 *
 * Only saltLength is updated; the other fields are left unchanged.
 *
 * @param[out] params RSA-PSS-params to be updated.
 * @param[in] sLen octet length of the salt.
 * @returns 0 if succeeded, -1 if arguments are invalid.
 * @ingroup rsa
 */
int RSA_PSS_params_set_saltLength(rsassa_pss_params_t *params,
				  const int sLen);

/**
 * RSASSA-PSS signature generation.
 *
 * @param[in] prv RSA private key.
 * @param[in] digest octet string computed by sign_digest_algo.
 * @param[in] dig_size octet length of @a digest.
 * @param[in] params pointer to the RSASSA-PSS-params (presumably a
 *                   rsassa_pss_params_t -- confirm in rsa_pss.c).
 * @returns signature, or NULL on error.
 *          NOTE(review): the length and ownership of the returned
 *          buffer are not documented here.
 * @ingroup rsa
 */
unsigned char *RSA_PSS_sign_digest(Prvkey_RSA *prv, unsigned char *digest,
				   int dig_size, void *params);

/**
 * EMSA-PSS encoding.
 *
 * @param[in] mHash message hash to be encoded, an octet string.
 * @param[in] hLen octet length of @a mHash.
 * @param[in] params pointer to the RSASSA-PSS-params.
 * @param[in,out] emLen on return, the length of the encoded message.
 *                NOTE(review): this parameter is in/out; the expected
 *                input value is not documented here -- confirm in
 *                rsa_pss.c.
 * @retval !NULL encoded message.
 * @retval NULL error
 * @par Error state:
 * - If the error is caused by an encoding error, this function
 * sets the error state to ERR_ST_RSA_ENCODING.
 * - If the error is caused by maskLen being too long, this function
 * sets the error state to ERR_ST_RSA_MSKTOOLONG.
 * @ingroup rsa
 */
unsigned char *RSA_PSS_encode(unsigned char *mHash, int hLen,
			      void *params, uint64_t *emLen);

/**
 * EMSA-PSS verification.
 *
 * @param[in] mHash message hash to be verified, an octet string.
 * @param[in] hLen octet length of @a mHash.
 * @param[in] EM encoded message, an octet string.
 * @param[in] emLen octet length of @a EM.
 * @param[in] emBits bit length of @a EM.
 * @param[in] params pointer to the RSASSA-PSS-params.
 * @retval 0 consistent
 * @retval 1 inconsistent
 * @retval -1 error
 * @par Error state:
 * - If the error is caused by an encoding error, this function
 * sets the error state to ERR_ST_RSA_INCONSISTENT.
 * - If the error is caused by maskLen being too long, this function
 * sets the error state to ERR_ST_RSA_MSKTOOLONG.
 * @ingroup rsa
 */
int RSA_PSS_verify(unsigned char *mHash, int hLen, uint8_t *EM,
		   uint64_t emLen, uint64_t emBits, void *params);

/* Debug output macro.
 *
 * To output debug messages, define RSA_DEBUG before this header is
 * included (or on the compiler command line), e.g.
 *
 *  #define RSA_DEBUG	1
 *
 * Debug messages are written to stderr, prefixed with the source file
 * name, line number and function name. When RSA_DEBUG is not defined,
 * RSA_DPRINTF() expands to nothing. */
#ifdef RSA_DEBUG
#include <stdio.h>
#include <libgen.h>

/* The do { } while (0) wrapper makes the three fprintf() calls a single
 * statement, so the macro behaves correctly as the body of an unbraced
 * if/else. Callers must terminate RSA_DPRINTF() with a semicolon.
 * __LINE__ is an int, hence %d.
 * NOTE(review): POSIX basename() may modify its argument; passing
 * __FILE__ (a string literal) relies on the path having no trailing '/'. */
#define RSA_DPRINTF(format, ...)					\
	do {								\
		fprintf(stderr, "%-15s (%4d) @%-30s: ",			\
			basename(__FILE__), __LINE__, __func__);	\
		fprintf(stderr, format, ## __VA_ARGS__);		\
		fprintf(stderr, "\n");					\
	} while (0)
#else
#define RSA_DPRINTF(format, ...)
#endif
#ifdef __cplusplus
}
#endif

#endif /* INCLUSION_GUARD_UUID_614A7D71_7E80_426F_A570_64F51F84160E */
