fixed nan in EM, added new test on EM
This commit is contained in:
@@ -572,7 +572,106 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
class CV_EMTest_Classification : public cvtest::BaseTest
|
||||
{
|
||||
public:
|
||||
CV_EMTest_Classification() {}
|
||||
protected:
|
||||
virtual void run(int)
|
||||
{
|
||||
// This test classifies spam by the following way:
|
||||
// 1. estimates distributions of "spam" / "not spam"
|
||||
// 2. predict classID using Bayes classifier for estimated distributions.
|
||||
|
||||
CvMLData data;
|
||||
string dataFilename = string(ts->get_data_path()) + "spambase.data";
|
||||
|
||||
if(data.read_csv(dataFilename.c_str()) != 0)
|
||||
{
|
||||
ts->printf(cvtest::TS::LOG, "File with spambase dataset cann't be read.\n");
|
||||
ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
|
||||
}
|
||||
|
||||
Mat values = data.get_values();
|
||||
CV_Assert(values.cols == 58);
|
||||
int responseIndex = 57;
|
||||
|
||||
Mat samples = values.colRange(0, responseIndex);
|
||||
Mat responses = values.col(responseIndex);
|
||||
|
||||
vector<int> trainSamplesMask(samples.rows, 0);
|
||||
int trainSamplesCount = (int)(0.5f * samples.rows);
|
||||
for(int i = 0; i < trainSamplesCount; i++)
|
||||
trainSamplesMask[i] = 1;
|
||||
RNG rng(0);
|
||||
for(size_t i = 0; i < trainSamplesMask.size(); i++)
|
||||
{
|
||||
int i1 = rng(trainSamplesMask.size());
|
||||
int i2 = rng(trainSamplesMask.size());
|
||||
std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);
|
||||
}
|
||||
|
||||
EM model0(3), model1(3);
|
||||
Mat samples0, samples1;
|
||||
for(int i = 0; i < samples.rows; i++)
|
||||
{
|
||||
if(trainSamplesMask[i])
|
||||
{
|
||||
Mat sample = samples.row(i);
|
||||
int resp = (int)responses.at<float>(i);
|
||||
if(resp == 0)
|
||||
samples0.push_back(sample);
|
||||
else
|
||||
samples1.push_back(sample);
|
||||
}
|
||||
}
|
||||
model0.train(samples0);
|
||||
model1.train(samples1);
|
||||
|
||||
Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)),
|
||||
testConfusionMat(2, 2, CV_32SC1, Scalar(0));
|
||||
const double lambda = 1.;
|
||||
for(int i = 0; i < samples.rows; i++)
|
||||
{
|
||||
double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
|
||||
Mat sample = samples.row(i);
|
||||
model0.predict(sample, noArray(), &sampleLogLikelihoods0);
|
||||
model1.predict(sample, noArray(), &sampleLogLikelihoods1);
|
||||
|
||||
int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
|
||||
|
||||
if(trainSamplesMask[i])
|
||||
trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
|
||||
else
|
||||
testConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
|
||||
}
|
||||
// std::cout << trainConfusionMat << std::endl;
|
||||
// std::cout << testConfusionMat << std::endl;
|
||||
|
||||
double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount;
|
||||
double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount);
|
||||
const double maxTrainError = 0.16;
|
||||
const double maxTestError = 0.19;
|
||||
|
||||
int code = cvtest::TS::OK;
|
||||
if(trainError > maxTrainError)
|
||||
{
|
||||
ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
|
||||
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
||||
}
|
||||
if(testError > maxTestError)
|
||||
{
|
||||
ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
|
||||
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
||||
}
|
||||
|
||||
ts->set_failed_test_info(code);
|
||||
}
|
||||
};
|
||||
|
||||
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
|
||||
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
|
||||
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
|
||||
TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
|
||||
TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }
|
||||
|
||||
|
Reference in New Issue
Block a user