modified likelihood computing
This commit is contained in:
parent
74b38e978b
commit
3b02ee4b29
@ -44,7 +44,7 @@
|
||||
namespace cv
|
||||
{
|
||||
|
||||
const double minEigenValue = 1.e-5;
|
||||
const double minEigenValue = DBL_MIN;
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -121,7 +121,7 @@ bool EM::trainM(InputArray samples,
|
||||
}
|
||||
|
||||
|
||||
int EM::predict(InputArray _sample, OutputArray _probs, double* _logLikelihood) const
|
||||
int EM::predict(InputArray _sample, OutputArray _probs, double* logLikelihood) const
|
||||
{
|
||||
Mat sample = _sample.getMat();
|
||||
CV_Assert(isTrained());
|
||||
@ -135,16 +135,13 @@ int EM::predict(InputArray _sample, OutputArray _probs, double* _logLikelihood)
|
||||
}
|
||||
|
||||
int label;
|
||||
double logLikelihood = 0.;
|
||||
Mat probs;
|
||||
if( _probs.needed() )
|
||||
{
|
||||
_probs.create(1, nclusters, CV_64FC1);
|
||||
probs = _probs.getMat();
|
||||
}
|
||||
computeProbabilities(sample, label, !probs.empty() ? &probs : 0, _logLikelihood ? &logLikelihood : 0);
|
||||
if(_logLikelihood)
|
||||
*_logLikelihood = logLikelihood;
|
||||
computeProbabilities(sample, label, !probs.empty() ? &probs : 0, logLikelihood);
|
||||
|
||||
return label;
|
||||
}
|
||||
@ -372,6 +369,7 @@ void EM::computeLogWeightDivDet()
|
||||
CV_Assert(!covsEigenValues.empty());
|
||||
|
||||
Mat logWeights;
|
||||
cv::max(weights, DBL_MIN, weights);
|
||||
log(weights, logWeights);
|
||||
|
||||
logWeightDivDet.create(1, nclusters, CV_64FC1);
|
||||
@ -504,28 +502,24 @@ void EM::computeProbabilities(const Mat& sample, int& label, Mat* probs, double*
|
||||
if(!probs && !logLikelihood)
|
||||
return;
|
||||
|
||||
if(probs)
|
||||
{
|
||||
Mat expL_Lmax;
|
||||
exp(L - L.at<double>(label), expL_Lmax);
|
||||
double partSum = 0, // sum_j!=q (exp(L_ij - L_iq))
|
||||
factor; // 1/(1 + partExpSum)
|
||||
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
||||
if(clusterIndex != label)
|
||||
partSum += expL_Lmax.at<double>(clusterIndex);
|
||||
factor = 1./(1 + partSum);
|
||||
Mat buf, *sampleProbs = probs ? probs : &buf;
|
||||
Mat expL_Lmax;
|
||||
exp(L - L.at<double>(label), expL_Lmax);
|
||||
double partSum = 0, // sum_j!=q (exp(L_ij - L_iq))
|
||||
factor; // 1/(1 + partExpSum)
|
||||
for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
|
||||
if(clusterIndex != label)
|
||||
partSum += expL_Lmax.at<double>(clusterIndex);
|
||||
factor = 1./(1 + partSum);
|
||||
|
||||
probs->create(1, nclusters, CV_64FC1);
|
||||
expL_Lmax *= factor;
|
||||
expL_Lmax.copyTo(*probs);
|
||||
}
|
||||
sampleProbs->create(1, nclusters, CV_64FC1);
|
||||
expL_Lmax *= factor;
|
||||
expL_Lmax.copyTo(*sampleProbs);
|
||||
|
||||
if(logLikelihood)
|
||||
{
|
||||
Mat expL;
|
||||
exp(L, expL);
|
||||
// note logLikelihood = log (sum_j exp(L_ij)) - 0.5 * dims * ln2Pi
|
||||
*logLikelihood = std::log(sum(expL)[0]) - (double)(0.5 * dim * CV_LOG2PI);
|
||||
double logWeightProbs = std::log(std::max(DBL_MIN, sum(*sampleProbs)[0]));
|
||||
*logLikelihood = logWeightProbs;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,7 +83,7 @@ void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat&
|
||||
|
||||
labels.create( data.rows, 1, labelType );
|
||||
|
||||
randn( data, Scalar::all(0.0), Scalar::all(1.0) );
|
||||
randn( data, Scalar::all(-1.0), Scalar::all(1.0) );
|
||||
vector<Mat> means(sizes.size());
|
||||
for(int i = 0; i < _means.rows; i++)
|
||||
means[i] = _means.row(i);
|
||||
@ -381,7 +381,7 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||
}
|
||||
else if( err > 0.006f )
|
||||
else if( err > 0.008f )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on train data.\n", caseIndex, err );
|
||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||
@ -401,7 +401,7 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||
}
|
||||
else if( err > 0.006f )
|
||||
else if( err > 0.008f )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "Case index %i : Bad accuracy (%f) on test data.\n", caseIndex, err );
|
||||
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||
@ -505,7 +505,8 @@ protected:
|
||||
virtual void run( int /*start_from*/ )
|
||||
{
|
||||
int code = cvtest::TS::OK;
|
||||
cv::EM em(2);
|
||||
const int nclusters = 2;
|
||||
cv::EM em(nclusters);
|
||||
|
||||
Mat samples = Mat(3,1,CV_64FC1);
|
||||
samples.at<double>(0,0) = 1;
|
||||
|
Loading…
x
Reference in New Issue
Block a user