webrtc/talk/base/nssstreamadapter.cc
jiayl@webrtc.org a576faf82a Enable SCTP and use OPENSSL on Anroid and NSS on other platforms.
It includes unit test fixes to properly initialize SSL if DTLS or SSL random number generator is used in the tests.
The private key and certificate constant strings used in some tests are updated to be compatible with NSS.
A few potentially overflow type conversions caused compiling warning on Windows and they are fixed by importing and using Chromium's checked_cast, which aborts the program if overflow occurs.
It also fixes a leak in nssstreamadapter.cc by releasing the PRFileDesc* in StreamClose.

BUG=2253
R=fischman@webrtc.org, juberti@google.com, wu@webrtc.org

Review URL: https://webrtc-codereview.appspot.com/4679005

git-svn-id: http://webrtc.googlecode.com/svn/trunk@5459 4adac7df-926f-26a2-2b94-8c16560cd09d
2014-01-29 17:45:53 +00:00

1033 lines
27 KiB
C++

/*
* libjingle
* Copyright 2004--2008, Google Inc.
* Copyright 2004--2011, RTFM, Inc.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* 3. The name of the author may not be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
* WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
* EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
* OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
* ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <vector>
#if HAVE_CONFIG_H
#include "config.h"
#endif // HAVE_CONFIG_H
#if HAVE_NSS_SSL_H
#include "talk/base/nssstreamadapter.h"
#include "keyhi.h"
#include "nspr.h"
#include "nss.h"
#include "pk11pub.h"
#include "secerr.h"
#ifdef NSS_SSL_RELATIVE_PATH
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"
#else
#include "net/third_party/nss/ssl/ssl.h"
#include "net/third_party/nss/ssl/sslerr.h"
#include "net/third_party/nss/ssl/sslproto.h"
#endif
#include "talk/base/nssidentity.h"
#include "talk/base/safe_conversions.h"
#include "talk/base/thread.h"
namespace talk_base {
PRDescIdentity NSSStreamAdapter::nspr_layer_identity = PR_INVALID_IO_LAYER;
#define UNIMPLEMENTED \
PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); \
LOG(LS_ERROR) \
<< "Call to unimplemented function "<< __FUNCTION__; ASSERT(false)
#ifdef SRTP_AES128_CM_HMAC_SHA1_80
#define HAVE_DTLS_SRTP
#endif
#ifdef HAVE_DTLS_SRTP
// SRTP cipher suite table
struct SrtpCipherMapEntry {
const char* external_name;
PRUint16 cipher_id;
};
// This isn't elegant, but it's better than an external reference
static const SrtpCipherMapEntry kSrtpCipherMap[] = {
{"AES_CM_128_HMAC_SHA1_80", SRTP_AES128_CM_HMAC_SHA1_80 },
{"AES_CM_128_HMAC_SHA1_32", SRTP_AES128_CM_HMAC_SHA1_32 },
{NULL, 0}
};
#endif
// Implementation of NSPR methods
static PRStatus StreamClose(PRFileDesc *socket) {
ASSERT(!socket->lower);
socket->dtor(socket);
return PR_SUCCESS;
}
static PRInt32 StreamRead(PRFileDesc *socket, void *buf, PRInt32 length) {
StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
size_t read;
int error;
StreamResult result = stream->Read(buf, length, &read, &error);
if (result == SR_SUCCESS) {
return checked_cast<PRInt32>(read);
}
if (result == SR_EOS) {
return 0;
}
if (result == SR_BLOCK) {
PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
return -1;
}
PR_SetError(PR_UNKNOWN_ERROR, error);
return -1;
}
static PRInt32 StreamWrite(PRFileDesc *socket, const void *buf,
PRInt32 length) {
StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
size_t written;
int error;
StreamResult result = stream->Write(buf, length, &written, &error);
if (result == SR_SUCCESS) {
return checked_cast<PRInt32>(written);
}
if (result == SR_BLOCK) {
LOG(LS_INFO) <<
"NSSStreamAdapter: write to underlying transport would block";
PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
return -1;
}
LOG(LS_ERROR) << "Write error";
PR_SetError(PR_UNKNOWN_ERROR, error);
return -1;
}
static PRInt32 StreamAvailable(PRFileDesc *socket) {
UNIMPLEMENTED;
return -1;
}
PRInt64 StreamAvailable64(PRFileDesc *socket) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamSync(PRFileDesc *socket) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PROffset32 StreamSeek(PRFileDesc *socket, PROffset32 offset,
PRSeekWhence how) {
UNIMPLEMENTED;
return -1;
}
static PROffset64 StreamSeek64(PRFileDesc *socket, PROffset64 offset,
PRSeekWhence how) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamFileInfo(PRFileDesc *socket, PRFileInfo *info) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRStatus StreamFileInfo64(PRFileDesc *socket, PRFileInfo64 *info) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRInt32 StreamWritev(PRFileDesc *socket, const PRIOVec *iov,
PRInt32 iov_size, PRIntervalTime timeout) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamConnect(PRFileDesc *socket, const PRNetAddr *addr,
PRIntervalTime timeout) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRFileDesc *StreamAccept(PRFileDesc *sd, PRNetAddr *addr,
PRIntervalTime timeout) {
UNIMPLEMENTED;
return NULL;
}
static PRStatus StreamBind(PRFileDesc *socket, const PRNetAddr *addr) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRStatus StreamListen(PRFileDesc *socket, PRIntn depth) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRStatus StreamShutdown(PRFileDesc *socket, PRIntn how) {
UNIMPLEMENTED;
return PR_FAILURE;
}
// Note: this is always nonblocking and ignores the timeout.
// TODO(ekr@rtfm.com): In future verify that the socket is
// actually in non-blocking mode.
// This function does not support peek.
static PRInt32 StreamRecv(PRFileDesc *socket, void *buf, PRInt32 amount,
PRIntn flags, PRIntervalTime to) {
ASSERT(flags == 0);
if (flags != 0) {
PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
return -1;
}
return StreamRead(socket, buf, amount);
}
// Note: this is always nonblocking and assumes a zero timeout.
// This function does not support peek.
static PRInt32 StreamSend(PRFileDesc *socket, const void *buf,
PRInt32 amount, PRIntn flags,
PRIntervalTime to) {
ASSERT(flags == 0);
return StreamWrite(socket, buf, amount);
}
static PRInt32 StreamRecvfrom(PRFileDesc *socket, void *buf,
PRInt32 amount, PRIntn flags,
PRNetAddr *addr, PRIntervalTime to) {
UNIMPLEMENTED;
return -1;
}
static PRInt32 StreamSendto(PRFileDesc *socket, const void *buf,
PRInt32 amount, PRIntn flags,
const PRNetAddr *addr, PRIntervalTime to) {
UNIMPLEMENTED;
return -1;
}
static PRInt16 StreamPoll(PRFileDesc *socket, PRInt16 in_flags,
PRInt16 *out_flags) {
UNIMPLEMENTED;
return -1;
}
static PRInt32 StreamAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
PRNetAddr **raddr,
void *buf, PRInt32 amount, PRIntervalTime t) {
UNIMPLEMENTED;
return -1;
}
static PRInt32 StreamTransmitFile(PRFileDesc *sd, PRFileDesc *socket,
const void *headers, PRInt32 hlen,
PRTransmitFileFlags flags, PRIntervalTime t) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamGetPeerName(PRFileDesc *socket, PRNetAddr *addr) {
// TODO(ekr@rtfm.com): Modify to return unique names for each channel
// somehow, as opposed to always the same static address. The current
// implementation messes up the session cache, which is why it's off
// elsewhere
addr->inet.family = PR_AF_INET;
addr->inet.port = 0;
addr->inet.ip = 0;
return PR_SUCCESS;
}
static PRStatus StreamGetSockName(PRFileDesc *socket, PRNetAddr *addr) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRStatus StreamGetSockOption(PRFileDesc *socket, PRSocketOptionData *opt) {
switch (opt->option) {
case PR_SockOpt_Nonblocking:
opt->value.non_blocking = PR_TRUE;
return PR_SUCCESS;
default:
UNIMPLEMENTED;
break;
}
return PR_FAILURE;
}
// Imitate setting socket options. These are mostly noops.
static PRStatus StreamSetSockOption(PRFileDesc *socket,
const PRSocketOptionData *opt) {
switch (opt->option) {
case PR_SockOpt_Nonblocking:
return PR_SUCCESS;
case PR_SockOpt_NoDelay:
return PR_SUCCESS;
default:
UNIMPLEMENTED;
break;
}
return PR_FAILURE;
}
static PRInt32 StreamSendfile(PRFileDesc *out, PRSendFileData *in,
PRTransmitFileFlags flags, PRIntervalTime to) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamConnectContinue(PRFileDesc *socket, PRInt16 flags) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRIntn StreamReserved(PRFileDesc *socket) {
UNIMPLEMENTED;
return -1;
}
static const struct PRIOMethods nss_methods = {
PR_DESC_LAYERED,
StreamClose,
StreamRead,
StreamWrite,
StreamAvailable,
StreamAvailable64,
StreamSync,
StreamSeek,
StreamSeek64,
StreamFileInfo,
StreamFileInfo64,
StreamWritev,
StreamConnect,
StreamAccept,
StreamBind,
StreamListen,
StreamShutdown,
StreamRecv,
StreamSend,
StreamRecvfrom,
StreamSendto,
StreamPoll,
StreamAcceptRead,
StreamTransmitFile,
StreamGetSockName,
StreamGetPeerName,
StreamReserved,
StreamReserved,
StreamGetSockOption,
StreamSetSockOption,
StreamSendfile,
StreamConnectContinue,
StreamReserved,
StreamReserved,
StreamReserved,
StreamReserved
};
NSSStreamAdapter::NSSStreamAdapter(StreamInterface *stream)
: SSLStreamAdapterHelper(stream),
ssl_fd_(NULL),
cert_ok_(false) {
}
bool NSSStreamAdapter::Init() {
if (nspr_layer_identity == PR_INVALID_IO_LAYER) {
nspr_layer_identity = PR_GetUniqueIdentity("nssstreamadapter");
}
PRFileDesc *pr_fd = PR_CreateIOLayerStub(nspr_layer_identity, &nss_methods);
if (!pr_fd)
return false;
pr_fd->secret = reinterpret_cast<PRFilePrivate *>(stream());
PRFileDesc *ssl_fd;
if (ssl_mode_ == SSL_MODE_DTLS) {
ssl_fd = DTLS_ImportFD(NULL, pr_fd);
} else {
ssl_fd = SSL_ImportFD(NULL, pr_fd);
}
ASSERT(ssl_fd != NULL); // This should never happen
if (!ssl_fd) {
PR_Close(pr_fd);
return false;
}
SECStatus rv;
// Turn on security.
rv = SSL_OptionSet(ssl_fd, SSL_SECURITY, PR_TRUE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error enabling security on SSL Socket";
return false;
}
// Disable SSLv2.
rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SSL2, PR_FALSE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error disabling SSL2";
return false;
}
// Disable caching.
// TODO(ekr@rtfm.com): restore this when I have the caching
// identity set.
rv = SSL_OptionSet(ssl_fd, SSL_NO_CACHE, PR_TRUE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error disabling cache";
return false;
}
// Disable session tickets.
rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error enabling tickets";
return false;
}
// Disable renegotiation.
rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_RENEGOTIATION,
SSL_RENEGOTIATE_NEVER);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error disabling renegotiation";
return false;
}
// Disable false start.
rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_FALSE_START, PR_FALSE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error disabling false start";
return false;
}
ssl_fd_ = ssl_fd;
return true;
}
NSSStreamAdapter::~NSSStreamAdapter() {
if (ssl_fd_)
PR_Close(ssl_fd_);
};
int NSSStreamAdapter::BeginSSL() {
SECStatus rv;
if (!Init()) {
Error("Init", -1, false);
return -1;
}
ASSERT(state_ == SSL_CONNECTING);
// The underlying stream has been opened. If we are in peer-to-peer mode
// then a peer certificate must have been specified by now.
ASSERT(!ssl_server_name_.empty() ||
peer_certificate_.get() != NULL ||
!peer_certificate_digest_algorithm_.empty());
LOG(LS_INFO) << "BeginSSL: "
<< (!ssl_server_name_.empty() ? ssl_server_name_ :
"with peer");
if (role_ == SSL_CLIENT) {
LOG(LS_INFO) << "BeginSSL: as client";
rv = SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
this);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
} else {
LOG(LS_INFO) << "BeginSSL: as server";
NSSIdentity *identity;
if (identity_.get()) {
identity = static_cast<NSSIdentity *>(identity_.get());
} else {
LOG(LS_ERROR) << "Can't be an SSL server without an identity";
Error("BeginSSL", -1, false);
return -1;
}
rv = SSL_ConfigSecureServer(ssl_fd_, identity->certificate().certificate(),
identity->keypair()->privkey(),
kt_rsa);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
// Insist on a certificate from the client
rv = SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
rv = SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, PR_TRUE);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
}
// Set the version range.
SSLVersionRange vrange;
vrange.min = (ssl_mode_ == SSL_MODE_DTLS) ?
SSL_LIBRARY_VERSION_TLS_1_1 :
SSL_LIBRARY_VERSION_TLS_1_0;
vrange.max = SSL_LIBRARY_VERSION_TLS_1_1;
rv = SSL_VersionRangeSet(ssl_fd_, &vrange);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
// SRTP
#ifdef HAVE_DTLS_SRTP
if (!srtp_ciphers_.empty()) {
rv = SSL_SetSRTPCiphers(
ssl_fd_, &srtp_ciphers_[0],
checked_cast<unsigned int>(srtp_ciphers_.size()));
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
}
#endif
// Certificate validation
rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
// Now start the handshake
rv = SSL_ResetHandshake(ssl_fd_, role_ == SSL_SERVER ? PR_TRUE : PR_FALSE);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
return ContinueSSL();
}
int NSSStreamAdapter::ContinueSSL() {
LOG(LS_INFO) << "ContinueSSL";
ASSERT(state_ == SSL_CONNECTING);
// Clear the DTLS timer
Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
SECStatus rv = SSL_ForceHandshake(ssl_fd_);
if (rv == SECSuccess) {
LOG(LS_INFO) << "Handshake complete";
ASSERT(cert_ok_);
if (!cert_ok_) {
Error("ContinueSSL", -1, true);
return -1;
}
state_ = SSL_CONNECTED;
StreamAdapterInterface::OnEvent(stream(), SE_OPEN|SE_READ|SE_WRITE, 0);
return 0;
}
PRInt32 err = PR_GetError();
switch (err) {
case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
if (ssl_mode_ != SSL_MODE_DTLS) {
Error("ContinueSSL", -1, true);
return -1;
} else {
LOG(LS_INFO) << "Malformed DTLS message. Ignoring.";
// Fall through
}
case PR_WOULD_BLOCK_ERROR:
LOG(LS_INFO) << "Would have blocked";
if (ssl_mode_ == SSL_MODE_DTLS) {
PRIntervalTime timeout;
SECStatus rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
if (rv == SECSuccess) {
LOG(LS_INFO) << "Timeout is " << timeout << " ms";
Thread::Current()->PostDelayed(PR_IntervalToMilliseconds(timeout),
this, MSG_DTLS_TIMEOUT, 0);
}
}
return 0;
default:
LOG(LS_INFO) << "Error " << err;
break;
}
Error("ContinueSSL", -1, true);
return -1;
}
void NSSStreamAdapter::Cleanup() {
if (state_ != SSL_ERROR) {
state_ = SSL_CLOSED;
}
if (ssl_fd_) {
PR_Close(ssl_fd_);
ssl_fd_ = NULL;
}
identity_.reset();
peer_certificate_.reset();
Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
}
StreamResult NSSStreamAdapter::Read(void* data, size_t data_len,
size_t* read, int* error) {
// SSL_CONNECTED sanity check.
switch (state_) {
case SSL_NONE:
case SSL_WAIT:
case SSL_CONNECTING:
return SR_BLOCK;
case SSL_CONNECTED:
break;
case SSL_CLOSED:
return SR_EOS;
case SSL_ERROR:
default:
if (error)
*error = ssl_error_code_;
return SR_ERROR;
}
PRInt32 rv = PR_Read(ssl_fd_, data, checked_cast<PRInt32>(data_len));
if (rv == 0) {
return SR_EOS;
}
// Error
if (rv < 0) {
PRInt32 err = PR_GetError();
switch (err) {
case PR_WOULD_BLOCK_ERROR:
return SR_BLOCK;
default:
Error("Read", -1, false);
*error = err; // libjingle semantics are that this is impl-specific
return SR_ERROR;
}
}
// Success
*read = rv;
return SR_SUCCESS;
}
StreamResult NSSStreamAdapter::Write(const void* data, size_t data_len,
size_t* written, int* error) {
// SSL_CONNECTED sanity check.
switch (state_) {
case SSL_NONE:
case SSL_WAIT:
case SSL_CONNECTING:
return SR_BLOCK;
case SSL_CONNECTED:
break;
case SSL_ERROR:
case SSL_CLOSED:
default:
if (error)
*error = ssl_error_code_;
return SR_ERROR;
}
PRInt32 rv = PR_Write(ssl_fd_, data, checked_cast<PRInt32>(data_len));
// Error
if (rv < 0) {
PRInt32 err = PR_GetError();
switch (err) {
case PR_WOULD_BLOCK_ERROR:
return SR_BLOCK;
default:
Error("Write", -1, false);
*error = err; // libjingle semantics are that this is impl-specific
return SR_ERROR;
}
}
// Success
*written = rv;
return SR_SUCCESS;
}
void NSSStreamAdapter::OnEvent(StreamInterface* stream, int events,
int err) {
int events_to_signal = 0;
int signal_error = 0;
ASSERT(stream == this->stream());
if ((events & SE_OPEN)) {
LOG(LS_INFO) << "NSSStreamAdapter::OnEvent SE_OPEN";
if (state_ != SSL_WAIT) {
ASSERT(state_ == SSL_NONE);
events_to_signal |= SE_OPEN;
} else {
state_ = SSL_CONNECTING;
if (int err = BeginSSL()) {
Error("BeginSSL", err, true);
return;
}
}
}
if ((events & (SE_READ|SE_WRITE))) {
LOG(LS_INFO) << "NSSStreamAdapter::OnEvent"
<< ((events & SE_READ) ? " SE_READ" : "")
<< ((events & SE_WRITE) ? " SE_WRITE" : "");
if (state_ == SSL_NONE) {
events_to_signal |= events & (SE_READ|SE_WRITE);
} else if (state_ == SSL_CONNECTING) {
if (int err = ContinueSSL()) {
Error("ContinueSSL", err, true);
return;
}
} else if (state_ == SSL_CONNECTED) {
if (events & SE_WRITE) {
LOG(LS_INFO) << " -- onStreamWriteable";
events_to_signal |= SE_WRITE;
}
if (events & SE_READ) {
LOG(LS_INFO) << " -- onStreamReadable";
events_to_signal |= SE_READ;
}
}
}
if ((events & SE_CLOSE)) {
LOG(LS_INFO) << "NSSStreamAdapter::OnEvent(SE_CLOSE, " << err << ")";
Cleanup();
events_to_signal |= SE_CLOSE;
// SE_CLOSE is the only event that uses the final parameter to OnEvent().
ASSERT(signal_error == 0);
signal_error = err;
}
if (events_to_signal)
StreamAdapterInterface::OnEvent(stream, events_to_signal, signal_error);
}
void NSSStreamAdapter::OnMessage(Message* msg) {
// Process our own messages and then pass others to the superclass
if (MSG_DTLS_TIMEOUT == msg->message_id) {
LOG(LS_INFO) << "DTLS timeout expired";
ContinueSSL();
} else {
StreamInterface::OnMessage(msg);
}
}
// Certificate verification callback. Called to check any certificate
SECStatus NSSStreamAdapter::AuthCertificateHook(void *arg,
PRFileDesc *fd,
PRBool checksig,
PRBool isServer) {
LOG(LS_INFO) << "NSSStreamAdapter::AuthCertificateHook";
NSSCertificate peer_cert(SSL_PeerCertificate(fd));
NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
stream->cert_ok_ = false;
// Read the peer's certificate chain.
CERTCertList* cert_list = SSL_PeerCertificateChain(fd);
ASSERT(cert_list != NULL);
// If the peer provided multiple certificates, check that they form a valid
// chain as defined by RFC 5246 Section 7.4.2: "Each following certificate
// MUST directly certify the one preceding it.". This check does NOT
// verify other requirements, such as whether the chain reaches a trusted
// root, self-signed certificates have valid signatures, certificates are not
// expired, etc.
// Even if the chain is valid, the leaf certificate must still match a
// provided certificate or digest.
if (!NSSCertificate::IsValidChain(cert_list)) {
CERT_DestroyCertList(cert_list);
PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
return SECFailure;
}
if (stream->peer_certificate_.get()) {
LOG(LS_INFO) << "Checking against specified certificate";
// The peer certificate was specified
if (reinterpret_cast<NSSCertificate *>(stream->peer_certificate_.get())->
Equals(&peer_cert)) {
LOG(LS_INFO) << "Accepted peer certificate";
stream->cert_ok_ = true;
}
} else if (!stream->peer_certificate_digest_algorithm_.empty()) {
LOG(LS_INFO) << "Checking against specified digest";
// The peer certificate digest was specified
unsigned char digest[64]; // Maximum size
std::size_t digest_length;
if (!peer_cert.ComputeDigest(
stream->peer_certificate_digest_algorithm_,
digest, sizeof(digest), &digest_length)) {
LOG(LS_ERROR) << "Digest computation failed";
} else {
Buffer computed_digest(digest, digest_length);
if (computed_digest == stream->peer_certificate_digest_value_) {
LOG(LS_INFO) << "Accepted peer certificate";
stream->cert_ok_ = true;
}
}
} else {
// Other modes, but we haven't implemented yet
// TODO(ekr@rtfm.com): Implement real certificate validation
UNIMPLEMENTED;
}
if (!stream->cert_ok_ && stream->ignore_bad_cert()) {
LOG(LS_WARNING) << "Ignoring cert error while verifying cert chain";
stream->cert_ok_ = true;
}
if (stream->cert_ok_)
stream->peer_certificate_.reset(new NSSCertificate(cert_list));
CERT_DestroyCertList(cert_list);
if (stream->cert_ok_)
return SECSuccess;
PORT_SetError(SEC_ERROR_UNTRUSTED_CERT);
return SECFailure;
}
SECStatus NSSStreamAdapter::GetClientAuthDataHook(void *arg, PRFileDesc *fd,
CERTDistNames *caNames,
CERTCertificate **pRetCert,
SECKEYPrivateKey **pRetKey) {
LOG(LS_INFO) << "Client cert requested";
NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
if (!stream->identity_.get()) {
LOG(LS_ERROR) << "No identity available";
return SECFailure;
}
NSSIdentity *identity = static_cast<NSSIdentity *>(stream->identity_.get());
// Destroyed internally by NSS
*pRetCert = CERT_DupCertificate(identity->certificate().certificate());
*pRetKey = SECKEY_CopyPrivateKey(identity->keypair()->privkey());
return SECSuccess;
}
// RFC 5705 Key Exporter
bool NSSStreamAdapter::ExportKeyingMaterial(const std::string& label,
const uint8* context,
size_t context_len,
bool use_context,
uint8* result,
size_t result_len) {
SECStatus rv = SSL_ExportKeyingMaterial(
ssl_fd_,
label.c_str(),
checked_cast<unsigned int>(label.size()),
use_context,
context,
checked_cast<unsigned int>(context_len),
result,
checked_cast<unsigned int>(result_len));
return rv == SECSuccess;
}
bool NSSStreamAdapter::SetDtlsSrtpCiphers(
const std::vector<std::string>& ciphers) {
#ifdef HAVE_DTLS_SRTP
std::vector<PRUint16> internal_ciphers;
if (state_ != SSL_NONE)
return false;
for (std::vector<std::string>::const_iterator cipher = ciphers.begin();
cipher != ciphers.end(); ++cipher) {
bool found = false;
for (const SrtpCipherMapEntry *entry = kSrtpCipherMap; entry->cipher_id;
++entry) {
if (*cipher == entry->external_name) {
found = true;
internal_ciphers.push_back(entry->cipher_id);
break;
}
}
if (!found) {
LOG(LS_ERROR) << "Could not find cipher: " << *cipher;
return false;
}
}
if (internal_ciphers.empty())
return false;
srtp_ciphers_ = internal_ciphers;
return true;
#else
return false;
#endif
}
bool NSSStreamAdapter::GetDtlsSrtpCipher(std::string* cipher) {
#ifdef HAVE_DTLS_SRTP
ASSERT(state_ == SSL_CONNECTED);
if (state_ != SSL_CONNECTED)
return false;
PRUint16 selected_cipher;
SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, &selected_cipher);
if (rv == SECFailure)
return false;
for (const SrtpCipherMapEntry *entry = kSrtpCipherMap;
entry->cipher_id; ++entry) {
if (selected_cipher == entry->cipher_id) {
*cipher = entry->external_name;
return true;
}
}
ASSERT(false); // This should never happen
#endif
return false;
}
bool NSSContext::initialized;
NSSContext *NSSContext::global_nss_context;
// Static initialization and shutdown
NSSContext *NSSContext::Instance() {
if (!global_nss_context) {
NSSContext *new_ctx = new NSSContext();
if (!(new_ctx->slot_ = PK11_GetInternalSlot())) {
delete new_ctx;
goto fail;
}
global_nss_context = new_ctx;
}
fail:
return global_nss_context;
}
bool NSSContext::InitializeSSL(VerificationCallback callback) {
ASSERT(!callback);
if (!initialized) {
SECStatus rv;
rv = NSS_NoDB_Init(NULL);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Couldn't initialize NSS error=" <<
PORT_GetError();
return false;
}
NSS_SetDomesticPolicy();
initialized = true;
}
return true;
}
bool NSSContext::InitializeSSLThread() {
// Not needed
return true;
}
bool NSSContext::CleanupSSL() {
// Not needed
return true;
}
bool NSSStreamAdapter::HaveDtls() {
return true;
}
bool NSSStreamAdapter::HaveDtlsSrtp() {
#ifdef HAVE_DTLS_SRTP
return true;
#else
return false;
#endif
}
bool NSSStreamAdapter::HaveExporter() {
return true;
}
} // namespace talk_base
#endif // HAVE_NSS_SSL_H