modified FernClassifier::train(); remove old RTreeClassifier and added new implementation CalonderClassifier; removed old find_obj_calonder and added new one
This commit is contained in:
@@ -692,9 +692,9 @@ Size FernClassifier::getPatchSize() const
|
||||
}
|
||||
|
||||
|
||||
FernClassifier::FernClassifier(const vector<Point2f>& points,
|
||||
const vector<Ptr<Mat> >& refimgs,
|
||||
const vector<int>& labels,
|
||||
FernClassifier::FernClassifier(const vector<vector<Point2f> >& points,
|
||||
const vector<Mat>& refimgs,
|
||||
const vector<vector<int> >& labels,
|
||||
int _nclasses, int _patchSize,
|
||||
int _signatureSize, int _nstructs,
|
||||
int _structSize, int _nviews, int _compressionMethod,
|
||||
@@ -829,41 +829,56 @@ void FernClassifier::prepare(int _nclasses, int _patchSize, int _signatureSize,
|
||||
}
|
||||
}
|
||||
|
||||
static int calcNumPoints( const vector<vector<Point2f> >& points )
|
||||
{
|
||||
int count = 0;
|
||||
for( size_t i = 0; i < points.size(); i++ )
|
||||
count += points[i].size();
|
||||
return count;
|
||||
}
|
||||
|
||||
void FernClassifier::train(const vector<Point2f>& points,
|
||||
const vector<Ptr<Mat> >& refimgs,
|
||||
const vector<int>& labels,
|
||||
void FernClassifier::train(const vector<vector<Point2f> >& points,
|
||||
const vector<Mat>& refimgs,
|
||||
const vector<vector<int> >& labels,
|
||||
int _nclasses, int _patchSize,
|
||||
int _signatureSize, int _nstructs,
|
||||
int _structSize, int _nviews, int _compressionMethod,
|
||||
const PatchGenerator& patchGenerator)
|
||||
{
|
||||
_nclasses = _nclasses > 0 ? _nclasses : (int)points.size();
|
||||
CV_Assert( points.size() == refimgs.size() );
|
||||
int numPoints = calcNumPoints( points );
|
||||
_nclasses = (!labels.empty() && _nclasses>0) ? _nclasses : numPoints;
|
||||
CV_Assert( labels.empty() || labels.size() == points.size() );
|
||||
|
||||
|
||||
prepare(_nclasses, _patchSize, _signatureSize, _nstructs,
|
||||
_structSize, _nviews, _compressionMethod);
|
||||
|
||||
// pass all the views of all the samples through the generated trees and accumulate
|
||||
// the statistics (posterior probabilities) in leaves.
|
||||
Mat patch;
|
||||
int i, j, nsamples = (int)points.size();
|
||||
RNG& rng = theRNG();
|
||||
|
||||
for( i = 0; i < nsamples; i++ )
|
||||
int globalPointIdx = 0;
|
||||
for( size_t imgIdx = 0; imgIdx < points.size(); imgIdx++ )
|
||||
{
|
||||
Point2f pt = points[i];
|
||||
const Mat& src = *refimgs[i];
|
||||
int classId = labels.empty() ? i : labels[i];
|
||||
if( verbose && (i+1)*progressBarSize/nsamples != i*progressBarSize/nsamples )
|
||||
putchar('.');
|
||||
CV_Assert( 0 <= classId && classId < nclasses );
|
||||
classCounters[classId] += _nviews;
|
||||
for( j = 0; j < _nviews; j++ )
|
||||
const Point2f* imgPoints = &points[imgIdx][0];
|
||||
const int* imgLabels = labels.empty() ? 0 : &labels[imgIdx][0];
|
||||
for( size_t pointIdx = 0; pointIdx < points[imgIdx].size(); pointIdx++, globalPointIdx++ )
|
||||
{
|
||||
patchGenerator(src, pt, patch, patchSize, rng);
|
||||
for( int f = 0; f < nstructs; f++ )
|
||||
posteriors[getLeaf(f, patch)*nclasses + classId]++;
|
||||
Point2f pt = imgPoints[pointIdx];
|
||||
const Mat& src = refimgs[imgIdx];
|
||||
int classId = imgLabels==0 ? globalPointIdx : imgLabels[pointIdx];
|
||||
if( verbose && (globalPointIdx+1)*progressBarSize/numPoints != globalPointIdx*progressBarSize/numPoints )
|
||||
putchar('.');
|
||||
CV_Assert( 0 <= classId && classId < nclasses );
|
||||
classCounters[classId] += _nviews;
|
||||
for( int v = 0; v < _nviews; v++ )
|
||||
{
|
||||
patchGenerator(src, pt, patch, patchSize, rng);
|
||||
for( int f = 0; f < nstructs; f++ )
|
||||
posteriors[getLeaf(f, patch)*nclasses + classId]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
if( verbose )
|
||||
|
Reference in New Issue
Block a user