Added plot data generation for visual descriptors comparison in the evaluation framework
This commit is contained in:
parent
f6f634bace
commit
88bd1f1d1a
@ -1503,6 +1503,25 @@ struct CV_EXPORTS L2
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* DescriptorMatching *
|
||||
\****************************************************************************************/
|
||||
/*
|
||||
* Struct for matching: match index and distance between descriptors
|
||||
*/
|
||||
struct DescriptorMatching
|
||||
{
|
||||
int index;
|
||||
float distance;
|
||||
|
||||
//less is better
|
||||
bool operator<( const DescriptorMatching &m) const
|
||||
{
|
||||
return distance < m.distance;
|
||||
}
|
||||
};
|
||||
|
||||
/****************************************************************************************\
|
||||
* DescriptorMatcher *
|
||||
\****************************************************************************************/
|
||||
@ -1545,6 +1564,28 @@ public:
|
||||
void match( const Mat& query, const Mat& mask,
|
||||
vector<int>& matches ) const;
|
||||
|
||||
/*
|
||||
* Find the best match for each descriptor from a query set
|
||||
*
|
||||
* query The query set of descriptors
|
||||
* matchings Matchings of the closest matches from the training set
|
||||
*/
|
||||
void match( const Mat& query, vector<DescriptorMatching>& matchings ) const;
|
||||
|
||||
/*
|
||||
* Find the best matches between two descriptor sets, with constraints
|
||||
* on which pairs of descriptors can be matched.
|
||||
*
|
||||
* The mask describes which descriptors can be matched. descriptors_1[i]
|
||||
* can be matched with descriptors_2[j] only if mask.at<char>(i,j) is non-zero.
|
||||
*
|
||||
* query The query set of descriptors
|
||||
* mask Mask specifying permissible matches.
|
||||
* matchings Matchings of the closest matches from the training set
|
||||
*/
|
||||
void match( const Mat& query, const Mat& mask,
|
||||
vector<DescriptorMatching>& matchings ) const;
|
||||
|
||||
/*
|
||||
* Find the best keypoint matches for small view changes.
|
||||
*
|
||||
@ -1574,6 +1615,13 @@ protected:
|
||||
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
|
||||
const Mat& mask, vector<int>& matches ) const = 0;
|
||||
|
||||
/*
|
||||
* Find matches; match() calls this. Must be implemented by the subclass.
|
||||
* The mask may be empty.
|
||||
*/
|
||||
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
|
||||
const Mat& mask, vector<DescriptorMatching>& matches ) const = 0;
|
||||
|
||||
static bool possibleMatch( const Mat& mask, int index_1, int index_2 )
|
||||
{
|
||||
return mask.empty() || mask.at<char>(index_1, index_2);
|
||||
@ -1609,6 +1657,18 @@ inline void DescriptorMatcher::match( const Mat& query, const Mat& mask,
|
||||
matchImpl( query, train, mask, matches );
|
||||
}
|
||||
|
||||
inline void DescriptorMatcher::match( const Mat& query, vector<DescriptorMatching>& matches ) const
|
||||
{
|
||||
matchImpl( query, train, Mat(), matches );
|
||||
}
|
||||
|
||||
|
||||
inline void DescriptorMatcher::match( const Mat& query, const Mat& mask,
|
||||
vector<DescriptorMatching>& matches ) const
|
||||
{
|
||||
matchImpl( query, train, mask, matches );
|
||||
}
|
||||
|
||||
inline void DescriptorMatcher::clear()
|
||||
{
|
||||
train.release();
|
||||
@ -1633,12 +1693,28 @@ protected:
|
||||
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
|
||||
const Mat& mask, vector<int>& matches ) const;
|
||||
|
||||
virtual void matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
|
||||
const Mat& mask, vector<DescriptorMatching>& matches ) const;
|
||||
|
||||
Distance distance;
|
||||
};
|
||||
|
||||
template<class Distance>
|
||||
void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
|
||||
const Mat& mask, vector<int>& matches ) const
|
||||
{
|
||||
vector<DescriptorMatching> matchings;
|
||||
matchImpl( descriptors_1, descriptors_2, mask, matchings);
|
||||
matches.resize( matchings.size() );
|
||||
for( size_t i=0;i<matchings.size();i++)
|
||||
{
|
||||
matches[i] = matchings[i].index;
|
||||
}
|
||||
}
|
||||
|
||||
template<class Distance>
|
||||
void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
|
||||
const Mat& mask, vector<DescriptorMatching>& matches ) const
|
||||
{
|
||||
typedef typename Distance::ValueType ValueType;
|
||||
typedef typename Distance::ResultType DistanceType;
|
||||
@ -1650,8 +1726,7 @@ void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat
|
||||
assert( DataType<ValueType>::type == descriptors_2.type() || descriptors_2.empty() );
|
||||
|
||||
int dimension = descriptors_1.cols;
|
||||
matches.clear();
|
||||
matches.reserve(descriptors_1.rows);
|
||||
matches.resize(descriptors_1.rows);
|
||||
|
||||
for( int i = 0; i < descriptors_1.rows; i++ )
|
||||
{
|
||||
@ -1674,7 +1749,12 @@ void BruteForceMatcher<Distance>::matchImpl( const Mat& descriptors_1, const Mat
|
||||
}
|
||||
|
||||
if( matchIndex != -1 )
|
||||
matches.push_back( matchIndex );
|
||||
{
|
||||
DescriptorMatching matching;
|
||||
matching.index = matchIndex;
|
||||
matching.distance = matchDistance;
|
||||
matches[i] = matching;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1742,6 +1822,12 @@ public:
|
||||
// indices A vector to be filled with keypoint class indices
|
||||
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<int>& indices ) = 0;
|
||||
|
||||
// Matches test keypoints to the training set
|
||||
// image The source image
|
||||
// points Test keypoints from the source image
|
||||
// matchings A vector to be filled with keypoint matchings
|
||||
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings ) {};
|
||||
|
||||
// Clears keypoints storing in collection
|
||||
virtual void clear();
|
||||
|
||||
@ -1816,6 +1902,8 @@ public:
|
||||
// loaded with DescriptorOneWay::Initialize, kd tree is used for finding minimum distances.
|
||||
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<int>& indices );
|
||||
|
||||
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings );
|
||||
|
||||
// Classify a set of keypoints. The same as match, but returns point classes rather than indices
|
||||
virtual void classify( const Mat& image, vector<KeyPoint>& points );
|
||||
|
||||
@ -1944,6 +2032,8 @@ public:
|
||||
|
||||
virtual void match( const Mat& image, vector<KeyPoint>& keypoints, vector<int>& indices );
|
||||
|
||||
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings );
|
||||
|
||||
virtual void classify( const Mat& image, vector<KeyPoint>& keypoints );
|
||||
|
||||
virtual void clear ();
|
||||
@ -2000,6 +2090,14 @@ public:
|
||||
matcher.match( descriptors, keypointIndices );
|
||||
};
|
||||
|
||||
virtual void match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings )
|
||||
{
|
||||
Mat descriptors;
|
||||
extractor.compute( image, points, descriptors );
|
||||
|
||||
matcher.match( descriptors, matchings );
|
||||
}
|
||||
|
||||
virtual void clear()
|
||||
{
|
||||
GenericDescriptorMatch::clear();
|
||||
|
@ -44,6 +44,8 @@
|
||||
using namespace std;
|
||||
using namespace cv;
|
||||
|
||||
//#define _KDTREE
|
||||
|
||||
/****************************************************************************************\
|
||||
* DescriptorExtractor *
|
||||
\****************************************************************************************/
|
||||
@ -332,15 +334,27 @@ void OneWayDescriptorMatch::add( KeyPointCollection& keypoints )
|
||||
|
||||
void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& indices)
|
||||
{
|
||||
vector<DescriptorMatching> matchings( points.size() );
|
||||
indices.resize(points.size());
|
||||
|
||||
match( image, points, matchings );
|
||||
|
||||
for( size_t i = 0; i < points.size(); i++ )
|
||||
indices[i] = matchings[i].index;
|
||||
}
|
||||
|
||||
void OneWayDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DescriptorMatching>& matchings )
|
||||
{
|
||||
matchings.resize( points.size() );
|
||||
IplImage _image = image;
|
||||
for( size_t i = 0; i < points.size(); i++ )
|
||||
{
|
||||
int descIdx = -1;
|
||||
int poseIdx = -1;
|
||||
float distance;
|
||||
base->FindDescriptor( &_image, points[i].pt, descIdx, poseIdx, distance );
|
||||
indices[i] = descIdx;
|
||||
|
||||
DescriptorMatching matching;
|
||||
matching.index = -1;
|
||||
base->FindDescriptor( &_image, points[i].pt, matching.index, poseIdx, matching.distance );
|
||||
matchings[i] = matching;
|
||||
}
|
||||
}
|
||||
|
||||
@ -631,6 +645,21 @@ void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints,
|
||||
}
|
||||
}
|
||||
|
||||
void FernDescriptorMatch::match( const Mat& image, vector<KeyPoint>& keypoints, vector<DescriptorMatching>& matchings )
|
||||
{
|
||||
trainFernClassifier();
|
||||
|
||||
matchings.resize( keypoints.size() );
|
||||
vector<float> signature( (size_t)classifier->getClassCount() );
|
||||
|
||||
for( size_t pi = 0; pi < keypoints.size(); pi++ )
|
||||
{
|
||||
calcBestProbAndMatchIdx( image, keypoints[pi].pt, matchings[pi].distance, matchings[pi].index, signature );
|
||||
//matching[pi].distance is log of probability so we need to transform it
|
||||
matchings[pi].distance = -matchings[pi].distance;
|
||||
}
|
||||
}
|
||||
|
||||
void FernDescriptorMatch::classify( const Mat& image, vector<KeyPoint>& keypoints )
|
||||
{
|
||||
trainFernClassifier();
|
||||
|
@ -549,9 +549,9 @@ inline float precision( int correctMatchCount, int falseMatchCount )
|
||||
}
|
||||
|
||||
void evaluateDescriptors( const vector<EllipticKeyPoint>& keypoints1, const vector<EllipticKeyPoint>& keypoints2,
|
||||
const vector<int>& matches1to2,
|
||||
vector< pair<DescriptorMatching, int> >& matches1to2,
|
||||
const Mat& img1, const Mat& img2, const Mat& H1to2,
|
||||
int& correctMatchCount, int& falseMatchCount, int& correspondenceCount )
|
||||
int &correctMatchCount, int &falseMatchCount, vector<int> &matchStatuses, int& correspondenceCount )
|
||||
{
|
||||
assert( !keypoints1.empty() && !keypoints2.empty() && !matches1to2.empty() );
|
||||
assert( keypoints1.size() == matches1to2.size() );
|
||||
@ -564,17 +564,27 @@ void evaluateDescriptors( const vector<EllipticKeyPoint>& keypoints1, const vect
|
||||
repeatability, correspCount,
|
||||
&thresholdedOverlapMask );
|
||||
correspondenceCount = thresholdedOverlapMask.nzcount();
|
||||
correctMatchCount = falseMatchCount = 0;
|
||||
|
||||
matchStatuses.resize( matches1to2.size() );
|
||||
correctMatchCount = 0;
|
||||
falseMatchCount = 0;
|
||||
|
||||
//the nearest descriptors should be examined first
|
||||
std::sort( matches1to2.begin(), matches1to2.end() );
|
||||
|
||||
for( size_t i1 = 0; i1 < matches1to2.size(); i1++ )
|
||||
{
|
||||
int i2 = matches1to2[i1];
|
||||
int i2 = matches1to2[i1].first.index;
|
||||
if( i2 > 0 )
|
||||
{
|
||||
if( thresholdedOverlapMask(i1, i2) )
|
||||
matchStatuses[i2] = thresholdedOverlapMask(matches1to2[i1].second, i2);
|
||||
if( matchStatuses[i2] )
|
||||
correctMatchCount++;
|
||||
else
|
||||
falseMatchCount++;
|
||||
}
|
||||
else
|
||||
matchStatuses[i2] = -1;
|
||||
}
|
||||
}
|
||||
|
||||
@ -615,11 +625,16 @@ class BaseQualityTest : public CvTest
|
||||
{
|
||||
public:
|
||||
BaseQualityTest( const char* _algName, const char* _testName, const char* _testFuncs ) :
|
||||
CvTest( _testName, _testFuncs ), algName(_algName) {}
|
||||
CvTest( _testName, _testFuncs ), algName(_algName)
|
||||
{
|
||||
//TODO: change this
|
||||
isWriteGraphicsData = true;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual string getRunParamsFilename() const = 0;
|
||||
virtual string getResultsFilename() const = 0;
|
||||
virtual string getPlotPath() const = 0;
|
||||
|
||||
virtual void validQualityClear( int datasetIdx ) = 0;
|
||||
virtual void calcQualityClear( int datasetIdx ) = 0;
|
||||
@ -650,9 +665,11 @@ protected:
|
||||
|
||||
virtual void processResults();
|
||||
virtual int processResults( int datasetIdx, int caseIdx ) = 0;
|
||||
void writeAllPlotData() const;
|
||||
virtual void writePlotData( const string &filename, int datasetIdx ) const {};
|
||||
|
||||
string algName;
|
||||
bool isWriteParams, isWriteResults;
|
||||
bool isWriteParams, isWriteResults, isWriteGraphicsData;
|
||||
};
|
||||
|
||||
void BaseQualityTest::readAllDatasetsRunParams()
|
||||
@ -811,6 +828,8 @@ void BaseQualityTest::processResults()
|
||||
{
|
||||
if( isWriteParams )
|
||||
writeAllDatasetsRunParams();
|
||||
if( isWriteGraphicsData )
|
||||
writeAllPlotData();
|
||||
|
||||
int res = CvTS::OK;
|
||||
if( isWriteResults )
|
||||
@ -838,6 +857,18 @@ void BaseQualityTest::processResults()
|
||||
ts->set_failed_test_info( res );
|
||||
}
|
||||
|
||||
void BaseQualityTest::writeAllPlotData() const
|
||||
{
|
||||
for( int di = 0; di < DATASETS_COUNT; di++ )
|
||||
{
|
||||
stringstream stream;
|
||||
stream << getPlotPath() << algName << "_" << DATASET_NAMES[di] << ".csv";
|
||||
string filename;
|
||||
stream >> filename;
|
||||
writePlotData( filename, di );
|
||||
}
|
||||
}
|
||||
|
||||
void BaseQualityTest::run ( int )
|
||||
{
|
||||
readAlgorithm ();
|
||||
@ -904,6 +935,7 @@ protected:
|
||||
|
||||
virtual string getRunParamsFilename() const;
|
||||
virtual string getResultsFilename() const;
|
||||
virtual string getPlotPath() const;
|
||||
|
||||
virtual void validQualityClear( int datasetIdx );
|
||||
virtual void calcQualityClear( int datasetIdx );
|
||||
@ -961,6 +993,11 @@ string DetectorQualityTest::getResultsFilename() const
|
||||
return string(ts->get_data_path()) + DETECTORS_DIR + algName + RES_POSTFIX;
|
||||
}
|
||||
|
||||
string DetectorQualityTest::getPlotPath() const
|
||||
{
|
||||
return string(ts->get_data_path()) + DETECTORS_DIR + "plots/";
|
||||
}
|
||||
|
||||
void DetectorQualityTest::validQualityClear( int datasetIdx )
|
||||
{
|
||||
validQuality[datasetIdx].clear();
|
||||
@ -1253,6 +1290,7 @@ public:
|
||||
{
|
||||
validQuality.resize(DATASETS_COUNT);
|
||||
calcQuality.resize(DATASETS_COUNT);
|
||||
calcDatasetQuality.resize(DATASETS_COUNT);
|
||||
commRunParams.resize(DATASETS_COUNT);
|
||||
|
||||
commRunParamsDefault.projectKeypointsFrom1Image = true;
|
||||
@ -1267,6 +1305,7 @@ protected:
|
||||
|
||||
virtual string getRunParamsFilename() const;
|
||||
virtual string getResultsFilename() const;
|
||||
virtual string getPlotPath() const;
|
||||
|
||||
virtual void validQualityClear( int datasetIdx );
|
||||
virtual void calcQualityClear( int datasetIdx );
|
||||
@ -1289,6 +1328,8 @@ protected:
|
||||
|
||||
virtual int processResults( int datasetIdx, int caseIdx );
|
||||
|
||||
virtual void writePlotData( const string &filename, int di ) const;
|
||||
|
||||
struct Quality
|
||||
{
|
||||
float recall;
|
||||
@ -1296,6 +1337,7 @@ protected:
|
||||
};
|
||||
vector<vector<Quality> > validQuality;
|
||||
vector<vector<Quality> > calcQuality;
|
||||
vector<vector<Quality> > calcDatasetQuality;
|
||||
|
||||
struct CommonRunParams
|
||||
{
|
||||
@ -1322,6 +1364,11 @@ string DescriptorQualityTest::getResultsFilename() const
|
||||
return string(ts->get_data_path()) + DESCRIPTORS_DIR + algName + RES_POSTFIX;
|
||||
}
|
||||
|
||||
string DescriptorQualityTest::getPlotPath() const
|
||||
{
|
||||
return string(ts->get_data_path()) + DESCRIPTORS_DIR + "plots/";
|
||||
}
|
||||
|
||||
void DescriptorQualityTest::validQualityClear( int datasetIdx )
|
||||
{
|
||||
validQuality[datasetIdx].clear();
|
||||
@ -1408,6 +1455,16 @@ void DescriptorQualityTest::setDefaultDatasetRunParams( int datasetIdx )
|
||||
commRunParams[datasetIdx].keypontsFilename = "surf_" + DATASET_NAMES[datasetIdx] + ".xml.gz";
|
||||
}
|
||||
|
||||
void DescriptorQualityTest::writePlotData( const string &filename, int di ) const
|
||||
{
|
||||
FILE *file = fopen (filename.c_str(),"w");
|
||||
size_t size = calcDatasetQuality[di].size();
|
||||
for (size_t i=0;i<size;i++)
|
||||
{
|
||||
fprintf( file, "%f, %f\n", 1 - calcDatasetQuality[di][i].precision, calcDatasetQuality[di][i].recall);
|
||||
}
|
||||
fclose( file );
|
||||
}
|
||||
|
||||
void DescriptorQualityTest::readAlgorithm( )
|
||||
{
|
||||
@ -1478,6 +1535,10 @@ void DescriptorQualityTest::runDatasetTest (const vector<Mat> &imgs, const vecto
|
||||
transformToEllipticKeyPoints( keypoints1, ekeypoints1 );
|
||||
|
||||
int progressCount = DATASETS_COUNT*TEST_CASE_COUNT;
|
||||
vector< pair<DescriptorMatching, int> > allMatchings;
|
||||
vector<int> allMatchStatuses;
|
||||
size_t matchingIndex = 0;
|
||||
int allCorrespCount = 0;
|
||||
for( int ci = 0; ci < TEST_CASE_COUNT; ci++ )
|
||||
{
|
||||
progress = update_progress( progress, di*TEST_CASE_COUNT + ci, progressCount, 0 );
|
||||
@ -1494,16 +1555,50 @@ void DescriptorQualityTest::runDatasetTest (const vector<Mat> &imgs, const vecto
|
||||
readKeypoints( keypontsFS, keypoints2, ci+1 );
|
||||
transformToEllipticKeyPoints( keypoints2, ekeypoints2 );
|
||||
descMatch->add( imgs[ci+1], keypoints2 );
|
||||
vector<int> matches1to2;
|
||||
descMatch->match( imgs[0], keypoints1, matches1to2 );
|
||||
vector<DescriptorMatching> matchings1to2;
|
||||
descMatch->match( imgs[0], keypoints1, matchings1to2 );
|
||||
vector< pair<DescriptorMatching, int> > matchings (matchings1to2.size());
|
||||
for( size_t i=0;i<matchings1to2.size();i++ )
|
||||
matchings[i] = pair<DescriptorMatching, int>( matchings1to2[i], i);
|
||||
|
||||
// TODO if( commRunParams[di].matchFilter )
|
||||
int correctMatchCount, falseMatchCount, correspCount;
|
||||
evaluateDescriptors( ekeypoints1, ekeypoints2, matches1to2, imgs[0], imgs[ci+1], Hs[ci],
|
||||
correctMatchCount, falseMatchCount, correspCount );
|
||||
int correspCount;
|
||||
int correctMatchCount = 0, falseMatchCount = 0;
|
||||
vector<int> matchStatuses;
|
||||
evaluateDescriptors( ekeypoints1, ekeypoints2, matchings, imgs[0], imgs[ci+1], Hs[ci],
|
||||
correctMatchCount, falseMatchCount, matchStatuses, correspCount );
|
||||
for( size_t i=0;i<matchings.size();i++ )
|
||||
matchings[i].second += matchingIndex;
|
||||
matchingIndex += matchings.size();
|
||||
|
||||
|
||||
allCorrespCount += correspCount;
|
||||
|
||||
//TODO: use merge
|
||||
std::copy( matchings.begin(), matchings.end(), std::back_inserter( allMatchings ) );
|
||||
std::copy( matchStatuses.begin(), matchStatuses.end(), std::back_inserter( allMatchStatuses ) );
|
||||
|
||||
printf ("%d %d %d \n", correctMatchCount, falseMatchCount, correspCount );
|
||||
|
||||
calcQuality[di][ci].recall = recall( correctMatchCount, correspCount );
|
||||
calcQuality[di][ci].precision = precision( correctMatchCount, falseMatchCount );
|
||||
descMatch->clear ();
|
||||
}
|
||||
|
||||
std::sort( allMatchings.begin(), allMatchings.end() );
|
||||
|
||||
calcDatasetQuality[di].resize( allMatchings.size() );
|
||||
int correctMatchCount = 0, falseMatchCount = 0;
|
||||
for( size_t i=0;i<allMatchings.size();i++)
|
||||
{
|
||||
if( allMatchStatuses[ allMatchings[i].second ] )
|
||||
correctMatchCount++;
|
||||
else
|
||||
falseMatchCount++;
|
||||
|
||||
calcDatasetQuality[di][i].recall = recall( correctMatchCount, allCorrespCount );
|
||||
calcDatasetQuality[di][i].precision = precision( correctMatchCount, falseMatchCount );
|
||||
}
|
||||
}
|
||||
|
||||
int DescriptorQualityTest::processResults( int datasetIdx, int caseIdx )
|
||||
|
Loading…
x
Reference in New Issue
Block a user