fixed nan in EM, added new test on EM

This commit is contained in:
Maria Dimashova
2012-04-16 10:30:42 +00:00
parent 94bcaeb2e9
commit 71d7482aee
3 changed files with 192 additions and 79 deletions

View File

@@ -572,7 +572,106 @@ protected:
}
};
class CV_EMTest_Classification : public cvtest::BaseTest
{
public:
CV_EMTest_Classification() {}
protected:
virtual void run(int)
{
// This test classifies spam by the following way:
// 1. estimates distributions of "spam" / "not spam"
// 2. predict classID using Bayes classifier for estimated distributions.
CvMLData data;
string dataFilename = string(ts->get_data_path()) + "spambase.data";
if(data.read_csv(dataFilename.c_str()) != 0)
{
ts->printf(cvtest::TS::LOG, "File with spambase dataset cann't be read.\n");
ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
}
Mat values = data.get_values();
CV_Assert(values.cols == 58);
int responseIndex = 57;
Mat samples = values.colRange(0, responseIndex);
Mat responses = values.col(responseIndex);
vector<int> trainSamplesMask(samples.rows, 0);
int trainSamplesCount = (int)(0.5f * samples.rows);
for(int i = 0; i < trainSamplesCount; i++)
trainSamplesMask[i] = 1;
RNG rng(0);
for(size_t i = 0; i < trainSamplesMask.size(); i++)
{
int i1 = rng(trainSamplesMask.size());
int i2 = rng(trainSamplesMask.size());
std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);
}
EM model0(3), model1(3);
Mat samples0, samples1;
for(int i = 0; i < samples.rows; i++)
{
if(trainSamplesMask[i])
{
Mat sample = samples.row(i);
int resp = (int)responses.at<float>(i);
if(resp == 0)
samples0.push_back(sample);
else
samples1.push_back(sample);
}
}
model0.train(samples0);
model1.train(samples1);
Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)),
testConfusionMat(2, 2, CV_32SC1, Scalar(0));
const double lambda = 1.;
for(int i = 0; i < samples.rows; i++)
{
double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
Mat sample = samples.row(i);
model0.predict(sample, noArray(), &sampleLogLikelihoods0);
model1.predict(sample, noArray(), &sampleLogLikelihoods1);
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
if(trainSamplesMask[i])
trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
else
testConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
}
// std::cout << trainConfusionMat << std::endl;
// std::cout << testConfusionMat << std::endl;
double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount;
double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount);
const double maxTrainError = 0.16;
const double maxTestError = 0.19;
int code = cvtest::TS::OK;
if(trainError > maxTrainError)
{
ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
}
if(testError > maxTestError)
{
ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
}
ts->set_failed_test_info(code);
}
};
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }