diff --git a/modules/features2d/include/opencv2/features2d/features2d.hpp b/modules/features2d/include/opencv2/features2d/features2d.hpp index 195eb6cc6..ad7ed2134 100644 --- a/modules/features2d/include/opencv2/features2d/features2d.hpp +++ b/modules/features2d/include/opencv2/features2d/features2d.hpp @@ -1659,6 +1659,8 @@ public: */ void match( const Mat& query, const Mat& mask, vector& matches ) const; + void match( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const; + /* * Find many matches for each descriptor from a query set * @@ -1686,21 +1688,21 @@ public: virtual void clear(); protected: - Mat train; + Mat m_train; /* * Find matches; match() calls this. Must be implemented by the subclass. * The mask may be empty. */ - virtual void matchImpl( const Mat& query, const Mat& mask, vector& matches ) const = 0; + virtual void matchImpl( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const = 0; /* * Find matches; match() calls this. Must be implemented by the subclass. * The mask may be empty. */ - virtual void matchImpl( const Mat& query, const Mat& mask, vector& matches ) const = 0; + virtual void matchImpl( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const = 0; - virtual void matchImpl( const Mat& query, const Mat& mask, vector >& matches, float threshold ) const = 0; + virtual void matchImpl( const Mat& query, const Mat& train, vector >& matches, float threshold, const Mat& mask ) const = 0; static bool possibleMatch( const Mat& mask, int index_1, int index_2 ) @@ -1725,20 +1727,20 @@ public: BruteForceMatcher( Distance d = Distance() ) : distance(d) {} virtual void index() {} protected: - virtual void matchImpl( const Mat& query, const Mat& mask, vector& matches ) const; + virtual void matchImpl( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const; - virtual void matchImpl( const Mat& query, const Mat& mask, vector& matches ) const; + virtual void matchImpl( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const; - virtual void matchImpl( const Mat& query, const Mat& mask, vector >& matches, float threshold ) const; + virtual void matchImpl( const Mat& query, const Mat& train, vector >& matches, float threshold, const Mat& mask ) const; Distance distance; }; template inline -void BruteForceMatcher::matchImpl( const Mat& query, const Mat& mask, vector& matches ) const +void BruteForceMatcher::matchImpl( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const { vector fullMatches; - matchImpl( query, mask, fullMatches); + matchImpl( query, train, fullMatches, mask ); matches.clear(); matches.resize( fullMatches.size() ); for( size_t i=0;i::matchImpl( const Mat& query, const Mat& mask, } template inline -void BruteForceMatcher::matchImpl( const Mat& query, const Mat& mask, vector& matches ) const +void BruteForceMatcher::matchImpl( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const { typedef typename Distance::ValueType ValueType; typedef typename Distance::ResultType DistanceType; @@ -1795,7 +1797,8 @@ void BruteForceMatcher::matchImpl( const Mat& query, const Mat& mask, } template inline -void BruteForceMatcher::matchImpl( const Mat& query, const Mat& mask, vector >& matches, float threshold ) const +void BruteForceMatcher::matchImpl( const Mat& query, const Mat& train, vector >& matches, + float threshold, const Mat& mask ) const { typedef typename Distance::ValueType ValueType; typedef typename Distance::ResultType DistanceType; @@ -1804,7 +1807,7 @@ void BruteForceMatcher::matchImpl( const Mat& query, const Mat& mask, assert( query.cols == train.cols || query.empty() || train.empty() ); assert( DataType::type == query.type() || query.empty() ); - assert( DataType::type == train.type() || train.empty() ); + assert( DataType::type == train.type() || train.empty() ); int dimension = query.cols; matches.clear(); @@ -1834,7 +1837,7 @@ void BruteForceMatcher::matchImpl( const Mat& query, const Mat& mask, } template<> -void BruteForceMatcher >::matchImpl( const Mat& query, const Mat& mask, vector& matches ) const; +void BruteForceMatcher >::matchImpl( const Mat& query, const Mat& train, vector& matches, const Mat& mask) const; CV_EXPORTS Ptr createDescriptorMatcher( const string& descriptorMatcherType ); diff --git a/modules/features2d/src/descriptors.cpp b/modules/features2d/src/descriptors.cpp index 528afd4b4..adad4875a 100644 --- a/modules/features2d/src/descriptors.cpp +++ b/modules/features2d/src/descriptors.cpp @@ -431,65 +431,70 @@ Ptr createDescriptorMatcher( const string& descriptorMatcherT \****************************************************************************************/ void DescriptorMatcher::add( const Mat& descriptors ) { - if( train.empty() ) + if( m_train.empty() ) { - train = descriptors; + m_train = descriptors; } else { // merge train and descriptors - Mat m( train.rows + descriptors.rows, train.cols, CV_32F ); - Mat m1 = m.rowRange( 0, train.rows ); - train.copyTo( m1 ); - Mat m2 = m.rowRange( train.rows + 1, m.rows ); + Mat m( m_train.rows + descriptors.rows, m_train.cols, CV_32F ); + Mat m1 = m.rowRange( 0, m_train.rows ); + m_train.copyTo( m1 ); + Mat m2 = m.rowRange( m_train.rows + 1, m.rows ); descriptors.copyTo( m2 ); - train = m; + m_train = m; } } void DescriptorMatcher::match( const Mat& query, vector& matches ) const { - matchImpl( query, Mat(), matches ); + matchImpl( query, m_train, matches, Mat() ); } void DescriptorMatcher::match( const Mat& query, const Mat& mask, vector& matches ) const { - matchImpl( query, mask, matches ); + matchImpl( query, m_train, matches, mask ); } void DescriptorMatcher::match( const Mat& query, vector& matches ) const { - matchImpl( query, Mat(), matches ); + matchImpl( query, m_train, matches, Mat() ); } void DescriptorMatcher::match( const Mat& query, const Mat& mask, vector& matches ) const { - matchImpl( query, mask, matches ); + matchImpl( query, m_train, matches, mask ); +} + +void DescriptorMatcher::match( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const +{ + matchImpl( query, train, matches, mask ); } void DescriptorMatcher::match( const Mat& query, vector >& matches, float threshold ) const { - matchImpl( query, Mat(), matches, threshold ); + matchImpl( query, m_train, matches, threshold, Mat() ); } void DescriptorMatcher::match( const Mat& query, const Mat& mask, vector >& matches, float threshold ) const { - matchImpl( query, mask, matches, threshold ); + matchImpl( query, m_train, matches, threshold, mask ); } void DescriptorMatcher::clear() { - train.release(); + m_train.release(); } /* * BruteForceMatcher L2 specialization */ template<> -void BruteForceMatcher >::matchImpl( const Mat& query, const Mat& mask, vector& matches ) const +void BruteForceMatcher >::matchImpl( const Mat& query, const Mat& train, vector& matches, const Mat& mask ) const { assert( mask.empty() || (mask.rows == query.rows && mask.cols == train.rows) ); assert( query.cols == train.cols || query.empty() || train.empty() ); diff --git a/samples/cpp/descriptor_extractor_matcher.cpp b/samples/cpp/descriptor_extractor_matcher.cpp index bd08d5675..154e9a805 100644 --- a/samples/cpp/descriptor_extractor_matcher.cpp +++ b/samples/cpp/descriptor_extractor_matcher.cpp @@ -64,10 +64,8 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective, cout << ">" << endl; cout << "< Matching descriptors..." << endl; - vector matches; - descriptorMatcher->clear(); - descriptorMatcher->add( descriptors2 ); - descriptorMatcher->match( descriptors1, matches ); + vector matches; + descriptorMatcher->match( descriptors1, descriptors2, matches, Mat() ); cout << ">" << endl; if( !H12.empty() ) @@ -81,11 +79,15 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective, cout << ">" << endl; } + vector trainIdxs( matches.size() ); + for( size_t i = 0; i < matches.size(); i++ ) + trainIdxs[i] = matches[i].indexTrain; + if( !isWarpPerspective && ransacReprojThreshold >= 0 ) { cout << "< Computing homography (RANSAC)..." << endl; vector points1; KeyPoint::convert(keypoints1, points1); - vector points2; KeyPoint::convert(keypoints2, points2, matches); + vector points2; KeyPoint::convert(keypoints2, points2, trainIdxs); H12 = findHomography( Mat(points1), Mat(points2), CV_RANSAC, ransacReprojThreshold ); cout << ">" << endl; } @@ -95,9 +97,8 @@ void doIteration( const Mat& img1, Mat& img2, bool isWarpPerspective, { vector matchesMask( matches.size(), 0 ); vector points1; KeyPoint::convert(keypoints1, points1); - vector points2; KeyPoint::convert(keypoints2, points2, matches); + vector points2; KeyPoint::convert(keypoints2, points2, trainIdxs); Mat points1t; perspectiveTransform(Mat(points1), points1t, H12); - vector::const_iterator mit = matches.begin(); for( size_t i1 = 0; i1 < points1.size(); i1++ ) { if( norm(points2[i1] - points1t.at(i1,0)) < 4 ) // inlier