modified EM interface; updated tests & samples

This commit is contained in:
Vadim Pisarevsky
2012-04-06 15:59:30 +00:00
parent 1c1c6b98f6
commit b8c310065c
8 changed files with 338 additions and 333 deletions

View File

@@ -320,6 +320,30 @@ void CV_KNearestTest::run( int /*start_from*/ )
ts->set_failed_test_info( code );
}
class EM_Params
{
public:
EM_Params(int nclusters=10, int covMatType=EM::COV_MAT_DIAGONAL, int startStep=EM::START_AUTO_STEP,
const cv::TermCriteria& termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON),
const cv::Mat* probs=0, const cv::Mat* weights=0,
const cv::Mat* means=0, const std::vector<cv::Mat>* covs=0)
: nclusters(nclusters), covMatType(covMatType), startStep(startStep),
probs(probs), weights(weights), means(means), covs(covs), termCrit(termCrit)
{}
int nclusters;
int covMatType;
int startStep;
// all 4 following matrices should have type CV_32FC1
const cv::Mat* probs;
const cv::Mat* weights;
const cv::Mat* means;
const std::vector<cv::Mat>* covs;
cv::TermCriteria termCrit;
};
//--------------------------------------------------------------------------------------------
class CV_EMTest : public cvtest::BaseTest
{
@@ -327,13 +351,13 @@ public:
CV_EMTest() {}
protected:
virtual void run( int start_from );
int runCase( int caseIndex, const cv::EM::Params& params,
int runCase( int caseIndex, const EM_Params& params,
const cv::Mat& trainData, const cv::Mat& trainLabels,
const cv::Mat& testData, const cv::Mat& testLabels,
const vector<int>& sizes);
};
int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
const cv::Mat& trainData, const cv::Mat& trainLabels,
const cv::Mat& testData, const cv::Mat& testLabels,
const vector<int>& sizes )
@@ -343,8 +367,13 @@ int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
cv::Mat labels;
float err;
cv::EM em;
em.train( trainData, Mat(), params, &labels );
cv::EM em(params.nclusters, params.covMatType, params.termCrit);
if( params.startStep == EM::START_AUTO_STEP )
em.train( trainData, labels );
else if( params.startStep == EM::START_E_STEP )
em.trainE( trainData, *params.means, *params.covs, *params.weights, labels );
else if( params.startStep == EM::START_M_STEP )
em.trainM( trainData, *params.probs, labels );
// check train error
if( !calcErr( labels, trainLabels, sizes, err , false ) )
@@ -363,7 +392,7 @@ int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
for( int i = 0; i < testData.rows; i++ )
{
Mat sample = testData.row(i);
labels.at<int>(i,0) = (int)em.predict( sample, 0 );
labels.at<int>(i,0) = (int)em.predict( sample, noArray() );
}
if( !calcErr( labels, testLabels, sizes, err, false ) )
{
@@ -398,7 +427,7 @@ void CV_EMTest::run( int /*start_from*/ )
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels;
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
cv::EM::Params params;
EM_Params params;
params.nclusters = 3;
Mat probs(trainData.rows, params.nclusters, CV_32FC1, cv::Scalar(1));
params.probs = &probs;
@@ -474,19 +503,16 @@ protected:
virtual void run( int /*start_from*/ )
{
int code = cvtest::TS::OK;
cv::EM em;
cv::EM em(2);
Mat samples = Mat(3,1,CV_32F);
samples.at<float>(0,0) = 1;
samples.at<float>(1,0) = 2;
samples.at<float>(2,0) = 3;
cv::EM::Params params;
params.nclusters = 2;
Mat labels;
em.train(samples, Mat(), params, &labels);
em.train(samples, labels);
Mat firstResult(samples.rows, 1, CV_32FC1);
for( int i = 0; i < samples.rows; i++)