From 5afd0e211e15109ee1e743c74ab13e9f4c1ab555 Mon Sep 17 00:00:00 2001 From: berak Date: Mon, 4 Jan 2016 11:47:08 +0100 Subject: [PATCH] ml: fix NormalBayesClassifier bulk prediction(#5911) --- modules/ml/src/nbayes.cpp | 2 +- modules/ml/test/test_mltests.cpp | 44 ++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/modules/ml/src/nbayes.cpp b/modules/ml/src/nbayes.cpp index 221db93e2..733dcf39f 100644 --- a/modules/ml/src/nbayes.cpp +++ b/modules/ml/src/nbayes.cpp @@ -236,7 +236,7 @@ public: if (results_prob) { rptype = results_prob->type(); - rpstep = results_prob->isContinuous() ? 1 : results_prob->step/results_prob->elemSize(); + rpstep = results_prob->isContinuous() ? results_prob->cols : results_prob->step/results_prob->elemSize(); } // allocate memory and initializing headers for calculating cv::AutoBuffer _buffer(nvars*2); diff --git a/modules/ml/test/test_mltests.cpp b/modules/ml/test/test_mltests.cpp index 2ffa531ec..70cc0f7ec 100644 --- a/modules/ml/test/test_mltests.cpp +++ b/modules/ml/test/test_mltests.cpp @@ -128,4 +128,48 @@ TEST(ML_Boost, regression) { CV_AMLTest test( CV_BOOST ); test.safe_run(); } TEST(ML_RTrees, regression) { CV_AMLTest test( CV_RTREES ); test.safe_run(); } TEST(DISABLED_ML_ERTrees, regression) { CV_AMLTest test( CV_ERTREES ); test.safe_run(); } +TEST(ML_NBAYES, regression_5911) +{ + int N=12; + Ptr nb = cv::ml::NormalBayesClassifier::create(); + + // data: + Mat_ X(N,4); + X << 1,2,3,4, 1,2,3,4, 1,2,3,4, 1,2,3,4, + 5,5,5,5, 5,5,5,5, 5,5,5,5, 5,5,5,5, + 4,3,2,1, 4,3,2,1, 4,3,2,1, 4,3,2,1; + + // labels: + Mat_ Y(N,1); + Y << 0,0,0,0, 1,1,1,1, 2,2,2,2; + nb->train(X, ml::ROW_SAMPLE, Y); + + // single prediction: + Mat R1,P1; + for (int i=0; ipredictProb(X.row(i), r, p); + R1.push_back(r); + P1.push_back(p); + } + + // bulk prediction (continuous memory): + Mat R2,P2; + nb->predictProb(X, R2, P2); + + EXPECT_EQ(sum(R1 == R2)[0], 255 * R2.total()); + EXPECT_EQ(sum(P1 == P2)[0], 255 * P2.total()); + + // bulk prediction, with non-continuous memory storage + Mat R3_(N, 1+1, CV_32S), + P3_(N, 3+1, CV_32F); + nb->predictProb(X, R3_.col(0), P3_.colRange(0,3)); + Mat R3 = R3_.col(0).clone(), + P3 = P3_.colRange(0,3).clone(); + + EXPECT_EQ(sum(R1 == R3)[0], 255 * R3.total()); + EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total()); +} + /* End of file. */