first implementation KNearest wrapper on KDTree
This commit is contained in:
parent
37b1a7560c
commit
9ddb23e025
@ -230,10 +230,11 @@ public:
|
|||||||
class CV_EXPORTS_W_MAP Params
|
class CV_EXPORTS_W_MAP Params
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Params(int defaultK=10, bool isclassifier=true);
|
Params(int defaultK=10, bool isclassifier_=true, int Emax_=INT_MAX);
|
||||||
|
|
||||||
CV_PROP_RW int defaultK;
|
CV_PROP_RW int defaultK;
|
||||||
CV_PROP_RW bool isclassifier;
|
CV_PROP_RW bool isclassifier;
|
||||||
|
CV_PROP_RW int Emax; // for implementation with KDTree
|
||||||
};
|
};
|
||||||
virtual void setParams(const Params& p) = 0;
|
virtual void setParams(const Params& p) = 0;
|
||||||
virtual Params getParams() const = 0;
|
virtual Params getParams() const = 0;
|
||||||
@ -241,7 +242,10 @@ public:
|
|||||||
OutputArray results,
|
OutputArray results,
|
||||||
OutputArray neighborResponses=noArray(),
|
OutputArray neighborResponses=noArray(),
|
||||||
OutputArray dist=noArray() ) const = 0;
|
OutputArray dist=noArray() ) const = 0;
|
||||||
static Ptr<KNearest> create(const Params& params=Params());
|
|
||||||
|
enum { DEFAULT=1, KDTREE=2 };
|
||||||
|
|
||||||
|
static Ptr<KNearest> create(const Params& params=Params(), int type=DEFAULT);
|
||||||
};
|
};
|
||||||
|
|
||||||
/****************************************************************************************\
|
/****************************************************************************************\
|
||||||
|
@ -49,10 +49,11 @@
|
|||||||
namespace cv {
|
namespace cv {
|
||||||
namespace ml {
|
namespace ml {
|
||||||
|
|
||||||
KNearest::Params::Params(int k, bool isclassifier_)
|
KNearest::Params::Params(int k, bool isclassifier_, int Emax_)
|
||||||
{
|
{
|
||||||
defaultK = k;
|
defaultK = k;
|
||||||
isclassifier = isclassifier_;
|
isclassifier = isclassifier_;
|
||||||
|
Emax = Emax_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -352,8 +353,156 @@ public:
|
|||||||
Params params;
|
Params params;
|
||||||
};
|
};
|
||||||
|
|
||||||
Ptr<KNearest> KNearest::create(const Params& p)
|
|
||||||
|
class KNearestKDTreeImpl : public KNearest
|
||||||
{
|
{
|
||||||
|
public:
|
||||||
|
KNearestKDTreeImpl(const Params& p)
|
||||||
|
{
|
||||||
|
params = p;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual ~KNearestKDTreeImpl() {}
|
||||||
|
|
||||||
|
Params getParams() const { return params; }
|
||||||
|
void setParams(const Params& p) { params = p; }
|
||||||
|
|
||||||
|
bool isClassifier() const { return params.isclassifier; }
|
||||||
|
bool isTrained() const { return !samples.empty(); }
|
||||||
|
|
||||||
|
String getDefaultModelName() const { return "opencv_ml_knn_kd"; }
|
||||||
|
|
||||||
|
void clear()
|
||||||
|
{
|
||||||
|
samples.release();
|
||||||
|
responses.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
int getVarCount() const { return samples.cols; }
|
||||||
|
|
||||||
|
bool train( const Ptr<TrainData>& data, int flags )
|
||||||
|
{
|
||||||
|
Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
|
||||||
|
Mat new_responses;
|
||||||
|
data->getTrainResponses().convertTo(new_responses, CV_32F);
|
||||||
|
bool update = (flags & UPDATE_MODEL) != 0 && !samples.empty();
|
||||||
|
|
||||||
|
CV_Assert( new_samples.type() == CV_32F );
|
||||||
|
|
||||||
|
if( !update )
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
CV_Assert( new_samples.cols == samples.cols &&
|
||||||
|
new_responses.cols == responses.cols );
|
||||||
|
}
|
||||||
|
|
||||||
|
samples.push_back(new_samples);
|
||||||
|
responses.push_back(new_responses);
|
||||||
|
|
||||||
|
tr.build(samples);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
float findNearest( InputArray _samples, int k,
|
||||||
|
OutputArray _results,
|
||||||
|
OutputArray _neighborResponses,
|
||||||
|
OutputArray _dists ) const
|
||||||
|
{
|
||||||
|
float result = 0.f;
|
||||||
|
CV_Assert( 0 < k );
|
||||||
|
|
||||||
|
Mat test_samples = _samples.getMat();
|
||||||
|
CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
|
||||||
|
int testcount = test_samples.rows;
|
||||||
|
|
||||||
|
if( testcount == 0 )
|
||||||
|
{
|
||||||
|
_results.release();
|
||||||
|
_neighborResponses.release();
|
||||||
|
_dists.release();
|
||||||
|
return 0.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat res, nr, d;
|
||||||
|
if( _results.needed() )
|
||||||
|
{
|
||||||
|
_results.create(testcount, 1, CV_32F);
|
||||||
|
res = _results.getMat();
|
||||||
|
}
|
||||||
|
if( _neighborResponses.needed() )
|
||||||
|
{
|
||||||
|
_neighborResponses.create(testcount, k, CV_32F);
|
||||||
|
nr = _neighborResponses.getMat();
|
||||||
|
}
|
||||||
|
if( _dists.needed() )
|
||||||
|
{
|
||||||
|
_dists.create(testcount, k, CV_32F);
|
||||||
|
d = _dists.getMat();
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i=0; i<test_samples.rows; ++i)
|
||||||
|
{
|
||||||
|
Mat _res, _nr, _d;
|
||||||
|
if (res.rows>i)
|
||||||
|
{
|
||||||
|
_res = res.row(i);
|
||||||
|
}
|
||||||
|
if (nr.rows>i)
|
||||||
|
{
|
||||||
|
_nr = nr.row(i);
|
||||||
|
}
|
||||||
|
if (d.rows>i)
|
||||||
|
{
|
||||||
|
_d = d.row(i);
|
||||||
|
}
|
||||||
|
tr.findNearest(test_samples.row(i), k, params.Emax, _res, _nr, _d, noArray());
|
||||||
|
}
|
||||||
|
|
||||||
|
return result; // currently always 0
|
||||||
|
}
|
||||||
|
|
||||||
|
float predict(InputArray inputs, OutputArray outputs, int) const
|
||||||
|
{
|
||||||
|
return findNearest( inputs, params.defaultK, outputs, noArray(), noArray() );
|
||||||
|
}
|
||||||
|
|
||||||
|
void write( FileStorage& fs ) const
|
||||||
|
{
|
||||||
|
fs << "is_classifier" << (int)params.isclassifier;
|
||||||
|
fs << "default_k" << params.defaultK;
|
||||||
|
|
||||||
|
fs << "samples" << samples;
|
||||||
|
fs << "responses" << responses;
|
||||||
|
}
|
||||||
|
|
||||||
|
void read( const FileNode& fn )
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
params.isclassifier = (int)fn["is_classifier"] != 0;
|
||||||
|
params.defaultK = (int)fn["default_k"];
|
||||||
|
|
||||||
|
fn["samples"] >> samples;
|
||||||
|
fn["responses"] >> responses;
|
||||||
|
}
|
||||||
|
|
||||||
|
KDTree tr;
|
||||||
|
|
||||||
|
Mat samples;
|
||||||
|
Mat responses;
|
||||||
|
Params params;
|
||||||
|
};
|
||||||
|
|
||||||
|
Ptr<KNearest> KNearest::create(const Params& p, int type)
|
||||||
|
{
|
||||||
|
if (KDTREE==type)
|
||||||
|
{
|
||||||
|
return makePtr<KNearestKDTreeImpl>(p);
|
||||||
|
}
|
||||||
|
|
||||||
return makePtr<KNearestImpl>(p);
|
return makePtr<KNearestImpl>(p);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -312,8 +312,10 @@ void CV_KNearestTest::run( int /*start_from*/ )
|
|||||||
generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
|
generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
|
||||||
|
|
||||||
int code = cvtest::TS::OK;
|
int code = cvtest::TS::OK;
|
||||||
Ptr<KNearest> knearest = KNearest::create(true);
|
|
||||||
knearest->train(trainData, cv::ml::ROW_SAMPLE, trainLabels);
|
// KNearest default implementation
|
||||||
|
Ptr<KNearest> knearest = KNearest::create();
|
||||||
|
knearest->train(trainData, ml::ROW_SAMPLE, trainLabels);
|
||||||
knearest->findNearest(testData, 4, bestLabels);
|
knearest->findNearest(testData, 4, bestLabels);
|
||||||
float err;
|
float err;
|
||||||
if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
|
if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
|
||||||
@ -326,6 +328,17 @@ void CV_KNearestTest::run( int /*start_from*/ )
|
|||||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
|
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
|
||||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KNearest KDTree implementation
|
||||||
|
Ptr<KNearest> knearestKdt = KNearest::create(ml::KNearest::Params(), ml::KNearest::KDTREE);
|
||||||
|
knearestKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
|
||||||
|
knearestKdt->findNearest(testData, 4, bestLabels);
|
||||||
|
if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Bad output labels.\n" );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
|
||||||
ts->set_failed_test_info( code );
|
ts->set_failed_test_info( code );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user