modified FernClassifier::train(); remove old RTreeClassifier and added new implementation CalonderClassifier; removed old find_obj_calonder and added new one

This commit is contained in:
Maria Dimashova
2010-07-26 08:58:46 +00:00
parent 1135bc2495
commit b5a71db742
5 changed files with 959 additions and 1612 deletions

View File

@@ -692,9 +692,9 @@ Size FernClassifier::getPatchSize() const
}
FernClassifier::FernClassifier(const vector<Point2f>& points,
const vector<Ptr<Mat> >& refimgs,
const vector<int>& labels,
FernClassifier::FernClassifier(const vector<vector<Point2f> >& points,
const vector<Mat>& refimgs,
const vector<vector<int> >& labels,
int _nclasses, int _patchSize,
int _signatureSize, int _nstructs,
int _structSize, int _nviews, int _compressionMethod,
@@ -829,41 +829,56 @@ void FernClassifier::prepare(int _nclasses, int _patchSize, int _signatureSize,
}
}
static int calcNumPoints( const vector<vector<Point2f> >& points )
{
int count = 0;
for( size_t i = 0; i < points.size(); i++ )
count += points[i].size();
return count;
}
void FernClassifier::train(const vector<Point2f>& points,
const vector<Ptr<Mat> >& refimgs,
const vector<int>& labels,
void FernClassifier::train(const vector<vector<Point2f> >& points,
const vector<Mat>& refimgs,
const vector<vector<int> >& labels,
int _nclasses, int _patchSize,
int _signatureSize, int _nstructs,
int _structSize, int _nviews, int _compressionMethod,
const PatchGenerator& patchGenerator)
{
_nclasses = _nclasses > 0 ? _nclasses : (int)points.size();
CV_Assert( points.size() == refimgs.size() );
int numPoints = calcNumPoints( points );
_nclasses = (!labels.empty() && _nclasses>0) ? _nclasses : numPoints;
CV_Assert( labels.empty() || labels.size() == points.size() );
prepare(_nclasses, _patchSize, _signatureSize, _nstructs,
_structSize, _nviews, _compressionMethod);
// pass all the views of all the samples through the generated trees and accumulate
// the statistics (posterior probabilities) in leaves.
Mat patch;
int i, j, nsamples = (int)points.size();
RNG& rng = theRNG();
for( i = 0; i < nsamples; i++ )
int globalPointIdx = 0;
for( size_t imgIdx = 0; imgIdx < points.size(); imgIdx++ )
{
Point2f pt = points[i];
const Mat& src = *refimgs[i];
int classId = labels.empty() ? i : labels[i];
if( verbose && (i+1)*progressBarSize/nsamples != i*progressBarSize/nsamples )
putchar('.');
CV_Assert( 0 <= classId && classId < nclasses );
classCounters[classId] += _nviews;
for( j = 0; j < _nviews; j++ )
const Point2f* imgPoints = &points[imgIdx][0];
const int* imgLabels = labels.empty() ? 0 : &labels[imgIdx][0];
for( size_t pointIdx = 0; pointIdx < points[imgIdx].size(); pointIdx++, globalPointIdx++ )
{
patchGenerator(src, pt, patch, patchSize, rng);
for( int f = 0; f < nstructs; f++ )
posteriors[getLeaf(f, patch)*nclasses + classId]++;
Point2f pt = imgPoints[pointIdx];
const Mat& src = refimgs[imgIdx];
int classId = imgLabels==0 ? globalPointIdx : imgLabels[pointIdx];
if( verbose && (globalPointIdx+1)*progressBarSize/numPoints != globalPointIdx*progressBarSize/numPoints )
putchar('.');
CV_Assert( 0 <= classId && classId < nclasses );
classCounters[classId] += _nviews;
for( int v = 0; v < _nviews; v++ )
{
patchGenerator(src, pt, patch, patchSize, rng);
for( int f = 0; f < nstructs; f++ )
posteriors[getLeaf(f, patch)*nclasses + classId]++;
}
}
}
if( verbose )