added cv::EM, moved CvEM to legacy, added/updated tests

This commit is contained in:
Maria Dimashova
2012-04-06 09:26:11 +00:00
parent cdc5bbc0bc
commit 85fa0e7763
13 changed files with 1726 additions and 1449 deletions

View File

@@ -46,6 +46,10 @@
#ifdef __cplusplus
#include <map>
#include <string>
#include <iostream>
// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definiton in this header
#undef check
@@ -549,114 +553,93 @@ protected:
/****************************************************************************************\
* Expectation - Maximization *
\****************************************************************************************/
struct CV_EXPORTS_W_MAP CvEMParams
namespace cv
{
CvEMParams();
CvEMParams( int nclusters, int cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
int start_step=0/*CvEM::START_AUTO_STEP*/,
CvTermCriteria term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
const CvMat* probs=0, const CvMat* weights=0, const CvMat* means=0, const CvMat** covs=0 );
CV_PROP_RW int nclusters;
CV_PROP_RW int cov_mat_type;
CV_PROP_RW int start_step;
const CvMat* probs;
const CvMat* weights;
const CvMat* means;
const CvMat** covs;
CV_PROP_RW CvTermCriteria term_crit;
};
class CV_EXPORTS_W CvEM : public CvStatModel
class CV_EXPORTS_W EM : public Algorithm
{
public:
// Type of covariation matrices
enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2};
// The initial step
enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
CV_WRAP CvEM();
CvEM( const CvMat* samples, const CvMat* sampleIdx=0,
CvEMParams params=CvEMParams(), CvMat* labels=0 );
//CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights,
// CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
class CV_EXPORTS_W Params
{
public:
Params(int nclusters=10, int covMatType=EM::COV_MAT_DIAGONAL, int startStep=EM::START_AUTO_STEP,
const cv::TermCriteria& termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON),
const cv::Mat* probs=0, const cv::Mat* weights=0,
const cv::Mat* means=0, const std::vector<cv::Mat>* covs=0);
virtual ~CvEM();
int nclusters;
int covMatType;
int startStep;
virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0,
CvEMParams params=CvEMParams(), CvMat* labels=0 );
// all 4 following matrices should have type CV_32FC1
const cv::Mat* probs;
const cv::Mat* weights;
const cv::Mat* means;
const std::vector<cv::Mat>* covs;
virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const;
cv::TermCriteria termCrit;
};
#ifndef SWIG
CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(),
CvEMParams params=CvEMParams() );
CV_WRAP virtual bool train( const cv::Mat& samples,
const cv::Mat& sampleIdx=cv::Mat(),
CvEMParams params=CvEMParams(),
CV_OUT cv::Mat* labels=0 );
CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const;
CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const;
CV_WRAP int getNClusters() const;
CV_WRAP cv::Mat getMeans() const;
CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const;
CV_WRAP cv::Mat getWeights() const;
CV_WRAP cv::Mat getProbs() const;
CV_WRAP inline double getLikelihood() const { return log_likelihood; }
CV_WRAP inline double getLikelihoodDelta() const { return log_likelihood_delta; }
#endif
CV_WRAP virtual void clear();
EM();
EM(const cv::Mat& samples, const cv::Mat samplesMask=cv::Mat(),
const EM::Params& params=EM::Params(), cv::Mat* labels=0, cv::Mat* probs=0, cv::Mat* likelihoods=0);
virtual ~EM();
virtual void clear();
int get_nclusters() const;
const CvMat* get_means() const;
const CvMat** get_covs() const;
const CvMat* get_weights() const;
const CvMat* get_probs() const;
virtual bool train(const cv::Mat& samples, const cv::Mat& samplesMask=cv::Mat(),
const EM::Params& params=EM::Params(), cv::Mat* labels=0, cv::Mat* probs=0, cv::Mat* likelihoods=0);
int predict(const cv::Mat& sample, cv::Mat* probs=0, double* likelihood=0) const;
inline double get_log_likelihood() const { return log_likelihood; }
inline double get_log_likelihood_delta() const { return log_likelihood_delta; }
// inline const CvMat * get_log_weight_div_det () const { return log_weight_div_det; };
// inline const CvMat * get_inv_eigen_values () const { return inv_eigen_values; };
// inline const CvMat ** get_cov_rotate_mats () const { return cov_rotate_mats; };
bool isTrained() const;
int getNClusters() const;
int getCovMatType() const;
virtual void read( CvFileStorage* fs, CvFileNode* node );
virtual void write( CvFileStorage* fs, const char* name ) const;
const cv::Mat& getWeights() const;
const cv::Mat& getMeans() const;
const std::vector<cv::Mat>& getCovs() const;
virtual void write_params( CvFileStorage* fs ) const;
virtual void read_params( CvFileStorage* fs, CvFileNode* node );
AlgorithmInfo* info() const;
virtual void read(const FileNode& fn);
protected:
virtual void setTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM::Params& params);
virtual void set_params( const CvEMParams& params,
const CvVectors& train_data );
virtual void init_em( const CvVectors& train_data );
virtual double run_em( const CvVectors& train_data );
virtual void init_auto( const CvVectors& samples );
virtual void kmeans( const CvVectors& train_data, int nclusters,
CvMat* labels, CvTermCriteria criteria,
const CvMat* means );
CvEMParams params;
double log_likelihood;
double log_likelihood_delta;
bool doTrain(const cv::TermCriteria& termCrit);
virtual void eStep();
virtual void mStep();
CvMat* means;
CvMat** covs;
CvMat* weights;
CvMat* probs;
void clusterTrainSamples();
void decomposeCovs();
void computeLogWeightDivDet();
CvMat* log_weight_div_det;
CvMat* inv_eigen_values;
CvMat** cov_rotate_mats;
void computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs, float* likelihood) const;
// all inner matrices have type CV_32FC1
int nclusters;
int covMatType;
int startStep;
cv::Mat trainSamples;
cv::Mat trainProbs;
cv::Mat trainLikelihoods;
cv::Mat trainLabels;
cv::Mat trainCounts;
cv::Mat weights;
cv::Mat means;
std::vector<cv::Mat> covs;
std::vector<cv::Mat> covsEigenValues;
std::vector<cv::Mat> covsRotateMats;
std::vector<cv::Mat> invCovsEigenValues;
cv::Mat logWeightDivDet;
};
} // namespace cv
/****************************************************************************************\
* Decision Tree *
@@ -2012,17 +1995,10 @@ CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
CvMat** responses,
int num_classes, ... );
#endif
/****************************************************************************************\
* Data *
\****************************************************************************************/
#include <map>
#include <string>
#include <iostream>
#define CV_COUNT 0
#define CV_PORTION 1
@@ -2133,8 +2109,6 @@ typedef CvSVMParams SVMParams;
typedef CvSVMKernel SVMKernel;
typedef CvSVMSolver SVMSolver;
typedef CvSVM SVM;
typedef CvEMParams EMParams;
typedef CvEM ExpectationMaximization;
typedef CvDTreeParams DTreeParams;
typedef CvMLData TrainData;
typedef CvDTree DecisionTree;
@@ -2156,5 +2130,7 @@ template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj();
}
#endif
#endif // __cplusplus
#endif // __OPENCV_ML_HPP__
/* End of file. */

File diff suppressed because it is too large Load Diff

View File

@@ -44,34 +44,49 @@
using namespace std;
using namespace cv;
void defaultDistribs( vector<Mat>& means, vector<Mat>& covs )
static
void defaultDistribs( Mat& means, vector<Mat>& covs )
{
float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f};
float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f};
float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f};
means.create(3, 2, CV_32FC1);
Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 );
Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 );
Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 );
means.resize(3), covs.resize(3);
m0.copyTo(means[0]), c0.copyTo(covs[0]);
m1.copyTo(means[1]), c1.copyTo(covs[1]);
m2.copyTo(means[2]), c2.copyTo(covs[2]);
Mat mr0 = means.row(0);
m0.copyTo(mr0);
c0.copyTo(covs[0]);
Mat mr1 = means.row(1);
m1.copyTo(mr1);
c1.copyTo(covs[1]);
Mat mr2 = means.row(2);
m2.copyTo(mr2);
c2.copyTo(covs[2]);
}
// generate points sets by normal distributions
void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const vector<Mat>& means, const vector<Mat>& covs, int labelType )
static
void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& _means, const vector<Mat>& covs, int labelType )
{
vector<int>::const_iterator sit = sizes.begin();
int total = 0;
for( ; sit != sizes.end(); ++sit )
total += *sit;
assert( means.size() == sizes.size() && covs.size() == sizes.size() );
assert( _means.rows == (int)sizes.size() && covs.size() == sizes.size() );
assert( !data.empty() && data.rows == total );
assert( data.type() == CV_32FC1 );
labels.create( data.rows, 1, labelType );
randn( data, Scalar::all(0.0), Scalar::all(1.0) );
vector<Mat> means(sizes.size());
for(int i = 0; i < _means.rows; i++)
means[i] = _means.row(i);
vector<Mat>::const_iterator mit = means.begin(), cit = covs.begin();
int bi, ei = 0;
sit = sizes.begin();
@@ -95,6 +110,7 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const vecto
}
}
static
int maxIdx( const vector<int>& count )
{
int idx = -1;
@@ -112,74 +128,83 @@ int maxIdx( const vector<int>& count )
return idx;
}
static
bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap )
{
int total = 0, setCount = (int)sizes.size();
vector<int>::const_iterator sit = sizes.begin();
for( ; sit != sizes.end(); ++sit )
total += *sit;
size_t total = 0, nclusters = sizes.size();
for(size_t i = 0; i < sizes.size(); i++)
total += sizes[i];
assert( !labels.empty() );
assert( labels.rows == total && labels.cols == 1 );
assert( labels.total() == total && (labels.cols == 1 || labels.rows == 1));
assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
bool isFlt = labels.type() == CV_32FC1;
labelsMap.resize(setCount);
vector<int>::iterator lmit = labelsMap.begin();
vector<bool> buzy(setCount, false);
int bi, ei = 0;
for( sit = sizes.begin(); sit != sizes.end(); ++sit, ++lmit )
labelsMap.resize(nclusters);
vector<bool> buzy(nclusters, false);
int startIndex = 0;
for( size_t clusterIndex = 0; clusterIndex < sizes.size(); clusterIndex++ )
{
vector<int> count( setCount, 0 );
bi = ei;
ei = bi + *sit;
if( isFlt )
vector<int> count( nclusters, 0 );
for( int i = startIndex; i < startIndex + sizes[clusterIndex]; i++)
{
for( int i = bi; i < ei; i++ )
count[(int)labels.at<float>(i, 0)]++;
int lbl = isFlt ? (int)labels.at<float>(i) : labels.at<int>(i);
CV_Assert(lbl < (int)nclusters);
count[lbl]++;
CV_Assert(count[lbl] < (int)total);
}
else
{
for( int i = bi; i < ei; i++ )
count[labels.at<int>(i, 0)]++;
}
*lmit = maxIdx( count );
if( buzy[*lmit] )
return false;
buzy[*lmit] = true;
startIndex += sizes[clusterIndex];
int cls = maxIdx( count );
CV_Assert( !buzy[cls] );
labelsMap[clusterIndex] = cls;
buzy[cls] = true;
}
return true;
for(size_t i = 0; i < buzy.size(); i++)
if(!buzy[i])
return false;
return true;
}
float calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, bool labelsEquivalent = true )
static
bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true )
{
int err = 0;
assert( !labels.empty() && !origLabels.empty() );
assert( labels.cols == 1 && origLabels.cols == 1 );
assert( labels.rows == origLabels.rows );
assert( labels.type() == origLabels.type() );
assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
err = 0;
CV_Assert( !labels.empty() && !origLabels.empty() );
CV_Assert( labels.rows == 1 || labels.cols == 1 );
CV_Assert( origLabels.rows == 1 || origLabels.cols == 1 );
CV_Assert( labels.total() == origLabels.total() );
CV_Assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
CV_Assert( origLabels.type() == labels.type() );
vector<int> labelsMap;
bool isFlt = labels.type() == CV_32FC1;
if( !labelsEquivalent )
{
getLabelsMap( labels, sizes, labelsMap );
if( !getLabelsMap( labels, sizes, labelsMap ) )
return false;
for( int i = 0; i < labels.rows; i++ )
if( isFlt )
err += labels.at<float>(i, 0) != labelsMap[(int)origLabels.at<float>(i, 0)];
err += labels.at<float>(i) != labelsMap[(int)origLabels.at<float>(i)] ? 1.f : 0.f;
else
err += labels.at<int>(i, 0) != labelsMap[origLabels.at<int>(i, 0)];
err += labels.at<int>(i) != labelsMap[origLabels.at<int>(i)] ? 1.f : 0.f;
}
else
{
for( int i = 0; i < labels.rows; i++ )
if( isFlt )
err += labels.at<float>(i, 0) != origLabels.at<float>(i, 0);
err += labels.at<float>(i) != origLabels.at<float>(i) ? 1.f : 0.f;
else
err += labels.at<int>(i, 0) != origLabels.at<int>(i, 0);
err += labels.at<int>(i) != origLabels.at<int>(i) ? 1.f : 0.f;
}
return (float)err / (float)labels.rows;
err /= (float)labels.rows;
return true;
}
//--------------------------------------------------------------------------------------------
@@ -198,7 +223,8 @@ void CV_KMeansTest::run( int /*start_from*/ )
Mat data( pointsCount, 2, CV_32FC1 ), labels;
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
vector<Mat> means, covs;
Mat means;
vector<Mat> covs;
defaultDistribs( means, covs );
generateData( data, labels, sizes, means, covs, CV_32SC1 );
@@ -207,8 +233,12 @@ void CV_KMeansTest::run( int /*start_from*/ )
Mat bestLabels;
// 1. flag==KMEANS_PP_CENTERS
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, noArray() );
err = calcErr( bestLabels, labels, sizes, false );
if( err > 0.01f )
if( !calcErr( bestLabels, labels, sizes, err , false ) )
{
ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_PP_CENTERS.\n" );
code = cvtest::TS::FAIL_INVALID_OUTPUT;
}
else if( err > 0.01f )
{
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
@@ -216,10 +246,14 @@ void CV_KMeansTest::run( int /*start_from*/ )
// 2. flag==KMEANS_RANDOM_CENTERS
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, noArray() );
err = calcErr( bestLabels, labels, sizes, false );
if( err > 0.01f )
if( !calcErr( bestLabels, labels, sizes, err, false ) )
{
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_RANDOM_CENTERS.\n" );
code = cvtest::TS::FAIL_INVALID_OUTPUT;
}
else if( err > 0.01f )
{
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_RANDOM_CENTERS.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
@@ -229,10 +263,14 @@ void CV_KMeansTest::run( int /*start_from*/ )
for( int i = 0; i < 0.5f * pointsCount; i++ )
bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3;
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, noArray() );
err = calcErr( bestLabels, labels, sizes, false );
if( err > 0.01f )
if( !calcErr( bestLabels, labels, sizes, err, false ) )
{
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_USE_INITIAL_LABELS.\n" );
code = cvtest::TS::FAIL_INVALID_OUTPUT;
}
else if( err > 0.01f )
{
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_USE_INITIAL_LABELS.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
@@ -255,7 +293,8 @@ void CV_KNearestTest::run( int /*start_from*/ )
// train data
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
vector<Mat> means, covs;
Mat means;
vector<Mat> covs;
defaultDistribs( means, covs );
generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1 );
@@ -267,8 +306,13 @@ void CV_KNearestTest::run( int /*start_from*/ )
KNearest knearest;
knearest.train( trainData, trainLabels );
knearest.find_nearest( testData, 4, &bestLabels );
float err = calcErr( bestLabels, testLabels, sizes, true );
if( err > 0.01f )
float err;
if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
{
ts->printf( cvtest::TS::LOG, "Bad output labels.\n" );
code = cvtest::TS::FAIL_INVALID_OUTPUT;
}
else if( err > 0.01f )
{
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
@@ -277,76 +321,167 @@ void CV_KNearestTest::run( int /*start_from*/ )
}
//--------------------------------------------------------------------------------------------
class CV_EMTest : public cvtest::BaseTest {
class CV_EMTest : public cvtest::BaseTest
{
public:
CV_EMTest() {}
protected:
virtual void run( int start_from );
int runCase( int caseIndex, const cv::EM::Params& params,
const cv::Mat& trainData, const cv::Mat& trainLabels,
const cv::Mat& testData, const cv::Mat& testLabels,
const vector<int>& sizes);
};
void CV_EMTest::run( int /*start_from*/ )
int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
const cv::Mat& trainData, const cv::Mat& trainLabels,
const cv::Mat& testData, const cv::Mat& testLabels,
const vector<int>& sizes )
{
int sizesArr[] = { 5000, 7000, 8000 };
int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
// train data
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
vector<Mat> means, covs;
defaultDistribs( means, covs );
generateData( trainData, trainLabels, sizes, means, covs, CV_32SC1 );
// test data
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels, bestLabels;
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
int code = cvtest::TS::OK;
cv::Mat labels;
float err;
ExpectationMaximization em;
CvEMParams params;
params.nclusters = 3;
em.train( trainData, Mat(), params, &bestLabels );
cv::EM em;
em.train( trainData, Mat(), params, &labels );
// check train error
err = calcErr( bestLabels, trainLabels, sizes, false );
if( err > 0.002f )
if( !calcErr( labels, trainLabels, sizes, err , false ) )
{
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on train data.\n", err );
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
code = cvtest::TS::FAIL_INVALID_OUTPUT;
}
else if( err > 0.006f )
{
ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on train data.\n", caseIndex, err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
// check test error
bestLabels.create( testData.rows, 1, CV_32SC1 );
labels.create( testData.rows, 1, CV_32SC1 );
for( int i = 0; i < testData.rows; i++ )
{
Mat sample( 1, testData.cols, CV_32FC1, testData.ptr<float>(i));
bestLabels.at<int>(i,0) = (int)em.predict( sample, 0 );
Mat sample = testData.row(i);
labels.at<int>(i,0) = (int)em.predict( sample, 0 );
}
err = calcErr( bestLabels, testLabels, sizes, false );
if( err > 0.005f )
if( !calcErr( labels, testLabels, sizes, err, false ) )
{
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
code = cvtest::TS::FAIL_INVALID_OUTPUT;
}
else if( err > 0.006f )
{
ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on test data.\n", caseIndex, err );
code = cvtest::TS::FAIL_BAD_ACCURACY;
}
return code;
}
void CV_EMTest::run( int /*start_from*/ )
{
int sizesArr[] = { 500, 700, 800 };
int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
// Points distribution
Mat means;
vector<Mat> covs;
defaultDistribs( means, covs );
// train data
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
generateData( trainData, trainLabels, sizes, means, covs, CV_32SC1 );
// test data
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels;
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
cv::EM::Params params;
params.nclusters = 3;
Mat probs(trainData.rows, params.nclusters, CV_32FC1, cv::Scalar(1));
params.probs = &probs;
Mat weights(1, params.nclusters, CV_32FC1, cv::Scalar(1));
params.weights = &weights;
params.means = &means;
params.covs = &covs;
int code = cvtest::TS::OK;
int caseIndex = 0;
{
params.startStep = cv::EM::START_AUTO_STEP;
params.covMatType = cv::EM::COV_MAT_GENERIC;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
{
params.startStep = cv::EM::START_AUTO_STEP;
params.covMatType = cv::EM::COV_MAT_DIAGONAL;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
{
params.startStep = cv::EM::START_AUTO_STEP;
params.covMatType = cv::EM::COV_MAT_SPHERICAL;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
{
params.startStep = cv::EM::START_M_STEP;
params.covMatType = cv::EM::COV_MAT_GENERIC;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
{
params.startStep = cv::EM::START_M_STEP;
params.covMatType = cv::EM::COV_MAT_DIAGONAL;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
{
params.startStep = cv::EM::START_M_STEP;
params.covMatType = cv::EM::COV_MAT_SPHERICAL;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
{
params.startStep = cv::EM::START_E_STEP;
params.covMatType = cv::EM::COV_MAT_GENERIC;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
{
params.startStep = cv::EM::START_E_STEP;
params.covMatType = cv::EM::COV_MAT_DIAGONAL;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
{
params.startStep = cv::EM::START_E_STEP;
params.covMatType = cv::EM::COV_MAT_SPHERICAL;
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
code = currCode == cvtest::TS::OK ? code : currCode;
}
ts->set_failed_test_info( code );
}
class CV_EMTest_Smoke : public cvtest::BaseTest {
class CV_EMTest_SaveLoad : public cvtest::BaseTest {
public:
CV_EMTest_Smoke() {}
CV_EMTest_SaveLoad() {}
protected:
virtual void run( int /*start_from*/ )
{
int code = cvtest::TS::OK;
CvEM em;
cv::EM em;
Mat samples = Mat(3,2,CV_32F);
Mat samples = Mat(3,1,CV_32F);
samples.at<float>(0,0) = 1;
samples.at<float>(1,0) = 2;
samples.at<float>(2,0) = 3;
CvEMParams params;
cv::EM::Params params;
params.nclusters = 2;
Mat labels;
@@ -361,10 +496,11 @@ protected:
string filename = tempfile() + ".xml";
{
FileStorage fs = FileStorage(filename, FileStorage::WRITE);
try
{
em.write(fs.fs, "EM");
fs << "em" << "{";
em.write(fs);
fs << "}";
}
catch(...)
{
@@ -378,11 +514,11 @@ protected:
// Read in
{
FileStorage fs = FileStorage(filename, FileStorage::READ);
FileNode fileNode = fs["EM"];
CV_Assert(fs.isOpened());
FileNode fn = fs["em"];
try
{
em.read(const_cast<CvFileStorage*>(fileNode.fs), const_cast<CvFileNode*>(fileNode.node));
em.read(fn);
}
catch(...)
{
@@ -410,4 +546,4 @@ protected:
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
TEST(ML_EM, smoke) { CV_EMTest_Smoke test; test.safe_run(); }
TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }

View File

@@ -451,7 +451,6 @@ CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
nbayes = 0;
knearest = 0;
svm = 0;
em = 0;
ann = 0;
dtree = 0;
boost = 0;
@@ -463,8 +462,6 @@ CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
knearest = new CvKNearest;
else if( !modelName.compare(CV_SVM) )
svm = new CvSVM;
else if( !modelName.compare(CV_EM) )
em = new CvEM;
else if( !modelName.compare(CV_ANN) )
ann = new CvANN_MLP;
else if( !modelName.compare(CV_DTREE) )
@@ -487,8 +484,6 @@ CV_MLBaseTest::~CV_MLBaseTest()
delete knearest;
if( svm )
delete svm;
if( em )
delete em;
if( ann )
delete ann;
if( dtree )
@@ -756,8 +751,6 @@ void CV_MLBaseTest::save( const char* filename )
knearest->save( filename );
else if( !modelName.compare(CV_SVM) )
svm->save( filename );
else if( !modelName.compare(CV_EM) )
em->save( filename );
else if( !modelName.compare(CV_ANN) )
ann->save( filename );
else if( !modelName.compare(CV_DTREE) )
@@ -778,8 +771,6 @@ void CV_MLBaseTest::load( const char* filename )
knearest->load( filename );
else if( !modelName.compare(CV_SVM) )
svm->load( filename );
else if( !modelName.compare(CV_EM) )
em->load( filename );
else if( !modelName.compare(CV_ANN) )
ann->load( filename );
else if( !modelName.compare(CV_DTREE) )

View File

@@ -44,7 +44,6 @@ protected:
CvNormalBayesClassifier* nbayes;
CvKNearest* knearest;
CvSVM* svm;
CvEM* em;
CvANN_MLP* ann;
CvDTree* dtree;
CvBoost* boost;