merged 2.4 into trunk
This commit is contained in:
@@ -129,7 +129,7 @@ int maxIdx( const vector<int>& count )
|
||||
}
|
||||
|
||||
static
|
||||
bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap )
|
||||
bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap, bool checkClusterUniq=true )
|
||||
{
|
||||
size_t total = 0, nclusters = sizes.size();
|
||||
for(size_t i = 0; i < sizes.size(); i++)
|
||||
@@ -158,21 +158,25 @@ bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& lab
|
||||
startIndex += sizes[clusterIndex];
|
||||
|
||||
int cls = maxIdx( count );
|
||||
CV_Assert( !buzy[cls] );
|
||||
CV_Assert( !checkClusterUniq || !buzy[cls] );
|
||||
|
||||
labelsMap[clusterIndex] = cls;
|
||||
|
||||
buzy[cls] = true;
|
||||
}
|
||||
for(size_t i = 0; i < buzy.size(); i++)
|
||||
if(!buzy[i])
|
||||
return false;
|
||||
|
||||
if(checkClusterUniq)
|
||||
{
|
||||
for(size_t i = 0; i < buzy.size(); i++)
|
||||
if(!buzy[i])
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static
|
||||
bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true )
|
||||
bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent = true, bool checkClusterUniq=true )
|
||||
{
|
||||
err = 0;
|
||||
CV_Assert( !labels.empty() && !origLabels.empty() );
|
||||
@@ -186,7 +190,7 @@ bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes
|
||||
bool isFlt = labels.type() == CV_32FC1;
|
||||
if( !labelsEquivalent )
|
||||
{
|
||||
if( !getLabelsMap( labels, sizes, labelsMap ) )
|
||||
if( !getLabelsMap( labels, sizes, labelsMap, checkClusterUniq ) )
|
||||
return false;
|
||||
|
||||
for( int i = 0; i < labels.rows; i++ )
|
||||
@@ -369,14 +373,14 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
|
||||
cv::EM em(params.nclusters, params.covMatType, params.termCrit);
|
||||
if( params.startStep == EM::START_AUTO_STEP )
|
||||
em.train( trainData, labels );
|
||||
em.train( trainData, noArray(), labels );
|
||||
else if( params.startStep == EM::START_E_STEP )
|
||||
em.trainE( trainData, *params.means, *params.covs, *params.weights, labels );
|
||||
em.trainE( trainData, *params.means, *params.covs, *params.weights, noArray(), labels );
|
||||
else if( params.startStep == EM::START_M_STEP )
|
||||
em.trainM( trainData, *params.probs, labels );
|
||||
em.trainM( trainData, *params.probs, noArray(), labels );
|
||||
|
||||
// check train error
|
||||
if( !calcErr( labels, trainLabels, sizes, err , false ) )
|
||||
if( !calcErr( labels, trainLabels, sizes, err , false, false ) )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||
@@ -392,11 +396,10 @@ int CV_EMTest::runCase( int caseIndex, const EM_Params& params,
|
||||
for( int i = 0; i < testData.rows; i++ )
|
||||
{
|
||||
Mat sample = testData.row(i);
|
||||
double likelihood = 0;
|
||||
Mat probs;
|
||||
labels.at<int>(i,0) = (int)em.predict( sample, probs, &likelihood );
|
||||
labels.at<int>(i) = static_cast<int>(em.predict( sample, probs )[1]);
|
||||
}
|
||||
if( !calcErr( labels, testLabels, sizes, err, false ) )
|
||||
if( !calcErr( labels, testLabels, sizes, err, false, false ) )
|
||||
{
|
||||
ts->printf( cvtest::TS::LOG, "Case index %i : Bad output labels.\n", caseIndex );
|
||||
code = cvtest::TS::FAIL_INVALID_OUTPUT;
|
||||
@@ -519,7 +522,7 @@ protected:
|
||||
|
||||
Mat firstResult(samples.rows, 1, CV_32SC1);
|
||||
for( int i = 0; i < samples.rows; i++)
|
||||
firstResult.at<int>(i) = em.predict(samples.row(i));
|
||||
firstResult.at<int>(i) = static_cast<int>(em.predict(samples.row(i))[1]);
|
||||
|
||||
// Write out
|
||||
string filename = tempfile() + ".xml";
|
||||
@@ -560,7 +563,7 @@ protected:
|
||||
|
||||
int errCaseCount = 0;
|
||||
for( int i = 0; i < samples.rows; i++)
|
||||
errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
|
||||
errCaseCount = std::abs(em.predict(samples.row(i))[1] - firstResult.at<int>(i)) < FLT_EPSILON ? 0 : 1;
|
||||
|
||||
if( errCaseCount > 0 )
|
||||
{
|
||||
@@ -572,7 +575,105 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
class CV_EMTest_Classification : public cvtest::BaseTest
|
||||
{
|
||||
public:
|
||||
CV_EMTest_Classification() {}
|
||||
protected:
|
||||
virtual void run(int)
|
||||
{
|
||||
// This test classifies spam by the following way:
|
||||
// 1. estimates distributions of "spam" / "not spam"
|
||||
// 2. predict classID using Bayes classifier for estimated distributions.
|
||||
|
||||
CvMLData data;
|
||||
string dataFilename = string(ts->get_data_path()) + "spambase.data";
|
||||
|
||||
if(data.read_csv(dataFilename.c_str()) != 0)
|
||||
{
|
||||
ts->printf(cvtest::TS::LOG, "File with spambase dataset cann't be read.\n");
|
||||
ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
|
||||
}
|
||||
|
||||
Mat values = data.get_values();
|
||||
CV_Assert(values.cols == 58);
|
||||
int responseIndex = 57;
|
||||
|
||||
Mat samples = values.colRange(0, responseIndex);
|
||||
Mat responses = values.col(responseIndex);
|
||||
|
||||
vector<int> trainSamplesMask(samples.rows, 0);
|
||||
int trainSamplesCount = (int)(0.5f * samples.rows);
|
||||
for(int i = 0; i < trainSamplesCount; i++)
|
||||
trainSamplesMask[i] = 1;
|
||||
RNG rng(0);
|
||||
for(size_t i = 0; i < trainSamplesMask.size(); i++)
|
||||
{
|
||||
int i1 = rng(static_cast<unsigned>(trainSamplesMask.size()));
|
||||
int i2 = rng(static_cast<unsigned>(trainSamplesMask.size()));
|
||||
std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);
|
||||
}
|
||||
|
||||
EM model0(3), model1(3);
|
||||
Mat samples0, samples1;
|
||||
for(int i = 0; i < samples.rows; i++)
|
||||
{
|
||||
if(trainSamplesMask[i])
|
||||
{
|
||||
Mat sample = samples.row(i);
|
||||
int resp = (int)responses.at<float>(i);
|
||||
if(resp == 0)
|
||||
samples0.push_back(sample);
|
||||
else
|
||||
samples1.push_back(sample);
|
||||
}
|
||||
}
|
||||
model0.train(samples0);
|
||||
model1.train(samples1);
|
||||
|
||||
Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)),
|
||||
testConfusionMat(2, 2, CV_32SC1, Scalar(0));
|
||||
const double lambda = 1.;
|
||||
for(int i = 0; i < samples.rows; i++)
|
||||
{
|
||||
Mat sample = samples.row(i);
|
||||
double sampleLogLikelihoods0 = model0.predict(sample)[0];
|
||||
double sampleLogLikelihoods1 = model1.predict(sample)[0];
|
||||
|
||||
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
|
||||
|
||||
if(trainSamplesMask[i])
|
||||
trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
|
||||
else
|
||||
testConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
|
||||
}
|
||||
// std::cout << trainConfusionMat << std::endl;
|
||||
// std::cout << testConfusionMat << std::endl;
|
||||
|
||||
double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount;
|
||||
double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount);
|
||||
const double maxTrainError = 0.16;
|
||||
const double maxTestError = 0.19;
|
||||
|
||||
int code = cvtest::TS::OK;
|
||||
if(trainError > maxTrainError)
|
||||
{
|
||||
ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
|
||||
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
||||
}
|
||||
if(testError > maxTestError)
|
||||
{
|
||||
ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
|
||||
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
||||
}
|
||||
|
||||
ts->set_failed_test_info(code);
|
||||
}
|
||||
};
|
||||
|
||||
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
|
||||
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
|
||||
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
|
||||
TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
|
||||
TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }
|
||||
|
||||
|
Reference in New Issue
Block a user