diff --git a/ssl/s3_clnt.c b/ssl/s3_clnt.c index 04af8514d..d5bcf5428 100644 --- a/ssl/s3_clnt.c +++ b/ssl/s3_clnt.c @@ -331,10 +331,8 @@ int ssl3_connect(SSL *s) /* Check if it is anon DH/ECDH, SRP auth */ /* or PSK */ - if (! - (s->s3->tmp. - new_cipher->algorithm_auth & (SSL_aNULL | SSL_aSRP)) - && !(s->s3->tmp.new_cipher->algorithm_mkey & SSL_kPSK)) { + if (!(s->s3->tmp.new_cipher->algorithm_auth & + (SSL_aNULL | SSL_aSRP | SSL_aPSK))) { ret = ssl3_get_server_certificate(s); if (ret <= 0) goto end; @@ -1414,7 +1412,7 @@ int ssl3_get_key_exchange(SSL *s) * Can't skip server key exchange if this is an ephemeral * ciphersuite. */ - if (alg_k & (SSL_kDHE | SSL_kECDHE)) { + if (alg_k & (SSL_kDHE | SSL_kECDHE | SSL_kDHEPSK | SSL_kECDHEPSK)) { SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, SSL_R_UNEXPECTED_MESSAGE); al = SSL_AD_UNEXPECTED_MESSAGE; goto f_err; @@ -1447,8 +1445,8 @@ int ssl3_get_key_exchange(SSL *s) al = SSL_AD_DECODE_ERROR; #ifndef OPENSSL_NO_PSK - if (alg_k & SSL_kPSK) { - char tmp_id_hint[PSK_MAX_IDENTITY_LEN + 1]; + /* PSK ciphersuites are preceded by an identity hint */ + if (alg_k & SSL_PSK) { param_len = 2; if (param_len > n) { @@ -1475,23 +1473,24 @@ int ssl3_get_key_exchange(SSL *s) } param_len += i; - /* - * If received PSK identity hint contains NULL characters, the hint - * is truncated from the first NULL. p may not be ending with NULL, - * so create a NULL-terminated string. - */ - memcpy(tmp_id_hint, p, i); - memset(tmp_id_hint + i, 0, PSK_MAX_IDENTITY_LEN + 1 - i); OPENSSL_free(s->session->psk_identity_hint); - s->session->psk_identity_hint = BUF_strdup(tmp_id_hint); - if (s->session->psk_identity_hint == NULL) { - al = SSL_AD_HANDSHAKE_FAILURE; - SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE); - goto f_err; + if (i != 0) { + s->session->psk_identity_hint = BUF_strndup((char *)p, i); + if (s->session->psk_identity_hint == NULL) { + al = SSL_AD_HANDSHAKE_FAILURE; + SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE); + goto f_err; + } + } else { + s->session->psk_identity_hint = NULL; } p += i; n -= param_len; + } + + /* Nothing else to do for plain PSK or RSAPSK */ + if (alg_k & (SSL_kPSK | SSL_kRSAPSK)) { } else #endif /* !OPENSSL_NO_PSK */ #ifndef OPENSSL_NO_SRP @@ -1661,7 +1660,7 @@ int ssl3_get_key_exchange(SSL *s) if (0) ; #endif #ifndef OPENSSL_NO_DH - else if (alg_k & SSL_kDHE) { + else if (alg_k & (SSL_kDHE | SSL_kDHEPSK)) { if ((dh = DH_new()) == NULL) { SSLerr(SSL_F_SSL3_GET_KEY_EXCHANGE, ERR_R_DH_LIB); goto err; @@ -1742,7 +1741,7 @@ int ssl3_get_key_exchange(SSL *s) #endif /* !OPENSSL_NO_DH */ #ifndef OPENSSL_NO_EC - else if (alg_k & SSL_kECDHE) { + else if (alg_k & (SSL_kECDHE | SSL_kECDHEPSK)) { EC_GROUP *ngroup; const EC_GROUP *group; @@ -1945,8 +1944,8 @@ int ssl3_get_key_exchange(SSL *s) } } } else { - /* aNULL, aSRP or kPSK do not need public keys */ - if (!(alg_a & (SSL_aNULL | SSL_aSRP)) && !(alg_k & SSL_kPSK)) { + /* aNULL, aSRP or PSK do not need public keys */ + if (!(alg_a & (SSL_aNULL | SSL_aSRP)) && !(alg_k & SSL_PSK)) { /* Might be wrong key type, check it */ if (ssl3_check_cert_and_algorithm(s)) /* Otherwise this shouldn't happen */ @@ -2329,6 +2328,9 @@ int ssl3_send_client_key_exchange(SSL *s) { unsigned char *p; int n; +#ifndef OPENSSL_NO_PSK + size_t pskhdrlen = 0; +#endif unsigned long alg_k; #ifndef OPENSSL_NO_RSA unsigned char *q; @@ -2344,17 +2346,89 @@ int ssl3_send_client_key_exchange(SSL *s) #endif unsigned char *pms = NULL; size_t pmslen = 0; + alg_k = s->s3->tmp.new_cipher->algorithm_mkey; if (s->state == SSL3_ST_CW_KEY_EXCH_A) { p = ssl_handshake_start(s); - alg_k = s->s3->tmp.new_cipher->algorithm_mkey; + +#ifndef OPENSSL_NO_PSK + if (alg_k & SSL_PSK) { + int psk_err = 1; + /* + * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a + * \0-terminated identity. The last byte is for us for simulating + * strnlen. + */ + char identity[PSK_MAX_IDENTITY_LEN + 1]; + size_t identitylen; + unsigned char psk[PSK_MAX_PSK_LEN]; + size_t psklen; + + if (s->psk_client_callback == NULL) { + SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, + SSL_R_PSK_NO_CLIENT_CB); + goto err; + } + + memset(identity, 0, sizeof(identity)); + + psklen = s->psk_client_callback(s, s->session->psk_identity_hint, + identity, sizeof(identity) - 1, + psk, sizeof(psk)); + + if (psklen > PSK_MAX_PSK_LEN) { + SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, + ERR_R_INTERNAL_ERROR); + goto psk_err; + } else if (psklen == 0) { + SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, + SSL_R_PSK_IDENTITY_NOT_FOUND); + goto psk_err; + } + + OPENSSL_free(s->s3->tmp.psk); + s->s3->tmp.psk = BUF_memdup(psk, psklen); + OPENSSL_cleanse(psk, psklen); + + if (s->s3->tmp.psk == NULL) + goto memerr; + + s->s3->tmp.psklen = psklen; + + identitylen = strlen(identity); + if (identitylen > PSK_MAX_IDENTITY_LEN) { + SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, + ERR_R_INTERNAL_ERROR); + goto psk_err; + } + OPENSSL_free(s->session->psk_identity); + s->session->psk_identity = BUF_strdup(identity); + if (s->session->psk_identity == NULL) + goto memerr; + + s2n(identitylen, p); + memcpy(p, identity, identitylen); + pskhdrlen = 2 + identitylen; + p += identitylen; + psk_err = 0; + psk_err: + OPENSSL_cleanse(identity, sizeof(identity)); + if (psk_err != 0) { + ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE); + goto err; + } + } + if (alg_k & SSL_kPSK) { + n = 0; + } else +#endif /* Fool emacs indentation */ if (0) { } #ifndef OPENSSL_NO_RSA - else if (alg_k & SSL_kRSA) { + else if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) { RSA *rsa; pmslen = SSL_MAX_MASTER_KEY_LENGTH; pms = OPENSSL_malloc(pmslen); @@ -2414,7 +2488,7 @@ int ssl3_send_client_key_exchange(SSL *s) } #endif #ifndef OPENSSL_NO_DH - else if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd)) { + else if (alg_k & (SSL_kDHE | SSL_kDHr | SSL_kDHd | SSL_kDHEPSK)) { DH *dh_srvr, *dh_clnt; if (s->s3->peer_dh_tmp != NULL) dh_srvr = s->s3->peer_dh_tmp; @@ -2493,7 +2567,7 @@ int ssl3_send_client_key_exchange(SSL *s) #endif #ifndef OPENSSL_NO_EC - else if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe)) { + else if (alg_k & (SSL_kECDHE | SSL_kECDHr | SSL_kECDHe | SSL_kECDHEPSK)) { const EC_GROUP *srvr_group = NULL; EC_KEY *tkey; int ecdh_clnt_cert = 0; @@ -2780,82 +2854,6 @@ int ssl3_send_client_key_exchange(SSL *s) goto err; } } -#endif -#ifndef OPENSSL_NO_PSK - else if (alg_k & SSL_kPSK) { - /* - * The callback needs PSK_MAX_IDENTITY_LEN + 1 bytes to return a - * \0-terminated identity. The last byte is for us for simulating - * strnlen. - */ - char identity[PSK_MAX_IDENTITY_LEN + 2]; - size_t identity_len; - unsigned char *t = NULL; - unsigned int psk_len = 0; - int psk_err = 1; - - n = 0; - if (s->psk_client_callback == NULL) { - SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, - SSL_R_PSK_NO_CLIENT_CB); - goto err; - } - - memset(identity, 0, sizeof(identity)); - /* Allocate maximum size buffer */ - pmslen = PSK_MAX_PSK_LEN * 2 + 4; - pms = OPENSSL_malloc(pmslen); - if (!pms) - goto memerr; - - psk_len = s->psk_client_callback(s, s->session->psk_identity_hint, - identity, sizeof(identity) - 1, - pms, pmslen); - if (psk_len > PSK_MAX_PSK_LEN) { - SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, - ERR_R_INTERNAL_ERROR); - goto psk_err; - } else if (psk_len == 0) { - SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, - SSL_R_PSK_IDENTITY_NOT_FOUND); - goto psk_err; - } - /* Change pmslen to real length */ - pmslen = 2 + psk_len + 2 + psk_len; - identity[PSK_MAX_IDENTITY_LEN + 1] = '\0'; - identity_len = strlen(identity); - if (identity_len > PSK_MAX_IDENTITY_LEN) { - SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, - ERR_R_INTERNAL_ERROR); - goto psk_err; - } - /* create PSK pre_master_secret */ - t = pms; - memmove(pms + psk_len + 4, pms, psk_len); - s2n(psk_len, t); - memset(t, 0, psk_len); - t += psk_len; - s2n(psk_len, t); - - OPENSSL_free(s->session->psk_identity); - s->session->psk_identity = BUF_strdup(identity); - if (s->session->psk_identity == NULL) { - SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, - ERR_R_MALLOC_FAILURE); - goto psk_err; - } - - s2n(identity_len, p); - memcpy(p, identity, identity_len); - n = 2 + identity_len; - psk_err = 0; - psk_err: - OPENSSL_cleanse(identity, sizeof(identity)); - if (psk_err != 0) { - ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE); - goto err; - } - } #endif else { ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE); @@ -2863,6 +2861,10 @@ int ssl3_send_client_key_exchange(SSL *s) goto err; } +#ifndef OPENSSL_NO_PSK + n += pskhdrlen; +#endif + if (!ssl_set_handshake_header(s, SSL3_MT_CLIENT_KEY_EXCHANGE, n)) { ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE); SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, ERR_R_INTERNAL_ERROR); @@ -2876,7 +2878,7 @@ int ssl3_send_client_key_exchange(SSL *s) n = ssl_do_write(s); #ifndef OPENSSL_NO_SRP /* Check for SRP */ - if (s->s3->tmp.new_cipher->algorithm_mkey & SSL_kSRP) { + if (alg_k & SSL_kSRP) { /* * If everything written generate master key: no need to save PMS as * srp_generate_client_master_secret generates it internally. @@ -2900,7 +2902,7 @@ int ssl3_send_client_key_exchange(SSL *s) pms = s->s3->tmp.pms; pmslen = s->s3->tmp.pmslen; } - if (pms == NULL) { + if (pms == NULL && !(alg_k & SSL_kPSK)) { ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR); SSLerr(SSL_F_SSL3_SEND_CLIENT_KEY_EXCHANGE, ERR_R_MALLOC_FAILURE); goto err; @@ -2923,6 +2925,10 @@ int ssl3_send_client_key_exchange(SSL *s) OPENSSL_free(encodedPoint); EC_KEY_free(clnt_ecdh); EVP_PKEY_free(srvr_pub_pkey); +#endif +#ifndef OPENSSL_NO_PSK + OPENSSL_clear_free(s->s3->tmp.psk, s->s3->tmp.psklen); + s->s3->tmp.psk = NULL; #endif s->state = SSL_ST_ERR; return (-1); @@ -3261,7 +3267,7 @@ int ssl3_check_cert_and_algorithm(SSL *s) } #endif #ifndef OPENSSL_NO_RSA - if (alg_k & SSL_kRSA) { + if (alg_k & (SSL_kRSA | SSL_kRSAPSK)) { if (!SSL_C_IS_EXPORT(s->s3->tmp.new_cipher) && !has_bits(i, EVP_PK_RSA | EVP_PKT_ENC)) { SSLerr(SSL_F_SSL3_CHECK_CERT_AND_ALGORITHM,