modified EM interface; updated tests & samples
This commit is contained in:
@@ -320,6 +320,30 @@ void CV_KNearestTest::run( int /*start_from*/ )
|
||||
ts->set_failed_test_info( code );
|
||||
}
|
||||
|
||||
class EM_Params
|
||||
{
|
||||
public:
|
||||
EM_Params(int nclusters=10, int covMatType=EM::COV_MAT_DIAGONAL, int startStep=EM::START_AUTO_STEP,
|
||||
const cv::TermCriteria& termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON),
|
||||
const cv::Mat* probs=0, const cv::Mat* weights=0,
|
||||
const cv::Mat* means=0, const std::vector<cv::Mat>* covs=0)
|
||||
: nclusters(nclusters), covMatType(covMatType), startStep(startStep),
|
||||
probs(probs), weights(weights), means(means), covs(covs), termCrit(termCrit)
|
||||
{}
|
||||
|
||||
int nclusters;
|
||||
int covMatType;
|
||||
int startStep;
|
||||
|
||||
// all 4 following matrices should have type CV_32FC1
|
||||
const cv::Mat* probs;
|
||||
const cv::Mat* weights;
|
||||
const cv::Mat* means;
|
||||
const std::vector<cv::Mat>* covs;
|
||||
|
||||
cv::TermCriteria termCrit;
|
||||
};
|
||||
|
||||
//--------------------------------------------------------------------------------------------
|
||||
class CV_EMTest : public cvtest::BaseTest
|
||||
{
|
||||
@@ -327,13 +351,13 @@ public:
|
||||
CV_EMTest() {}
|
||||
protected:
|
||||
virtual void run( int start_from );
|
||||
int runCase( int caseIndex, const cv::EM::Params& params,
|
||||
int runCase( int caseIndex, const EM_Params& params,
|
||||
const cv::Mat& trainData, const cv::Mat& trainLabels,
|
||||
const cv::Mat& testData, const cv::Mat& testLabels,
|
||||
const vector<int>& sizes);
|
||||
};
|
||||
|
||||
int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
|
||||
int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
const cv::Mat& trainData, const cv::Mat& trainLabels,
|
||||
const cv::Mat& testData, const cv::Mat& testLabels,
|
||||
const vector<int>& sizes )
|
||||
@@ -343,8 +367,13 @@ int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
|
||||
cv::Mat labels;
|
||||
float err;
|
||||
|
||||
cv::EM em;
|
||||
em.train( trainData, Mat(), params, &labels );
|
||||
cv::EM em(params.nclusters, params.covMatType, params.termCrit);
|
||||
if( params.startStep == EM::START_AUTO_STEP )
|
||||
em.train( trainData, labels );
|
||||
else if( params.startStep == EM::START_E_STEP )
|
||||
em.trainE( trainData, *params.means, *params.covs, *params.weights, labels );
|
||||
else if( params.startStep == EM::START_M_STEP )
|
||||
em.trainM( trainData, *params.probs, labels );
|
||||
|
||||
// check train error
|
||||
if( !calcErr( labels, trainLabels, sizes, err , false ) )
|
||||
@@ -363,7 +392,7 @@ int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
|
||||
for( int i = 0; i < testData.rows; i++ )
|
||||
{
|
||||
Mat sample = testData.row(i);
|
||||
labels.at<int>(i,0) = (int)em.predict( sample, 0 );
|
||||
labels.at<int>(i,0) = (int)em.predict( sample, noArray() );
|
||||
}
|
||||
if( !calcErr( labels, testLabels, sizes, err, false ) )
|
||||
{
|
||||
@@ -398,7 +427,7 @@ void CV_EMTest::run( int /*start_from*/ )
|
||||
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels;
|
||||
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
|
||||
|
||||
cv::EM::Params params;
|
||||
EM_Params params;
|
||||
params.nclusters = 3;
|
||||
Mat probs(trainData.rows, params.nclusters, CV_32FC1, cv::Scalar(1));
|
||||
params.probs = &probs;
|
||||
@@ -474,19 +503,16 @@ protected:
|
||||
virtual void run( int /*start_from*/ )
|
||||
{
|
||||
int code = cvtest::TS::OK;
|
||||
cv::EM em;
|
||||
cv::EM em(2);
|
||||
|
||||
Mat samples = Mat(3,1,CV_32F);
|
||||
samples.at<float>(0,0) = 1;
|
||||
samples.at<float>(1,0) = 2;
|
||||
samples.at<float>(2,0) = 3;
|
||||
|
||||
cv::EM::Params params;
|
||||
params.nclusters = 2;
|
||||
|
||||
Mat labels;
|
||||
|
||||
em.train(samples, Mat(), params, &labels);
|
||||
em.train(samples, labels);
|
||||
|
||||
Mat firstResult(samples.rows, 1, CV_32FC1);
|
||||
for( int i = 0; i < samples.rows; i++)
|
||||
|
Reference in New Issue
Block a user