refactored train and predict methods of em

This commit is contained in:
Maria Dimashova
2012-04-17 06:29:40 +00:00
parent 8f7e5811b6
commit 3dfa917879
7 changed files with 56 additions and 65 deletions

View File

@@ -373,11 +373,11 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
cv::EM em(params.nclusters, params.covMatType, params.termCrit);
if( params.startStep == EM::START_AUTO_STEP )
em.train( trainData, labels );
em.train( trainData, noArray(), labels );
else if( params.startStep == EM::START_E_STEP )
em.trainE( trainData, *params.means, *params.covs, *params.weights, labels );
em.trainE( trainData, *params.means, *params.covs, *params.weights, noArray(), labels );
else if( params.startStep == EM::START_M_STEP )
em.trainM( trainData, *params.probs, labels );
em.trainM( trainData, *params.probs, noArray(), labels );
// check train error
if( !calcErr( labels, trainLabels, sizes, err , false, false ) )
@@ -396,9 +396,8 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
for( int i = 0; i < testData.rows; i++ )
{
Mat sample = testData.row(i);
double likelihood = 0;
Mat probs;
labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood );
labels.at<int>(i) = static_cast<int>(em.predict( sample, probs )[1]);
}
if( !calcErr( labels, testLabels, sizes, err, false, false ) )
{
@@ -523,7 +522,7 @@ protected:
Mat firstResult(samples.rows, 1, CV_32SC1);
for( int i = 0; i < samples.rows; i++)
firstResult.at<int>(i) = em.predict(samples.row(i));
firstResult.at<int>(i) = static_cast<int>(em.predict(samples.row(i))[1]);
// Write out
string filename = tempfile() + ".xml";
@@ -564,7 +563,7 @@ protected:
int errCaseCount = 0;
for( int i = 0; i < samples.rows; i++)
errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
errCaseCount = std::abs(em.predict(samples.row(i))[1] - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
if( errCaseCount > 0 )
{
@@ -637,10 +636,9 @@ protected:
const double lambda = 1.;
for(int i = 0; i < samples.rows; i++)
{
double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
Mat sample = samples.row(i);
model0.predict(sample, noArray(), &sampleLogLikelihoods0);
model1.predict(sample, noArray(), &sampleLogLikelihoods1);
double sampleLogLikelihoods0 = model0.predict(sample)[0];
double sampleLogLikelihoods1 = model1.predict(sample)[0];
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;