fixed test on em
This commit is contained in:
parent
30f8d5a7d7
commit
570c8254b2
@ -129,7 +129,7 @@ int maxIdx( const vector<int>& count )
|
||||
}
|
||||
|
||||
static
|
||||
bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap )
|
||||
bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap, bool checkClusterUniq=true )
|
||||
{
|
||||
size_t total = 0, nclusters = sizes.size();
|
||||
for(size_t i = 0; i < sizes.size(); i++)
|
||||
@ -158,21 +158,26 @@ bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& lab
|
||||
startIndex += sizes[clusterIndex];
|
||||
|
||||
int cls = maxIdx( count );
|
||||
CV_Assert( !buzy[cls] );
|
||||
if(checkClusterUniq)
|
||||
CV_Assert( !buzy[cls] );
|
||||
|
||||
labelsMap[clusterIndex] = cls;
|
||||
|
||||
buzy[cls] = true;
|
||||
}
|
||||
for(size_t i = 0; i < buzy.size(); i++)
|
||||
if(!buzy[i])
|
||||
return false;
|
||||
|
||||
if(checkClusterUniq)
|
||||
{
|
||||
for(size_t i = 0; i < buzy.size(); i++)
|
||||
if(!buzy[i])
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static
|
||||
bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true )
|
||||
bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true, bool checkClusterUniq=true )
|
||||
{
|
||||
err = 0;
|
||||
CV_Assert( !labels.empty() && !origLabels.empty() );
|
||||
@ -186,7 +191,7 @@ bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes
|
||||
bool isFlt = labels.type() == CV_32FC1;
|
||||
if( !labelsEquivalent )
|
||||
{
|
||||
if( !getLabelsMap( labels, sizes, labelsMap ) )
|
||||
if( !getLabelsMap( labels, sizes, labelsMap, checkClusterUniq ) )
|
||||
return false;
|
||||
|
||||
for( int i = 0; i < labels.rows; i++ )
|
||||
@ -376,7 +381,7 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
em.trainM( trainData, *params.probs, labels );
|
||||
|
||||
// check train error
|
||||
if( !calcErr( labels, trainLabels, sizes, err , false ) )
|
||||
if( !calcErr( labels, trainLabels, sizes, err , false, false ) )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||
@ -396,7 +401,7 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
Mat probs;
|
||||
labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood );
|
||||
}
|
||||
if( !calcErr( labels, testLabels, sizes, err, false ) )
|
||||
if( !calcErr( labels, testLabels, sizes, err, false, false ) )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||
|
Loading…
x
Reference in New Issue
Block a user