modified FernClassifier::train(); remove old RTreeClassifier and added new implementation CalonderClassifier; removed old find_obj_calonder and added new one
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -144,7 +144,7 @@ void DescriptorExtractor::removeBorderKeypoints( vector<KeyPoint>& keypoints,
|
||||
}
|
||||
|
||||
/****************************************************************************************\
|
||||
* SiftDescriptorExtractor *
|
||||
* SiftDescriptorExtractor *
|
||||
\****************************************************************************************/
|
||||
SiftDescriptorExtractor::SiftDescriptorExtractor( double magnification, bool isNormalize, bool recalculateAngles,
|
||||
int nOctaves, int nOctaveLayers, int firstOctave, int angleMode )
|
||||
@@ -188,7 +188,7 @@ void SiftDescriptorExtractor::write (FileStorage &fs) const
|
||||
}
|
||||
|
||||
/****************************************************************************************\
|
||||
* SurfDescriptorExtractor *
|
||||
* SurfDescriptorExtractor *
|
||||
\****************************************************************************************/
|
||||
SurfDescriptorExtractor::SurfDescriptorExtractor( int nOctaves,
|
||||
int nOctaveLayers, bool extended )
|
||||
@@ -228,6 +228,10 @@ void SurfDescriptorExtractor::write( FileStorage &fs ) const
|
||||
fs << "extended" << surf.extended;
|
||||
}
|
||||
|
||||
/****************************************************************************************\
|
||||
* Factory functions for descriptor extractor and matcher creating *
|
||||
\****************************************************************************************/
|
||||
|
||||
Ptr<DescriptorExtractor> createDescriptorExtractor( const string& descriptorExtractorType )
|
||||
{
|
||||
DescriptorExtractor* de = 0;
|
||||
@@ -270,7 +274,9 @@ Ptr<DescriptorMatcher> createDescriptorMatcher( const string& descriptorMatcherT
|
||||
return dm;
|
||||
}
|
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* BruteForceMatcher L2 specialization *
|
||||
\****************************************************************************************/
|
||||
template<>
|
||||
void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const Mat& descriptors_2,
|
||||
const Mat& /*mask*/, vector<int>& matches ) const
|
||||
@@ -317,7 +323,6 @@ void BruteForceMatcher<L2<float> >::matchImpl( const Mat& descriptors_1, const M
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* GenericDescriptorMatch *
|
||||
\****************************************************************************************/
|
||||
@@ -394,6 +399,9 @@ void GenericDescriptorMatch::clear()
|
||||
collection.clear();
|
||||
}
|
||||
|
||||
/*
|
||||
* Factory function for GenericDescriptorMatch creating
|
||||
*/
|
||||
Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericDescritptorMatchType, const string ¶msFilename )
|
||||
{
|
||||
GenericDescriptorMatch *descriptorMatch = 0;
|
||||
@@ -409,7 +417,7 @@ Ptr<GenericDescriptorMatch> createGenericDescriptorMatch( const string& genericD
|
||||
}
|
||||
else if( ! genericDescritptorMatchType.compare ("CALONDER") )
|
||||
{
|
||||
descriptorMatch = new CalonderDescriptorMatch ();
|
||||
//descriptorMatch = new CalonderDescriptorMatch ();
|
||||
}
|
||||
|
||||
if( !paramsFilename.empty() && descriptorMatch != 0 )
|
||||
@@ -626,6 +634,7 @@ void OneWayDescriptorMatch::clear ()
|
||||
/****************************************************************************************\
|
||||
* CalonderDescriptorMatch *
|
||||
\****************************************************************************************/
|
||||
#if 0
|
||||
CalonderDescriptorMatch::Params::Params( const RNG& _rng, const PatchGenerator& _patchGen,
|
||||
int _numTrees, int _depth, int _views,
|
||||
size_t _reducedNumDim,
|
||||
@@ -774,6 +783,7 @@ void CalonderDescriptorMatch::write( FileStorage& fs ) const
|
||||
fs << "numQuantBits" << params.numQuantBits;
|
||||
fs << "printStatus" << params.printStatus;
|
||||
}
|
||||
#endif
|
||||
|
||||
/****************************************************************************************\
|
||||
* FernDescriptorMatch *
|
||||
@@ -827,22 +837,13 @@ void FernDescriptorMatch::trainFernClassifier()
|
||||
{
|
||||
assert( params.filename.empty() );
|
||||
|
||||
vector<Point2f> points;
|
||||
vector<Ptr<Mat> > refimgs;
|
||||
vector<int> labels;
|
||||
for( size_t imageIdx = 0; imageIdx < collection.images.size(); imageIdx++ )
|
||||
{
|
||||
for( size_t pointIdx = 0; pointIdx < collection.points[imageIdx].size(); pointIdx++ )
|
||||
{
|
||||
refimgs.push_back(new Mat (collection.images[imageIdx]));
|
||||
points.push_back(collection.points[imageIdx][pointIdx].pt);
|
||||
labels.push_back((int)pointIdx);
|
||||
}
|
||||
}
|
||||
vector<vector<Point2f> > points;
|
||||
for( size_t imgIdx = 0; imgIdx < collection.images.size(); imgIdx++ )
|
||||
KeyPoint::convert( collection.points[imgIdx], points[imgIdx] );
|
||||
|
||||
classifier = new FernClassifier( points, refimgs, labels, params.nclasses, params.patchSize,
|
||||
params.signatureSize, params.nstructs, params.structSize, params.nviews,
|
||||
params.compressionMethod, params.patchGenerator );
|
||||
classifier = new FernClassifier( points, collection.images, vector<vector<int> >(), 0, // each points is a class
|
||||
params.patchSize, params.signatureSize, params.nstructs, params.structSize,
|
||||
params.nviews, params.compressionMethod, params.patchGenerator );
|
||||
}
|
||||
}
|
||||
|
||||
@@ -966,4 +967,59 @@ void FernDescriptorMatch::clear ()
|
||||
classifier.release();
|
||||
}
|
||||
|
||||
/****************************************************************************************\
|
||||
* VectorDescriptorMatch *
|
||||
\****************************************************************************************/
|
||||
void VectorDescriptorMatch::add( const Mat& image, vector<KeyPoint>& keypoints )
|
||||
{
|
||||
Mat descriptors;
|
||||
extractor->compute( image, keypoints, descriptors );
|
||||
matcher->add( descriptors );
|
||||
|
||||
collection.add( Mat(), keypoints );
|
||||
};
|
||||
|
||||
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<int>& keypointIndices )
|
||||
{
|
||||
Mat descriptors;
|
||||
extractor->compute( image, points, descriptors );
|
||||
|
||||
matcher->match( descriptors, keypointIndices );
|
||||
};
|
||||
|
||||
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points, vector<DMatch>& matches )
|
||||
{
|
||||
Mat descriptors;
|
||||
extractor->compute( image, points, descriptors );
|
||||
|
||||
matcher->match( descriptors, matches );
|
||||
}
|
||||
|
||||
void VectorDescriptorMatch::match( const Mat& image, vector<KeyPoint>& points,
|
||||
vector<vector<DMatch> >& matches, float threshold )
|
||||
{
|
||||
Mat descriptors;
|
||||
extractor->compute( image, points, descriptors );
|
||||
|
||||
matcher->match( descriptors, matches, threshold );
|
||||
}
|
||||
|
||||
void VectorDescriptorMatch::clear()
|
||||
{
|
||||
GenericDescriptorMatch::clear();
|
||||
matcher->clear();
|
||||
}
|
||||
|
||||
void VectorDescriptorMatch::read( const FileNode& fn )
|
||||
{
|
||||
GenericDescriptorMatch::read(fn);
|
||||
extractor->read (fn);
|
||||
}
|
||||
|
||||
void VectorDescriptorMatch::write (FileStorage& fs) const
|
||||
{
|
||||
GenericDescriptorMatch::write(fs);
|
||||
extractor->write (fs);
|
||||
}
|
||||
|
||||
}
|
||||
|
@@ -692,9 +692,9 @@ Size FernClassifier::getPatchSize() const
|
||||
}
|
||||
|
||||
|
||||
FernClassifier::FernClassifier(const vector<Point2f>& points,
|
||||
const vector<Ptr<Mat> >& refimgs,
|
||||
const vector<int>& labels,
|
||||
FernClassifier::FernClassifier(const vector<vector<Point2f> >& points,
|
||||
const vector<Mat>& refimgs,
|
||||
const vector<vector<int> >& labels,
|
||||
int _nclasses, int _patchSize,
|
||||
int _signatureSize, int _nstructs,
|
||||
int _structSize, int _nviews, int _compressionMethod,
|
||||
@@ -829,41 +829,56 @@ void FernClassifier::prepare(int _nclasses, int _patchSize, int _signatureSize,
|
||||
}
|
||||
}
|
||||
|
||||
static int calcNumPoints( const vector<vector<Point2f> >& points )
|
||||
{
|
||||
int count = 0;
|
||||
for( size_t i = 0; i < points.size(); i++ )
|
||||
count += points[i].size();
|
||||
return count;
|
||||
}
|
||||
|
||||
void FernClassifier::train(const vector<Point2f>& points,
|
||||
const vector<Ptr<Mat> >& refimgs,
|
||||
const vector<int>& labels,
|
||||
void FernClassifier::train(const vector<vector<Point2f> >& points,
|
||||
const vector<Mat>& refimgs,
|
||||
const vector<vector<int> >& labels,
|
||||
int _nclasses, int _patchSize,
|
||||
int _signatureSize, int _nstructs,
|
||||
int _structSize, int _nviews, int _compressionMethod,
|
||||
const PatchGenerator& patchGenerator)
|
||||
{
|
||||
_nclasses = _nclasses > 0 ? _nclasses : (int)points.size();
|
||||
CV_Assert( points.size() == refimgs.size() );
|
||||
int numPoints = calcNumPoints( points );
|
||||
_nclasses = (!labels.empty() && _nclasses>0) ? _nclasses : numPoints;
|
||||
CV_Assert( labels.empty() || labels.size() == points.size() );
|
||||
|
||||
|
||||
prepare(_nclasses, _patchSize, _signatureSize, _nstructs,
|
||||
_structSize, _nviews, _compressionMethod);
|
||||
|
||||
// pass all the views of all the samples through the generated trees and accumulate
|
||||
// the statistics (posterior probabilities) in leaves.
|
||||
Mat patch;
|
||||
int i, j, nsamples = (int)points.size();
|
||||
RNG& rng = theRNG();
|
||||
|
||||
for( i = 0; i < nsamples; i++ )
|
||||
int globalPointIdx = 0;
|
||||
for( size_t imgIdx = 0; imgIdx < points.size(); imgIdx++ )
|
||||
{
|
||||
Point2f pt = points[i];
|
||||
const Mat& src = *refimgs[i];
|
||||
int classId = labels.empty() ? i : labels[i];
|
||||
if( verbose && (i+1)*progressBarSize/nsamples != i*progressBarSize/nsamples )
|
||||
putchar('.');
|
||||
CV_Assert( 0 <= classId && classId < nclasses );
|
||||
classCounters[classId] += _nviews;
|
||||
for( j = 0; j < _nviews; j++ )
|
||||
const Point2f* imgPoints = &points[imgIdx][0];
|
||||
const int* imgLabels = labels.empty() ? 0 : &labels[imgIdx][0];
|
||||
for( size_t pointIdx = 0; pointIdx < points[imgIdx].size(); pointIdx++, globalPointIdx++ )
|
||||
{
|
||||
patchGenerator(src, pt, patch, patchSize, rng);
|
||||
for( int f = 0; f < nstructs; f++ )
|
||||
posteriors[getLeaf(f, patch)*nclasses + classId]++;
|
||||
Point2f pt = imgPoints[pointIdx];
|
||||
const Mat& src = refimgs[imgIdx];
|
||||
int classId = imgLabels==0 ? globalPointIdx : imgLabels[pointIdx];
|
||||
if( verbose && (globalPointIdx+1)*progressBarSize/numPoints != globalPointIdx*progressBarSize/numPoints )
|
||||
putchar('.');
|
||||
CV_Assert( 0 <= classId && classId < nclasses );
|
||||
classCounters[classId] += _nviews;
|
||||
for( int v = 0; v < _nviews; v++ )
|
||||
{
|
||||
patchGenerator(src, pt, patch, patchSize, rng);
|
||||
for( int f = 0; f < nstructs; f++ )
|
||||
posteriors[getLeaf(f, patch)*nclasses + classId]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if( verbose )
|
||||
|
Reference in New Issue
Block a user