modified EM interface; updated tests & samples

This commit is contained in:
Vadim Pisarevsky
2012-04-06 15:59:30 +00:00
parent 1c1c6b98f6
commit b8c310065c
8 changed files with 338 additions and 333 deletions

View File

@@ -555,61 +555,66 @@ protected:
\****************************************************************************************/
namespace cv
{
class CV_EXPORTS EM : public Algorithm
class CV_EXPORTS_W EM : public Algorithm
{
public:
// Type of covariation matrices
enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2};
enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};
// Default parameters
enum {DEFAULT_NCLUSTERS=10, DEFAULT_MAX_ITERS=100};
// The initial step
enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
class CV_EXPORTS Params
{
public:
Params(int nclusters=10, int covMatType=EM::COV_MAT_DIAGONAL, int startStep=EM::START_AUTO_STEP,
const cv::TermCriteria& termCrit=cv::TermCriteria(cv::TermCriteria::COUNT+cv::TermCriteria::EPS, 100, FLT_EPSILON),
const cv::Mat* probs=0, const cv::Mat* weights=0,
const cv::Mat* means=0, const std::vector<cv::Mat>* covs=0);
int nclusters;
int covMatType;
int startStep;
// all 4 following matrices should have type CV_32FC1
const cv::Mat* probs;
const cv::Mat* weights;
const cv::Mat* means;
const std::vector<cv::Mat>* covs;
cv::TermCriteria termCrit;
};
EM();
EM(const cv::Mat& samples, const cv::Mat samplesMask=cv::Mat(),
const EM::Params& params=EM::Params(), cv::Mat* labels=0, cv::Mat* probs=0, cv::Mat* likelihoods=0);
CV_WRAP EM(int nclusters=EM::DEFAULT_NCLUSTERS, int covMatType=EM::COV_MAT_DIAGONAL,
const TermCriteria& termcrit=TermCriteria(TermCriteria::COUNT+
TermCriteria::EPS,
EM::DEFAULT_MAX_ITERS, FLT_EPSILON));
virtual ~EM();
virtual void clear();
CV_WRAP virtual void clear();
virtual bool train(const cv::Mat& samples, const cv::Mat& samplesMask=cv::Mat(),
const EM::Params& params=EM::Params(), cv::Mat* labels=0, cv::Mat* probs=0, cv::Mat* likelihoods=0);
int predict(const cv::Mat& sample, cv::Mat* probs=0, double* likelihood=0) const;
CV_WRAP virtual bool train(InputArray samples,
OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray likelihoods=noArray());
CV_WRAP virtual bool trainE(InputArray samples,
InputArray means0,
InputArray covs0=noArray(),
InputArray weights0=noArray(),
OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray likelihoods=noArray());
CV_WRAP virtual bool trainM(InputArray samples,
InputArray probs0,
OutputArray labels=noArray(),
OutputArray probs=noArray(),
OutputArray likelihoods=noArray());
CV_WRAP int predict(InputArray sample,
OutputArray probs=noArray(),
CV_OUT double* likelihood=0) const;
bool isTrained() const;
int getNClusters() const;
int getCovMatType() const;
const cv::Mat& getWeights() const;
const cv::Mat& getMeans() const;
const std::vector<cv::Mat>& getCovs() const;
CV_WRAP bool isTrained() const;
AlgorithmInfo* info() const;
virtual void read(const FileNode& fn);
protected:
virtual void setTrainData(const cv::Mat& samples, const cv::Mat& samplesMask, const EM::Params& params);
virtual void setTrainData(int startStep, const Mat& samples,
const Mat* probs0,
const Mat* means0,
const vector<Mat>* covs0,
const Mat* weights0);
bool doTrain(const cv::TermCriteria& termCrit);
bool doTrain(int startStep,
OutputArray labels,
OutputArray probs,
OutputArray likelihoods);
virtual void eStep();
virtual void mStep();
@@ -617,27 +622,28 @@ protected:
void decomposeCovs();
void computeLogWeightDivDet();
void computeProbabilities(const cv::Mat& sample, int& label, cv::Mat* probs, float* likelihood) const;
void computeProbabilities(const Mat& sample, int& label, Mat* probs, float* likelihood) const;
// all inner matrices have type CV_32FC1
int nclusters;
int covMatType;
int startStep;
CV_PROP_RW int nclusters;
CV_PROP_RW int covMatType;
CV_PROP_RW int maxIters;
CV_PROP_RW double epsilon;
cv::Mat trainSamples;
cv::Mat trainProbs;
cv::Mat trainLikelihoods;
cv::Mat trainLabels;
cv::Mat trainCounts;
Mat trainSamples;
Mat trainProbs;
Mat trainLikelihoods;
Mat trainLabels;
Mat trainCounts;
cv::Mat weights;
cv::Mat means;
std::vector<cv::Mat> covs;
CV_PROP Mat weights;
CV_PROP Mat means;
CV_PROP vector<Mat> covs;
std::vector<cv::Mat> covsEigenValues;
std::vector<cv::Mat> covsRotateMats;
std::vector<cv::Mat> invCovsEigenValues;
cv::Mat logWeightDivDet;
vector<Mat> covsEigenValues;
vector<Mat> covsRotateMats;
vector<Mat> invCovsEigenValues;
Mat logWeightDivDet;
};
} // namespace cv