/* rsa_pss.c */
/*
 * Copyright (c) 2017-2019 National Institute of Informatics in Japan,
 * All rights reserved.
 *
 * This file or a portion of this file is licensed under the terms of
 * the NAREGI Public License, found at http://www.naregi.org/download/index.html.
 * If you redistribute this file, with or without modifications, you must
 * include this notice in the file.
 */

#include "aiconfig.h" /* for ok_rand.h */

#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <inttypes.h>

#include <aicrypto/ok_asn1.h>
#include <aicrypto/ok_err.h>
#include <aicrypto/ok_rand.h>
#include <aicrypto/ok_rsa.h>
#include <aicrypto/ok_tool.h>

/**
 * Pointer to the salt generation function called by RSA_PSS_encode().
 * It defaults to RSA_PSS_salt_gen. To change it, use
 * set_RSA_PSS_salt_gen_function(); passing NULL there restores the default.
 */
static SGF salt_gen = RSA_PSS_salt_gen;

/*
 * Integer ceiling of a / b (intended for a >= 0 and b > 0): the whole
 * quotient, plus one when the division leaves a remainder.
 */
static inline int ceiling(int a, int b)
{
	const int whole = a / b;

	return whole + ((a % b) != 0);
}

#ifdef RSA_DEBUG
/*
 * Debug helper: print "label:" followed by a hex dump of buf[0 .. bytes-1]
 * to stderr, eight octets per line.
 *
 * The index is a size_t so it can neither overflow nor trigger a
 * signed/unsigned comparison against the size_t byte count.
 */
static inline void dump(const char *label, const uint8_t *buf,
			const size_t bytes)
{
	size_t i;

	fprintf(stderr, "%s:\n", label);
	for (i = 0; i < bytes; i++) {
		if (i % 8 == 0) {
			fprintf(stderr, "    ");
		}
		fprintf(stderr, "%02x ", buf[i]);
		if ((i + 1) % 8 == 0) {
			fprintf(stderr, "\n");
		}
	}
	fprintf(stderr, "\n");
}
#endif /* RSA_DEBUG */

/**
 * MGF1 is a mask generation function based on a hash function
 * (RFC 8017, Appendix B.2.1).
 *
 * @param[in] mgfSeed seed from which the mask is generated, an octet
 *            string; exactly ha->hash_size octets of it are read
 * @param[in] maskLen  intended length in octets of the mask, at most 2^32 hLen
 *            (hLen being ha->hash_size)
 * @param[in] ha hash function information
 * @param[out] mask buffer receiving the mask; must hold maskLen octets
 * @returns 0 when successful, -1 when an error occurred.
 * @par Error state:
 * If the error is caused by maskLen being too long, this function
 * sets the error state to ERR_ST_RSA_MSKTOOLONG.
 * @ingroup rsa
 */
static int MGF1(uint8_t *mgfSeed, uint64_t maskLen, HASHAlgorithm *ha,
		uint8_t *mask);

/*
 * Membership test for the OAEP-PSSDigestAlgorithms set.
 *
 * Returns 0 when hash_algo is one of the accepted OBJ_HASH_* identifiers
 * (SHA-1, SHA-224, SHA-256, SHA-384, SHA-512, SHA-512/224, SHA-512/256),
 * and -1 otherwise.
 */
int is_in_OAEP_PSSDigestAlgorithms(const int hash_algo)
{
	switch (hash_algo) {
	case OBJ_HASH_SHA1:
	case OBJ_HASH_SHA224:
	case OBJ_HASH_SHA256:
	case OBJ_HASH_SHA384:
	case OBJ_HASH_SHA512:
	case OBJ_HASH_SHA512224:
	case OBJ_HASH_SHA512256:
		return 0;
	default:
		return -1;
	}
}

/*
 * Membership test for the PKCS1MGFAlgorithms set.
 *
 * Returns 0 when algo is one of the accepted OBJ_MGF1_* identifiers
 * (MGF1 over SHA-1, SHA-224, SHA-256, SHA-384, SHA-512, SHA-512/224 or
 * SHA-512/256), and -1 otherwise.
 */
int is_in_PKCS1MGFAlgorithms(const int algo)
{
	switch (algo) {
	case OBJ_MGF1_SHA1:
	case OBJ_MGF1_SHA224:
	case OBJ_MGF1_SHA256:
	case OBJ_MGF1_SHA384:
	case OBJ_MGF1_SHA512:
	case OBJ_MGF1_SHA512224:
	case OBJ_MGF1_SHA512256:
		return 0;
	default:
		return -1;
	}
}

/*
 * Map a hash identifier (OBJ_HASH_*) to the MGF1 identifier built on the
 * same hash (OBJ_MGF1_*).
 *
 * Returns the matching OBJ_MGF1_* value, or -1 for an unsupported hash.
 */
int get_PKCS1MGFAlgorithms(const int algo)
{
	switch (algo) {
	case OBJ_HASH_SHA1:
		return OBJ_MGF1_SHA1;
	case OBJ_HASH_SHA224:
		return OBJ_MGF1_SHA224;
	case OBJ_HASH_SHA256:
		return OBJ_MGF1_SHA256;
	case OBJ_HASH_SHA384:
		return OBJ_MGF1_SHA384;
	case OBJ_HASH_SHA512:
		return OBJ_MGF1_SHA512;
	case OBJ_HASH_SHA512224:
		return OBJ_MGF1_SHA512224;
	case OBJ_HASH_SHA512256:
		return OBJ_MGF1_SHA512256;
	default:
		return -1;
	}
}

/*
 * Map an MGF1 identifier (OBJ_MGF1_*) to the identifier of its
 * underlying hash function (OBJ_HASH_*).
 *
 * Returns the matching OBJ_HASH_* value, or -1 for an unsupported MGF1
 * identifier.
 */
int get_HashAlgorithm(const int mgf1algo)
{
	switch (mgf1algo) {
	case OBJ_MGF1_SHA1:
		return OBJ_HASH_SHA1;
	case OBJ_MGF1_SHA224:
		return OBJ_HASH_SHA224;
	case OBJ_MGF1_SHA256:
		return OBJ_HASH_SHA256;
	case OBJ_MGF1_SHA384:
		return OBJ_HASH_SHA384;
	case OBJ_MGF1_SHA512:
		return OBJ_HASH_SHA512;
	case OBJ_MGF1_SHA512224:
		return OBJ_HASH_SHA512224;
	case OBJ_MGF1_SHA512256:
		return OBJ_HASH_SHA512256;
	default:
		return -1;
	}
}

/*
 * Default salt generator: fill salt[0 .. sLen-1] with octets from
 * RAND_bytes().
 *
 * When sLen is not positive, no randomness is drawn and a single NUL
 * octet is written instead. salt must therefore always point to at
 * least one writable octet. Negative lengths are never handed to
 * RAND_bytes().
 *
 * NOTE(review): anything RAND_bytes() may report is ignored; the SGF
 * signature returns void, so a failure cannot reach the caller.
 */
void RSA_PSS_salt_gen(uint8_t *salt, const int sLen)
{
	assert(salt != NULL);

	if (sLen <= 0) {
		salt[0] = '\0';
		return;
	}

	RAND_init();
	RAND_bytes(salt, sLen);
	RAND_cleanup();
}

/*
 * Validate a complete set of RSASSA-PSS parameters and store them.
 *
 * Acceptance rules:
 *  - hash_algo must be in OAEP-PSSDigestAlgorithms;
 *  - mgfalgo must be in PKCS1MGFAlgorithms;
 *  - sLen must be non-negative;
 *  - trailerfld must equal 1.
 *
 * params is written only when every check passes.
 * Returns 0 on success, and -1 when any argument is rejected.
 */
int RSA_PSS_params_set(rsassa_pss_params_t *params, const int hash_algo,
		       const int mgfalgo, const int sLen, const int trailerfld)
{
	assert(params != NULL);

	const int bad_hash = (is_in_OAEP_PSSDigestAlgorithms(hash_algo) != 0);
	const int bad_mgf = (is_in_PKCS1MGFAlgorithms(mgfalgo) != 0);
	const int bad_salt = (sLen < 0);
	const int bad_trailer = (trailerfld != 1);

	if (bad_hash || bad_mgf || bad_salt || bad_trailer) {
#ifdef RSA_DEBUG
		/* Report only the first problem, in argument order. */
		if (bad_hash) {
			RSA_DPRINTF("invalid hash algorithm: %d\n", hash_algo);
		} else if (bad_mgf) {
			RSA_DPRINTF("invalid MGF algorithm: %d\n", mgfalgo);
		} else if (bad_salt) {
			RSA_DPRINTF("invalid salt length: %d\n", sLen);
		} else {
			RSA_DPRINTF("invalid trailer field: %d\n", trailerfld);
		}
#endif /* RSA_DEBUG */
		return -1;
	}

	params->trailerField = trailerfld;
	params->saltLength = sLen;
	params->maskGenAlgorithm = mgfalgo;
	params->hashAlgorithm = hash_algo;

	return 0;
}

/*
 * Fill params with the parameters RFC 8017 A.2.3 recommends for
 * hash_algo:
 *  - MGF1 over the same hash;
 *  - a salt as long as the hash output (hash_size(hash_algo));
 *  - trailer field 1.
 *
 * Returns -1 when hash_algo is not in OAEP-PSSDigestAlgorithms; otherwise
 * returns the result of RSA_PSS_params_set().
 */
int RSA_PSS_params_set_recommend(rsassa_pss_params_t *params,
				 const int hash_algo)
{
	int mgf;
	int salt_len;
	/*
	 * trailerField is the trailer field number, for compatibility with
	 * IEEE 1363a [IEEE1363A].  It SHALL be 1 for this version of the
	 * document, which represents the trailer field with hexadecimal
	 * value 0xbc.  Other trailer fields (including the trailer field
	 * HashID || 0xcc in IEEE 1363a) are not supported in this document.
	 *   TrailerField ::= INTEGER { trailerFieldBC(1) }
	 */
	const int trailer = 1;

	assert(params != NULL);

	/*
	 * RFC8017 A.2.3.  RSASSA-PSS
	 *
	 * hashAlgorithm identifies the hash function.  It SHALL be an
	 * algorithm ID with an OID in the set OAEP-PSSDigestAlgorithms
	 * (snip)
	 */
	if (is_in_OAEP_PSSDigestAlgorithms(hash_algo) != 0) {
#ifdef RSA_DEBUG
		RSA_DPRINTF("invalid hash algorithm: %d\n", hash_algo);
#endif /* RSA_DEBUG */
		return -1;
	}

	/*
	 * maskGenAlgorithm identifies the mask generation function.  It
	 * SHALL be an algorithm ID with an OID in the set PKCS1MGFAlgorithms
	 * (snip).
	 *
	 * For MGF1 (and more generally, for other mask generation functions
	 * based on a hash function), it is RECOMMENDED that the underlying
	 * hash function be the same as the one identified by hashAlgorithm
	 * (snip).
	 */
	mgf = get_PKCS1MGFAlgorithms(hash_algo);

	/*
	 * saltLength is the octet length of the salt.  It SHALL be an
	 * integer.  For a given hashAlgorithm, the default value of
	 * saltLength is the octet length of the hash value. (snip)
	 */
	salt_len = hash_size(hash_algo);

	return RSA_PSS_params_set(params, hash_algo, mgf, salt_len, trailer);
}

/*
 * Fill params with the recommended parameter set for SHA-1. This is
 * equivalent to RSA_PSS_params_set_recommend(params, OBJ_HASH_SHA1).
 * Returns that function's result.
 */
int RSA_PSS_params_set_default(rsassa_pss_params_t *params)
{
	const int default_hash = OBJ_HASH_SHA1;

	return RSA_PSS_params_set_recommend(params, default_hash);
}

/*
 * Replace only the mask generation algorithm in params.
 *
 * mgfalgo must be in PKCS1MGFAlgorithms; otherwise params is left
 * untouched and -1 is returned. Returns 0 on success.
 */
int RSA_PSS_params_set_maskGenAlgorithm(rsassa_pss_params_t *params,
					 const int mgfalgo)
{
	assert(params != NULL);

	if (is_in_PKCS1MGFAlgorithms(mgfalgo) == 0) {
		params->maskGenAlgorithm = mgfalgo;
		return 0;
	}

#ifdef RSA_DEBUG
	RSA_DPRINTF("invalid MGF algorithm: %d\n", mgfalgo);
#endif /* RSA_DEBUG */
	return -1;
}

/*
 * Replace only the salt length (in octets) in params.
 *
 * A negative sLen is rejected with -1, and params is left untouched.
 * Returns 0 on success.
 */
int RSA_PSS_params_set_saltLength(rsassa_pss_params_t *params,
				  const int sLen)
{
	assert(params != NULL);

	if (sLen >= 0) {
		params->saltLength = sLen;
		return 0;
	}

#ifdef RSA_DEBUG
	RSA_DPRINTF("invalid salt length: %d\n", sLen);
#endif /* RSA_DEBUG */
	return -1;
}

/*
 * Install func as the salt generator used by RSA_PSS_encode().
 * Passing NULL reinstates the built-in RSA_PSS_salt_gen.
 */
void set_RSA_PSS_salt_gen_function(SGF func)
{
	if (func == NULL) {
		func = RSA_PSS_salt_gen;
	}
	salt_gen = func;
}

/*
 * RSASSA-PSS-SIGN (K, M) -- RFC 8017 8.1.1, applied to an already
 * computed digest of M.
 *
 * The modulus is taken to be prv->size * 8 bits long. EM is encoded by
 * RSA_PSS_encode() into ceil((modBits - 1) / 8) octets and is then
 * signed with OK_do_sign(). The temporary EM buffer is freed here.
 * NOTE(review): this assumes the modulus bit length is an exact
 * multiple of 8 (prv->size presumably being the modulus size in octets).
 *
 * Returns the value of OK_do_sign(), or NULL when encoding fails.
 */
unsigned char *RSA_PSS_sign_digest(Prvkey_RSA *prv, unsigned char *digest,
				   int dig_size, void *params)
{
	assert(prv != NULL);
	assert(digest != NULL);
	assert(dig_size > 0);

	/*
	 * Step 1: EM = EMSA-PSS-ENCODE (M, modBits - 1).
	 * The octet length of EM is one less than k when modBits - 1 is
	 * divisible by 8, and equal to k otherwise.
	 */
	const uint64_t modBits = prv->size * 8;
#ifdef RSA_DEBUG
	RSA_DPRINTF("emBits=%"PRIu64" \n", modBits);
#endif /* RSA_DEBUG */
	uint64_t emLen = ceiling(modBits - 1, 8);
#ifdef RSA_DEBUG
	RSA_DPRINTF("emLen=%"PRIu64" (before) \n", emLen);
#endif /* RSA_DEBUG */

	unsigned char *em = RSA_PSS_encode(digest, dig_size, params, &emLen);
	if (em == NULL) {
		/* "message too long" / "encoding error" */
		return NULL;
	}
#ifdef RSA_DEBUG
	RSA_DPRINTF("emLen=%"PRIu64" (after) \n", emLen);
	dump("EM", em, emLen);
#endif /* RSA_DEBUG */

	/* Step 2: apply the RSA signature primitive to EM. */
	unsigned char *signature = OK_do_sign((Key *) prv, em, emLen, NULL);
	free(em);

	/* Step 3: output S (presumably NULL when the key is unusable). */
	return signature;
}

/* EMSA-PSS-ENCODE (M, emBits) */
unsigned char *RSA_PSS_encode(unsigned char *mHash, int hLen,
			      void *params, uint64_t *emLen)
{
	int hash_algo;
	int mgf1_algo;
	int sLen;
#if 0
	/*
	 * This variable is unused since the value that can be set
	 * at present is '1' only.
	 */
	int trailerField;
#endif

#ifdef RSA_DEBUG
	dump("mHash", mHash, hLen);
#endif /* RSA_DEBUG */
	if (params) {
		rsassa_pss_params_t *p = (rsassa_pss_params_t *) params;
		//hash_algo = p->hashAlgorithm;
		hash_algo = sign_digest_algo;
		mgf1_algo = p->maskGenAlgorithm;
		sLen = p->saltLength;
#if 0
		trailerField = p->trailerField;
#endif
	} else {
		/* set default */
		hash_algo = OBJ_HASH_SHA1;
		mgf1_algo = OBJ_MGF1_SHA1;
		sLen = 20;
#if 0
		trailerField = 1;
#endif
	}

	/*
	 * 3.   If emLen < hLen + sLen + 2, output "encoding error" and stop.
	 */
#ifdef RSA_DEBUG
	RSA_DPRINTF("emLen=%"PRIu64"\n", *emLen);
	RSA_DPRINTF(" hLen=%d\n", hLen);
	RSA_DPRINTF(" sLen=%d\n", sLen);
#endif /* RSA_DEBUG */
	if (*emLen < hLen + sLen + 2) {
		OK_set_error(ERR_ST_RSA_ENCODING, ERR_LC_RSA, ERR_PT_RSAPSS+0,
			     NULL);
		return NULL;
	}

	/*
	 * 4.   Generate a random octet string salt of length sLen; if sLen =
	 *      0, then salt is the empty string.
	 */
	uint8_t salt[sLen + 1];
	if (sLen != 0) {
		(*salt_gen)(salt, sLen);
	} else {
		salt[0] = '\0';
	}
#ifdef RSA_DEBUG
	RSA_DPRINTF("sLen = %d\n", sLen);
	dump("salt", salt, sLen);
#endif /* RSA_DEBUG */

	/*
	 * 5.   Let
	 *
	 *         M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt;
	 *
	 *      M' is an octet string of length 8 + hLen + sLen with eight
	 *      initial zero octets.
	 */
	uint8_t M_[8 + hLen + sLen];
	memset(M_, 0, sizeof(M_));
	memcpy(&M_[8], mHash, hLen);
	memcpy(&M_[8 + hLen], salt, sLen);
#ifdef RSA_DEBUG
	dump("M'", M_, 8 + hLen + sLen);
#endif /* RSA_DEBUG */

	/*
	 * 6.   Let H = Hash(M'), an octet string of length hLen.
	 */
	HASHAlgorithm *ha;
	uint8_t H[hLen];

	ha = gethashalgobyaioid(hash_algo);
#ifdef RSA_DEBUG
	RSA_DPRINTF("H = %s(M')\n", ha->name);
#endif /* RSA_DEBUG */
	(*ha->hash_compute)(sizeof(M_), M_, H);
#ifdef RSA_DEBUG
	dump("H", H, hLen);
#endif /* RSA_DEBUG */

	/*
	 * 7.   Generate an octet string PS consisting of emLen - sLen - hLen
	 *      - 2 zero octets.  The length of PS may be 0.
	 */
	int psLen = *emLen - sLen - hLen - 2;
#ifdef RSA_DEBUG
	RSA_DPRINTF("psLen=%d\n", psLen);
#endif /* RSA_DEBUG */

	/*
	 * 8.   Let DB = PS || 0x01 || salt; DB is an octet string of length
	 *      emLen - hLen - 1.
	 */
	uint64_t dbLen = *emLen - hLen - 1;
#ifdef RSA_DEBUG
	RSA_DPRINTF("dbLen=%"PRIu64" \n", dbLen);
#endif /* RSA_DEBUG */
	uint8_t DB[dbLen + 1];
	memset(DB, 0, dbLen + 1);
	DB[psLen] = 1;
	memcpy(&DB[psLen+1], salt, sLen);

	/*
	 * 9.   Let dbMask = MGF(H, emLen - hLen - 1).
	 */
	int mgf1_hash_algo = get_HashAlgorithm(mgf1_algo);
#ifdef RSA_DEBUG
	RSA_DPRINTF("     mgf1_algo = %d\n", mgf1_algo);
	RSA_DPRINTF("mgf1_hash_algo = %d\n", mgf1_hash_algo);
	RSA_DPRINTF("     hash_algo = %d\n",      hash_algo);
#endif /* RSA_DEBUG */
	if (mgf1_hash_algo != hash_algo) {
		ha = gethashalgobyaioid(mgf1_hash_algo);
	}

	uint8_t dbMask[dbLen + 1];
	int ret;

	memset(dbMask, 0, dbLen + 1);
	ret = MGF1(H, dbLen, ha, dbMask);
	if (ret != 0) {
		return NULL;
	}
#ifdef RSA_DEBUG
	dump("dbMask", dbMask, dbLen);
#endif /* RSA_DEBUG */

	/*
	 * 10.  Let maskedDB = DB \xor dbMask.
	 */
	uint64_t i;
	for (i = 0; i < dbLen; i++) {
		DB[i] ^= dbMask[i];
	}
#ifdef RSA_DEBUG
	dump("DB^dbMask", DB, dbLen);
#endif /* RSA_DEBUG */

	/*
	 * 11.  Set the leftmost 8emLen - emBits bits of the leftmost octet
	 *      in maskedDB to zero.
	 */
	DB[0] &= 0x7f;

	/*
	 * 12.  Let EM = maskedDB || H || 0xbc.
	 */
	uint8_t *EM = (uint8_t *) malloc(*emLen);
	memcpy(&EM[0], DB, dbLen);
	memcpy(&EM[dbLen], H, hLen);
#if 0
	/*
	 *  Should implement the following steps after a value
	 *  other than 1 will be able to set in trailerField.
	 */
	if (trailerField == 1) ...
#endif
	EM[dbLen + hLen] = 0xbc;
#ifdef RSA_DEBUG
	dump("EM", EM, *emLen);
#endif /* RSA_DEBUG */

	/*
	 * 13.  Output EM.
	 */
	*emLen = dbLen + hLen + 1;
#ifdef RSA_DEBUG
	RSA_DPRINTF("emLen=%"PRIu64"\n", *emLen);
#endif /* RSA_DEBUG */
	return EM;
}

/* EMSA-PSS-VERIFY (M, EM, emBits) */
int RSA_PSS_verify(unsigned char *mHash, int hLen,
		   uint8_t *EM, uint64_t emLen, uint64_t emBits,
		   void *params)
{
	int hash_algo;
	int mgf1_algo;
	int sLen;
#if 0
	/*
	 * This variable is unused since the value that can be set
	 * at present is '1' only.
	 */
	int trailerField;
#endif

#ifdef RSA_DEBUG
	RSA_DPRINTF("RSA_PSS_verify start\n");
	dump("mHash", mHash, hLen);
	dump("EM", EM, emLen);
#endif /* RSA_DEBUG */
	if (params) {
		rsassa_pss_params_t *p = (rsassa_pss_params_t *) params;
		//hash_algo = p->hashAlgorithm;
		hash_algo = sign_digest_algo;
		mgf1_algo = p->maskGenAlgorithm;
		sLen = p->saltLength;
#if 0
		trailerField = p->trailerField;
#endif
	} else {
		/* set default */
		hash_algo = OBJ_HASH_SHA1;
		mgf1_algo = OBJ_MGF1_SHA1;
		sLen = 20;
#if 0
		trailerField = 1;
#endif
	}

	/*
	 * 3.   If emLen < hLen + sLen + 2, output "inconsistent" and stop.
	 */
#ifdef RSA_DEBUG
	RSA_DPRINTF("emLen=%"PRIu64"\n", emLen);
	RSA_DPRINTF(" hLen=%d\n", hLen);
	RSA_DPRINTF(" sLen=%d\n", sLen);
#endif /* RSA_DEBUG */
	if (emLen < hLen + sLen + 2) {
		OK_set_error(ERR_ST_RSA_INCONSISTENT, ERR_LC_RSA,
			     ERR_PT_RSAPSS+2, NULL);
		return 1;
	}

	/*
	 * 4.   If the rightmost octet of EM does not have hexadecimal value
	 *      0xbc, output "inconsistent" and stop.
	 */
#ifdef RSA_DEBUG
	RSA_DPRINTF("EM[%lu] = %02x\n", emLen - 1, EM[emLen - 1]);
#endif /* RSA_DEBUG */
	if (EM[emLen - 1] != 0xbc) {
		OK_set_error(ERR_ST_RSA_INCONSISTENT, ERR_LC_RSA,
			     ERR_PT_RSAPSS+3, NULL);
		return 1;
	}

	/*
	 * 5.   Let maskedDB be the leftmost emLen - hLen - 1 octets of EM,
	 *      and let H be the next hLen octets.
	 */
	uint64_t dbLen = emLen - hLen - 1;
	uint8_t DB[dbLen];
	uint8_t *H;

	memcpy(DB, EM, dbLen);
	H = &EM[dbLen];
#ifdef RSA_DEBUG
	RSA_DPRINTF("dbLen=%"PRIu64" \n", dbLen);
	dump("maskedDB", DB, dbLen);
	dump("H", H, hLen);
#endif /* RSA_DEBUG */

	/*
	 * 6.   If the leftmost 8emLen - emBits bits of the leftmost octet in
	 *      maskedDB are not all equal to zero, output "inconsistent" and
	 *      stop.
	 */
	uint8_t bits = 8 * emLen - emBits;
#ifdef RSA_DEBUG
	RSA_DPRINTF("8 * emLen - emBits = %d\n", bits);
#endif /* RSA_DEBUG */
	if (bits != 0) {
		uint8_t maskbits = 0xff << (8 - bits);
#ifdef RSA_DEBUG
		RSA_DPRINTF("bits = %d, maskbits = %02x, maskedDB=%02x\n",
			    bits, maskbits, DB[0]);
#endif /* RSA_DEBUG */
		if ((DB[0] & maskbits) != 0) {
			OK_set_error(ERR_ST_RSA_INCONSISTENT, ERR_LC_RSA,
				     ERR_PT_RSAPSS+4, NULL);
			return 1;
		}
	}

	/*
	 * 7.   Let dbMask = MGF(H, emLen - hLen - 1).
	 */
	HASHAlgorithm *ha;
	int mgf1_hash_algo = get_HashAlgorithm(mgf1_algo);
#ifdef RSA_DEBUG
	RSA_DPRINTF("     mgf1_algo = %d\n", mgf1_algo);
	RSA_DPRINTF("mgf1_hash_algo = %d\n", mgf1_hash_algo);
#endif /* RSA_DEBUG */
	ha = gethashalgobyaioid(mgf1_hash_algo);

	uint8_t dbMask[dbLen];
	int ret;

	memset(dbMask, 0, dbLen);
	ret = MGF1(H, dbLen, ha, dbMask);
	if (ret != 0) {
		return -1;
	}
#ifdef RSA_DEBUG
	dump("dbMask", dbMask, dbLen);
#endif /* RSA_DEBUG */

	/*
	 * 8.   Let DB = maskedDB \xor dbMask.
	 */
	uint64_t i;
	for (i = 0; i < dbLen; i++) {
		DB[i] ^= dbMask[i];
	}
#ifdef RSA_DEBUG
	dump("maskedDB^dbMask", DB, dbLen);
#endif /* RSA_DEBUG */

	/*
	 * 9.   Set the leftmost 8emLen - emBits bits of the leftmost octet
	 *      in DB to zero.
	 */
	if (bits != 0) {
		uint8_t maskbits = 0xff >> bits;
#ifdef RSA_DEBUG
		RSA_DPRINTF("maskbits = %02x\n", maskbits);
#endif /* RSA_DEBUG */
		DB[0] &= maskbits;
	}
#ifdef RSA_DEBUG
	dump("DB", DB, dbLen);
#endif /* RSA_DEBUG */

	/*
	 * 10.  If the emLen - hLen - sLen - 2 leftmost octets of DB are not
	 *      zero or if the octet at position emLen - hLen - sLen - 1 (the
	 *      leftmost position is "position 1") does not have hexadecimal
	 *      value 0x01, output "inconsistent" and stop.
	 */
	/* dbLen: emLen - hLen - 1 */
	uint64_t psLen = dbLen - sLen - 1;
	int nonzero = 0;

	for (i = 0; i < psLen; i++) {
		if (DB[i] != 0) {
			nonzero = 1;
			break;
		}
	}
#ifdef RSA_DEBUG
	RSA_DPRINTF("nonzero = %d\n", nonzero);
	if (nonzero == 1) {
		RSA_DPRINTF("DB[%lu] = %02x\n", i, DB[i]);
	}
	RSA_DPRINTF("DB[%lu] = %02x\n", dbLen - sLen, DB[dbLen - sLen -1]);
#endif /* RSA_DEBUG */
	if (nonzero || DB[dbLen - sLen -1] != 0x01) {
		OK_set_error(ERR_ST_RSA_INCONSISTENT, ERR_LC_RSA,
			     ERR_PT_RSAPSS+5, NULL);
		return 1;
	}

	/*
	 * 11.  Let salt be the last sLen octets of DB.
	 */
	uint8_t *salt;
	salt = &DB[dbLen - sLen];
#ifdef RSA_DEBUG
	RSA_DPRINTF("sLen = %d\n", sLen);
	dump("salt", salt, sLen);
#endif /* RSA_DEBUG */

	/*
	 * 12.  Let
	 *
	 *         M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt ;
	 *
	 *      M' is an octet string of length 8 + hLen + sLen with eight
	 *      initial zero octets.
	 */
	uint8_t M_[8 + hLen + sLen];
	memset(M_, 0, sizeof(M_));
	memcpy(&M_[8], mHash, hLen);
	memcpy(&M_[8 + hLen], salt, sLen);
#ifdef RSA_DEBUG
	dump("M'", M_, 8 + hLen + sLen);
#endif /* RSA_DEBUG */

	/*
	 * 13.  Let H' = Hash(M'), an octet string of length hLen.
	 */
	uint8_t H_[hLen];

	if (hash_algo != mgf1_hash_algo) {
		ha = gethashalgobyaioid(hash_algo);
	}

#ifdef RSA_DEBUG
	RSA_DPRINTF("H' = %s(M')\n", ha->name);
#endif /* RSA_DEBUG */
	(*ha->hash_compute)(sizeof(M_), M_, H_);
#ifdef RSA_DEBUG
	dump("H'", H_, hLen);
#endif /* RSA_DEBUG */

	/*
	 * 14.  If H = H', output "consistent".  Otherwise, output
	 *      "inconsistent".
	 */
	if (memcmp(H, H_, hLen) != 0) {
		OK_set_error(ERR_ST_RSA_INCONSISTENT, ERR_LC_RSA,
			     ERR_PT_RSAPSS+6, NULL);
		return 1;
	}
	/* consistent */
	return 0;
}

/*
 * MGF1 (mgfSeed, maskLen) -- RFC 8017 B.2.1.
 *
 * mgfSeed is read as exactly ha->hash_size (= hLen) octets. The mask is
 * written straight into the caller's buffer, one hash block at a time.
 * This avoids a stack array proportional to maskLen (and a zero-length
 * one when maskLen is 0). It also avoids narrowing the block count to
 * int/uint32_t.
 *
 * Returns:
 *  - 0 on success;
 *  - -1 when ha is unusable (NULL or a zero hash size);
 *  - -1 when maskLen > 2^32 hLen, with error state ERR_ST_RSA_MSKTOOLONG.
 */
static int MGF1(uint8_t *mgfSeed, uint64_t maskLen, HASHAlgorithm *ha,
		uint8_t *mask)
{
	if (ha == NULL || ha->hash_size <= 0) {
		return -1;
	}

	const size_t hLen = (size_t) ha->hash_size;
	/* max_mask : 2^32 hLen (hLen is a digest size, so no overflow) */
	const uint64_t max_mask = ((uint64_t) 1 << 32) * (uint64_t) hLen;

#ifdef RSA_DEBUG
	RSA_DPRINTF("max_mask = %"PRIu64"\n", max_mask);
	RSA_DPRINTF("maskLen = %"PRIu64"\n", maskLen);
	RSA_DPRINTF("hLen = %d\n", ha->hash_size);
#endif /* RSA_DEBUG */
	/*
	 * 1.  If maskLen > 2^32 hLen, output "mask too long" and stop.
	 */
	if (maskLen > max_mask) {
		OK_set_error(ERR_ST_RSA_MSKTOOLONG, ERR_LC_RSA, ERR_PT_RSAPSS+1,
			     NULL);
		return -1;
	}

	/*
	 * 2.-4.  T = Hash(mgfSeed || C(0)) || Hash(mgfSeed || C(1)) || ...
	 *        for counter = 0 .. ceil(maskLen / hLen) - 1, and the mask is
	 *        the leading maskLen octets of T.  Each block is copied into
	 *        mask as soon as it is computed; the last one is truncated
	 *        to the octets still needed.
	 *
	 *        counter runs at most 2^32 times (guaranteed by step 1); its
	 *        unsigned wrap after the final block is harmless.
	 */
	uint8_t buf[hLen + 4];	/* mgfSeed || C */
	uint8_t digest[hLen];
	uint64_t done = 0;
	uint32_t counter = 0;

	memcpy(buf, mgfSeed, hLen);
	while (done < maskLen) {
		/*
		 *  A.  Convert counter to an octet string C of length 4
		 *      octets (see Section 4.1):
		 *
		 *         C = I2OSP (counter, 4) .
		 */
		buf[hLen + 0] = (uint8_t) (counter >> 24);
		buf[hLen + 1] = (uint8_t) (counter >> 16);
		buf[hLen + 2] = (uint8_t) (counter >> 8);
		buf[hLen + 3] = (uint8_t) counter;
#ifdef RSA_DEBUG
		RSA_DPRINTF("counter = %"PRIu32"\n", counter);
		dump("mgfSeed || C", buf, hLen + 4);
#endif /* RSA_DEBUG */

		/*
		 *  B.  Concatenate the hash of the seed mgfSeed and C to
		 *      the octet string T:
		 *
		 *         T = T || Hash(mgfSeed || C) .
		 */
		(*ha->hash_compute)(sizeof(buf), buf, digest);
#ifdef RSA_DEBUG
		dump("Hash(mgfSeed || C)", digest, hLen);
#endif /* RSA_DEBUG */

		uint64_t chunk = maskLen - done;
		if (chunk > hLen) {
			chunk = hLen;
		}
		memcpy(&mask[done], digest, (size_t) chunk);
		done += chunk;
		counter++;
	}

	return 0;
}
