diff --git a/MongoDB/include/Poco/MongoDB/PoolableConnectionFactory.h b/MongoDB/include/Poco/MongoDB/PoolableConnectionFactory.h index 573310ad6..60115b188 100644 --- a/MongoDB/include/Poco/MongoDB/PoolableConnectionFactory.h +++ b/MongoDB/include/Poco/MongoDB/PoolableConnectionFactory.h @@ -28,24 +28,38 @@ namespace Poco { template<> class PoolableObjectFactory /// PoolableObjectFactory specialisation for Connection. New connections - /// are created with the given address. + /// are created with the given address or URI. + /// + /// If a Connection::SocketFactory is given, it must live for the entire + /// lifetime of the PoolableObjectFactory. { public: PoolableObjectFactory(Net::SocketAddress& address): - _address(address) + _address(address), + _pSocketFactory(0) { } PoolableObjectFactory(const std::string& address): - _address(address) + _address(address), + _pSocketFactory(0) + { + } + + PoolableObjectFactory(const std::string& uri, MongoDB::Connection::SocketFactory& socketFactory): + _uri(uri), + _pSocketFactory(&socketFactory) { } MongoDB::Connection::Ptr createObject() { - return new MongoDB::Connection(_address); + if (_pSocketFactory) + return new MongoDB::Connection(_uri, *_pSocketFactory); + else + return new MongoDB::Connection(_address); } - + bool validateObject(MongoDB::Connection::Ptr pObject) { return true; @@ -65,6 +79,8 @@ public: private: Net::SocketAddress _address; + std::string _uri; + MongoDB::Connection::SocketFactory* _pSocketFactory; }; diff --git a/MongoDB/src/Connection.cpp b/MongoDB/src/Connection.cpp index a5944f062..56bb192ce 100644 --- a/MongoDB/src/Connection.cpp +++ b/MongoDB/src/Connection.cpp @@ -152,35 +152,38 @@ void Connection::connect(const std::string& uri, SocketFactory& socketFactory) std::string host = theURI.getHost(); Poco::UInt16 port = theURI.getPort(); if (port == 0) port = 27017; + std::string databaseName = theURI.getPath(); + if (!databaseName.empty() && databaseName[0] == '/') databaseName.erase(0, 1); if (databaseName.empty()) databaseName = "admin"; - bool secure = false; + + bool ssl = false; Poco::Timespan connectTimeout; Poco::Timespan socketTimeout; - std::string authMethod = Database::AUTH_SCRAM_SHA1; + std::string authMechanism = Database::AUTH_SCRAM_SHA1; Poco::URI::QueryParameters params = theURI.getQueryParameters(); for (Poco::URI::QueryParameters::const_iterator it = params.begin(); it != params.end(); ++it) { if (it->first == "ssl") { - secure = (it->second == "true"); + ssl = (it->second == "true"); } else if (it->first == "connectTimeoutMS") { - connectTimeout = 1000*Poco::NumberParser::parse(it->second); + connectTimeout = static_cast(1000)*Poco::NumberParser::parse(it->second); } else if (it->first == "socketTimeoutMS") { - socketTimeout = 1000*Poco::NumberParser::parse(it->second); + socketTimeout = static_cast(1000)*Poco::NumberParser::parse(it->second); } else if (it->first == "authMechanism") { - authMethod = it->second; + authMechanism = it->second; } } - connect(socketFactory.createSocket(host, port, connectTimeout, secure)); + connect(socketFactory.createSocket(host, port, connectTimeout, ssl)); if (socketTimeout > 0) { @@ -201,7 +204,7 @@ void Connection::connect(const std::string& uri, SocketFactory& socketFactory) else username = userInfo; Database database(databaseName); - if (!database.authenticate(*this, username, password, authMethod)) + if (!database.authenticate(*this, username, password, authMechanism)) throw Poco::NoPermissionException(Poco::format("Access to MongoDB database %s denied for user %s", databaseName, username)); } } diff --git a/MongoDB/testsuite/src/MongoDBTest.cpp b/MongoDB/testsuite/src/MongoDBTest.cpp index 0f8060a7d..6703f2764 100644 --- a/MongoDB/testsuite/src/MongoDBTest.cpp +++ b/MongoDB/testsuite/src/MongoDBTest.cpp @@ -221,10 +221,10 @@ void MongoDBTest::testDeleteRequest() void MongoDBTest::testCursorRequest() { Poco::MongoDB::Database db("team"); - + Poco::SharedPtr deleteRequest = db.createDeleteRequest("numbers"); _mongo->sendRequest(*deleteRequest); - + Poco::SharedPtr insertRequest = db.createInsertRequest("numbers"); for(int i = 0; i < 10000; ++i) { @@ -395,6 +395,51 @@ void MongoDBTest::testUUID() } +void MongoDBTest::testConnectURI() +{ + Poco::MongoDB::Connection conn; + Poco::MongoDB::Connection::SocketFactory sf; + + conn.connect("mongodb://127.0.0.1", sf); + conn.disconnect(); + + try + { + conn.connect("http://127.0.0.1", sf); + fail("invalid URI scheme - must throw"); + } + catch (Poco::UnknownURISchemeException&) + { + } + + try + { + conn.connect("mongodb://127.0.0.1?ssl=true", sf); + fail("SSL not supported, must throw"); + } + catch (Poco::NotImplementedException&) + { + } + + conn.connect("mongodb://127.0.0.1/admin?ssl=false&connectTimeoutMS=10000&socketTimeoutMS=10000", sf); + conn.disconnect(); + + try + { + conn.connect("mongodb://127.0.0.1/admin?connectTimeoutMS=foo", sf); + fail("invalid parameter - must throw"); + } + catch (Poco::Exception&) + { + } + +#ifdef MONGODB_TEST_AUTH + conn.connect("mongodb://admin:admin@127.0.0.1/admin", sf); + conn.disconnect(); +#endif +} + + CppUnit::Test* MongoDBTest::suite() { try @@ -423,6 +468,7 @@ CppUnit::Test* MongoDBTest::suite() CppUnit_addTest(pSuite, MongoDBTest, testObjectID); CppUnit_addTest(pSuite, MongoDBTest, testCommand); CppUnit_addTest(pSuite, MongoDBTest, testUUID); + CppUnit_addTest(pSuite, MongoDBTest, testConnectURI); return pSuite; } diff --git a/MongoDB/testsuite/src/MongoDBTest.h b/MongoDB/testsuite/src/MongoDBTest.h index 9818c4de4..4c5954612 100644 --- a/MongoDB/testsuite/src/MongoDBTest.h +++ b/MongoDB/testsuite/src/MongoDBTest.h @@ -39,6 +39,7 @@ public: void testObjectID(); void testCommand(); void testUUID(); + void testConnectURI(); void setUp(); void tearDown();