Updated ml module interfaces and documentation
This commit is contained in:
@@ -48,37 +48,49 @@ namespace ml
|
||||
|
||||
const double minEigenValue = DBL_EPSILON;
|
||||
|
||||
EM::Params::Params(int _nclusters, int _covMatType, const TermCriteria& _termCrit)
|
||||
{
|
||||
nclusters = _nclusters;
|
||||
covMatType = _covMatType;
|
||||
termCrit = _termCrit;
|
||||
}
|
||||
|
||||
class CV_EXPORTS EMImpl : public EM
|
||||
{
|
||||
public:
|
||||
EMImpl(const Params& _params)
|
||||
|
||||
int nclusters;
|
||||
int covMatType;
|
||||
TermCriteria termCrit;
|
||||
|
||||
CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, termCrit)
|
||||
|
||||
void setClustersNumber(int val)
|
||||
{
|
||||
setParams(_params);
|
||||
nclusters = val;
|
||||
CV_Assert(nclusters > 1);
|
||||
}
|
||||
|
||||
int getClustersNumber() const
|
||||
{
|
||||
return nclusters;
|
||||
}
|
||||
|
||||
void setCovarianceMatrixType(int val)
|
||||
{
|
||||
covMatType = val;
|
||||
CV_Assert(covMatType == COV_MAT_SPHERICAL ||
|
||||
covMatType == COV_MAT_DIAGONAL ||
|
||||
covMatType == COV_MAT_GENERIC);
|
||||
}
|
||||
|
||||
int getCovarianceMatrixType() const
|
||||
{
|
||||
return covMatType;
|
||||
}
|
||||
|
||||
EMImpl()
|
||||
{
|
||||
nclusters = DEFAULT_NCLUSTERS;
|
||||
covMatType=EM::COV_MAT_DIAGONAL;
|
||||
termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
|
||||
}
|
||||
|
||||
virtual ~EMImpl() {}
|
||||
|
||||
void setParams(const Params& _params)
|
||||
{
|
||||
params = _params;
|
||||
CV_Assert(params.nclusters > 1);
|
||||
CV_Assert(params.covMatType == COV_MAT_SPHERICAL ||
|
||||
params.covMatType == COV_MAT_DIAGONAL ||
|
||||
params.covMatType == COV_MAT_GENERIC);
|
||||
}
|
||||
|
||||
Params getParams() const
|
||||
{
|
||||
return params;
|
||||
}
|
||||
|
||||
void clear()
|
||||
{
|
||||
trainSamples.release();
|
||||
@@ -100,10 +112,10 @@ public:
|
||||
bool train(const Ptr<TrainData>& data, int)
|
||||
{
|
||||
Mat samples = data->getTrainSamples(), labels;
|
||||
return train_(samples, labels, noArray(), noArray());
|
||||
return trainEM(samples, labels, noArray(), noArray());
|
||||
}
|
||||
|
||||
bool train_(InputArray samples,
|
||||
bool trainEM(InputArray samples,
|
||||
OutputArray logLikelihoods,
|
||||
OutputArray labels,
|
||||
OutputArray probs)
|
||||
@@ -157,7 +169,7 @@ public:
|
||||
{
|
||||
if( _outputs.fixedType() )
|
||||
ptype = _outputs.type();
|
||||
_outputs.create(samples.rows, params.nclusters, ptype);
|
||||
_outputs.create(samples.rows, nclusters, ptype);
|
||||
}
|
||||
else
|
||||
nsamples = std::min(nsamples, 1);
|
||||
@@ -193,7 +205,7 @@ public:
|
||||
{
|
||||
if( _probs.fixedType() )
|
||||
ptype = _probs.type();
|
||||
_probs.create(1, params.nclusters, ptype);
|
||||
_probs.create(1, nclusters, ptype);
|
||||
probs = _probs.getMat();
|
||||
}
|
||||
|
||||
@@ -311,7 +323,6 @@ public:
|
||||
const std::vector<Mat>* covs0,
|
||||
const Mat* weights0)
|
||||
{
|
||||
int nclusters = params.nclusters, covMatType = params.covMatType;
|
||||
clear();
|
||||
|
||||
checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);
|
||||
@@ -350,7 +361,6 @@ public:
|
||||
|
||||
void decomposeCovs()
|
||||
{
|
||||
int nclusters = params.nclusters, covMatType = params.covMatType;
|
||||
CV_Assert(!covs.empty());
|
||||
covsEigenValues.resize(nclusters);
|
||||
if(covMatType == COV_MAT_GENERIC)
|
||||
@@ -383,7 +393,6 @@ public:
|
||||
|
||||
void clusterTrainSamples()
|
||||
{
|
||||
int nclusters = params.nclusters;
|
||||
int nsamples = trainSamples.rows;
|
||||
|
||||
// Cluster samples, compute/update means
|
||||
@@ -443,7 +452,6 @@ public:
|
||||
|
||||
void computeLogWeightDivDet()
|
||||
{
|
||||
int nclusters = params.nclusters;
|
||||
CV_Assert(!covsEigenValues.empty());
|
||||
|
||||
Mat logWeights;
|
||||
@@ -458,7 +466,7 @@ public:
|
||||
double logDetCov = 0.;
|
||||
const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
|
||||
for(int di = 0; di < evalCount; di++)
|
||||
logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(params.covMatType != COV_MAT_SPHERICAL ? di : 0));
|
||||
logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));
|
||||
|
||||
logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
|
||||
}
|
||||
@@ -466,7 +474,6 @@ public:
|
||||
|
||||
bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
|
||||
{
|
||||
int nclusters = params.nclusters;
|
||||
int dim = trainSamples.cols;
|
||||
// Precompute the empty initial train data in the cases of START_E_STEP and START_AUTO_STEP
|
||||
if(startStep != START_M_STEP)
|
||||
@@ -488,9 +495,9 @@ public:
|
||||
mStep();
|
||||
|
||||
double trainLogLikelihood, prevTrainLogLikelihood = 0.;
|
||||
int maxIters = (params.termCrit.type & TermCriteria::MAX_ITER) ?
|
||||
params.termCrit.maxCount : DEFAULT_MAX_ITERS;
|
||||
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0.;
|
||||
int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
|
||||
termCrit.maxCount : DEFAULT_MAX_ITERS;
|
||||
double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;
|
||||
|
||||
for(int iter = 0; ; iter++)
|
||||
{
|
||||
@@ -521,12 +528,12 @@ public:
|
||||
covs.resize(nclusters);
|
||||
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
||||
{
|
||||
if(params.covMatType == COV_MAT_SPHERICAL)
|
||||
if(covMatType == COV_MAT_SPHERICAL)
|
||||
{
|
||||
covs[clusterIndex].create(dim, dim, CV_64FC1);
|
||||
setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
|
||||
}
|
||||
else if(params.covMatType == COV_MAT_DIAGONAL)
|
||||
else if(covMatType == COV_MAT_DIAGONAL)
|
||||
{
|
||||
covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
|
||||
}
|
||||
@@ -555,7 +562,6 @@ public:
|
||||
// see Alex Smola's blog http://blog.smola.org/page/2 for
|
||||
// details on the log-sum-exp trick
|
||||
|
||||
int nclusters = params.nclusters, covMatType = params.covMatType;
|
||||
int stype = sample.type();
|
||||
CV_Assert(!means.empty());
|
||||
CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
|
||||
@@ -621,7 +627,7 @@ public:
|
||||
void eStep()
|
||||
{
|
||||
// Compute probs_ik from means_k, covs_k and weights_k.
|
||||
trainProbs.create(trainSamples.rows, params.nclusters, CV_64FC1);
|
||||
trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
|
||||
trainLabels.create(trainSamples.rows, 1, CV_32SC1);
|
||||
trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);
|
||||
|
||||
@@ -642,8 +648,6 @@ public:
|
||||
void mStep()
|
||||
{
|
||||
// Update means_k, covs_k and weights_k from probs_ik
|
||||
int nclusters = params.nclusters;
|
||||
int covMatType = params.covMatType;
|
||||
int dim = trainSamples.cols;
|
||||
|
||||
// Update weights
|
||||
@@ -755,12 +759,12 @@ public:
|
||||
|
||||
void write_params(FileStorage& fs) const
|
||||
{
|
||||
fs << "nclusters" << params.nclusters;
|
||||
fs << "cov_mat_type" << (params.covMatType == COV_MAT_SPHERICAL ? String("spherical") :
|
||||
params.covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
|
||||
params.covMatType == COV_MAT_GENERIC ? String("generic") :
|
||||
format("unknown_%d", params.covMatType));
|
||||
writeTermCrit(fs, params.termCrit);
|
||||
fs << "nclusters" << nclusters;
|
||||
fs << "cov_mat_type" << (covMatType == COV_MAT_SPHERICAL ? String("spherical") :
|
||||
covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
|
||||
covMatType == COV_MAT_GENERIC ? String("generic") :
|
||||
format("unknown_%d", covMatType));
|
||||
writeTermCrit(fs, termCrit);
|
||||
}
|
||||
|
||||
void write(FileStorage& fs) const
|
||||
@@ -781,15 +785,13 @@ public:
|
||||
|
||||
void read_params(const FileNode& fn)
|
||||
{
|
||||
Params _params;
|
||||
_params.nclusters = (int)fn["nclusters"];
|
||||
nclusters = (int)fn["nclusters"];
|
||||
String s = (String)fn["cov_mat_type"];
|
||||
_params.covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
|
||||
covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
|
||||
s == "diagonal" ? COV_MAT_DIAGONAL :
|
||||
s == "generic" ? COV_MAT_GENERIC : -1;
|
||||
CV_Assert(_params.covMatType >= 0);
|
||||
_params.termCrit = readTermCrit(fn);
|
||||
setParams(_params);
|
||||
CV_Assert(covMatType >= 0);
|
||||
termCrit = readTermCrit(fn);
|
||||
}
|
||||
|
||||
void read(const FileNode& fn)
|
||||
@@ -820,8 +822,6 @@ public:
|
||||
std::copy(covs.begin(), covs.end(), _covs.begin());
|
||||
}
|
||||
|
||||
Params params;
|
||||
|
||||
// all inner matrices have type CV_64FC1
|
||||
Mat trainSamples;
|
||||
Mat trainProbs;
|
||||
@@ -838,41 +838,9 @@ public:
|
||||
Mat logWeightDivDet;
|
||||
};
|
||||
|
||||
|
||||
Ptr<EM> EM::train(InputArray samples, OutputArray logLikelihoods,
|
||||
OutputArray labels, OutputArray probs,
|
||||
const EM::Params& params)
|
||||
Ptr<EM> EM::create()
|
||||
{
|
||||
Ptr<EMImpl> em = makePtr<EMImpl>(params);
|
||||
if(!em->train_(samples, logLikelihoods, labels, probs))
|
||||
em.release();
|
||||
return em;
|
||||
}
|
||||
|
||||
Ptr<EM> EM::train_startWithE(InputArray samples, InputArray means0,
|
||||
InputArray covs0, InputArray weights0,
|
||||
OutputArray logLikelihoods, OutputArray labels,
|
||||
OutputArray probs, const EM::Params& params)
|
||||
{
|
||||
Ptr<EMImpl> em = makePtr<EMImpl>(params);
|
||||
if(!em->trainE(samples, means0, covs0, weights0, logLikelihoods, labels, probs))
|
||||
em.release();
|
||||
return em;
|
||||
}
|
||||
|
||||
Ptr<EM> EM::train_startWithM(InputArray samples, InputArray probs0,
|
||||
OutputArray logLikelihoods, OutputArray labels,
|
||||
OutputArray probs, const EM::Params& params)
|
||||
{
|
||||
Ptr<EMImpl> em = makePtr<EMImpl>(params);
|
||||
if(!em->trainM(samples, probs0, logLikelihoods, labels, probs))
|
||||
em.release();
|
||||
return em;
|
||||
}
|
||||
|
||||
Ptr<EM> EM::create(const Params& params)
|
||||
{
|
||||
return makePtr<EMImpl>(params);
|
||||
return makePtr<EMImpl>();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user