ml: fix NormalBayesClassifier bulk prediction(#5911)

This commit is contained in:
berak
2016-01-04 11:47:08 +01:00
parent fade402899
commit 5afd0e211e
2 changed files with 45 additions and 1 deletions

View File

@@ -236,7 +236,7 @@ public:
if (results_prob) if (results_prob)
{ {
rptype = results_prob->type(); rptype = results_prob->type();
rpstep = results_prob->isContinuous() ? 1 : results_prob->step/results_prob->elemSize(); rpstep = results_prob->isContinuous() ? results_prob->cols : results_prob->step/results_prob->elemSize();
} }
// allocate memory and initializing headers for calculating // allocate memory and initializing headers for calculating
cv::AutoBuffer<double> _buffer(nvars*2); cv::AutoBuffer<double> _buffer(nvars*2);

View File

@@ -128,4 +128,48 @@ TEST(ML_Boost, regression) { CV_AMLTest test( CV_BOOST ); test.safe_run(); }
TEST(ML_RTrees, regression) { CV_AMLTest test( CV_RTREES ); test.safe_run(); } TEST(ML_RTrees, regression) { CV_AMLTest test( CV_RTREES ); test.safe_run(); }
TEST(DISABLED_ML_ERTrees, regression) { CV_AMLTest test( CV_ERTREES ); test.safe_run(); } TEST(DISABLED_ML_ERTrees, regression) { CV_AMLTest test( CV_ERTREES ); test.safe_run(); }
TEST(ML_NBAYES, regression_5911)
{
int N=12;
Ptr<ml::NormalBayesClassifier> nb = cv::ml::NormalBayesClassifier::create();
// data:
Mat_<float> X(N,4);
X << 1,2,3,4, 1,2,3,4, 1,2,3,4, 1,2,3,4,
5,5,5,5, 5,5,5,5, 5,5,5,5, 5,5,5,5,
4,3,2,1, 4,3,2,1, 4,3,2,1, 4,3,2,1;
// labels:
Mat_<int> Y(N,1);
Y << 0,0,0,0, 1,1,1,1, 2,2,2,2;
nb->train(X, ml::ROW_SAMPLE, Y);
// single prediction:
Mat R1,P1;
for (int i=0; i<N; i++)
{
Mat r,p;
nb->predictProb(X.row(i), r, p);
R1.push_back(r);
P1.push_back(p);
}
// bulk prediction (continuous memory):
Mat R2,P2;
nb->predictProb(X, R2, P2);
EXPECT_EQ(sum(R1 == R2)[0], 255 * R2.total());
EXPECT_EQ(sum(P1 == P2)[0], 255 * P2.total());
// bulk prediction, with non-continuous memory storage
Mat R3_(N, 1+1, CV_32S),
P3_(N, 3+1, CV_32F);
nb->predictProb(X, R3_.col(0), P3_.colRange(0,3));
Mat R3 = R3_.col(0).clone(),
P3 = P3_.colRange(0,3).clone();
EXPECT_EQ(sum(R1 == R3)[0], 255 * R3.total());
EXPECT_EQ(sum(P1 == P3)[0], 255 * P3.total());
}
/* End of file. */ /* End of file. */