refactored train and predict methods of em
This commit is contained in:
@@ -373,11 +373,11 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
|
||||
cv::EM em(params.nclusters, params.covMatType, params.termCrit);
|
||||
if( params.startStep == EM::START_AUTO_STEP )
|
||||
em.train( trainData, labels );
|
||||
em.train( trainData, noArray(), labels );
|
||||
else if( params.startStep == EM::START_E_STEP )
|
||||
em.trainE( trainData, *params.means, *params.covs, *params.weights, labels );
|
||||
em.trainE( trainData, *params.means, *params.covs, *params.weights, noArray(), labels );
|
||||
else if( params.startStep == EM::START_M_STEP )
|
||||
em.trainM( trainData, *params.probs, labels );
|
||||
em.trainM( trainData, *params.probs, noArray(), labels );
|
||||
|
||||
// check train error
|
||||
if( !calcErr( labels, trainLabels, sizes, err , false, false ) )
|
||||
@@ -396,9 +396,8 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
for( int i = 0; i < testData.rows; i++ )
|
||||
{
|
||||
Mat sample = testData.row(i);
|
||||
double likelihood = 0;
|
||||
Mat probs;
|
||||
labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood );
|
||||
labels.at<int>(i) = static_cast<int>(em.predict( sample, probs )[1]);
|
||||
}
|
||||
if( !calcErr( labels, testLabels, sizes, err, false, false ) )
|
||||
{
|
||||
@@ -523,7 +522,7 @@ protected:
|
||||
|
||||
Mat firstResult(samples.rows, 1, CV_32SC1);
|
||||
for( int i = 0; i < samples.rows; i++)
|
||||
firstResult.at<int>(i) = em.predict(samples.row(i));
|
||||
firstResult.at<int>(i) = static_cast<int>(em.predict(samples.row(i))[1]);
|
||||
|
||||
// Write out
|
||||
string filename = tempfile() + ".xml";
|
||||
@@ -564,7 +563,7 @@ protected:
|
||||
|
||||
int errCaseCount = 0;
|
||||
for( int i = 0; i < samples.rows; i++)
|
||||
errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
|
||||
errCaseCount = std::abs(em.predict(samples.row(i))[1] - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
|
||||
|
||||
if( errCaseCount > 0 )
|
||||
{
|
||||
@@ -637,10 +636,9 @@ protected:
|
||||
const double lambda = 1.;
|
||||
for(int i = 0; i < samples.rows; i++)
|
||||
{
|
||||
double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
|
||||
Mat sample = samples.row(i);
|
||||
model0.predict(sample, noArray(), &sampleLogLikelihoods0);
|
||||
model1.predict(sample, noArray(), &sampleLogLikelihoods1);
|
||||
double sampleLogLikelihoods0 = model0.predict(sample)[0];
|
||||
double sampleLogLikelihoods1 = model1.predict(sample)[0];
|
||||
|
||||
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
|
||||
|
||||
|
Reference in New Issue
Block a user