Incorporated several critical fixes to the EM implementation from Albert G (ticket #264)
This commit is contained in:
parent
7174957f0d
commit
4d676165ea
@ -789,8 +789,9 @@ double CvEM::run_em( const CvVectors& train_data )
|
|||||||
int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters;
|
int nsamples = train_data.count, dims = train_data.dims, nclusters = params.nclusters;
|
||||||
double min_variation = FLT_EPSILON;
|
double min_variation = FLT_EPSILON;
|
||||||
double min_det_value = MAX( DBL_MIN, pow( min_variation, dims ));
|
double min_det_value = MAX( DBL_MIN, pow( min_variation, dims ));
|
||||||
double likelihood_bias = -CV_LOG2PI * (double)nsamples * (double)dims / 2., _log_likelihood = -DBL_MAX;
|
double _log_likelihood = -DBL_MAX;
|
||||||
int start_step = params.start_step;
|
int start_step = params.start_step;
|
||||||
|
double sum_max_val;
|
||||||
|
|
||||||
int i, j, k, n;
|
int i, j, k, n;
|
||||||
int is_general = 0, is_diagonal = 0, is_spherical = 0;
|
int is_general = 0, is_diagonal = 0, is_spherical = 0;
|
||||||
@ -912,6 +913,7 @@ double CvEM::run_em( const CvVectors& train_data )
|
|||||||
// e-step: compute probs_ik from means_k, covs_k and weights_k.
|
// e-step: compute probs_ik from means_k, covs_k and weights_k.
|
||||||
CV_CALL(cvLog( weights, log_weights ));
|
CV_CALL(cvLog( weights, log_weights ));
|
||||||
|
|
||||||
|
sum_max_val = 0.;
|
||||||
// S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k)
|
// S_ik = -0.5[log(det(Sigma_k)) + (x_i - mu_k)' Sigma_k^(-1) (x_i - mu_k)] + log(weights_k)
|
||||||
for( k = 0; k < nclusters; k++ )
|
for( k = 0; k < nclusters; k++ )
|
||||||
{
|
{
|
||||||
@ -934,14 +936,16 @@ double CvEM::run_em( const CvVectors& train_data )
|
|||||||
cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T );
|
cvGEMM( centered_sample, u, 1, 0, 0, centered_sample, CV_GEMM_B_T );
|
||||||
for( j = 0; j < dims; j++ )
|
for( j = 0; j < dims; j++ )
|
||||||
p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j];
|
p += csample[j]*csample[j]*w_data[is_spherical ? 0 : j];
|
||||||
pp[k] = -0.5*p + log_weights->data.db[k];
|
//pp[k] = -0.5*p + log_weights->data.db[k];
|
||||||
|
pp[k] = -0.5*(p+CV_LOG2PI * (double)dims) + log_weights->data.db[k];
|
||||||
|
|
||||||
// S_ik <- S_ik - max_j S_ij
|
// S_ik <- S_ik - max_j S_ij
|
||||||
if( k == nclusters - 1 )
|
if( k == nclusters - 1 )
|
||||||
{
|
{
|
||||||
double max_val = 0;
|
double max_val = pp[0];
|
||||||
for( j = 0; j < nclusters; j++ )
|
for( j = 1; j < nclusters; j++ )
|
||||||
max_val = MAX( max_val, pp[j] );
|
max_val = MAX( max_val, pp[j] );
|
||||||
|
sum_max_val += max_val;
|
||||||
for( j = 0; j < nclusters; j++ )
|
for( j = 0; j < nclusters; j++ )
|
||||||
pp[j] -= max_val;
|
pp[j] -= max_val;
|
||||||
}
|
}
|
||||||
@ -953,7 +957,7 @@ double CvEM::run_em( const CvVectors& train_data )
|
|||||||
|
|
||||||
// alpha_ik = exp( S_ik ) / sum_j exp( S_ij ),
|
// alpha_ik = exp( S_ik ) / sum_j exp( S_ij ),
|
||||||
// log_likelihood = sum_i log (sum_j exp(S_ij))
|
// log_likelihood = sum_i log (sum_j exp(S_ij))
|
||||||
for( i = 0, _log_likelihood = likelihood_bias; i < nsamples; i++ )
|
for( i = 0, _log_likelihood = 0; i < nsamples; i++ )
|
||||||
{
|
{
|
||||||
double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0;
|
double* pp = (double*)(probs->data.ptr + probs->step*i), sum = 0;
|
||||||
for( j = 0; j < nclusters; j++ )
|
for( j = 0; j < nclusters; j++ )
|
||||||
@ -966,9 +970,11 @@ double CvEM::run_em( const CvVectors& train_data )
|
|||||||
}
|
}
|
||||||
_log_likelihood -= log( sum );
|
_log_likelihood -= log( sum );
|
||||||
}
|
}
|
||||||
|
_log_likelihood+=sum_max_val;
|
||||||
|
|
||||||
// check termination criteria
|
// check termination criteria
|
||||||
if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
|
//if( fabs( (_log_likelihood - prev_log_likelihood) / prev_log_likelihood ) < params.term_crit.epsilon )
|
||||||
|
if( fabs( (_log_likelihood - prev_log_likelihood) ) < params.term_crit.epsilon )
|
||||||
break;
|
break;
|
||||||
prev_log_likelihood = _log_likelihood;
|
prev_log_likelihood = _log_likelihood;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user