webrtc/talk/base/nssstreamadapter.cc

1039 lines
27 KiB
C++
Raw Normal View History

/*
* libjingle
* Copyright 2004--2008, Google Inc.
* Copyright 2004--2011, RTFM, Inc.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
* 3. The name of the author may not be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
* WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
* EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
* WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
* OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
* ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <vector>
#if HAVE_CONFIG_H
#include "config.h"
#endif // HAVE_CONFIG_H
#if HAVE_NSS_SSL_H
#include "talk/base/nssstreamadapter.h"
#include "keyhi.h"
#include "nspr.h"
#include "nss.h"
#include "pk11pub.h"
#include "secerr.h"
#ifdef NSS_SSL_RELATIVE_PATH
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"
#else
#include "net/third_party/nss/ssl/ssl.h"
#include "net/third_party/nss/ssl/sslerr.h"
#include "net/third_party/nss/ssl/sslproto.h"
#endif
#include "talk/base/nssidentity.h"
#include "talk/base/safe_conversions.h"
#include "talk/base/thread.h"
namespace talk_base {
PRDescIdentity NSSStreamAdapter::nspr_layer_identity = PR_INVALID_IO_LAYER;
#define UNIMPLEMENTED \
PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); \
LOG(LS_ERROR) \
<< "Call to unimplemented function "<< __FUNCTION__; ASSERT(false)
#ifdef SRTP_AES128_CM_HMAC_SHA1_80
#define HAVE_DTLS_SRTP
#endif
#ifdef HAVE_DTLS_SRTP
// SRTP cipher suite table
struct SrtpCipherMapEntry {
const char* external_name;
PRUint16 cipher_id;
};
// This isn't elegant, but it's better than an external reference
static const SrtpCipherMapEntry kSrtpCipherMap[] = {
{"AES_CM_128_HMAC_SHA1_80", SRTP_AES128_CM_HMAC_SHA1_80 },
{"AES_CM_128_HMAC_SHA1_32", SRTP_AES128_CM_HMAC_SHA1_32 },
{NULL, 0}
};
#endif
// Implementation of NSPR methods
static PRStatus StreamClose(PRFileDesc *socket) {
ASSERT(!socket->lower);
socket->dtor(socket);
return PR_SUCCESS;
}
static PRInt32 StreamRead(PRFileDesc *socket, void *buf, PRInt32 length) {
StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
size_t read;
int error;
StreamResult result = stream->Read(buf, length, &read, &error);
if (result == SR_SUCCESS) {
return checked_cast<PRInt32>(read);
}
if (result == SR_EOS) {
return 0;
}
if (result == SR_BLOCK) {
PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
return -1;
}
PR_SetError(PR_UNKNOWN_ERROR, error);
return -1;
}
static PRInt32 StreamWrite(PRFileDesc *socket, const void *buf,
PRInt32 length) {
StreamInterface *stream = reinterpret_cast<StreamInterface *>(socket->secret);
size_t written;
int error;
StreamResult result = stream->Write(buf, length, &written, &error);
if (result == SR_SUCCESS) {
return checked_cast<PRInt32>(written);
}
if (result == SR_BLOCK) {
LOG(LS_INFO) <<
"NSSStreamAdapter: write to underlying transport would block";
PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
return -1;
}
LOG(LS_ERROR) << "Write error";
PR_SetError(PR_UNKNOWN_ERROR, error);
return -1;
}
static PRInt32 StreamAvailable(PRFileDesc *socket) {
UNIMPLEMENTED;
return -1;
}
PRInt64 StreamAvailable64(PRFileDesc *socket) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamSync(PRFileDesc *socket) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PROffset32 StreamSeek(PRFileDesc *socket, PROffset32 offset,
PRSeekWhence how) {
UNIMPLEMENTED;
return -1;
}
static PROffset64 StreamSeek64(PRFileDesc *socket, PROffset64 offset,
PRSeekWhence how) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamFileInfo(PRFileDesc *socket, PRFileInfo *info) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRStatus StreamFileInfo64(PRFileDesc *socket, PRFileInfo64 *info) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRInt32 StreamWritev(PRFileDesc *socket, const PRIOVec *iov,
PRInt32 iov_size, PRIntervalTime timeout) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamConnect(PRFileDesc *socket, const PRNetAddr *addr,
PRIntervalTime timeout) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRFileDesc *StreamAccept(PRFileDesc *sd, PRNetAddr *addr,
PRIntervalTime timeout) {
UNIMPLEMENTED;
return NULL;
}
static PRStatus StreamBind(PRFileDesc *socket, const PRNetAddr *addr) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRStatus StreamListen(PRFileDesc *socket, PRIntn depth) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRStatus StreamShutdown(PRFileDesc *socket, PRIntn how) {
UNIMPLEMENTED;
return PR_FAILURE;
}
// Note: this is always nonblocking and ignores the timeout.
// TODO(ekr@rtfm.com): In future verify that the socket is
// actually in non-blocking mode.
// This function does not support peek.
static PRInt32 StreamRecv(PRFileDesc *socket, void *buf, PRInt32 amount,
PRIntn flags, PRIntervalTime to) {
ASSERT(flags == 0);
if (flags != 0) {
PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
return -1;
}
return StreamRead(socket, buf, amount);
}
// Note: this is always nonblocking and assumes a zero timeout.
// This function does not support peek.
static PRInt32 StreamSend(PRFileDesc *socket, const void *buf,
PRInt32 amount, PRIntn flags,
PRIntervalTime to) {
ASSERT(flags == 0);
return StreamWrite(socket, buf, amount);
}
static PRInt32 StreamRecvfrom(PRFileDesc *socket, void *buf,
PRInt32 amount, PRIntn flags,
PRNetAddr *addr, PRIntervalTime to) {
UNIMPLEMENTED;
return -1;
}
static PRInt32 StreamSendto(PRFileDesc *socket, const void *buf,
PRInt32 amount, PRIntn flags,
const PRNetAddr *addr, PRIntervalTime to) {
UNIMPLEMENTED;
return -1;
}
static PRInt16 StreamPoll(PRFileDesc *socket, PRInt16 in_flags,
PRInt16 *out_flags) {
UNIMPLEMENTED;
return -1;
}
static PRInt32 StreamAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
PRNetAddr **raddr,
void *buf, PRInt32 amount, PRIntervalTime t) {
UNIMPLEMENTED;
return -1;
}
static PRInt32 StreamTransmitFile(PRFileDesc *sd, PRFileDesc *socket,
const void *headers, PRInt32 hlen,
PRTransmitFileFlags flags, PRIntervalTime t) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamGetPeerName(PRFileDesc *socket, PRNetAddr *addr) {
// TODO(ekr@rtfm.com): Modify to return unique names for each channel
// somehow, as opposed to always the same static address. The current
// implementation messes up the session cache, which is why it's off
// elsewhere
addr->inet.family = PR_AF_INET;
addr->inet.port = 0;
addr->inet.ip = 0;
return PR_SUCCESS;
}
static PRStatus StreamGetSockName(PRFileDesc *socket, PRNetAddr *addr) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRStatus StreamGetSockOption(PRFileDesc *socket, PRSocketOptionData *opt) {
switch (opt->option) {
case PR_SockOpt_Nonblocking:
opt->value.non_blocking = PR_TRUE;
return PR_SUCCESS;
default:
UNIMPLEMENTED;
break;
}
return PR_FAILURE;
}
// Imitate setting socket options. These are mostly noops.
static PRStatus StreamSetSockOption(PRFileDesc *socket,
const PRSocketOptionData *opt) {
switch (opt->option) {
case PR_SockOpt_Nonblocking:
return PR_SUCCESS;
case PR_SockOpt_NoDelay:
return PR_SUCCESS;
default:
UNIMPLEMENTED;
break;
}
return PR_FAILURE;
}
static PRInt32 StreamSendfile(PRFileDesc *out, PRSendFileData *in,
PRTransmitFileFlags flags, PRIntervalTime to) {
UNIMPLEMENTED;
return -1;
}
static PRStatus StreamConnectContinue(PRFileDesc *socket, PRInt16 flags) {
UNIMPLEMENTED;
return PR_FAILURE;
}
static PRIntn StreamReserved(PRFileDesc *socket) {
UNIMPLEMENTED;
return -1;
}
static const struct PRIOMethods nss_methods = {
PR_DESC_LAYERED,
StreamClose,
StreamRead,
StreamWrite,
StreamAvailable,
StreamAvailable64,
StreamSync,
StreamSeek,
StreamSeek64,
StreamFileInfo,
StreamFileInfo64,
StreamWritev,
StreamConnect,
StreamAccept,
StreamBind,
StreamListen,
StreamShutdown,
StreamRecv,
StreamSend,
StreamRecvfrom,
StreamSendto,
StreamPoll,
StreamAcceptRead,
StreamTransmitFile,
StreamGetSockName,
StreamGetPeerName,
StreamReserved,
StreamReserved,
StreamGetSockOption,
StreamSetSockOption,
StreamSendfile,
StreamConnectContinue,
StreamReserved,
StreamReserved,
StreamReserved,
StreamReserved
};
NSSStreamAdapter::NSSStreamAdapter(StreamInterface *stream)
: SSLStreamAdapterHelper(stream),
ssl_fd_(NULL),
cert_ok_(false) {
}
bool NSSStreamAdapter::Init() {
if (nspr_layer_identity == PR_INVALID_IO_LAYER) {
nspr_layer_identity = PR_GetUniqueIdentity("nssstreamadapter");
}
PRFileDesc *pr_fd = PR_CreateIOLayerStub(nspr_layer_identity, &nss_methods);
if (!pr_fd)
return false;
pr_fd->secret = reinterpret_cast<PRFilePrivate *>(stream());
PRFileDesc *ssl_fd;
if (ssl_mode_ == SSL_MODE_DTLS) {
ssl_fd = DTLS_ImportFD(NULL, pr_fd);
} else {
ssl_fd = SSL_ImportFD(NULL, pr_fd);
}
ASSERT(ssl_fd != NULL); // This should never happen
if (!ssl_fd) {
PR_Close(pr_fd);
return false;
}
SECStatus rv;
// Turn on security.
rv = SSL_OptionSet(ssl_fd, SSL_SECURITY, PR_TRUE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error enabling security on SSL Socket";
return false;
}
// Disable SSLv2.
rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SSL2, PR_FALSE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error disabling SSL2";
return false;
}
// Disable caching.
// TODO(ekr@rtfm.com): restore this when I have the caching
// identity set.
rv = SSL_OptionSet(ssl_fd, SSL_NO_CACHE, PR_TRUE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error disabling cache";
return false;
}
// Disable session tickets.
rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_SESSION_TICKETS, PR_FALSE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error enabling tickets";
return false;
}
// Disable renegotiation.
rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_RENEGOTIATION,
SSL_RENEGOTIATE_NEVER);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error disabling renegotiation";
return false;
}
// Disable false start.
rv = SSL_OptionSet(ssl_fd, SSL_ENABLE_FALSE_START, PR_FALSE);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Error disabling false start";
return false;
}
ssl_fd_ = ssl_fd;
return true;
}
NSSStreamAdapter::~NSSStreamAdapter() {
if (ssl_fd_)
PR_Close(ssl_fd_);
};
int NSSStreamAdapter::BeginSSL() {
SECStatus rv;
if (!Init()) {
Error("Init", -1, false);
return -1;
}
ASSERT(state_ == SSL_CONNECTING);
// The underlying stream has been opened. If we are in peer-to-peer mode
// then a peer certificate must have been specified by now.
ASSERT(!ssl_server_name_.empty() ||
peer_certificate_.get() != NULL ||
!peer_certificate_digest_algorithm_.empty());
LOG(LS_INFO) << "BeginSSL: "
<< (!ssl_server_name_.empty() ? ssl_server_name_ :
"with peer");
if (role_ == SSL_CLIENT) {
LOG(LS_INFO) << "BeginSSL: as client";
rv = SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
this);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
} else {
LOG(LS_INFO) << "BeginSSL: as server";
NSSIdentity *identity;
if (identity_.get()) {
identity = static_cast<NSSIdentity *>(identity_.get());
} else {
LOG(LS_ERROR) << "Can't be an SSL server without an identity";
Error("BeginSSL", -1, false);
return -1;
}
rv = SSL_ConfigSecureServer(ssl_fd_, identity->certificate().certificate(),
identity->keypair()->privkey(),
kt_rsa);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
// Insist on a certificate from the client
rv = SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
rv = SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, PR_TRUE);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
}
// Set the version range.
SSLVersionRange vrange;
vrange.min = (ssl_mode_ == SSL_MODE_DTLS) ?
SSL_LIBRARY_VERSION_TLS_1_1 :
SSL_LIBRARY_VERSION_TLS_1_0;
vrange.max = SSL_LIBRARY_VERSION_TLS_1_1;
rv = SSL_VersionRangeSet(ssl_fd_, &vrange);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
// SRTP
#ifdef HAVE_DTLS_SRTP
if (!srtp_ciphers_.empty()) {
rv = SSL_SetSRTPCiphers(
ssl_fd_, &srtp_ciphers_[0],
checked_cast<unsigned int>(srtp_ciphers_.size()));
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
}
#endif
// Certificate validation
rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
// Now start the handshake
rv = SSL_ResetHandshake(ssl_fd_, role_ == SSL_SERVER ? PR_TRUE : PR_FALSE);
if (rv != SECSuccess) {
Error("BeginSSL", -1, false);
return -1;
}
return ContinueSSL();
}
int NSSStreamAdapter::ContinueSSL() {
LOG(LS_INFO) << "ContinueSSL";
ASSERT(state_ == SSL_CONNECTING);
// Clear the DTLS timer
Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
SECStatus rv = SSL_ForceHandshake(ssl_fd_);
if (rv == SECSuccess) {
LOG(LS_INFO) << "Handshake complete";
ASSERT(cert_ok_);
if (!cert_ok_) {
Error("ContinueSSL", -1, true);
return -1;
}
state_ = SSL_CONNECTED;
StreamAdapterInterface::OnEvent(stream(), SE_OPEN|SE_READ|SE_WRITE, 0);
return 0;
}
PRInt32 err = PR_GetError();
switch (err) {
case SSL_ERROR_RX_MALFORMED_HANDSHAKE:
if (ssl_mode_ != SSL_MODE_DTLS) {
Error("ContinueSSL", -1, true);
return -1;
} else {
LOG(LS_INFO) << "Malformed DTLS message. Ignoring.";
// Fall through
}
case PR_WOULD_BLOCK_ERROR:
LOG(LS_INFO) << "Would have blocked";
if (ssl_mode_ == SSL_MODE_DTLS) {
PRIntervalTime timeout;
SECStatus rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
if (rv == SECSuccess) {
LOG(LS_INFO) << "Timeout is " << timeout << " ms";
Thread::Current()->PostDelayed(PR_IntervalToMilliseconds(timeout),
this, MSG_DTLS_TIMEOUT, 0);
}
}
return 0;
default:
LOG(LS_INFO) << "Error " << err;
break;
}
Error("ContinueSSL", -1, true);
return -1;
}
void NSSStreamAdapter::Cleanup() {
if (state_ != SSL_ERROR) {
state_ = SSL_CLOSED;
}
if (ssl_fd_) {
PR_Close(ssl_fd_);
ssl_fd_ = NULL;
}
identity_.reset();
peer_certificate_.reset();
Thread::Current()->Clear(this, MSG_DTLS_TIMEOUT);
}
StreamResult NSSStreamAdapter::Read(void* data, size_t data_len,
size_t* read, int* error) {
// SSL_CONNECTED sanity check.
switch (state_) {
case SSL_NONE:
case SSL_WAIT:
case SSL_CONNECTING:
return SR_BLOCK;
case SSL_CONNECTED:
break;
case SSL_CLOSED:
return SR_EOS;
case SSL_ERROR:
default:
if (error)
*error = ssl_error_code_;
return SR_ERROR;
}
PRInt32 rv = PR_Read(ssl_fd_, data, checked_cast<PRInt32>(data_len));
if (rv == 0) {
return SR_EOS;
}
// Error
if (rv < 0) {
PRInt32 err = PR_GetError();
switch (err) {
case PR_WOULD_BLOCK_ERROR:
return SR_BLOCK;
default:
Error("Read", -1, false);
*error = err; // libjingle semantics are that this is impl-specific
return SR_ERROR;
}
}
// Success
*read = rv;
return SR_SUCCESS;
}
StreamResult NSSStreamAdapter::Write(const void* data, size_t data_len,
size_t* written, int* error) {
// SSL_CONNECTED sanity check.
switch (state_) {
case SSL_NONE:
case SSL_WAIT:
case SSL_CONNECTING:
return SR_BLOCK;
case SSL_CONNECTED:
break;
case SSL_ERROR:
case SSL_CLOSED:
default:
if (error)
*error = ssl_error_code_;
return SR_ERROR;
}
PRInt32 rv = PR_Write(ssl_fd_, data, checked_cast<PRInt32>(data_len));
// Error
if (rv < 0) {
PRInt32 err = PR_GetError();
switch (err) {
case PR_WOULD_BLOCK_ERROR:
return SR_BLOCK;
default:
Error("Write", -1, false);
*error = err; // libjingle semantics are that this is impl-specific
return SR_ERROR;
}
}
// Success
*written = rv;
return SR_SUCCESS;
}
void NSSStreamAdapter::OnEvent(StreamInterface* stream, int events,
int err) {
int events_to_signal = 0;
int signal_error = 0;
ASSERT(stream == this->stream());
if ((events & SE_OPEN)) {
LOG(LS_INFO) << "NSSStreamAdapter::OnEvent SE_OPEN";
if (state_ != SSL_WAIT) {
ASSERT(state_ == SSL_NONE);
events_to_signal |= SE_OPEN;
} else {
state_ = SSL_CONNECTING;
if (int err = BeginSSL()) {
Error("BeginSSL", err, true);
return;
}
}
}
if ((events & (SE_READ|SE_WRITE))) {
LOG(LS_INFO) << "NSSStreamAdapter::OnEvent"
<< ((events & SE_READ) ? " SE_READ" : "")
<< ((events & SE_WRITE) ? " SE_WRITE" : "");
if (state_ == SSL_NONE) {
events_to_signal |= events & (SE_READ|SE_WRITE);
} else if (state_ == SSL_CONNECTING) {
if (int err = ContinueSSL()) {
Error("ContinueSSL", err, true);
return;
}
} else if (state_ == SSL_CONNECTED) {
if (events & SE_WRITE) {
LOG(LS_INFO) << " -- onStreamWriteable";
events_to_signal |= SE_WRITE;
}
if (events & SE_READ) {
LOG(LS_INFO) << " -- onStreamReadable";
events_to_signal |= SE_READ;
}
}
}
if ((events & SE_CLOSE)) {
LOG(LS_INFO) << "NSSStreamAdapter::OnEvent(SE_CLOSE, " << err << ")";
Cleanup();
events_to_signal |= SE_CLOSE;
// SE_CLOSE is the only event that uses the final parameter to OnEvent().
ASSERT(signal_error == 0);
signal_error = err;
}
if (events_to_signal)
StreamAdapterInterface::OnEvent(stream, events_to_signal, signal_error);
}
void NSSStreamAdapter::OnMessage(Message* msg) {
// Process our own messages and then pass others to the superclass
if (MSG_DTLS_TIMEOUT == msg->message_id) {
LOG(LS_INFO) << "DTLS timeout expired";
ContinueSSL();
} else {
StreamInterface::OnMessage(msg);
}
}
// Certificate verification callback. Called to check any certificate
SECStatus NSSStreamAdapter::AuthCertificateHook(void *arg,
PRFileDesc *fd,
PRBool checksig,
PRBool isServer) {
LOG(LS_INFO) << "NSSStreamAdapter::AuthCertificateHook";
// SSL_PeerCertificate returns a pointer that is owned by the caller, and
// the NSSCertificate constructor copies its argument, so |raw_peer_cert|
// must be destroyed in this function.
CERTCertificate* raw_peer_cert = SSL_PeerCertificate(fd);
NSSCertificate peer_cert(raw_peer_cert);
CERT_DestroyCertificate(raw_peer_cert);
NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
stream->cert_ok_ = false;
// Read the peer's certificate chain.
CERTCertList* cert_list = SSL_PeerCertificateChain(fd);
ASSERT(cert_list != NULL);
// If the peer provided multiple certificates, check that they form a valid
// chain as defined by RFC 5246 Section 7.4.2: "Each following certificate
// MUST directly certify the one preceding it.". This check does NOT
// verify other requirements, such as whether the chain reaches a trusted
// root, self-signed certificates have valid signatures, certificates are not
// expired, etc.
// Even if the chain is valid, the leaf certificate must still match a
// provided certificate or digest.
if (!NSSCertificate::IsValidChain(cert_list)) {
CERT_DestroyCertList(cert_list);
PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
return SECFailure;
}
if (stream->peer_certificate_.get()) {
LOG(LS_INFO) << "Checking against specified certificate";
// The peer certificate was specified
if (reinterpret_cast<NSSCertificate *>(stream->peer_certificate_.get())->
Equals(&peer_cert)) {
LOG(LS_INFO) << "Accepted peer certificate";
stream->cert_ok_ = true;
}
} else if (!stream->peer_certificate_digest_algorithm_.empty()) {
LOG(LS_INFO) << "Checking against specified digest";
// The peer certificate digest was specified
unsigned char digest[64]; // Maximum size
size_t digest_length;
if (!peer_cert.ComputeDigest(
stream->peer_certificate_digest_algorithm_,
digest, sizeof(digest), &digest_length)) {
LOG(LS_ERROR) << "Digest computation failed";
} else {
Buffer computed_digest(digest, digest_length);
if (computed_digest == stream->peer_certificate_digest_value_) {
LOG(LS_INFO) << "Accepted peer certificate";
stream->cert_ok_ = true;
}
}
} else {
// Other modes, but we haven't implemented yet
// TODO(ekr@rtfm.com): Implement real certificate validation
UNIMPLEMENTED;
}
if (!stream->cert_ok_ && stream->ignore_bad_cert()) {
LOG(LS_WARNING) << "Ignoring cert error while verifying cert chain";
stream->cert_ok_ = true;
}
if (stream->cert_ok_)
stream->peer_certificate_.reset(new NSSCertificate(cert_list));
CERT_DestroyCertList(cert_list);
if (stream->cert_ok_)
return SECSuccess;
PORT_SetError(SEC_ERROR_UNTRUSTED_CERT);
return SECFailure;
}
SECStatus NSSStreamAdapter::GetClientAuthDataHook(void *arg, PRFileDesc *fd,
CERTDistNames *caNames,
CERTCertificate **pRetCert,
SECKEYPrivateKey **pRetKey) {
LOG(LS_INFO) << "Client cert requested";
NSSStreamAdapter *stream = reinterpret_cast<NSSStreamAdapter *>(arg);
if (!stream->identity_.get()) {
LOG(LS_ERROR) << "No identity available";
return SECFailure;
}
NSSIdentity *identity = static_cast<NSSIdentity *>(stream->identity_.get());
// Destroyed internally by NSS
*pRetCert = CERT_DupCertificate(identity->certificate().certificate());
*pRetKey = SECKEY_CopyPrivateKey(identity->keypair()->privkey());
return SECSuccess;
}
// RFC 5705 Key Exporter
bool NSSStreamAdapter::ExportKeyingMaterial(const std::string& label,
const uint8* context,
size_t context_len,
bool use_context,
uint8* result,
size_t result_len) {
SECStatus rv = SSL_ExportKeyingMaterial(
ssl_fd_,
label.c_str(),
checked_cast<unsigned int>(label.size()),
use_context,
context,
checked_cast<unsigned int>(context_len),
result,
checked_cast<unsigned int>(result_len));
return rv == SECSuccess;
}
bool NSSStreamAdapter::SetDtlsSrtpCiphers(
const std::vector<std::string>& ciphers) {
#ifdef HAVE_DTLS_SRTP
std::vector<PRUint16> internal_ciphers;
if (state_ != SSL_NONE)
return false;
for (std::vector<std::string>::const_iterator cipher = ciphers.begin();
cipher != ciphers.end(); ++cipher) {
bool found = false;
for (const SrtpCipherMapEntry *entry = kSrtpCipherMap; entry->cipher_id;
++entry) {
if (*cipher == entry->external_name) {
found = true;
internal_ciphers.push_back(entry->cipher_id);
break;
}
}
if (!found) {
LOG(LS_ERROR) << "Could not find cipher: " << *cipher;
return false;
}
}
if (internal_ciphers.empty())
return false;
srtp_ciphers_ = internal_ciphers;
return true;
#else
return false;
#endif
}
bool NSSStreamAdapter::GetDtlsSrtpCipher(std::string* cipher) {
#ifdef HAVE_DTLS_SRTP
ASSERT(state_ == SSL_CONNECTED);
if (state_ != SSL_CONNECTED)
return false;
PRUint16 selected_cipher;
SECStatus rv = SSL_GetSRTPCipher(ssl_fd_, &selected_cipher);
if (rv == SECFailure)
return false;
for (const SrtpCipherMapEntry *entry = kSrtpCipherMap;
entry->cipher_id; ++entry) {
if (selected_cipher == entry->cipher_id) {
*cipher = entry->external_name;
return true;
}
}
ASSERT(false); // This should never happen
#endif
return false;
}
bool NSSContext::initialized;
NSSContext *NSSContext::global_nss_context;
// Static initialization and shutdown
NSSContext *NSSContext::Instance() {
if (!global_nss_context) {
NSSContext *new_ctx = new NSSContext();
if (!(new_ctx->slot_ = PK11_GetInternalSlot())) {
delete new_ctx;
goto fail;
}
global_nss_context = new_ctx;
}
fail:
return global_nss_context;
}
bool NSSContext::InitializeSSL(VerificationCallback callback) {
ASSERT(!callback);
if (!initialized) {
SECStatus rv;
rv = NSS_NoDB_Init(NULL);
if (rv != SECSuccess) {
LOG(LS_ERROR) << "Couldn't initialize NSS error=" <<
PORT_GetError();
return false;
}
NSS_SetDomesticPolicy();
initialized = true;
}
return true;
}
bool NSSContext::InitializeSSLThread() {
// Not needed
return true;
}
bool NSSContext::CleanupSSL() {
// Not needed
return true;
}
bool NSSStreamAdapter::HaveDtls() {
return true;
}
bool NSSStreamAdapter::HaveDtlsSrtp() {
#ifdef HAVE_DTLS_SRTP
return true;
#else
return false;
#endif
}
bool NSSStreamAdapter::HaveExporter() {
return true;
}
} // namespace talk_base
#endif // HAVE_NSS_SSL_H