fixed em test
This commit is contained in:
parent
94c258cf15
commit
008a1c91fd
@ -87,8 +87,10 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const vecto
|
||||
r = r * (*cit) + *mit;
|
||||
if( labelType == CV_32FC1 )
|
||||
labels.at<float>(p, 0) = (float)l;
|
||||
else
|
||||
else if( labelType == CV_32SC1 )
|
||||
labels.at<int>(p, 0) = l;
|
||||
else
|
||||
CV_DbgAssert(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -201,20 +203,23 @@ void CV_KMeansTest::run( int /*start_from*/ )
|
||||
generateData( data, labels, sizes, means, covs, CV_32SC1 );
|
||||
|
||||
int code = cvtest::TS::OK;
|
||||
float err;
|
||||
Mat bestLabels;
|
||||
// 1. flag==KMEANS_PP_CENTERS
|
||||
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, noArray() );
|
||||
if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
|
||||
err = calcErr( bestLabels, labels, sizes, false );
|
||||
if( err > 0.01f )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
|
||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
|
||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||
}
|
||||
|
||||
// 2. flag==KMEANS_RANDOM_CENTERS
|
||||
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, noArray() );
|
||||
if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
|
||||
err = calcErr( bestLabels, labels, sizes, false );
|
||||
if( err > 0.01f )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
|
||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
|
||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||
}
|
||||
|
||||
@ -224,9 +229,10 @@ void CV_KMeansTest::run( int /*start_from*/ )
|
||||
for( int i = 0; i < 0.5f * pointsCount; i++ )
|
||||
bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3;
|
||||
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, noArray() );
|
||||
if( calcErr( bestLabels, labels, sizes, false ) > 0.01f )
|
||||
err = calcErr( bestLabels, labels, sizes, false );
|
||||
if( err > 0.01f )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "bad accuracy if flag==KMEANS_PP_CENTERS" );
|
||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
|
||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||
}
|
||||
|
||||
@ -261,9 +267,10 @@ void CV_KNearestTest::run( int /*start_from*/ )
|
||||
KNearest knearest;
|
||||
knearest.train( trainData, trainLabels );
|
||||
knearest.find_nearest( testData, 4, &bestLabels );
|
||||
if( calcErr( bestLabels, testLabels, sizes, true ) > 0.01f )
|
||||
float err = calcErr( bestLabels, testLabels, sizes, true );
|
||||
if( err > 0.01f )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "bad accuracy on test data" );
|
||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
|
||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||
}
|
||||
ts->set_failed_test_info( code );
|
||||
@ -294,15 +301,17 @@ void CV_EMTest::run( int /*start_from*/ )
|
||||
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
|
||||
|
||||
int code = cvtest::TS::OK;
|
||||
float err;
|
||||
ExpectationMaximization em;
|
||||
CvEMParams params;
|
||||
params.nclusters = 3;
|
||||
em.train( trainData, Mat(), params, &bestLabels );
|
||||
|
||||
// check train error
|
||||
if( calcErr( bestLabels, trainLabels, sizes, true ) > 0.002f )
|
||||
err = calcErr( bestLabels, trainLabels, sizes, false );
|
||||
if( err > 0.002f )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "bad accuracy on train data" );
|
||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on train data.\n", err );
|
||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||
}
|
||||
|
||||
@ -313,9 +322,10 @@ void CV_EMTest::run( int /*start_from*/ )
|
||||
Mat sample( 1, testData.cols, CV_32FC1, testData.ptr<float>(i));
|
||||
bestLabels.at<int>(i,0) = (int)em.predict( sample, 0 );
|
||||
}
|
||||
if( calcErr( bestLabels, testLabels, sizes, true ) > 0.005f )
|
||||
err = calcErr( bestLabels, testLabels, sizes, false );
|
||||
if( err > 0.005f )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "bad accuracy on test data" );
|
||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
|
||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||
}
|
||||
|
||||
@ -324,4 +334,4 @@ void CV_EMTest::run( int /*start_from*/ )
|
||||
|
||||
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
|
||||
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
|
||||
TEST(ML_EMTest, accuracy) { CV_EMTest test; test.safe_run(); }
|
||||
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
|
||||
|
Loading…
x
Reference in New Issue
Block a user