From 3d5afc8b83632c712d9ee62f6c5421a7fe5ace58 Mon Sep 17 00:00:00 2001
From: Andy Polyakov <appro@openssl.org>
Date: Thu, 2 Jun 2005 18:29:21 +0000
Subject: [PATCH] PSS update [from 0.9.7].

---
 crypto/rsa/rsa.h     |  4 +-
 crypto/rsa/rsa_err.c |  4 +-
 crypto/rsa/rsa_pss.c | 89 +++++++++++++++++++++++++++++---------------
 3 files changed, 63 insertions(+), 34 deletions(-)

diff --git a/crypto/rsa/rsa.h b/crypto/rsa/rsa.h
index 1f6131b42..a1ca34760 100644
--- a/crypto/rsa/rsa.h
+++ b/crypto/rsa/rsa.h
@@ -411,7 +411,7 @@ void ERR_load_RSA_strings(void);
 #define RSA_R_NULL_BEFORE_BLOCK_MISSING			 113
 #define RSA_R_N_DOES_NOT_EQUAL_P_Q			 127
 #define RSA_R_OAEP_DECODING_ERROR			 121
-#define RSA_R_ONE_CHECK_FAILED				 135
+#define RSA_R_SLEN_RECOVERY_FAILED			 135
 #define RSA_R_PADDING_CHECK_FAILED			 114
 #define RSA_R_P_NOT_PRIME				 128
 #define RSA_R_Q_NOT_PRIME				 129
@@ -421,7 +421,7 @@ void ERR_load_RSA_strings(void);
 #define RSA_R_UNKNOWN_ALGORITHM_TYPE			 117
 #define RSA_R_UNKNOWN_PADDING_TYPE			 118
 #define RSA_R_WRONG_SIGNATURE_LENGTH			 119
-#define RSA_R_ZERO_CHECK_FAILED				 136
+#define RSA_R_SLEN_CHECK_FAILED				 136
 
 #ifdef  __cplusplus
 }
diff --git a/crypto/rsa/rsa_err.c b/crypto/rsa/rsa_err.c
index 48e8f3931..cfb1e908a 100644
--- a/crypto/rsa/rsa_err.c
+++ b/crypto/rsa/rsa_err.c
@@ -141,7 +141,7 @@ static ERR_STRING_DATA RSA_str_reasons[]=
 {ERR_REASON(RSA_R_NULL_BEFORE_BLOCK_MISSING),"null before block missing"},
 {ERR_REASON(RSA_R_N_DOES_NOT_EQUAL_P_Q)  ,"n does not equal p q"},
 {ERR_REASON(RSA_R_OAEP_DECODING_ERROR)   ,"oaep decoding error"},
-{ERR_REASON(RSA_R_ONE_CHECK_FAILED)      ,"one check failed"},
+{ERR_REASON(RSA_R_SLEN_RECOVERY_FAILED)  ,"salt length recovery failed"},
 {ERR_REASON(RSA_R_PADDING_CHECK_FAILED)  ,"padding check failed"},
 {ERR_REASON(RSA_R_P_NOT_PRIME)           ,"p not prime"},
 {ERR_REASON(RSA_R_Q_NOT_PRIME)           ,"q not prime"},
@@ -151,7 +151,7 @@ static ERR_STRING_DATA RSA_str_reasons[]=
 {ERR_REASON(RSA_R_UNKNOWN_ALGORITHM_TYPE),"unknown algorithm type"},
 {ERR_REASON(RSA_R_UNKNOWN_PADDING_TYPE)  ,"unknown padding type"},
 {ERR_REASON(RSA_R_WRONG_SIGNATURE_LENGTH),"wrong signature length"},
-{ERR_REASON(RSA_R_ZERO_CHECK_FAILED)     ,"zero check failed"},
+{ERR_REASON(RSA_R_SLEN_CHECK_FAILED)     ,"salt length check failed"},
 {0,NULL}
 	};
 
diff --git a/crypto/rsa/rsa_pss.c b/crypto/rsa/rsa_pss.c
index 5dcdb5460..2815628f5 100644
--- a/crypto/rsa/rsa_pss.c
+++ b/crypto/rsa/rsa_pss.c
@@ -76,10 +76,35 @@ int RSA_verify_PKCS1_PSS(RSA *rsa, const unsigned char *mHash,
 	unsigned char *DB = NULL;
 	EVP_MD_CTX ctx;
 	unsigned char H_[EVP_MAX_MD_SIZE];
+
+	hLen = EVP_MD_size(Hash);
+	/*
+	 * Negative sLen has special meanings:
+	 *	-1	sLen == hLen
+	 *	-2	salt length is autorecovered from signature
+	 *	-N	reserved
+	 */
+	if      (sLen == -1)	sLen = hLen;
+	else if (sLen == -2)	sLen = -2;
+	else if (sLen < -2)
+		{
+		RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_SLEN_CHECK_FAILED);
+		goto err;
+		}
+
 	MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
 	emLen = RSA_size(rsa);
-	hLen = EVP_MD_size(Hash);
-	if (emLen < (hLen + sLen + 2))
+	if (EM[0] & (0xFF << MSBits))
+		{
+		RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_FIRST_OCTET_INVALID);
+		goto err;
+		}
+	if (MSBits == 0)
+		{
+		EM++;
+		emLen--;
+		}
+	if (emLen < (hLen + sLen + 2)) /* sLen can be small negative */
 		{
 		RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_DATA_TOO_LARGE);
 		goto err;
@@ -89,16 +114,6 @@ int RSA_verify_PKCS1_PSS(RSA *rsa, const unsigned char *mHash,
 		RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_LAST_OCTET_INVALID);
 		goto err;
 		}
-	if (EM[0] & (0xFF << MSBits))
-		{
-		RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_FIRST_OCTET_INVALID);
-		goto err;
-		}
-	if (!MSBits)
-		{
-		EM++;
-		emLen--;
-		}
 	maskedDBLen = emLen - hLen - 1;
 	H = EM + maskedDBLen;
 	DB = OPENSSL_malloc(maskedDBLen);
@@ -112,26 +127,23 @@ int RSA_verify_PKCS1_PSS(RSA *rsa, const unsigned char *mHash,
 		DB[i] ^= EM[i];
 	if (MSBits)
 		DB[0] &= 0xFF >> (8 - MSBits);
-	for (i = 0; i < (emLen - hLen - sLen - 2); i++)
+	for (i = 0; DB[i] == 0 && i < (maskedDBLen-1); i++) ;
+	if (DB[i++] != 0x1)
 		{
-		if (DB[i] != 0)	
-			{
-			RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS,
-						RSA_R_ZERO_CHECK_FAILED);
-			goto err;
-			}
+		RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_SLEN_RECOVERY_FAILED);
+		goto err;
 		}
-	if (DB[i] != 0x1)
+	if (sLen >= 0 && (maskedDBLen - i) != sLen)
 		{
-		RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_ONE_CHECK_FAILED);
+		RSAerr(RSA_F_RSA_VERIFY_PKCS1_PSS, RSA_R_SLEN_CHECK_FAILED);
 		goto err;
 		}
 	EVP_MD_CTX_init(&ctx);
 	EVP_DigestInit_ex(&ctx, Hash, NULL);
 	EVP_DigestUpdate(&ctx, zeroes, sizeof zeroes);
 	EVP_DigestUpdate(&ctx, mHash, hLen);
-	if (sLen)
-		EVP_DigestUpdate(&ctx, DB + maskedDBLen - sLen, sLen);
+	if (maskedDBLen - i)
+		EVP_DigestUpdate(&ctx, DB + i, maskedDBLen - i);
 	EVP_DigestFinal(&ctx, H_, NULL);
 	EVP_MD_CTX_cleanup(&ctx);
 	if (memcmp(H_, H, hLen))
@@ -159,22 +171,39 @@ int RSA_padding_add_PKCS1_PSS(RSA *rsa, unsigned char *EM,
 	int hLen, maskedDBLen, MSBits, emLen;
 	unsigned char *H, *salt = NULL, *p;
 	EVP_MD_CTX ctx;
-	MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
-	emLen = RSA_size(rsa);
+
 	hLen = EVP_MD_size(Hash);
-	if (sLen < 0)
-		sLen = 0;
-	if (emLen < (hLen + sLen + 2))
+	/*
+	 * Negative sLen has special meanings:
+	 *	-1	sLen == hLen
+	 *	-2	salt length is maximized
+	 *	-N	reserved
+	 */
+	if      (sLen == -1)	sLen = hLen;
+	else if (sLen == -2)	sLen = -2;
+	else if (sLen < -2)
 		{
-		RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_PSS,
-		   RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
+		RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_PSS, RSA_R_SLEN_CHECK_FAILED);
 		goto err;
 		}
+
+	MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
+	emLen = RSA_size(rsa);
 	if (MSBits == 0)
 		{
 		*EM++ = 0;
 		emLen--;
 		}
+	if (sLen == -2)
+		{
+		sLen = emLen - hLen - 2;
+		}
+	else if (emLen < (hLen + sLen + 2))
+		{
+		RSAerr(RSA_F_RSA_PADDING_ADD_PKCS1_PSS,
+		   RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
+		goto err;
+		}
 	if (sLen > 0)
 		{
 		salt = OPENSSL_malloc(sLen);