added cv::EM, moved CvEM to legacy, added/updated tests
This commit is contained in:
parent
cdc5bbc0bc
commit
85fa0e7763
modules
contrib
legacy
ml
samples/cpp
@ -66,7 +66,7 @@ struct CV_EXPORTS CvMotionModel
|
|||||||
}
|
}
|
||||||
|
|
||||||
float low_pass_gain; // low pass gain
|
float low_pass_gain; // low pass gain
|
||||||
CvEMParams em_params; // EM parameters
|
cv::EM::Params em_params; // EM parameters
|
||||||
};
|
};
|
||||||
|
|
||||||
// Mean Shift Tracker parameters for specifying use of HSV channel and CamShift parameters.
|
// Mean Shift Tracker parameters for specifying use of HSV channel and CamShift parameters.
|
||||||
@ -109,7 +109,7 @@ struct CV_EXPORTS CvHybridTrackerParams
|
|||||||
float ms_tracker_weight;
|
float ms_tracker_weight;
|
||||||
CvFeatureTrackerParams ft_params;
|
CvFeatureTrackerParams ft_params;
|
||||||
CvMeanShiftTrackerParams ms_params;
|
CvMeanShiftTrackerParams ms_params;
|
||||||
CvEMParams em_params;
|
cv::EM::Params em_params;
|
||||||
int motion_model;
|
int motion_model;
|
||||||
float low_pass_gain;
|
float low_pass_gain;
|
||||||
};
|
};
|
||||||
@ -182,7 +182,7 @@ private:
|
|||||||
|
|
||||||
CvMat* samples;
|
CvMat* samples;
|
||||||
CvMat* labels;
|
CvMat* labels;
|
||||||
CvEM em_model;
|
cv::EM em_model;
|
||||||
|
|
||||||
Rect prev_window;
|
Rect prev_window;
|
||||||
Point2f prev_center;
|
Point2f prev_center;
|
||||||
|
@ -137,11 +137,11 @@ void CvHybridTracker::newTracker(Mat image, Rect selection) {
|
|||||||
params.em_params.probs = NULL;
|
params.em_params.probs = NULL;
|
||||||
params.em_params.nclusters = 1;
|
params.em_params.nclusters = 1;
|
||||||
params.em_params.weights = NULL;
|
params.em_params.weights = NULL;
|
||||||
params.em_params.cov_mat_type = CvEM::COV_MAT_SPHERICAL;
|
params.em_params.covMatType = cv::EM::COV_MAT_SPHERICAL;
|
||||||
params.em_params.start_step = CvEM::START_AUTO_STEP;
|
params.em_params.startStep = cv::EM::START_AUTO_STEP;
|
||||||
params.em_params.term_crit.max_iter = 10000;
|
params.em_params.termCrit.maxCount = 10000;
|
||||||
params.em_params.term_crit.epsilon = 0.001;
|
params.em_params.termCrit.epsilon = 0.001;
|
||||||
params.em_params.term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
|
params.em_params.termCrit.type = cv::TermCriteria::COUNT + cv::TermCriteria::EPS;
|
||||||
|
|
||||||
samples = cvCreateMat(2, 1, CV_32FC1);
|
samples = cvCreateMat(2, 1, CV_32FC1);
|
||||||
labels = cvCreateMat(2, 1, CV_32SC1);
|
labels = cvCreateMat(2, 1, CV_32SC1);
|
||||||
@ -221,7 +221,10 @@ void CvHybridTracker::updateTrackerWithEM(Mat image) {
|
|||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
|
|
||||||
em_model.train(samples, 0, params.em_params, labels);
|
cv::Mat lbls;
|
||||||
|
em_model.train(samples, cv::Mat(), params.em_params, &lbls);
|
||||||
|
if(labels)
|
||||||
|
*labels = lbls;
|
||||||
|
|
||||||
curr_center.x = (float)em_model.getMeans().at<double> (0, 0);
|
curr_center.x = (float)em_model.getMeans().at<double> (0, 0);
|
||||||
curr_center.y = (float)em_model.getMeans().at<double> (0, 1);
|
curr_center.y = (float)em_model.getMeans().at<double> (0, 1);
|
||||||
|
@ -1 +1 @@
|
|||||||
ocv_define_module(legacy opencv_calib3d opencv_highgui opencv_video)
|
ocv_define_module(legacy opencv_calib3d opencv_highgui opencv_video opencv_ml)
|
||||||
|
@ -46,6 +46,7 @@
|
|||||||
#include "opencv2/imgproc/imgproc_c.h"
|
#include "opencv2/imgproc/imgproc_c.h"
|
||||||
#include "opencv2/features2d/features2d.hpp"
|
#include "opencv2/features2d/features2d.hpp"
|
||||||
#include "opencv2/calib3d/calib3d.hpp"
|
#include "opencv2/calib3d/calib3d.hpp"
|
||||||
|
#include "opencv2/ml/ml.hpp"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
@ -1761,10 +1762,106 @@ protected:
|
|||||||
IplImage* m_mask;
|
IplImage* m_mask;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/****************************************************************************************\
|
||||||
|
* Expectation - Maximization *
|
||||||
|
\****************************************************************************************/
|
||||||
|
struct CV_EXPORTS_W_MAP CvEMParams
|
||||||
|
{
|
||||||
|
CvEMParams();
|
||||||
|
CvEMParams( int nclusters, int cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
|
||||||
|
int start_step=0/*CvEM::START_AUTO_STEP*/,
|
||||||
|
CvTermCriteria term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
|
||||||
|
const CvMat* probs=0, const CvMat* weights=0, const CvMat* means=0, const CvMat** covs=0 );
|
||||||
|
|
||||||
|
CV_PROP_RW int nclusters;
|
||||||
|
CV_PROP_RW int cov_mat_type;
|
||||||
|
CV_PROP_RW int start_step;
|
||||||
|
const CvMat* probs;
|
||||||
|
const CvMat* weights;
|
||||||
|
const CvMat* means;
|
||||||
|
const CvMat** covs;
|
||||||
|
CV_PROP_RW CvTermCriteria term_crit;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
class CV_EXPORTS_W CvEM : public CvStatModel
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
// Type of covariation matrices
|
||||||
|
enum { COV_MAT_SPHERICAL=cv::EM::COV_MAT_SPHERICAL,
|
||||||
|
COV_MAT_DIAGONAL =cv::EM::COV_MAT_DIAGONAL,
|
||||||
|
COV_MAT_GENERIC =cv::EM::COV_MAT_GENERIC };
|
||||||
|
|
||||||
|
// The initial step
|
||||||
|
enum { START_E_STEP=cv::EM::START_E_STEP,
|
||||||
|
START_M_STEP=cv::EM::START_M_STEP,
|
||||||
|
START_AUTO_STEP=cv::EM::START_AUTO_STEP };
|
||||||
|
|
||||||
|
CV_WRAP CvEM();
|
||||||
|
CvEM( const CvMat* samples, const CvMat* sampleIdx=0,
|
||||||
|
CvEMParams params=CvEMParams(), CvMat* labels=0 );
|
||||||
|
|
||||||
|
virtual ~CvEM();
|
||||||
|
|
||||||
|
virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0,
|
||||||
|
CvEMParams params=CvEMParams(), CvMat* labels=0 );
|
||||||
|
|
||||||
|
virtual float predict( const CvMat* sample, CV_OUT CvMat* probs, bool isNormalize=true ) const;
|
||||||
|
|
||||||
|
#ifndef SWIG
|
||||||
|
CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(),
|
||||||
|
CvEMParams params=CvEMParams() );
|
||||||
|
|
||||||
|
CV_WRAP virtual bool train( const cv::Mat& samples,
|
||||||
|
const cv::Mat& sampleIdx=cv::Mat(),
|
||||||
|
CvEMParams params=CvEMParams(),
|
||||||
|
CV_OUT cv::Mat* labels=0 );
|
||||||
|
|
||||||
|
CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0, bool isNormalize=true ) const;
|
||||||
|
CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const;
|
||||||
|
|
||||||
|
CV_WRAP int getNClusters() const;
|
||||||
|
CV_WRAP const cv::Mat& getMeans() const;
|
||||||
|
CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const;
|
||||||
|
CV_WRAP const cv::Mat& getWeights() const;
|
||||||
|
CV_WRAP const cv::Mat& getProbs() const;
|
||||||
|
|
||||||
|
CV_WRAP inline double getLikelihood() const { return emObj.isTrained() ? likelihood : DBL_MAX; }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
CV_WRAP virtual void clear();
|
||||||
|
|
||||||
|
int get_nclusters() const;
|
||||||
|
const CvMat* get_means() const;
|
||||||
|
const CvMat** get_covs() const;
|
||||||
|
const CvMat* get_weights() const;
|
||||||
|
const CvMat* get_probs() const;
|
||||||
|
|
||||||
|
inline double get_log_likelihood() const { return getLikelihood(); }
|
||||||
|
|
||||||
|
virtual void read( CvFileStorage* fs, CvFileNode* node );
|
||||||
|
virtual void write( CvFileStorage* fs, const char* name ) const;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void set_mat_hdrs();
|
||||||
|
|
||||||
|
cv::EM emObj;
|
||||||
|
cv::Mat probs;
|
||||||
|
double likelihood;
|
||||||
|
|
||||||
|
CvMat meansHdr;
|
||||||
|
std::vector<CvMat> covsHdrs;
|
||||||
|
std::vector<CvMat*> covsPtrs;
|
||||||
|
CvMat weightsHdr;
|
||||||
|
CvMat probsHdr;
|
||||||
|
};
|
||||||
|
|
||||||
namespace cv
|
namespace cv
|
||||||
{
|
{
|
||||||
|
|
||||||
|
typedef CvEMParams EMParams;
|
||||||
|
typedef CvEM ExpectationMaximization;
|
||||||
|
|
||||||
/*!
|
/*!
|
||||||
The Patch Generator class
|
The Patch Generator class
|
||||||
*/
|
*/
|
||||||
|
270
modules/legacy/src/em.cpp
Normal file
270
modules/legacy/src/em.cpp
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
//
|
||||||
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||||
|
//
|
||||||
|
// By downloading, copying, installing or using the software you agree to this license.
|
||||||
|
// If you do not agree to this license, do not download, install,
|
||||||
|
// copy or use the software.
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// Intel License Agreement
|
||||||
|
// For Open Source Computer Vision Library
|
||||||
|
//
|
||||||
|
// Copyright( C) 2000, Intel Corporation, all rights reserved.
|
||||||
|
// Third party copyrights are property of their respective owners.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without modification,
|
||||||
|
// are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// * Redistribution's of source code must retain the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer.
|
||||||
|
//
|
||||||
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// * The name of Intel Corporation may not be used to endorse or promote products
|
||||||
|
// derived from this software without specific prior written permission.
|
||||||
|
//
|
||||||
|
// This software is provided by the copyright holders and contributors "as is" and
|
||||||
|
// any express or implied warranties, including, but not limited to, the implied
|
||||||
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||||
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||||
|
// indirect, incidental, special, exemplary, or consequential damages
|
||||||
|
//(including, but not limited to, procurement of substitute goods or services;
|
||||||
|
// loss of use, data, or profits; or business interruption) however caused
|
||||||
|
// and on any theory of liability, whether in contract, strict liability,
|
||||||
|
// or tort(including negligence or otherwise) arising in any way out of
|
||||||
|
// the use of this software, even ifadvised of the possibility of such damage.
|
||||||
|
//
|
||||||
|
//M*/
|
||||||
|
|
||||||
|
#include "precomp.hpp"
|
||||||
|
|
||||||
|
CvEMParams::CvEMParams() : nclusters(10), cov_mat_type(CvEM::COV_MAT_DIAGONAL),
|
||||||
|
start_step(CvEM::START_AUTO_STEP), probs(0), weights(0), means(0), covs(0)
|
||||||
|
{
|
||||||
|
term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
|
||||||
|
}
|
||||||
|
|
||||||
|
CvEMParams::CvEMParams( int _nclusters, int _cov_mat_type, int _start_step,
|
||||||
|
CvTermCriteria _term_crit, const CvMat* _probs,
|
||||||
|
const CvMat* _weights, const CvMat* _means, const CvMat** _covs ) :
|
||||||
|
nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
|
||||||
|
probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
|
||||||
|
{}
|
||||||
|
|
||||||
|
CvEM::CvEM() : likelihood(DBL_MAX)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
CvEM::CvEM( const CvMat* samples, const CvMat* sample_idx,
|
||||||
|
CvEMParams params, CvMat* labels ) : likelihood(DBL_MAX)
|
||||||
|
{
|
||||||
|
train(samples, sample_idx, params, labels);
|
||||||
|
}
|
||||||
|
|
||||||
|
CvEM::~CvEM()
|
||||||
|
{
|
||||||
|
clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CvEM::clear()
|
||||||
|
{
|
||||||
|
emObj.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CvEM::read( CvFileStorage* fs, CvFileNode* node )
|
||||||
|
{
|
||||||
|
cv::FileNode fn(fs, node);
|
||||||
|
emObj.read(fn);
|
||||||
|
set_mat_hdrs();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CvEM::write( CvFileStorage* _fs, const char* name ) const
|
||||||
|
{
|
||||||
|
cv::FileStorage fs = _fs;
|
||||||
|
if(name)
|
||||||
|
fs << name << "{";
|
||||||
|
emObj.write(fs);
|
||||||
|
if(name)
|
||||||
|
fs << "}";
|
||||||
|
}
|
||||||
|
|
||||||
|
double CvEM::calcLikelihood( const cv::Mat &input_sample ) const
|
||||||
|
{
|
||||||
|
double likelihood;
|
||||||
|
emObj.predict(input_sample, 0, &likelihood);
|
||||||
|
return likelihood;
|
||||||
|
}
|
||||||
|
|
||||||
|
float
|
||||||
|
CvEM::predict( const CvMat* _sample, CvMat* _probs, bool isNormalize ) const
|
||||||
|
{
|
||||||
|
cv::Mat prbs;
|
||||||
|
int cls = emObj.predict(_sample, _probs ? &prbs : 0);
|
||||||
|
if(_probs)
|
||||||
|
{
|
||||||
|
if(isNormalize)
|
||||||
|
cv::normalize(prbs, prbs, 1, 0, cv::NORM_L1);
|
||||||
|
*_probs = prbs;
|
||||||
|
}
|
||||||
|
return (float)cls;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CvEM::set_mat_hdrs()
|
||||||
|
{
|
||||||
|
if(emObj.isTrained())
|
||||||
|
{
|
||||||
|
meansHdr = emObj.getMeans();
|
||||||
|
covsHdrs.resize(emObj.getNClusters());
|
||||||
|
covsPtrs.resize(emObj.getNClusters());
|
||||||
|
const std::vector<cv::Mat>& covs = emObj.getCovs();
|
||||||
|
for(size_t i = 0; i < covsHdrs.size(); i++)
|
||||||
|
{
|
||||||
|
covsHdrs[i] = covs[i];
|
||||||
|
covsPtrs[i] = &covsHdrs[i];
|
||||||
|
}
|
||||||
|
weightsHdr = emObj.getWeights();
|
||||||
|
probsHdr = probs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static
|
||||||
|
void init_params(const CvEMParams& src, cv::EM::Params& dst,
|
||||||
|
cv::Mat& prbs, cv::Mat& weights,
|
||||||
|
cv::Mat& means, cv::vector<cv::Mat>& covsHdrs)
|
||||||
|
{
|
||||||
|
dst.nclusters = src.nclusters;
|
||||||
|
dst.covMatType = src.cov_mat_type;
|
||||||
|
dst.startStep = src.start_step;
|
||||||
|
dst.termCrit = src.term_crit;
|
||||||
|
|
||||||
|
prbs = src.probs;
|
||||||
|
dst.probs = &prbs;
|
||||||
|
|
||||||
|
weights = src.weights;
|
||||||
|
dst.weights = &weights;
|
||||||
|
|
||||||
|
means = src.means;
|
||||||
|
dst.means = &means;
|
||||||
|
|
||||||
|
if(src.covs)
|
||||||
|
{
|
||||||
|
covsHdrs.resize(src.nclusters);
|
||||||
|
for(size_t i = 0; i < covsHdrs.size(); i++)
|
||||||
|
covsHdrs[i] = src.covs[i];
|
||||||
|
dst.covs = &covsHdrs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CvEM::train( const CvMat* _samples, const CvMat* _sample_idx,
|
||||||
|
CvEMParams _params, CvMat* _labels )
|
||||||
|
{
|
||||||
|
cv::EM::Params params;
|
||||||
|
cv::Mat prbs, weights, means;
|
||||||
|
std::vector<cv::Mat> covsHdrs;
|
||||||
|
init_params(_params, params, prbs, weights, means, covsHdrs);
|
||||||
|
|
||||||
|
cv::Mat lbls;
|
||||||
|
cv::Mat likelihoods;
|
||||||
|
bool isOk = emObj.train(_samples, _sample_idx, params, _labels ? &lbls : 0, &probs, &likelihoods );
|
||||||
|
if(isOk)
|
||||||
|
{
|
||||||
|
if(_labels)
|
||||||
|
*_labels = lbls;
|
||||||
|
likelihood = cv::sum(likelihoods)[0];
|
||||||
|
set_mat_hdrs();
|
||||||
|
}
|
||||||
|
|
||||||
|
return isOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
int CvEM::get_nclusters() const
|
||||||
|
{
|
||||||
|
return emObj.getNClusters();
|
||||||
|
}
|
||||||
|
|
||||||
|
const CvMat* CvEM::get_means() const
|
||||||
|
{
|
||||||
|
return emObj.isTrained() ? &meansHdr : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CvMat** CvEM::get_covs() const
|
||||||
|
{
|
||||||
|
return emObj.isTrained() ? (const CvMat**)&covsPtrs[0] : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CvMat* CvEM::get_weights() const
|
||||||
|
{
|
||||||
|
return emObj.isTrained() ? &weightsHdr : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
const CvMat* CvEM::get_probs() const
|
||||||
|
{
|
||||||
|
return emObj.isTrained() ? &probsHdr : 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
using namespace cv;
|
||||||
|
|
||||||
|
CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
|
||||||
|
{
|
||||||
|
train(samples, sample_idx, params, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
|
||||||
|
CvEMParams _params, Mat* _labels )
|
||||||
|
{
|
||||||
|
cv::EM::Params params;
|
||||||
|
cv::Mat prbs, weights, means;
|
||||||
|
std::vector<cv::Mat> covsHdrs;
|
||||||
|
init_params(_params, params, prbs, weights, means, covsHdrs);
|
||||||
|
|
||||||
|
cv::Mat likelihoods;
|
||||||
|
bool isOk = emObj.train(_samples, _sample_idx, params, _labels, &probs, &likelihoods);
|
||||||
|
if(isOk)
|
||||||
|
{
|
||||||
|
likelihoods = cv::sum(likelihoods).val[0];
|
||||||
|
set_mat_hdrs();
|
||||||
|
}
|
||||||
|
|
||||||
|
return isOk;
|
||||||
|
}
|
||||||
|
|
||||||
|
float
|
||||||
|
CvEM::predict( const Mat& _sample, Mat* _probs, bool isNormalize ) const
|
||||||
|
{
|
||||||
|
int cls = emObj.predict(_sample, _probs);
|
||||||
|
if(_probs && isNormalize)
|
||||||
|
cv::normalize(*_probs, *_probs, 1, 0, cv::NORM_L1);
|
||||||
|
|
||||||
|
return (float)cls;
|
||||||
|
}
|
||||||
|
|
||||||
|
int CvEM::getNClusters() const
|
||||||
|
{
|
||||||
|
return emObj.getNClusters();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Mat& CvEM::getMeans() const
|
||||||
|
{
|
||||||
|
return emObj.getMeans();
|
||||||
|
}
|
||||||
|
|
||||||
|
void CvEM::getCovs(vector<Mat>& _covs) const
|
||||||
|
{
|
||||||
|
_covs = emObj.getCovs();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Mat& CvEM::getWeights() const
|
||||||
|
{
|
||||||
|
return emObj.getWeights();
|
||||||
|
}
|
||||||
|
|
||||||
|
const Mat& CvEM::getProbs() const
|
||||||
|
{
|
||||||
|
return probs;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/* End of file. */
|
445
modules/legacy/test/test_em.cpp
Normal file
445
modules/legacy/test/test_em.cpp
Normal file
@ -0,0 +1,445 @@
|
|||||||
|
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
//
|
||||||
|
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||||
|
//
|
||||||
|
// By downloading, copying, installing or using the software you agree to this license.
|
||||||
|
// If you do not agree to this license, do not download, install,
|
||||||
|
// copy or use the software.
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// Intel License Agreement
|
||||||
|
// For Open Source Computer Vision Library
|
||||||
|
//
|
||||||
|
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
||||||
|
// Third party copyrights are property of their respective owners.
|
||||||
|
//
|
||||||
|
// Redistribution and use in source and binary forms, with or without modification,
|
||||||
|
// are permitted provided that the following conditions are met:
|
||||||
|
//
|
||||||
|
// * Redistribution's of source code must retain the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer.
|
||||||
|
//
|
||||||
|
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||||
|
// this list of conditions and the following disclaimer in the documentation
|
||||||
|
// and/or other materials provided with the distribution.
|
||||||
|
//
|
||||||
|
// * The name of Intel Corporation may not be used to endorse or promote products
|
||||||
|
// derived from this software without specific prior written permission.
|
||||||
|
//
|
||||||
|
// This software is provided by the copyright holders and contributors "as is" and
|
||||||
|
// any express or implied warranties, including, but not limited to, the implied
|
||||||
|
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||||
|
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||||
|
// indirect, incidental, special, exemplary, or consequential damages
|
||||||
|
// (including, but not limited to, procurement of substitute goods or services;
|
||||||
|
// loss of use, data, or profits; or business interruption) however caused
|
||||||
|
// and on any theory of liability, whether in contract, strict liability,
|
||||||
|
// or tort (including negligence or otherwise) arising in any way out of
|
||||||
|
// the use of this software, even if advised of the possibility of such damage.
|
||||||
|
//
|
||||||
|
//M*/
|
||||||
|
|
||||||
|
#include "test_precomp.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace cv;
|
||||||
|
|
||||||
|
static
|
||||||
|
void defaultDistribs( Mat& means, vector<Mat>& covs )
|
||||||
|
{
|
||||||
|
float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f};
|
||||||
|
float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f};
|
||||||
|
float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f};
|
||||||
|
means.create(3, 2, CV_32FC1);
|
||||||
|
Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 );
|
||||||
|
Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 );
|
||||||
|
Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 );
|
||||||
|
means.resize(3), covs.resize(3);
|
||||||
|
|
||||||
|
Mat mr0 = means.row(0);
|
||||||
|
m0.copyTo(mr0);
|
||||||
|
c0.copyTo(covs[0]);
|
||||||
|
|
||||||
|
Mat mr1 = means.row(1);
|
||||||
|
m1.copyTo(mr1);
|
||||||
|
c1.copyTo(covs[1]);
|
||||||
|
|
||||||
|
Mat mr2 = means.row(2);
|
||||||
|
m2.copyTo(mr2);
|
||||||
|
c2.copyTo(covs[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate points sets by normal distributions
|
||||||
|
static
|
||||||
|
void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& _means, const vector<Mat>& covs, int labelType )
|
||||||
|
{
|
||||||
|
vector<int>::const_iterator sit = sizes.begin();
|
||||||
|
int total = 0;
|
||||||
|
for( ; sit != sizes.end(); ++sit )
|
||||||
|
total += *sit;
|
||||||
|
assert( _means.rows == (int)sizes.size() && covs.size() == sizes.size() );
|
||||||
|
assert( !data.empty() && data.rows == total );
|
||||||
|
assert( data.type() == CV_32FC1 );
|
||||||
|
|
||||||
|
labels.create( data.rows, 1, labelType );
|
||||||
|
|
||||||
|
randn( data, Scalar::all(0.0), Scalar::all(1.0) );
|
||||||
|
vector<Mat> means(sizes.size());
|
||||||
|
for(int i = 0; i < _means.rows; i++)
|
||||||
|
means[i] = _means.row(i);
|
||||||
|
vector<Mat>::const_iterator mit = means.begin(), cit = covs.begin();
|
||||||
|
int bi, ei = 0;
|
||||||
|
sit = sizes.begin();
|
||||||
|
for( int p = 0, l = 0; sit != sizes.end(); ++sit, ++mit, ++cit, l++ )
|
||||||
|
{
|
||||||
|
bi = ei;
|
||||||
|
ei = bi + *sit;
|
||||||
|
assert( mit->rows == 1 && mit->cols == data.cols );
|
||||||
|
assert( cit->rows == data.cols && cit->cols == data.cols );
|
||||||
|
for( int i = bi; i < ei; i++, p++ )
|
||||||
|
{
|
||||||
|
Mat r(1, data.cols, CV_32FC1, data.ptr<float>(i));
|
||||||
|
r = r * (*cit) + *mit;
|
||||||
|
if( labelType == CV_32FC1 )
|
||||||
|
labels.at<float>(p, 0) = (float)l;
|
||||||
|
else if( labelType == CV_32SC1 )
|
||||||
|
labels.at<int>(p, 0) = l;
|
||||||
|
else
|
||||||
|
CV_DbgAssert(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static
|
||||||
|
int maxIdx( const vector<int>& count )
|
||||||
|
{
|
||||||
|
int idx = -1;
|
||||||
|
int maxVal = -1;
|
||||||
|
vector<int>::const_iterator it = count.begin();
|
||||||
|
for( int i = 0; it != count.end(); ++it, i++ )
|
||||||
|
{
|
||||||
|
if( *it > maxVal)
|
||||||
|
{
|
||||||
|
maxVal = *it;
|
||||||
|
idx = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert( idx >= 0);
|
||||||
|
return idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
static
|
||||||
|
bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap )
|
||||||
|
{
|
||||||
|
size_t total = 0, nclusters = sizes.size();
|
||||||
|
for(size_t i = 0; i < sizes.size(); i++)
|
||||||
|
total += sizes[i];
|
||||||
|
|
||||||
|
assert( !labels.empty() );
|
||||||
|
assert( labels.total() == total && (labels.cols == 1 || labels.rows == 1));
|
||||||
|
assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
|
||||||
|
|
||||||
|
bool isFlt = labels.type() == CV_32FC1;
|
||||||
|
|
||||||
|
labelsMap.resize(nclusters);
|
||||||
|
|
||||||
|
vector<bool> buzy(nclusters, false);
|
||||||
|
int startIndex = 0;
|
||||||
|
for( size_t clusterIndex = 0; clusterIndex < sizes.size(); clusterIndex++ )
|
||||||
|
{
|
||||||
|
vector<int> count( nclusters, 0 );
|
||||||
|
for( int i = startIndex; i < startIndex + sizes[clusterIndex]; i++)
|
||||||
|
{
|
||||||
|
int lbl = isFlt ? (int)labels.at<float>(i) : labels.at<int>(i);
|
||||||
|
CV_Assert(lbl < (int)nclusters);
|
||||||
|
count[lbl]++;
|
||||||
|
CV_Assert(count[lbl] < (int)total);
|
||||||
|
}
|
||||||
|
startIndex += sizes[clusterIndex];
|
||||||
|
|
||||||
|
int cls = maxIdx( count );
|
||||||
|
CV_Assert( !buzy[cls] );
|
||||||
|
|
||||||
|
labelsMap[clusterIndex] = cls;
|
||||||
|
|
||||||
|
buzy[cls] = true;
|
||||||
|
}
|
||||||
|
for(size_t i = 0; i < buzy.size(); i++)
|
||||||
|
if(!buzy[i])
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static
|
||||||
|
bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true )
|
||||||
|
{
|
||||||
|
err = 0;
|
||||||
|
CV_Assert( !labels.empty() && !origLabels.empty() );
|
||||||
|
CV_Assert( labels.rows == 1 || labels.cols == 1 );
|
||||||
|
CV_Assert( origLabels.rows == 1 || origLabels.cols == 1 );
|
||||||
|
CV_Assert( labels.total() == origLabels.total() );
|
||||||
|
CV_Assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
|
||||||
|
CV_Assert( origLabels.type() == labels.type() );
|
||||||
|
|
||||||
|
vector<int> labelsMap;
|
||||||
|
bool isFlt = labels.type() == CV_32FC1;
|
||||||
|
if( !labelsEquivalent )
|
||||||
|
{
|
||||||
|
if( !getLabelsMap( labels, sizes, labelsMap ) )
|
||||||
|
return false;
|
||||||
|
|
||||||
|
for( int i = 0; i < labels.rows; i++ )
|
||||||
|
if( isFlt )
|
||||||
|
err += labels.at<float>(i) != labelsMap[(int)origLabels.at<float>(i)] ? 1.f : 0.f;
|
||||||
|
else
|
||||||
|
err += labels.at<int>(i) != labelsMap[origLabels.at<int>(i)] ? 1.f : 0.f;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for( int i = 0; i < labels.rows; i++ )
|
||||||
|
if( isFlt )
|
||||||
|
err += labels.at<float>(i) != origLabels.at<float>(i) ? 1.f : 0.f;
|
||||||
|
else
|
||||||
|
err += labels.at<int>(i) != origLabels.at<int>(i) ? 1.f : 0.f;
|
||||||
|
}
|
||||||
|
err /= (float)labels.rows;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
class CV_CvEMTest : public cvtest::BaseTest
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
CV_CvEMTest() {}
|
||||||
|
protected:
|
||||||
|
virtual void run( int start_from );
|
||||||
|
int runCase( int caseIndex, const CvEMParams& params,
|
||||||
|
const cv::Mat& trainData, const cv::Mat& trainLabels,
|
||||||
|
const cv::Mat& testData, const cv::Mat& testLabels,
|
||||||
|
const vector<int>& sizes);
|
||||||
|
};
|
||||||
|
|
||||||
|
int CV_CvEMTest::runCase( int caseIndex, const CvEMParams& params,
|
||||||
|
const cv::Mat& trainData, const cv::Mat& trainLabels,
|
||||||
|
const cv::Mat& testData, const cv::Mat& testLabels,
|
||||||
|
const vector<int>& sizes )
|
||||||
|
{
|
||||||
|
int code = cvtest::TS::OK;
|
||||||
|
|
||||||
|
cv::Mat labels;
|
||||||
|
float err;
|
||||||
|
|
||||||
|
CvEM em;
|
||||||
|
em.train( trainData, Mat(), params, &labels );
|
||||||
|
|
||||||
|
// check train error
|
||||||
|
if( !calcErr( labels, trainLabels, sizes, err , false ) )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
else if( err > 0.006f )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on train data.\n", caseIndex, err );
|
||||||
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check test error
|
||||||
|
labels.create( testData.rows, 1, CV_32SC1 );
|
||||||
|
for( int i = 0; i < testData.rows; i++ )
|
||||||
|
{
|
||||||
|
Mat sample = testData.row(i);
|
||||||
|
labels.at<int>(i,0) = (int)em.predict( sample, 0 );
|
||||||
|
}
|
||||||
|
if( !calcErr( labels, testLabels, sizes, err, false ) )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
else if( err > 0.006f )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on test data.\n", caseIndex, err );
|
||||||
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
|
}
|
||||||
|
|
||||||
|
return code;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CV_CvEMTest::run( int /*start_from*/ )
|
||||||
|
{
|
||||||
|
int sizesArr[] = { 500, 700, 800 };
|
||||||
|
int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
|
||||||
|
|
||||||
|
// Points distribution
|
||||||
|
Mat means;
|
||||||
|
vector<Mat> covs;
|
||||||
|
defaultDistribs( means, covs );
|
||||||
|
|
||||||
|
// train data
|
||||||
|
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
|
||||||
|
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
|
||||||
|
generateData( trainData, trainLabels, sizes, means, covs, CV_32SC1 );
|
||||||
|
|
||||||
|
// test data
|
||||||
|
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels;
|
||||||
|
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
|
||||||
|
|
||||||
|
CvEMParams params;
|
||||||
|
params.nclusters = 3;
|
||||||
|
Mat probs(trainData.rows, params.nclusters, CV_32FC1, cv::Scalar(1));
|
||||||
|
CvMat probsHdr = probs;
|
||||||
|
params.probs = &probsHdr;
|
||||||
|
Mat weights(1, params.nclusters, CV_32FC1, cv::Scalar(1));
|
||||||
|
CvMat weightsHdr = weights;
|
||||||
|
params.weights = &weightsHdr;
|
||||||
|
CvMat meansHdr = means;
|
||||||
|
params.means = &meansHdr;
|
||||||
|
std::vector<CvMat> covsHdrs(params.nclusters);
|
||||||
|
std::vector<const CvMat*> covsPtrs(params.nclusters);
|
||||||
|
for(int i = 0; i < params.nclusters; i++)
|
||||||
|
{
|
||||||
|
covsHdrs[i] = covs[i];
|
||||||
|
covsPtrs[i] = &covsHdrs[i];
|
||||||
|
}
|
||||||
|
params.covs = &covsPtrs[0];
|
||||||
|
|
||||||
|
int code = cvtest::TS::OK;
|
||||||
|
int caseIndex = 0;
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_AUTO_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_GENERIC;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_AUTO_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_DIAGONAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_AUTO_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_SPHERICAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_M_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_GENERIC;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_M_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_DIAGONAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_M_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_SPHERICAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_E_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_GENERIC;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_E_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_DIAGONAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.start_step = cv::EM::START_E_STEP;
|
||||||
|
params.cov_mat_type = cv::EM::COV_MAT_SPHERICAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
|
||||||
|
ts->set_failed_test_info( code );
|
||||||
|
}
|
||||||
|
|
||||||
|
class CV_CvEMTest_SaveLoad : public cvtest::BaseTest {
|
||||||
|
public:
|
||||||
|
CV_CvEMTest_SaveLoad() {}
|
||||||
|
protected:
|
||||||
|
virtual void run( int /*start_from*/ )
|
||||||
|
{
|
||||||
|
int code = cvtest::TS::OK;
|
||||||
|
cv::EM em;
|
||||||
|
|
||||||
|
Mat samples = Mat(3,1,CV_32F);
|
||||||
|
samples.at<float>(0,0) = 1;
|
||||||
|
samples.at<float>(1,0) = 2;
|
||||||
|
samples.at<float>(2,0) = 3;
|
||||||
|
|
||||||
|
cv::EM::Params params;
|
||||||
|
params.nclusters = 2;
|
||||||
|
|
||||||
|
Mat labels;
|
||||||
|
|
||||||
|
em.train(samples, Mat(), params, &labels);
|
||||||
|
|
||||||
|
Mat firstResult(samples.rows, 1, CV_32FC1);
|
||||||
|
for( int i = 0; i < samples.rows; i++)
|
||||||
|
firstResult.at<float>(i) = em.predict( samples.row(i) );
|
||||||
|
|
||||||
|
// Write out
|
||||||
|
|
||||||
|
string filename = tempfile() + ".xml";
|
||||||
|
{
|
||||||
|
FileStorage fs = FileStorage(filename, FileStorage::WRITE);
|
||||||
|
try
|
||||||
|
{
|
||||||
|
fs << "em" << "{";
|
||||||
|
em.write(fs);
|
||||||
|
fs << "}";
|
||||||
|
}
|
||||||
|
catch(...)
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Crash in write method.\n" );
|
||||||
|
ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
em.clear();
|
||||||
|
|
||||||
|
// Read in
|
||||||
|
{
|
||||||
|
FileStorage fs = FileStorage(filename, FileStorage::READ);
|
||||||
|
CV_Assert(fs.isOpened());
|
||||||
|
FileNode fn = fs["em"];
|
||||||
|
try
|
||||||
|
{
|
||||||
|
em.read(fn);
|
||||||
|
}
|
||||||
|
catch(...)
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Crash in read method.\n" );
|
||||||
|
ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
remove( filename.c_str() );
|
||||||
|
|
||||||
|
int errCaseCount = 0;
|
||||||
|
for( int i = 0; i < samples.rows; i++)
|
||||||
|
errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<float>(i)) < FLT_EPSILON ? 0 : 1;
|
||||||
|
|
||||||
|
if( errCaseCount > 0 )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Different prediction results before writeing and after reading (errCaseCount=%d).\n", errCaseCount );
|
||||||
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
|
}
|
||||||
|
|
||||||
|
ts->set_failed_test_info( code );
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(ML_CvEM, accuracy) { CV_CvEMTest test; test.safe_run(); }
|
||||||
|
TEST(ML_CvEM, save_load) { CV_CvEMTest_SaveLoad test; test.safe_run(); }
|
@ -46,6 +46,10 @@
|
|||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
// Apple defines a check() macro somewhere in the debug headers
|
// Apple defines a check() macro somewhere in the debug headers
|
||||||
// that interferes with a method definiton in this header
|
// that interferes with a method definiton in this header
|
||||||
#undef check
|
#undef check
|
||||||
@ -549,114 +553,93 @@ protected:
|
|||||||
/****************************************************************************************\
|
/****************************************************************************************\
|
||||||
* Expectation - Maximization *
|
* Expectation - Maximization *
|
||||||
\****************************************************************************************/
|
\****************************************************************************************/
|
||||||
|
namespace cv
|
||||||
struct CV_EXPORTS_W_MAP CvEMParams
|
|
||||||
{
|
{
|
||||||
CvEMParams();
|
class CV_EXPORTS_W EM : public Algorithm
|
||||||
CvEMParams( int nclusters, int cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
|
|
||||||
int start_step=0/*CvEM::START_AUTO_STEP*/,
|
|
||||||
CvTermCriteria term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
|
|
||||||
const CvMat* probs=0, const CvMat* weights=0, const CvMat* means=0, const CvMat** covs=0 );
|
|
||||||
|
|
||||||
CV_PROP_RW int nclusters;
|
|
||||||
CV_PROP_RW int cov_mat_type;
|
|
||||||
CV_PROP_RW int start_step;
|
|
||||||
const CvMat* probs;
|
|
||||||
const CvMat* weights;
|
|
||||||
const CvMat* means;
|
|
||||||
const CvMat** covs;
|
|
||||||
CV_PROP_RW CvTermCriteria term_crit;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
|
||||||
class CV_EXPORTS_W CvEM : public CvStatModel
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
// Type of covariation matrices
|
// Type of covariation matrices
|
||||||
enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
|
enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2};
|
||||||
|
|
||||||
// The initial step
|
// The initial step
|
||||||
enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
|
enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
|
||||||
|
|
||||||
CV_WRAP CvEM();
|
class CV_EXPORTS_W Params
|
||||||
CvEM( const CvMat* samples, const CvMat* sampleIdx=0,
|
{
|
||||||
CvEMParams params=CvEMParams(), CvMat* labels=0 );
|
public:
|
||||||
//CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights,
|
Params(int nclusters=10, int covMatType=EM::COV_MAT_DIAGONAL, int startStep=EM::START_AUTO_STEP,
|
||||||
// CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
|
const cv::TermCriteria& termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON),
|
||||||
|
const cv::Mat* probs=0, const cv::Mat* weights=0,
|
||||||
|
const cv::Mat* means=0, const std::vector<cv::Mat>* covs=0);
|
||||||
|
|
||||||
virtual ~CvEM();
|
int nclusters;
|
||||||
|
int covMatType;
|
||||||
|
int startStep;
|
||||||
|
|
||||||
virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0,
|
// all 4 following matrices should have type CV_32FC1
|
||||||
CvEMParams params=CvEMParams(), CvMat* labels=0 );
|
const cv::Mat* probs;
|
||||||
|
const cv::Mat* weights;
|
||||||
|
const cv::Mat* means;
|
||||||
|
const std::vector<cv::Mat>* covs;
|
||||||
|
|
||||||
virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const;
|
cv::TermCriteria termCrit;
|
||||||
|
};
|
||||||
|
|
||||||
#ifndef SWIG
|
EM();
|
||||||
CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(),
|
EM(const cv::Mat& samples, const cv::Mat samplesMask=cv::Mat(),
|
||||||
CvEMParams params=CvEMParams() );
|
const EM::Params& params=EM::Params(), cv::Mat* labels=0, cv::Mat* probs=0, cv::Mat* likelihoods=0);
|
||||||
|
virtual ~EM();
|
||||||
CV_WRAP virtual bool train( const cv::Mat& samples,
|
virtual void clear();
|
||||||
const cv::Mat& sampleIdx=cv::Mat(),
|
|
||||||
CvEMParams params=CvEMParams(),
|
|
||||||
CV_OUT cv::Mat* labels=0 );
|
|
||||||
|
|
||||||
CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const;
|
|
||||||
CV_WRAP virtual double calcLikelihood( const cv::Mat &sample ) const;
|
|
||||||
|
|
||||||
CV_WRAP int getNClusters() const;
|
|
||||||
CV_WRAP cv::Mat getMeans() const;
|
|
||||||
CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const;
|
|
||||||
CV_WRAP cv::Mat getWeights() const;
|
|
||||||
CV_WRAP cv::Mat getProbs() const;
|
|
||||||
|
|
||||||
CV_WRAP inline double getLikelihood() const { return log_likelihood; }
|
|
||||||
CV_WRAP inline double getLikelihoodDelta() const { return log_likelihood_delta; }
|
|
||||||
#endif
|
|
||||||
|
|
||||||
CV_WRAP virtual void clear();
|
|
||||||
|
|
||||||
int get_nclusters() const;
|
virtual bool train(const cv::Mat& samples, const cv::Mat& samplesMask=cv::Mat(),
|
||||||
const CvMat* get_means() const;
|
const EM::Params& params=EM::Params(), cv::Mat* labels=0, cv::Mat* probs=0, cv::Mat* likelihoods=0);
|
||||||
const CvMat** get_covs() const;
|
int predict(const cv::Mat& sample, cv::Mat* probs=0, double* likelihood=0) const;
|
||||||
const CvMat* get_weights() const;
|
|
||||||
const CvMat* get_probs() const;
|
|
||||||
|
|
||||||
inline double get_log_likelihood() const { return log_likelihood; }
|
bool isTrained() const;
|
||||||
inline double get_log_likelihood_delta() const { return log_likelihood_delta; }
|
int getNClusters() const;
|
||||||
|
int getCovMatType() const;
|
||||||
// inline const CvMat * get_log_weight_div_det () const { return log_weight_div_det; };
|
|
||||||
// inline const CvMat * get_inv_eigen_values () const { return inv_eigen_values; };
|
|
||||||
// inline const CvMat ** get_cov_rotate_mats () const { return cov_rotate_mats; };
|
|
||||||
|
|
||||||
virtual void read( CvFileStorage* fs, CvFileNode* node );
|
const cv::Mat& getWeights() const;
|
||||||
virtual void write( CvFileStorage* fs, const char* name ) const;
|
const cv::Mat& getMeans() const;
|
||||||
|
const std::vector<cv::Mat>& getCovs() const;
|
||||||
|
|
||||||
virtual void write_params( CvFileStorage* fs ) const;
|
AlgorithmInfo* info() const;
|
||||||
virtual void read_params( CvFileStorage* fs, CvFileNode* node );
|
virtual void read(const FileNode& fn);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
virtual void setTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM::Params& params);
|
||||||
|
|
||||||
virtual void set_params( const CvEMParams& params,
|
bool doTrain(const cv::TermCriteria& termCrit);
|
||||||
const CvVectors& train_data );
|
virtual void eStep();
|
||||||
virtual void init_em( const CvVectors& train_data );
|
virtual void mStep();
|
||||||
virtual double run_em( const CvVectors& train_data );
|
|
||||||
virtual void init_auto( const CvVectors& samples );
|
|
||||||
virtual void kmeans( const CvVectors& train_data, int nclusters,
|
|
||||||
CvMat* labels, CvTermCriteria criteria,
|
|
||||||
const CvMat* means );
|
|
||||||
CvEMParams params;
|
|
||||||
double log_likelihood;
|
|
||||||
double log_likelihood_delta;
|
|
||||||
|
|
||||||
CvMat* means;
|
void clusterTrainSamples();
|
||||||
CvMat** covs;
|
void decomposeCovs();
|
||||||
CvMat* weights;
|
void computeLogWeightDivDet();
|
||||||
CvMat* probs;
|
|
||||||
|
|
||||||
CvMat* log_weight_div_det;
|
void computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs, float* likelihood) const;
|
||||||
CvMat* inv_eigen_values;
|
|
||||||
CvMat** cov_rotate_mats;
|
// all inner matrices have type CV_32FC1
|
||||||
|
int nclusters;
|
||||||
|
int covMatType;
|
||||||
|
int startStep;
|
||||||
|
|
||||||
|
cv::Mat trainSamples;
|
||||||
|
cv::Mat trainProbs;
|
||||||
|
cv::Mat trainLikelihoods;
|
||||||
|
cv::Mat trainLabels;
|
||||||
|
cv::Mat trainCounts;
|
||||||
|
|
||||||
|
cv::Mat weights;
|
||||||
|
cv::Mat means;
|
||||||
|
std::vector<cv::Mat> covs;
|
||||||
|
|
||||||
|
std::vector<cv::Mat> covsEigenValues;
|
||||||
|
std::vector<cv::Mat> covsRotateMats;
|
||||||
|
std::vector<cv::Mat> invCovsEigenValues;
|
||||||
|
cv::Mat logWeightDivDet;
|
||||||
};
|
};
|
||||||
|
} // namespace cv
|
||||||
|
|
||||||
/****************************************************************************************\
|
/****************************************************************************************\
|
||||||
* Decision Tree *
|
* Decision Tree *
|
||||||
@ -2012,17 +1995,10 @@ CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
|
|||||||
CvMat** responses,
|
CvMat** responses,
|
||||||
int num_classes, ... );
|
int num_classes, ... );
|
||||||
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/****************************************************************************************\
|
/****************************************************************************************\
|
||||||
* Data *
|
* Data *
|
||||||
\****************************************************************************************/
|
\****************************************************************************************/
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <string>
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#define CV_COUNT 0
|
#define CV_COUNT 0
|
||||||
#define CV_PORTION 1
|
#define CV_PORTION 1
|
||||||
|
|
||||||
@ -2133,8 +2109,6 @@ typedef CvSVMParams SVMParams;
|
|||||||
typedef CvSVMKernel SVMKernel;
|
typedef CvSVMKernel SVMKernel;
|
||||||
typedef CvSVMSolver SVMSolver;
|
typedef CvSVMSolver SVMSolver;
|
||||||
typedef CvSVM SVM;
|
typedef CvSVM SVM;
|
||||||
typedef CvEMParams EMParams;
|
|
||||||
typedef CvEM ExpectationMaximization;
|
|
||||||
typedef CvDTreeParams DTreeParams;
|
typedef CvDTreeParams DTreeParams;
|
||||||
typedef CvMLData TrainData;
|
typedef CvMLData TrainData;
|
||||||
typedef CvDTree DecisionTree;
|
typedef CvDTree DecisionTree;
|
||||||
@ -2156,5 +2130,7 @@ template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj();
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
#endif // __cplusplus
|
||||||
|
#endif // __OPENCV_ML_HPP__
|
||||||
|
|
||||||
/* End of file. */
|
/* End of file. */
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -44,34 +44,49 @@
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace cv;
|
using namespace cv;
|
||||||
|
|
||||||
void defaultDistribs( vector<Mat>& means, vector<Mat>& covs )
|
static
|
||||||
|
void defaultDistribs( Mat& means, vector<Mat>& covs )
|
||||||
{
|
{
|
||||||
float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f};
|
float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f};
|
||||||
float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f};
|
float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f};
|
||||||
float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f};
|
float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f};
|
||||||
|
means.create(3, 2, CV_32FC1);
|
||||||
Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 );
|
Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 );
|
||||||
Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 );
|
Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 );
|
||||||
Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 );
|
Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 );
|
||||||
means.resize(3), covs.resize(3);
|
means.resize(3), covs.resize(3);
|
||||||
m0.copyTo(means[0]), c0.copyTo(covs[0]);
|
|
||||||
m1.copyTo(means[1]), c1.copyTo(covs[1]);
|
Mat mr0 = means.row(0);
|
||||||
m2.copyTo(means[2]), c2.copyTo(covs[2]);
|
m0.copyTo(mr0);
|
||||||
|
c0.copyTo(covs[0]);
|
||||||
|
|
||||||
|
Mat mr1 = means.row(1);
|
||||||
|
m1.copyTo(mr1);
|
||||||
|
c1.copyTo(covs[1]);
|
||||||
|
|
||||||
|
Mat mr2 = means.row(2);
|
||||||
|
m2.copyTo(mr2);
|
||||||
|
c2.copyTo(covs[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// generate points sets by normal distributions
|
// generate points sets by normal distributions
|
||||||
void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const vector<Mat>& means, const vector<Mat>& covs, int labelType )
|
static
|
||||||
|
void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& _means, const vector<Mat>& covs, int labelType )
|
||||||
{
|
{
|
||||||
vector<int>::const_iterator sit = sizes.begin();
|
vector<int>::const_iterator sit = sizes.begin();
|
||||||
int total = 0;
|
int total = 0;
|
||||||
for( ; sit != sizes.end(); ++sit )
|
for( ; sit != sizes.end(); ++sit )
|
||||||
total += *sit;
|
total += *sit;
|
||||||
assert( means.size() == sizes.size() && covs.size() == sizes.size() );
|
assert( _means.rows == (int)sizes.size() && covs.size() == sizes.size() );
|
||||||
assert( !data.empty() && data.rows == total );
|
assert( !data.empty() && data.rows == total );
|
||||||
assert( data.type() == CV_32FC1 );
|
assert( data.type() == CV_32FC1 );
|
||||||
|
|
||||||
labels.create( data.rows, 1, labelType );
|
labels.create( data.rows, 1, labelType );
|
||||||
|
|
||||||
randn( data, Scalar::all(0.0), Scalar::all(1.0) );
|
randn( data, Scalar::all(0.0), Scalar::all(1.0) );
|
||||||
|
vector<Mat> means(sizes.size());
|
||||||
|
for(int i = 0; i < _means.rows; i++)
|
||||||
|
means[i] = _means.row(i);
|
||||||
vector<Mat>::const_iterator mit = means.begin(), cit = covs.begin();
|
vector<Mat>::const_iterator mit = means.begin(), cit = covs.begin();
|
||||||
int bi, ei = 0;
|
int bi, ei = 0;
|
||||||
sit = sizes.begin();
|
sit = sizes.begin();
|
||||||
@ -95,6 +110,7 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const vecto
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static
|
||||||
int maxIdx( const vector<int>& count )
|
int maxIdx( const vector<int>& count )
|
||||||
{
|
{
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
@ -112,74 +128,83 @@ int maxIdx( const vector<int>& count )
|
|||||||
return idx;
|
return idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static
|
||||||
bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap )
|
bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap )
|
||||||
{
|
{
|
||||||
int total = 0, setCount = (int)sizes.size();
|
size_t total = 0, nclusters = sizes.size();
|
||||||
vector<int>::const_iterator sit = sizes.begin();
|
for(size_t i = 0; i < sizes.size(); i++)
|
||||||
for( ; sit != sizes.end(); ++sit )
|
total += sizes[i];
|
||||||
total += *sit;
|
|
||||||
assert( !labels.empty() );
|
assert( !labels.empty() );
|
||||||
assert( labels.rows == total && labels.cols == 1 );
|
assert( labels.total() == total && (labels.cols == 1 || labels.rows == 1));
|
||||||
assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
|
assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
|
||||||
|
|
||||||
bool isFlt = labels.type() == CV_32FC1;
|
bool isFlt = labels.type() == CV_32FC1;
|
||||||
labelsMap.resize(setCount);
|
|
||||||
vector<int>::iterator lmit = labelsMap.begin();
|
labelsMap.resize(nclusters);
|
||||||
vector<bool> buzy(setCount, false);
|
|
||||||
int bi, ei = 0;
|
vector<bool> buzy(nclusters, false);
|
||||||
for( sit = sizes.begin(); sit != sizes.end(); ++sit, ++lmit )
|
int startIndex = 0;
|
||||||
|
for( size_t clusterIndex = 0; clusterIndex < sizes.size(); clusterIndex++ )
|
||||||
{
|
{
|
||||||
vector<int> count( setCount, 0 );
|
vector<int> count( nclusters, 0 );
|
||||||
bi = ei;
|
for( int i = startIndex; i < startIndex + sizes[clusterIndex]; i++)
|
||||||
ei = bi + *sit;
|
|
||||||
if( isFlt )
|
|
||||||
{
|
{
|
||||||
for( int i = bi; i < ei; i++ )
|
int lbl = isFlt ? (int)labels.at<float>(i) : labels.at<int>(i);
|
||||||
count[(int)labels.at<float>(i, 0)]++;
|
CV_Assert(lbl < (int)nclusters);
|
||||||
|
count[lbl]++;
|
||||||
|
CV_Assert(count[lbl] < (int)total);
|
||||||
}
|
}
|
||||||
else
|
startIndex += sizes[clusterIndex];
|
||||||
{
|
|
||||||
for( int i = bi; i < ei; i++ )
|
int cls = maxIdx( count );
|
||||||
count[labels.at<int>(i, 0)]++;
|
CV_Assert( !buzy[cls] );
|
||||||
}
|
|
||||||
|
labelsMap[clusterIndex] = cls;
|
||||||
*lmit = maxIdx( count );
|
|
||||||
if( buzy[*lmit] )
|
buzy[cls] = true;
|
||||||
return false;
|
|
||||||
buzy[*lmit] = true;
|
|
||||||
}
|
}
|
||||||
return true;
|
for(size_t i = 0; i < buzy.size(); i++)
|
||||||
|
if(!buzy[i])
|
||||||
|
return false;
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
float calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, bool labelsEquivalent = true )
|
static
|
||||||
|
bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true )
|
||||||
{
|
{
|
||||||
int err = 0;
|
err = 0;
|
||||||
assert( !labels.empty() && !origLabels.empty() );
|
CV_Assert( !labels.empty() && !origLabels.empty() );
|
||||||
assert( labels.cols == 1 && origLabels.cols == 1 );
|
CV_Assert( labels.rows == 1 || labels.cols == 1 );
|
||||||
assert( labels.rows == origLabels.rows );
|
CV_Assert( origLabels.rows == 1 || origLabels.cols == 1 );
|
||||||
assert( labels.type() == origLabels.type() );
|
CV_Assert( labels.total() == origLabels.total() );
|
||||||
assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
|
CV_Assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
|
||||||
|
CV_Assert( origLabels.type() == labels.type() );
|
||||||
|
|
||||||
vector<int> labelsMap;
|
vector<int> labelsMap;
|
||||||
bool isFlt = labels.type() == CV_32FC1;
|
bool isFlt = labels.type() == CV_32FC1;
|
||||||
if( !labelsEquivalent )
|
if( !labelsEquivalent )
|
||||||
{
|
{
|
||||||
getLabelsMap( labels, sizes, labelsMap );
|
if( !getLabelsMap( labels, sizes, labelsMap ) )
|
||||||
|
return false;
|
||||||
|
|
||||||
for( int i = 0; i < labels.rows; i++ )
|
for( int i = 0; i < labels.rows; i++ )
|
||||||
if( isFlt )
|
if( isFlt )
|
||||||
err += labels.at<float>(i, 0) != labelsMap[(int)origLabels.at<float>(i, 0)];
|
err += labels.at<float>(i) != labelsMap[(int)origLabels.at<float>(i)] ? 1.f : 0.f;
|
||||||
else
|
else
|
||||||
err += labels.at<int>(i, 0) != labelsMap[origLabels.at<int>(i, 0)];
|
err += labels.at<int>(i) != labelsMap[origLabels.at<int>(i)] ? 1.f : 0.f;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
for( int i = 0; i < labels.rows; i++ )
|
for( int i = 0; i < labels.rows; i++ )
|
||||||
if( isFlt )
|
if( isFlt )
|
||||||
err += labels.at<float>(i, 0) != origLabels.at<float>(i, 0);
|
err += labels.at<float>(i) != origLabels.at<float>(i) ? 1.f : 0.f;
|
||||||
else
|
else
|
||||||
err += labels.at<int>(i, 0) != origLabels.at<int>(i, 0);
|
err += labels.at<int>(i) != origLabels.at<int>(i) ? 1.f : 0.f;
|
||||||
}
|
}
|
||||||
return (float)err / (float)labels.rows;
|
err /= (float)labels.rows;
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//--------------------------------------------------------------------------------------------
|
//--------------------------------------------------------------------------------------------
|
||||||
@ -198,7 +223,8 @@ void CV_KMeansTest::run( int /*start_from*/ )
|
|||||||
|
|
||||||
Mat data( pointsCount, 2, CV_32FC1 ), labels;
|
Mat data( pointsCount, 2, CV_32FC1 ), labels;
|
||||||
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
|
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
|
||||||
vector<Mat> means, covs;
|
Mat means;
|
||||||
|
vector<Mat> covs;
|
||||||
defaultDistribs( means, covs );
|
defaultDistribs( means, covs );
|
||||||
generateData( data, labels, sizes, means, covs, CV_32SC1 );
|
generateData( data, labels, sizes, means, covs, CV_32SC1 );
|
||||||
|
|
||||||
@ -207,8 +233,12 @@ void CV_KMeansTest::run( int /*start_from*/ )
|
|||||||
Mat bestLabels;
|
Mat bestLabels;
|
||||||
// 1. flag==KMEANS_PP_CENTERS
|
// 1. flag==KMEANS_PP_CENTERS
|
||||||
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, noArray() );
|
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_PP_CENTERS, noArray() );
|
||||||
err = calcErr( bestLabels, labels, sizes, false );
|
if( !calcErr( bestLabels, labels, sizes, err , false ) )
|
||||||
if( err > 0.01f )
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_PP_CENTERS.\n" );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
else if( err > 0.01f )
|
||||||
{
|
{
|
||||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
|
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
|
||||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
@ -216,10 +246,14 @@ void CV_KMeansTest::run( int /*start_from*/ )
|
|||||||
|
|
||||||
// 2. flag==KMEANS_RANDOM_CENTERS
|
// 2. flag==KMEANS_RANDOM_CENTERS
|
||||||
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, noArray() );
|
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_RANDOM_CENTERS, noArray() );
|
||||||
err = calcErr( bestLabels, labels, sizes, false );
|
if( !calcErr( bestLabels, labels, sizes, err, false ) )
|
||||||
if( err > 0.01f )
|
|
||||||
{
|
{
|
||||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
|
ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_RANDOM_CENTERS.\n" );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
else if( err > 0.01f )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_RANDOM_CENTERS.\n", err );
|
||||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -229,10 +263,14 @@ void CV_KMeansTest::run( int /*start_from*/ )
|
|||||||
for( int i = 0; i < 0.5f * pointsCount; i++ )
|
for( int i = 0; i < 0.5f * pointsCount; i++ )
|
||||||
bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3;
|
bestLabels.at<int>( rng.next() % pointsCount, 0 ) = rng.next() % 3;
|
||||||
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, noArray() );
|
kmeans( data, 3, bestLabels, TermCriteria( TermCriteria::COUNT, iters, 0.0), 0, KMEANS_USE_INITIAL_LABELS, noArray() );
|
||||||
err = calcErr( bestLabels, labels, sizes, false );
|
if( !calcErr( bestLabels, labels, sizes, err, false ) )
|
||||||
if( err > 0.01f )
|
|
||||||
{
|
{
|
||||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_PP_CENTERS.\n", err );
|
ts->printf( cvtest::TS::LOG, "Bad output labels if flag==KMEANS_USE_INITIAL_LABELS.\n" );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
else if( err > 0.01f )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) if flag==KMEANS_USE_INITIAL_LABELS.\n", err );
|
||||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -255,7 +293,8 @@ void CV_KNearestTest::run( int /*start_from*/ )
|
|||||||
// train data
|
// train data
|
||||||
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
|
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
|
||||||
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
|
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
|
||||||
vector<Mat> means, covs;
|
Mat means;
|
||||||
|
vector<Mat> covs;
|
||||||
defaultDistribs( means, covs );
|
defaultDistribs( means, covs );
|
||||||
generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1 );
|
generateData( trainData, trainLabels, sizes, means, covs, CV_32FC1 );
|
||||||
|
|
||||||
@ -267,8 +306,13 @@ void CV_KNearestTest::run( int /*start_from*/ )
|
|||||||
KNearest knearest;
|
KNearest knearest;
|
||||||
knearest.train( trainData, trainLabels );
|
knearest.train( trainData, trainLabels );
|
||||||
knearest.find_nearest( testData, 4, &bestLabels );
|
knearest.find_nearest( testData, 4, &bestLabels );
|
||||||
float err = calcErr( bestLabels, testLabels, sizes, true );
|
float err;
|
||||||
if( err > 0.01f )
|
if( !calcErr( bestLabels, testLabels, sizes, err, true ) )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Bad output labels.\n" );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
else if( err > 0.01f )
|
||||||
{
|
{
|
||||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
|
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
|
||||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
@ -277,76 +321,167 @@ void CV_KNearestTest::run( int /*start_from*/ )
|
|||||||
}
|
}
|
||||||
|
|
||||||
//--------------------------------------------------------------------------------------------
|
//--------------------------------------------------------------------------------------------
|
||||||
class CV_EMTest : public cvtest::BaseTest {
|
class CV_EMTest : public cvtest::BaseTest
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
CV_EMTest() {}
|
CV_EMTest() {}
|
||||||
protected:
|
protected:
|
||||||
virtual void run( int start_from );
|
virtual void run( int start_from );
|
||||||
|
int runCase( int caseIndex, const cv::EM::Params& params,
|
||||||
|
const cv::Mat& trainData, const cv::Mat& trainLabels,
|
||||||
|
const cv::Mat& testData, const cv::Mat& testLabels,
|
||||||
|
const vector<int>& sizes);
|
||||||
};
|
};
|
||||||
|
|
||||||
void CV_EMTest::run( int /*start_from*/ )
|
int CV_EMTest::runCase( int caseIndex, const cv::EM::Params& params,
|
||||||
|
const cv::Mat& trainData, const cv::Mat& trainLabels,
|
||||||
|
const cv::Mat& testData, const cv::Mat& testLabels,
|
||||||
|
const vector<int>& sizes )
|
||||||
{
|
{
|
||||||
int sizesArr[] = { 5000, 7000, 8000 };
|
|
||||||
int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
|
|
||||||
|
|
||||||
// train data
|
|
||||||
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
|
|
||||||
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
|
|
||||||
vector<Mat> means, covs;
|
|
||||||
defaultDistribs( means, covs );
|
|
||||||
generateData( trainData, trainLabels, sizes, means, covs, CV_32SC1 );
|
|
||||||
|
|
||||||
// test data
|
|
||||||
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels, bestLabels;
|
|
||||||
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
|
|
||||||
|
|
||||||
int code = cvtest::TS::OK;
|
int code = cvtest::TS::OK;
|
||||||
|
|
||||||
|
cv::Mat labels;
|
||||||
float err;
|
float err;
|
||||||
ExpectationMaximization em;
|
|
||||||
CvEMParams params;
|
cv::EM em;
|
||||||
params.nclusters = 3;
|
em.train( trainData, Mat(), params, &labels );
|
||||||
em.train( trainData, Mat(), params, &bestLabels );
|
|
||||||
|
|
||||||
// check train error
|
// check train error
|
||||||
err = calcErr( bestLabels, trainLabels, sizes, false );
|
if( !calcErr( labels, trainLabels, sizes, err , false ) )
|
||||||
if( err > 0.002f )
|
|
||||||
{
|
{
|
||||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on train data.\n", err );
|
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
else if( err > 0.006f )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on train data.\n", caseIndex, err );
|
||||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
}
|
}
|
||||||
|
|
||||||
// check test error
|
// check test error
|
||||||
bestLabels.create( testData.rows, 1, CV_32SC1 );
|
labels.create( testData.rows, 1, CV_32SC1 );
|
||||||
for( int i = 0; i < testData.rows; i++ )
|
for( int i = 0; i < testData.rows; i++ )
|
||||||
{
|
{
|
||||||
Mat sample( 1, testData.cols, CV_32FC1, testData.ptr<float>(i));
|
Mat sample = testData.row(i);
|
||||||
bestLabels.at<int>(i,0) = (int)em.predict( sample, 0 );
|
labels.at<int>(i,0) = (int)em.predict( sample, 0 );
|
||||||
}
|
}
|
||||||
err = calcErr( bestLabels, testLabels, sizes, false );
|
if( !calcErr( labels, testLabels, sizes, err, false ) )
|
||||||
if( err > 0.005f )
|
|
||||||
{
|
{
|
||||||
ts->printf( cvtest::TS::LOG, "Bad accuracy (%f) on test data.\n", err );
|
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||||
|
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||||
|
}
|
||||||
|
else if( err > 0.006f )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on test data.\n", caseIndex, err );
|
||||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return code;
|
||||||
|
}
|
||||||
|
|
||||||
|
void CV_EMTest::run( int /*start_from*/ )
|
||||||
|
{
|
||||||
|
int sizesArr[] = { 500, 700, 800 };
|
||||||
|
int pointsCount = sizesArr[0]+ sizesArr[1] + sizesArr[2];
|
||||||
|
|
||||||
|
// Points distribution
|
||||||
|
Mat means;
|
||||||
|
vector<Mat> covs;
|
||||||
|
defaultDistribs( means, covs );
|
||||||
|
|
||||||
|
// train data
|
||||||
|
Mat trainData( pointsCount, 2, CV_32FC1 ), trainLabels;
|
||||||
|
vector<int> sizes( sizesArr, sizesArr + sizeof(sizesArr) / sizeof(sizesArr[0]) );
|
||||||
|
generateData( trainData, trainLabels, sizes, means, covs, CV_32SC1 );
|
||||||
|
|
||||||
|
// test data
|
||||||
|
Mat testData( pointsCount, 2, CV_32FC1 ), testLabels;
|
||||||
|
generateData( testData, testLabels, sizes, means, covs, CV_32SC1 );
|
||||||
|
|
||||||
|
cv::EM::Params params;
|
||||||
|
params.nclusters = 3;
|
||||||
|
Mat probs(trainData.rows, params.nclusters, CV_32FC1, cv::Scalar(1));
|
||||||
|
params.probs = &probs;
|
||||||
|
Mat weights(1, params.nclusters, CV_32FC1, cv::Scalar(1));
|
||||||
|
params.weights = &weights;
|
||||||
|
params.means = &means;
|
||||||
|
params.covs = &covs;
|
||||||
|
|
||||||
|
int code = cvtest::TS::OK;
|
||||||
|
int caseIndex = 0;
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_AUTO_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_GENERIC;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_AUTO_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_DIAGONAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_AUTO_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_SPHERICAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_M_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_GENERIC;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_M_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_DIAGONAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_M_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_SPHERICAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_E_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_GENERIC;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_E_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_DIAGONAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
params.startStep = cv::EM::START_E_STEP;
|
||||||
|
params.covMatType = cv::EM::COV_MAT_SPHERICAL;
|
||||||
|
int currCode = runCase(caseIndex++, params, trainData, trainLabels, testData, testLabels, sizes);
|
||||||
|
code = currCode == cvtest::TS::OK ? code : currCode;
|
||||||
|
}
|
||||||
|
|
||||||
ts->set_failed_test_info( code );
|
ts->set_failed_test_info( code );
|
||||||
}
|
}
|
||||||
|
|
||||||
class CV_EMTest_Smoke : public cvtest::BaseTest {
|
class CV_EMTest_SaveLoad : public cvtest::BaseTest {
|
||||||
public:
|
public:
|
||||||
CV_EMTest_Smoke() {}
|
CV_EMTest_SaveLoad() {}
|
||||||
protected:
|
protected:
|
||||||
virtual void run( int /*start_from*/ )
|
virtual void run( int /*start_from*/ )
|
||||||
{
|
{
|
||||||
int code = cvtest::TS::OK;
|
int code = cvtest::TS::OK;
|
||||||
CvEM em;
|
cv::EM em;
|
||||||
|
|
||||||
Mat samples = Mat(3,2,CV_32F);
|
Mat samples = Mat(3,1,CV_32F);
|
||||||
samples.at<float>(0,0) = 1;
|
samples.at<float>(0,0) = 1;
|
||||||
samples.at<float>(1,0) = 2;
|
samples.at<float>(1,0) = 2;
|
||||||
samples.at<float>(2,0) = 3;
|
samples.at<float>(2,0) = 3;
|
||||||
|
|
||||||
CvEMParams params;
|
cv::EM::Params params;
|
||||||
params.nclusters = 2;
|
params.nclusters = 2;
|
||||||
|
|
||||||
Mat labels;
|
Mat labels;
|
||||||
@ -361,10 +496,11 @@ protected:
|
|||||||
string filename = tempfile() + ".xml";
|
string filename = tempfile() + ".xml";
|
||||||
{
|
{
|
||||||
FileStorage fs = FileStorage(filename, FileStorage::WRITE);
|
FileStorage fs = FileStorage(filename, FileStorage::WRITE);
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
em.write(fs.fs, "EM");
|
fs << "em" << "{";
|
||||||
|
em.write(fs);
|
||||||
|
fs << "}";
|
||||||
}
|
}
|
||||||
catch(...)
|
catch(...)
|
||||||
{
|
{
|
||||||
@ -378,11 +514,11 @@ protected:
|
|||||||
// Read in
|
// Read in
|
||||||
{
|
{
|
||||||
FileStorage fs = FileStorage(filename, FileStorage::READ);
|
FileStorage fs = FileStorage(filename, FileStorage::READ);
|
||||||
FileNode fileNode = fs["EM"];
|
CV_Assert(fs.isOpened());
|
||||||
|
FileNode fn = fs["em"];
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
em.read(const_cast<CvFileStorage*>(fileNode.fs), const_cast<CvFileNode*>(fileNode.node));
|
em.read(fn);
|
||||||
}
|
}
|
||||||
catch(...)
|
catch(...)
|
||||||
{
|
{
|
||||||
@ -410,4 +546,4 @@ protected:
|
|||||||
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
|
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
|
||||||
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
|
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
|
||||||
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
|
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
|
||||||
TEST(ML_EM, smoke) { CV_EMTest_Smoke test; test.safe_run(); }
|
TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
|
||||||
|
@ -451,7 +451,6 @@ CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
|
|||||||
nbayes = 0;
|
nbayes = 0;
|
||||||
knearest = 0;
|
knearest = 0;
|
||||||
svm = 0;
|
svm = 0;
|
||||||
em = 0;
|
|
||||||
ann = 0;
|
ann = 0;
|
||||||
dtree = 0;
|
dtree = 0;
|
||||||
boost = 0;
|
boost = 0;
|
||||||
@ -463,8 +462,6 @@ CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
|
|||||||
knearest = new CvKNearest;
|
knearest = new CvKNearest;
|
||||||
else if( !modelName.compare(CV_SVM) )
|
else if( !modelName.compare(CV_SVM) )
|
||||||
svm = new CvSVM;
|
svm = new CvSVM;
|
||||||
else if( !modelName.compare(CV_EM) )
|
|
||||||
em = new CvEM;
|
|
||||||
else if( !modelName.compare(CV_ANN) )
|
else if( !modelName.compare(CV_ANN) )
|
||||||
ann = new CvANN_MLP;
|
ann = new CvANN_MLP;
|
||||||
else if( !modelName.compare(CV_DTREE) )
|
else if( !modelName.compare(CV_DTREE) )
|
||||||
@ -487,8 +484,6 @@ CV_MLBaseTest::~CV_MLBaseTest()
|
|||||||
delete knearest;
|
delete knearest;
|
||||||
if( svm )
|
if( svm )
|
||||||
delete svm;
|
delete svm;
|
||||||
if( em )
|
|
||||||
delete em;
|
|
||||||
if( ann )
|
if( ann )
|
||||||
delete ann;
|
delete ann;
|
||||||
if( dtree )
|
if( dtree )
|
||||||
@ -756,8 +751,6 @@ void CV_MLBaseTest::save( const char* filename )
|
|||||||
knearest->save( filename );
|
knearest->save( filename );
|
||||||
else if( !modelName.compare(CV_SVM) )
|
else if( !modelName.compare(CV_SVM) )
|
||||||
svm->save( filename );
|
svm->save( filename );
|
||||||
else if( !modelName.compare(CV_EM) )
|
|
||||||
em->save( filename );
|
|
||||||
else if( !modelName.compare(CV_ANN) )
|
else if( !modelName.compare(CV_ANN) )
|
||||||
ann->save( filename );
|
ann->save( filename );
|
||||||
else if( !modelName.compare(CV_DTREE) )
|
else if( !modelName.compare(CV_DTREE) )
|
||||||
@ -778,8 +771,6 @@ void CV_MLBaseTest::load( const char* filename )
|
|||||||
knearest->load( filename );
|
knearest->load( filename );
|
||||||
else if( !modelName.compare(CV_SVM) )
|
else if( !modelName.compare(CV_SVM) )
|
||||||
svm->load( filename );
|
svm->load( filename );
|
||||||
else if( !modelName.compare(CV_EM) )
|
|
||||||
em->load( filename );
|
|
||||||
else if( !modelName.compare(CV_ANN) )
|
else if( !modelName.compare(CV_ANN) )
|
||||||
ann->load( filename );
|
ann->load( filename );
|
||||||
else if( !modelName.compare(CV_DTREE) )
|
else if( !modelName.compare(CV_DTREE) )
|
||||||
|
@ -44,7 +44,6 @@ protected:
|
|||||||
CvNormalBayesClassifier* nbayes;
|
CvNormalBayesClassifier* nbayes;
|
||||||
CvKNearest* knearest;
|
CvKNearest* knearest;
|
||||||
CvSVM* svm;
|
CvSVM* svm;
|
||||||
CvEM* em;
|
|
||||||
CvANN_MLP* ann;
|
CvANN_MLP* ann;
|
||||||
CvDTree* dtree;
|
CvDTree* dtree;
|
||||||
CvBoost* boost;
|
CvBoost* boost;
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
#include "opencv2/ml/ml.hpp"
|
#include "opencv2/legacy/legacy.hpp"
|
||||||
#include "opencv2/highgui/highgui.hpp"
|
#include "opencv2/highgui/highgui.hpp"
|
||||||
|
|
||||||
using namespace cv;
|
using namespace cv;
|
||||||
|
@ -11,7 +11,6 @@ const Scalar WHITE_COLOR = CV_RGB(255,255,255);
|
|||||||
const string winName = "points";
|
const string winName = "points";
|
||||||
const int testStep = 5;
|
const int testStep = 5;
|
||||||
|
|
||||||
|
|
||||||
Mat img, imgDst;
|
Mat img, imgDst;
|
||||||
RNG rng;
|
RNG rng;
|
||||||
|
|
||||||
@ -19,16 +18,16 @@ vector<Point> trainedPoints;
|
|||||||
vector<int> trainedPointsMarkers;
|
vector<int> trainedPointsMarkers;
|
||||||
vector<Scalar> classColors;
|
vector<Scalar> classColors;
|
||||||
|
|
||||||
#define NBC 0 // normal Bayessian classifier
|
#define _NBC_ 0 // normal Bayessian classifier
|
||||||
#define KNN 0 // k nearest neighbors classifier
|
#define _KNN_ 0 // k nearest neighbors classifier
|
||||||
#define SVM 0 // support vectors machine
|
#define _SVM_ 0 // support vectors machine
|
||||||
#define DT 1 // decision tree
|
#define _DT_ 1 // decision tree
|
||||||
#define BT 0 // ADA Boost
|
#define _BT_ 0 // ADA Boost
|
||||||
#define GBT 0 // gradient boosted trees
|
#define _GBT_ 0 // gradient boosted trees
|
||||||
#define RF 0 // random forest
|
#define _RF_ 0 // random forest
|
||||||
#define ERT 0 // extremely randomized trees
|
#define _ERT_ 0 // extremely randomized trees
|
||||||
#define ANN 0 // artificial neural networks
|
#define _ANN_ 0 // artificial neural networks
|
||||||
#define EM 0 // expectation-maximization
|
#define _EM_ 0 // expectation-maximization
|
||||||
|
|
||||||
void on_mouse( int event, int x, int y, int /*flags*/, void* )
|
void on_mouse( int event, int x, int y, int /*flags*/, void* )
|
||||||
{
|
{
|
||||||
@ -48,13 +47,13 @@ void on_mouse( int event, int x, int y, int /*flags*/, void* )
|
|||||||
}
|
}
|
||||||
else if( event == CV_EVENT_RBUTTONUP )
|
else if( event == CV_EVENT_RBUTTONUP )
|
||||||
{
|
{
|
||||||
#if BT
|
#if _BT_
|
||||||
if( classColors.size() < 2 )
|
if( classColors.size() < 2 )
|
||||||
{
|
{
|
||||||
#endif
|
#endif
|
||||||
classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
|
classColors.push_back( Scalar((uchar)rng(256), (uchar)rng(256), (uchar)rng(256)) );
|
||||||
updateFlag = true;
|
updateFlag = true;
|
||||||
#if BT
|
#if _BT_
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
cout << "New class can not be added, because CvBoost can only be used for 2-class classification" << endl;
|
cout << "New class can not be added, because CvBoost can only be used for 2-class classification" << endl;
|
||||||
@ -98,7 +97,7 @@ void prepare_train_data( Mat& samples, Mat& classes )
|
|||||||
samples.convertTo( samples, CV_32FC1 );
|
samples.convertTo( samples, CV_32FC1 );
|
||||||
}
|
}
|
||||||
|
|
||||||
#if NBC
|
#if _NBC_
|
||||||
void find_decision_boundary_NBC()
|
void find_decision_boundary_NBC()
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -125,7 +124,7 @@ void find_decision_boundary_NBC()
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#if KNN
|
#if _KNN_
|
||||||
void find_decision_boundary_KNN( int K )
|
void find_decision_boundary_KNN( int K )
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -151,7 +150,7 @@ void find_decision_boundary_KNN( int K )
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if SVM
|
#if _SVM_
|
||||||
void find_decision_boundary_SVM( CvSVMParams params )
|
void find_decision_boundary_SVM( CvSVMParams params )
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -185,7 +184,7 @@ void find_decision_boundary_SVM( CvSVMParams params )
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if DT
|
#if _DT_
|
||||||
void find_decision_boundary_DT()
|
void find_decision_boundary_DT()
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -225,7 +224,7 @@ void find_decision_boundary_DT()
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if BT
|
#if _BT_
|
||||||
void find_decision_boundary_BT()
|
void find_decision_boundary_BT()
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -265,7 +264,7 @@ void find_decision_boundary_BT()
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if GBT
|
#if _GBT_
|
||||||
void find_decision_boundary_GBT()
|
void find_decision_boundary_GBT()
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -305,7 +304,7 @@ void find_decision_boundary_GBT()
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if RF
|
#if _RF_
|
||||||
void find_decision_boundary_RF()
|
void find_decision_boundary_RF()
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -346,7 +345,7 @@ void find_decision_boundary_RF()
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if ERT
|
#if _ERT_
|
||||||
void find_decision_boundary_ERT()
|
void find_decision_boundary_ERT()
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -390,7 +389,7 @@ void find_decision_boundary_ERT()
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if ANN
|
#if _ANN_
|
||||||
void find_decision_boundary_ANN( const Mat& layer_sizes )
|
void find_decision_boundary_ANN( const Mat& layer_sizes )
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -435,7 +434,7 @@ void find_decision_boundary_ANN( const Mat& layer_sizes )
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if EM
|
#if _EM_
|
||||||
void find_decision_boundary_EM()
|
void find_decision_boundary_EM()
|
||||||
{
|
{
|
||||||
img.copyTo( imgDst );
|
img.copyTo( imgDst );
|
||||||
@ -443,19 +442,12 @@ void find_decision_boundary_EM()
|
|||||||
Mat trainSamples, trainClasses;
|
Mat trainSamples, trainClasses;
|
||||||
prepare_train_data( trainSamples, trainClasses );
|
prepare_train_data( trainSamples, trainClasses );
|
||||||
|
|
||||||
CvEM em;
|
cv::EM em;
|
||||||
CvEMParams params;
|
cv::EM::Params params;
|
||||||
params.covs = NULL;
|
|
||||||
params.means = NULL;
|
|
||||||
params.weights = NULL;
|
|
||||||
params.probs = NULL;
|
|
||||||
params.nclusters = classColors.size();
|
params.nclusters = classColors.size();
|
||||||
params.cov_mat_type = CvEM::COV_MAT_GENERIC;
|
params.covMatType = cv::EM::COV_MAT_GENERIC;
|
||||||
params.start_step = CvEM::START_AUTO_STEP;
|
params.startStep = cv::EM::START_AUTO_STEP;
|
||||||
params.term_crit.max_iter = 10;
|
params.termCrit = cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::COUNT, 10, 0.1);
|
||||||
params.term_crit.epsilon = 0.1;
|
|
||||||
params.term_crit.type = CV_TERMCRIT_ITER | CV_TERMCRIT_EPS;
|
|
||||||
|
|
||||||
|
|
||||||
// learn classifier
|
// learn classifier
|
||||||
em.train( trainSamples, Mat(), params, &trainClasses );
|
em.train( trainSamples, Mat(), params, &trainClasses );
|
||||||
@ -509,12 +501,12 @@ int main()
|
|||||||
|
|
||||||
if( key == 'r' ) // run
|
if( key == 'r' ) // run
|
||||||
{
|
{
|
||||||
#if NBC
|
#if _NBC_
|
||||||
find_decision_boundary_NBC();
|
find_decision_boundary_NBC();
|
||||||
cvNamedWindow( "NormalBayesClassifier", WINDOW_AUTOSIZE );
|
cvNamedWindow( "NormalBayesClassifier", WINDOW_AUTOSIZE );
|
||||||
imshow( "NormalBayesClassifier", imgDst );
|
imshow( "NormalBayesClassifier", imgDst );
|
||||||
#endif
|
#endif
|
||||||
#if KNN
|
#if _KNN_
|
||||||
int K = 3;
|
int K = 3;
|
||||||
find_decision_boundary_KNN( K );
|
find_decision_boundary_KNN( K );
|
||||||
namedWindow( "kNN", WINDOW_AUTOSIZE );
|
namedWindow( "kNN", WINDOW_AUTOSIZE );
|
||||||
@ -526,7 +518,7 @@ int main()
|
|||||||
imshow( "kNN2", imgDst );
|
imshow( "kNN2", imgDst );
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if SVM
|
#if _SVM_
|
||||||
//(1)-(2)separable and not sets
|
//(1)-(2)separable and not sets
|
||||||
CvSVMParams params;
|
CvSVMParams params;
|
||||||
params.svm_type = CvSVM::C_SVC;
|
params.svm_type = CvSVM::C_SVC;
|
||||||
@ -549,37 +541,37 @@ int main()
|
|||||||
imshow( "classificationSVM2", imgDst );
|
imshow( "classificationSVM2", imgDst );
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if DT
|
#if _DT_
|
||||||
find_decision_boundary_DT();
|
find_decision_boundary_DT();
|
||||||
namedWindow( "DT", WINDOW_AUTOSIZE );
|
namedWindow( "DT", WINDOW_AUTOSIZE );
|
||||||
imshow( "DT", imgDst );
|
imshow( "DT", imgDst );
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if BT
|
#if _BT_
|
||||||
find_decision_boundary_BT();
|
find_decision_boundary_BT();
|
||||||
namedWindow( "BT", WINDOW_AUTOSIZE );
|
namedWindow( "BT", WINDOW_AUTOSIZE );
|
||||||
imshow( "BT", imgDst);
|
imshow( "BT", imgDst);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if GBT
|
#if _GBT_
|
||||||
find_decision_boundary_GBT();
|
find_decision_boundary_GBT();
|
||||||
namedWindow( "GBT", WINDOW_AUTOSIZE );
|
namedWindow( "GBT", WINDOW_AUTOSIZE );
|
||||||
imshow( "GBT", imgDst);
|
imshow( "GBT", imgDst);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if RF
|
#if _RF_
|
||||||
find_decision_boundary_RF();
|
find_decision_boundary_RF();
|
||||||
namedWindow( "RF", WINDOW_AUTOSIZE );
|
namedWindow( "RF", WINDOW_AUTOSIZE );
|
||||||
imshow( "RF", imgDst);
|
imshow( "RF", imgDst);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if ERT
|
#if _ERT_
|
||||||
find_decision_boundary_ERT();
|
find_decision_boundary_ERT();
|
||||||
namedWindow( "ERT", WINDOW_AUTOSIZE );
|
namedWindow( "ERT", WINDOW_AUTOSIZE );
|
||||||
imshow( "ERT", imgDst);
|
imshow( "ERT", imgDst);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if ANN
|
#if _ANN_
|
||||||
Mat layer_sizes1( 1, 3, CV_32SC1 );
|
Mat layer_sizes1( 1, 3, CV_32SC1 );
|
||||||
layer_sizes1.at<int>(0) = 2;
|
layer_sizes1.at<int>(0) = 2;
|
||||||
layer_sizes1.at<int>(1) = 5;
|
layer_sizes1.at<int>(1) = 5;
|
||||||
@ -589,7 +581,7 @@ int main()
|
|||||||
imshow( "ANN", imgDst );
|
imshow( "ANN", imgDst );
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if EM
|
#if _EM_
|
||||||
find_decision_boundary_EM();
|
find_decision_boundary_EM();
|
||||||
namedWindow( "EM", WINDOW_AUTOSIZE );
|
namedWindow( "EM", WINDOW_AUTOSIZE );
|
||||||
imshow( "EM", imgDst );
|
imshow( "EM", imgDst );
|
||||||
|
Loading…
x
Reference in New Issue
Block a user