refactored likelihood computing
This commit is contained in:
parent
51385ac73a
commit
04d24a8824
@ -58,7 +58,7 @@ EM::EM(int _nclusters, int _covMatType, const TermCriteria& _criteria)
|
|||||||
|
|
||||||
EM::~EM()
|
EM::~EM()
|
||||||
{
|
{
|
||||||
clear();
|
//clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
void EM::clear()
|
void EM::clear()
|
||||||
@ -322,6 +322,8 @@ void EM::clusterTrainSamples()
|
|||||||
int nsamples = trainSamples.rows;
|
int nsamples = trainSamples.rows;
|
||||||
|
|
||||||
// Cluster samples, compute/update means
|
// Cluster samples, compute/update means
|
||||||
|
|
||||||
|
// Convert samples and means to 32F, because kmeans requires this type.
|
||||||
Mat trainSamplesFlt, meansFlt;
|
Mat trainSamplesFlt, meansFlt;
|
||||||
if(trainSamples.type() != CV_32FC1)
|
if(trainSamples.type() != CV_32FC1)
|
||||||
trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
|
trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
|
||||||
@ -338,6 +340,7 @@ void EM::clusterTrainSamples()
|
|||||||
Mat labels;
|
Mat labels;
|
||||||
kmeans(trainSamplesFlt, nclusters, labels, TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5), 10, KMEANS_PP_CENTERS, meansFlt);
|
kmeans(trainSamplesFlt, nclusters, labels, TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5), 10, KMEANS_PP_CENTERS, meansFlt);
|
||||||
|
|
||||||
|
// Convert samples and means back to 64F.
|
||||||
CV_Assert(meansFlt.type() == CV_32FC1);
|
CV_Assert(meansFlt.type() == CV_32FC1);
|
||||||
if(trainSamples.type() != CV_64FC1)
|
if(trainSamples.type() != CV_64FC1)
|
||||||
{
|
{
|
||||||
@ -476,6 +479,8 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
|
|||||||
// L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
|
// L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
|
||||||
// q = arg(max_k(L_ik))
|
// q = arg(max_k(L_ik))
|
||||||
// probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
|
// probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
|
||||||
|
// see Alex Smola's blog http://blog.smola.org/page/2 for
|
||||||
|
// details on the log-sum-exp trick
|
||||||
|
|
||||||
CV_Assert(!means.empty());
|
CV_Assert(!means.empty());
|
||||||
CV_Assert(sample.type() == CV_64FC1);
|
CV_Assert(sample.type() == CV_64FC1);
|
||||||
@ -511,29 +516,22 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
|
|||||||
if(!probs && !logLikelihood)
|
if(!probs && !logLikelihood)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
Mat expL_Lmax(L.size(), CV_64FC1);
|
|
||||||
double maxLVal = L.at<double>(label);
|
double maxLVal = L.at<double>(label);
|
||||||
|
Mat expL_Lmax = L; // exp(L_ij - L_iq)
|
||||||
for(int i = 0; i < L.cols; i++)
|
for(int i = 0; i < L.cols; i++)
|
||||||
expL_Lmax.at<double>(i) = std::exp(L.at<double>(i) - maxLVal);
|
expL_Lmax.at<double>(i) = std::exp(L.at<double>(i) - maxLVal);
|
||||||
|
double expDiffSum = sum(expL_Lmax)[0]; // sum_j(exp(L_ij - L_iq))
|
||||||
double partSum = 0; // sum_j!=q (exp(L_ij - L_iq))
|
|
||||||
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
|
||||||
if(clusterIndex != label)
|
|
||||||
partSum += expL_Lmax.at<double>(clusterIndex);
|
|
||||||
|
|
||||||
if(probs)
|
if(probs)
|
||||||
{
|
{
|
||||||
probs->create(1, nclusters, CV_64FC1);
|
probs->create(1, nclusters, CV_64FC1);
|
||||||
double factor = 1./(1 + partSum);
|
double factor = 1./expDiffSum;
|
||||||
expL_Lmax *= factor;
|
expL_Lmax *= factor;
|
||||||
expL_Lmax.copyTo(*probs);
|
expL_Lmax.copyTo(*probs);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(logLikelihood)
|
if(logLikelihood)
|
||||||
{
|
*logLikelihood = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
|
||||||
double logWeightProbs = std::log((1 + partSum) * std::exp(maxLVal)) - 0.5 * dim * CV_LOG2PI;
|
|
||||||
*logLikelihood = logWeightProbs;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void EM::eStep()
|
void EM::eStep()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user