first implementation KNearest wrapper on KDTree
This commit is contained in:
		@@ -230,10 +230,11 @@ public:
 | 
			
		||||
    class CV_EXPORTS_W_MAP Params
 | 
			
		||||
    {
 | 
			
		||||
    public:
 | 
			
		||||
        Params(int defaultK=10, bool isclassifier=true);
 | 
			
		||||
        Params(int defaultK=10, bool isclassifier_=true, int Emax_=INT_MAX);
 | 
			
		||||
 | 
			
		||||
        CV_PROP_RW int defaultK;
 | 
			
		||||
        CV_PROP_RW bool isclassifier;
 | 
			
		||||
        CV_PROP_RW int Emax; // for implementation with KDTree
 | 
			
		||||
    };
 | 
			
		||||
    virtual void setParams(const Params& p) = 0;
 | 
			
		||||
    virtual Params getParams() const = 0;
 | 
			
		||||
@@ -241,7 +242,10 @@ public:
 | 
			
		||||
                               OutputArray results,
 | 
			
		||||
                               OutputArray neighborResponses=noArray(),
 | 
			
		||||
                               OutputArray dist=noArray() ) const = 0;
 | 
			
		||||
    static Ptr<KNearest> create(const Params& params=Params());
 | 
			
		||||
 | 
			
		||||
    enum { DEFAULT=1, KDTREE=2 };
 | 
			
		||||
 | 
			
		||||
    static Ptr<KNearest> create(const Params& params=Params(), int type=DEFAULT);
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/****************************************************************************************\
 | 
			
		||||
 
 | 
			
		||||
@@ -49,10 +49,11 @@
 | 
			
		||||
namespace cv {
 | 
			
		||||
namespace ml {
 | 
			
		||||
 | 
			
		||||
KNearest::Params::Params(int k, bool isclassifier_)
 | 
			
		||||
KNearest::Params::Params(int k, bool isclassifier_, int Emax_)
 | 
			
		||||
{
 | 
			
		||||
    defaultK = k;
 | 
			
		||||
    isclassifier = isclassifier_;
 | 
			
		||||
    Emax = Emax_;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -352,8 +353,156 @@ public:
 | 
			
		||||
    Params params;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
Ptr<KNearest> KNearest::create(const Params& p)
 | 
			
		||||
 | 
			
		||||
class KNearestKDTreeImpl : public KNearest
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    KNearestKDTreeImpl(const Params& p)
 | 
			
		||||
    {
 | 
			
		||||
        params = p;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    virtual ~KNearestKDTreeImpl() {}
 | 
			
		||||
 | 
			
		||||
    Params getParams() const { return params; }
 | 
			
		||||
    void setParams(const Params& p) { params = p; }
 | 
			
		||||
 | 
			
		||||
    bool isClassifier() const { return params.isclassifier; }
 | 
			
		||||
    bool isTrained() const { return !samples.empty(); }
 | 
			
		||||
 | 
			
		||||
    String getDefaultModelName() const { return "opencv_ml_knn_kd"; }
 | 
			
		||||
 | 
			
		||||
    void clear()
 | 
			
		||||
    {
 | 
			
		||||
        samples.release();
 | 
			
		||||
        responses.release();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    int getVarCount() const { return samples.cols; }
 | 
			
		||||
 | 
			
		||||
    bool train( const Ptr<TrainData>& data, int flags )
 | 
			
		||||
    {
 | 
			
		||||
        Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
 | 
			
		||||
        Mat new_responses;
 | 
			
		||||
        data->getTrainResponses().convertTo(new_responses, CV_32F);
 | 
			
		||||
        bool update = (flags & UPDATE_MODEL) != 0 && !samples.empty();
 | 
			
		||||
 | 
			
		||||
        CV_Assert( new_samples.type() == CV_32F );
 | 
			
		||||
 | 
			
		||||
        if( !update )
 | 
			
		||||
        {
 | 
			
		||||
            clear();
 | 
			
		||||
        }
 | 
			
		||||
        else
 | 
			
		||||
        {
 | 
			
		||||
            CV_Assert( new_samples.cols == samples.cols &&
 | 
			
		||||
                       new_responses.cols == responses.cols );
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        samples.push_back(new_samples);
 | 
			
		||||
        responses.push_back(new_responses);
 | 
			
		||||
 | 
			
		||||
        tr.build(samples);
 | 
			
		||||
 | 
			
		||||
        return true;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    float findNearest( InputArray _samples, int k,
 | 
			
		||||
                       OutputArray _results,
 | 
			
		||||
                       OutputArray _neighborResponses,
 | 
			
		||||
                       OutputArray _dists ) const
 | 
			
		||||
    {
 | 
			
		||||
        float result = 0.f;
 | 
			
		||||
        CV_Assert( 0 < k );
 | 
			
		||||
 | 
			
		||||
        Mat test_samples = _samples.getMat();
 | 
			
		||||
        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
 | 
			
		||||
        int testcount = test_samples.rows;
 | 
			
		||||
 | 
			
		||||
        if( testcount == 0 )
 | 
			
		||||
        {
 | 
			
		||||
            _results.release();
 | 
			
		||||
            _neighborResponses.release();
 | 
			
		||||
            _dists.release();
 | 
			
		||||
            return 0.f;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        Mat res, nr, d;
 | 
			
		||||
        if( _results.needed() )
 | 
			
		||||
        {
 | 
			
		||||
            _results.create(testcount, 1, CV_32F);
 | 
			
		||||
            res = _results.getMat();
 | 
			
		||||
        }
 | 
			
		||||
        if( _neighborResponses.needed() )
 | 
			
		||||
        {
 | 
			
		||||
            _neighborResponses.create(testcount, k, CV_32F);
 | 
			
		||||
            nr = _neighborResponses.getMat();
 | 
			
		||||
        }
 | 
			
		||||
        if( _dists.needed() )
 | 
			
		||||
        {
 | 
			
		||||
            _dists.create(testcount, k, CV_32F);
 | 
			
		||||
            d = _dists.getMat();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        for (int i=0; i<test_samples.rows; ++i)
 | 
			
		||||
        {
 | 
			
		||||
            Mat _res, _nr, _d;
 | 
			
		||||
            if (res.rows>i)
 | 
			
		||||
            {
 | 
			
		||||
                _res = res.row(i);
 | 
			
		||||
            }
 | 
			
		||||
            if (nr.rows>i)
 | 
			
		||||
            {
 | 
			
		||||
                _nr = nr.row(i);
 | 
			
		||||
            }
 | 
			
		||||
            if (d.rows>i)
 | 
			
		||||
            {
 | 
			
		||||
                _d = d.row(i);
 | 
			
		||||
            }
 | 
			
		||||
            tr.findNearest(test_samples.row(i), k, params.Emax, _res, _nr, _d, noArray());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return result; // currently always 0
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    float predict(InputArray inputs, OutputArray outputs, int) const
 | 
			
		||||
    {
 | 
			
		||||
        return findNearest( inputs, params.defaultK, outputs, noArray(), noArray() );
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void write( FileStorage& fs ) const
 | 
			
		||||
    {
 | 
			
		||||
        fs << "is_classifier" << (int)params.isclassifier;
 | 
			
		||||
        fs << "default_k" << params.defaultK;
 | 
			
		||||
 | 
			
		||||
        fs << "samples" << samples;
 | 
			
		||||
        fs << "responses" << responses;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    void read( const FileNode& fn )
 | 
			
		||||
    {
 | 
			
		||||
        clear();
 | 
			
		||||
        params.isclassifier = (int)fn["is_classifier"] != 0;
 | 
			
		||||
        params.defaultK = (int)fn["default_k"];
 | 
			
		||||
 | 
			
		||||
        fn["samples"] >> samples;
 | 
			
		||||
        fn["responses"] >> responses;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    KDTree tr;
 | 
			
		||||
 | 
			
		||||
    Mat samples;
 | 
			
		||||
    Mat responses;
 | 
			
		||||
    Params params;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
Ptr<KNearest> KNearest::create(const Params& p, int type)
 | 
			
		||||
{
 | 
			
		||||
    if (KDTREE==type)
 | 
			
		||||
    {
 | 
			
		||||
        return makePtr<KNearestKDTreeImpl>(p);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return makePtr<KNearestImpl>(p);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -312,9 +312,11 @@ void CV_KNearestTest::run( int /*start_from*/ )
 | 
			
		||||
    generateData( testData, testLabels, sizes, means, covs, CV_32FC1, CV_32FC1 );
 | 
			
		||||
 | 
			
		||||
    int code = cvtest::TS::OK;
 | 
			
		||||
    Ptr<KNearest> knearest = KNearest::create(true);
 | 
			
		||||
    knearest->train(trainData, cv::ml::ROW_SAMPLE, trainLabels);
 | 
			
		||||
    knearest->findNearest( testData, 4, bestLabels);
 | 
			
		||||
 | 
			
		||||
    // KNearest default implementation
 | 
			
		||||
    Ptr<KNearest> knearest = KNearest::create();
 | 
			
		||||
    knearest->train(trainData, ml::ROW_SAMPLE, trainLabels);
 | 
			
		||||
    knearest->findNearest(testData, 4, bestLabels);
 | 
			
		||||
    float err;
 | 
			
		||||
    if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
 | 
			
		||||
    {
 | 
			
		||||
@@ -326,6 +328,17 @@ void CV_KNearestTest::run( int /*start_from*/ )
 | 
			
		||||
        ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
 | 
			
		||||
        code = cvtest::TS::FAIL_BAD_ACCURACY;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // KNearest KDTree implementation
 | 
			
		||||
    Ptr<KNearest> knearestKdt = KNearest::create(ml::KNearest::Params(), ml::KNearest::KDTREE);
 | 
			
		||||
    knearestKdt->train(trainData, ml::ROW_SAMPLE, trainLabels);
 | 
			
		||||
    knearestKdt->findNearest(testData, 4, bestLabels);
 | 
			
		||||
    if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
 | 
			
		||||
    {
 | 
			
		||||
        ts->printf( cvtest::TS::LOG, "Bad output labels.\n" );
 | 
			
		||||
        code = cvtest::TS::FAIL_INVALID_OUTPUT;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ts->set_failed_test_info( code );
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user