modified EM interface; updated tests & samples
This commit is contained in:
@@ -1821,10 +1821,10 @@ public:
|
||||
CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const;
|
||||
|
||||
CV_WRAP int getNClusters() const;
|
||||
CV_WRAP const cv::Mat& getMeans() const;
|
||||
CV_WRAP cv::Mat getMeans() const;
|
||||
CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const;
|
||||
CV_WRAP const cv::Mat& getWeights() const;
|
||||
CV_WRAP const cv::Mat& getProbs() const;
|
||||
CV_WRAP cv::Mat getWeights() const;
|
||||
CV_WRAP cv::Mat getProbs() const;
|
||||
|
||||
CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? likelihood : DBL_MAX; }
|
||||
#endif
|
||||
|
@@ -41,6 +41,8 @@
|
||||
|
||||
#include "precomp.hpp"
|
||||
|
||||
using namespace cv;
|
||||
|
||||
CvEMParams::CvEMParams() : nclusters(10), cov_mat_type(CvEM::COV_MAT_DIAGONAL),
|
||||
start_step(CvEM::START_AUTO_STEP), probs(0), weights(0), means(0), covs(0)
|
||||
{
|
||||
@@ -76,38 +78,44 @@ void CvEM::clear()
|
||||
|
||||
void CvEM::read( CvFileStorage* fs, CvFileNode* node )
|
||||
{
|
||||
cv::FileNode fn(fs, node);
|
||||
FileNode fn(fs, node);
|
||||
emObj.read(fn);
|
||||
set_mat_hdrs();
|
||||
}
|
||||
|
||||
void CvEM::write( CvFileStorage* _fs, const char* name ) const
|
||||
{
|
||||
cv::FileStorage fs = _fs;
|
||||
FileStorage fs = _fs;
|
||||
if(name)
|
||||
fs << name << "{";
|
||||
emObj.write(fs);
|
||||
if(name)
|
||||
fs << "}";
|
||||
fs.fs.obj = 0;
|
||||
}
|
||||
|
||||
double CvEM::calcLikelihood( const cv::Mat &input_sample ) const
|
||||
double CvEM::calcLikelihood( const Mat &input_sample ) const
|
||||
{
|
||||
double likelihood;
|
||||
emObj.predict(input_sample, 0, &likelihood);
|
||||
emObj.predict(input_sample, noArray(), &likelihood);
|
||||
return likelihood;
|
||||
}
|
||||
|
||||
float
|
||||
CvEM::predict( const CvMat* _sample, CvMat* _probs, bool isNormalize ) const
|
||||
{
|
||||
cv::Mat prbs;
|
||||
int cls = emObj.predict(_sample, _probs ? &prbs : 0);
|
||||
Mat prbs0 = cvarrToMat(_probs), prbs = prbs0, sample = cvarrToMat(_sample);
|
||||
int cls = emObj.predict(sample, _probs ? _OutputArray(prbs) : _OutputArray::_OutputArray());
|
||||
if(_probs)
|
||||
{
|
||||
if(isNormalize)
|
||||
cv::normalize(prbs, prbs, 1, 0, cv::NORM_L1);
|
||||
*_probs = prbs;
|
||||
normalize(prbs, prbs, 1, 0, NORM_L1);
|
||||
|
||||
if( prbs.data != prbs0.data )
|
||||
{
|
||||
CV_Assert( prbs.size == prbs0.size );
|
||||
prbs.convertTo(prbs0, prbs0.type());
|
||||
}
|
||||
}
|
||||
return (float)cls;
|
||||
}
|
||||
@@ -116,73 +124,55 @@ void CvEM::set_mat_hdrs()
|
||||
{
|
||||
if(emObj.isTrained())
|
||||
{
|
||||
meansHdr = emObj.getMeans();
|
||||
covsHdrs.resize(emObj.getNClusters());
|
||||
covsPtrs.resize(emObj.getNClusters());
|
||||
const std::vector<cv::Mat>& covs = emObj.getCovs();
|
||||
meansHdr = emObj.get<Mat>("means");
|
||||
int K = emObj.get<int>("nclusters");
|
||||
covsHdrs.resize(K);
|
||||
covsPtrs.resize(K);
|
||||
const std::vector<Mat>& covs = emObj.get<vector<Mat> >("covs");
|
||||
for(size_t i = 0; i < covsHdrs.size(); i++)
|
||||
{
|
||||
covsHdrs[i] = covs[i];
|
||||
covsPtrs[i] = &covsHdrs[i];
|
||||
}
|
||||
weightsHdr = emObj.getWeights();
|
||||
weightsHdr = emObj.get<Mat>("weights");
|
||||
probsHdr = probs;
|
||||
}
|
||||
}
|
||||
|
||||
static
|
||||
void init_params(const CvEMParams& src, cv::EM::Params& dst,
|
||||
cv::Mat& prbs, cv::Mat& weights,
|
||||
cv::Mat& means, cv::vector<cv::Mat>& covsHdrs)
|
||||
void init_params(const CvEMParams& src,
|
||||
Mat& prbs, Mat& weights,
|
||||
Mat& means, vector<Mat>& covsHdrs)
|
||||
{
|
||||
dst.nclusters = src.nclusters;
|
||||
dst.covMatType = src.cov_mat_type;
|
||||
dst.startStep = src.start_step;
|
||||
dst.termCrit = src.term_crit;
|
||||
|
||||
prbs = src.probs;
|
||||
dst.probs = &prbs;
|
||||
|
||||
weights = src.weights;
|
||||
dst.weights = &weights;
|
||||
|
||||
means = src.means;
|
||||
dst.means = &means;
|
||||
|
||||
if(src.covs)
|
||||
{
|
||||
covsHdrs.resize(src.nclusters);
|
||||
for(size_t i = 0; i < covsHdrs.size(); i++)
|
||||
covsHdrs[i] = src.covs[i];
|
||||
dst.covs = &covsHdrs;
|
||||
}
|
||||
}
|
||||
|
||||
bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
|
||||
CvEMParams _params, CvMat* _labels )
|
||||
{
|
||||
cv::EM::Params params;
|
||||
cv::Mat prbs, weights, means;
|
||||
std::vector<cv::Mat> covsHdrs;
|
||||
init_params(_params, params, prbs, weights, means, covsHdrs);
|
||||
|
||||
cv::Mat lbls;
|
||||
cv::Mat likelihoods;
|
||||
bool isOk = emObj.train(_samples, _sample_idx, params, _labels ? &lbls : 0, &probs, &likelihoods );
|
||||
if(isOk)
|
||||
{
|
||||
if(_labels)
|
||||
*_labels = lbls;
|
||||
likelihood = cv::sum(likelihoods)[0];
|
||||
set_mat_hdrs();
|
||||
}
|
||||
CV_Assert(_sample_idx == 0);
|
||||
Mat samples = cvarrToMat(_samples), labels0, labels;
|
||||
if( _labels )
|
||||
labels0 = labels = cvarrToMat(_labels);
|
||||
|
||||
bool isOk = train(samples, Mat(), _params, _labels ? &labels : 0);
|
||||
CV_Assert( labels0.data == labels.data );
|
||||
|
||||
return isOk;
|
||||
}
|
||||
|
||||
int CvEM::get_nclusters() const
|
||||
{
|
||||
return emObj.getNClusters();
|
||||
return emObj.get<int>("nclusters");
|
||||
}
|
||||
|
||||
const CvMat* CvEM::get_means() const
|
||||
@@ -215,16 +205,29 @@ CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
|
||||
bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
|
||||
CvEMParams _params, Mat* _labels )
|
||||
{
|
||||
cv::EM::Params params;
|
||||
cv::Mat prbs, weights, means;
|
||||
std::vector<cv::Mat> covsHdrs;
|
||||
init_params(_params, params, prbs, weights, means, covsHdrs);
|
||||
Mat prbs, weights, means, likelihoods;
|
||||
std::vector<Mat> covsHdrs;
|
||||
init_params(_params, prbs, weights, means, covsHdrs);
|
||||
|
||||
cv::Mat likelihoods;
|
||||
bool isOk = emObj.train(_samples, _sample_idx, params, _labels, &probs, &likelihoods);
|
||||
emObj = EM(_params.nclusters, _params.cov_mat_type, _params.term_crit);
|
||||
bool isOk = false;
|
||||
if( _params.start_step == EM::START_AUTO_STEP )
|
||||
isOk = emObj.train(_samples, _labels ? _OutputArray(*_labels) : _OutputArray::_OutputArray(),
|
||||
probs, likelihoods);
|
||||
else if( _params.start_step == EM::START_E_STEP )
|
||||
isOk = emObj.trainE(_samples, means, covsHdrs, weights,
|
||||
_labels ? _OutputArray(*_labels) : _OutputArray::_OutputArray(),
|
||||
probs, likelihoods);
|
||||
else if( _params.start_step == EM::START_M_STEP )
|
||||
isOk = emObj.trainM(_samples, prbs,
|
||||
_labels ? _OutputArray(*_labels) : _OutputArray::_OutputArray(),
|
||||
probs, likelihoods);
|
||||
else
|
||||
CV_Error(CV_StsBadArg, "Bad start type of EM algorithm");
|
||||
|
||||
if(isOk)
|
||||
{
|
||||
likelihoods = cv::sum(likelihoods).val[0];
|
||||
likelihoods = sum(likelihoods).val[0];
|
||||
set_mat_hdrs();
|
||||
}
|
||||
|
||||
@@ -234,34 +237,34 @@ bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
|
||||
float
|
||||
CvEM::predict( const Mat& _sample, Mat* _probs, bool isNormalize ) const
|
||||
{
|
||||
int cls = emObj.predict(_sample, _probs);
|
||||
int cls = emObj.predict(_sample, _probs ? _OutputArray(*_probs) : _OutputArray::_OutputArray());
|
||||
if(_probs && isNormalize)
|
||||
cv::normalize(*_probs, *_probs, 1, 0, cv::NORM_L1);
|
||||
normalize(*_probs, *_probs, 1, 0, NORM_L1);
|
||||
|
||||
return (float)cls;
|
||||
}
|
||||
|
||||
int CvEM::getNClusters() const
|
||||
{
|
||||
return emObj.getNClusters();
|
||||
return emObj.get<int>("nclusters");
|
||||
}
|
||||
|
||||
const Mat& CvEM::getMeans() const
|
||||
Mat CvEM::getMeans() const
|
||||
{
|
||||
return emObj.getMeans();
|
||||
return emObj.get<Mat>("means");
|
||||
}
|
||||
|
||||
void CvEM::getCovs(vector<Mat>& _covs) const
|
||||
{
|
||||
_covs = emObj.getCovs();
|
||||
_covs = emObj.get<vector<Mat> >("covs");
|
||||
}
|
||||
|
||||
const Mat& CvEM::getWeights() const
|
||||
Mat CvEM::getWeights() const
|
||||
{
|
||||
return emObj.getWeights();
|
||||
return emObj.get<Mat>("weights");
|
||||
}
|
||||
|
||||
const Mat& CvEM::getProbs() const
|
||||
Mat CvEM::getProbs() const
|
||||
{
|
||||
return probs;
|
||||
}
|
||||
|
@@ -371,19 +371,20 @@ protected:
|
||||
virtual void run( int /*start_from*/ )
|
||||
{
|
||||
int code = cvtest::TS::OK;
|
||||
cv::EM em;
|
||||
|
||||
Mat samples = Mat(3,1,CV_32F);
|
||||
samples.at<float>(0,0) = 1;
|
||||
samples.at<float>(1,0) = 2;
|
||||
samples.at<float>(2,0) = 3;
|
||||
|
||||
Mat labels(samples.rows, 1, CV_32S);
|
||||
|
||||
cv::EM::Params params;
|
||||
CvEMParams params;
|
||||
params.nclusters = 2;
|
||||
|
||||
Mat labels;
|
||||
CvMat samples_c = samples, labels_c = labels;
|
||||
|
||||
em.train(samples, Mat(), params, &labels);
|
||||
CvEM em(&samples_c, 0, params, &labels_c);
|
||||
|
||||
Mat firstResult(samples.rows, 1, CV_32FC1);
|
||||
for( int i = 0; i < samples.rows; i++)
|
||||
@@ -396,9 +397,7 @@ protected:
|
||||
FileStorage fs = FileStorage(filename, FileStorage::WRITE);
|
||||
try
|
||||
{
|
||||
fs << "em" << "{";
|
||||
em.write(fs);
|
||||
fs << "}";
|
||||
em.write(fs.fs, "em");
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
@@ -416,7 +415,7 @@ protected:
|
||||
FileNode fn = fs["em"];
|
||||
try
|
||||
{
|
||||
em.read(fn);
|
||||
em.read(fs.fs, (CvFileNode*)fn.node);
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
|
Reference in New Issue
Block a user