Commit 71d7482aee: fixed NaN in EM, added a new EM test
Parent: 94bcaeb2e9
@@ -563,7 +563,7 @@ public:
     enum {COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2, COV_MAT_DEFAULT=COV_MAT_DIAGONAL};

     // Default parameters
-    enum {DEFAULT_NCLUSTERS=10, DEFAULT_MAX_ITERS=100};
+    enum {DEFAULT_NCLUSTERS=5, DEFAULT_MAX_ITERS=100};

     // The initial step
     enum {START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0};
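Note: for context, a minimal usage sketch (mine, not part of the patch) of where these constants surface in the 2.4-era cv::EM API; with the change above, a default-constructed model now uses 5 mixture components instead of 10. The constructor signature is assumed from the header this hunk edits.

#include <opencv2/ml/ml.hpp>

int main()
{
    cv::EM emDefault;                              // nclusters = EM::DEFAULT_NCLUSTERS, now 5
    cv::EM emCustom(3, cv::EM::COV_MAT_DIAGONAL);  // 3 components, diagonal covariances
    return 0;
}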
@@ -635,7 +635,6 @@ protected:
     Mat trainProbs;
     Mat trainLogLikelihoods;
     Mat trainLabels;
-    Mat trainCounts;

     CV_PROP Mat weights;
     CV_PROP Mat means;
@@ -2035,7 +2034,7 @@ public:

     // returns:
     // 0 - OK
-    // 1 - file can not be opened or is not correct
+    // -1 - file can not be opened or is not correct
     int read_csv( const char* filename );

     const CvMat* get_values() const;
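Note: a hedged usage sketch (mine, not from the patch) of the convention the corrected comment documents: read_csv() returns 0 on success and -1 when the file cannot be opened or parsed. The filename below is just an example.

#include <opencv2/ml/ml.hpp>
#include <cstdio>

int main()
{
    CvMLData data;
    if( data.read_csv("spambase.data") != 0 )  // -1 on failure
    {
        std::printf("failed to read csv\n");
        return 1;
    }
    const CvMat* values = data.get_values();   // parsed table as one matrix
    std::printf("rows: %d, cols: %d\n", values->rows, values->cols);
    return 0;
}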
@@ -44,7 +44,7 @@
 namespace cv
 {

-const double minEigenValue = DBL_MIN;
+const double minEigenValue = DBL_EPSILON;

 ///////////////////////////////////////////////////////////////////////////////////////////////////////

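Note: this one-line change is plausibly the core of the "fixed NaN" part of the commit. Eigenvalues of each covariance matrix are clamped from below by minEigenValue before being inverted, and a DBL_MIN floor makes the inverse so large that the Mahalanobis-style sums downstream can overflow to inf, after which differences of infinities yield NaN. A standalone illustration (mine, not code from the patch):

#include <cfloat>
#include <cmath>
#include <cstdio>

int main()
{
    double invMin = 1.0 / DBL_MIN;      // ~4.49e307, one multiplication away from overflow
    double invEps = 1.0 / DBL_EPSILON;  // ~4.50e15, comfortably finite

    double dev2 = 10.0;                 // squared deviation of a sample from a cluster mean
    double termMin = dev2 * invMin;     // overflows to inf
    double termEps = dev2 * invEps;     // ~4.5e16, stays finite

    std::printf("DBL_MIN floor: %g (inf? %d), inf - inf = %g\n",
                termMin, (int)std::isinf(termMin), termMin - termMin);  // prints nan
    std::printf("DBL_EPSILON floor: %g\n", termEps);
    return 0;
}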
@@ -67,7 +67,6 @@ void EM::clear()
     trainProbs.release();
     trainLogLikelihoods.release();
     trainLabels.release();
-    trainCounts.release();

     weights.release();
     means.release();
@@ -469,7 +468,6 @@ bool EM::doTrain(int startStep, OutputArray labels, OutputArray probs, OutputArr
     trainProbs.release();
     trainLabels.release();
     trainLogLikelihoods.release();
-    trainCounts.release();

     return true;
 }
@@ -556,97 +554,114 @@ void EM::eStep()

 void EM::mStep()
 {
-    trainCounts.create(1, nclusters, CV_32SC1);
-    trainCounts = Scalar(0);
-
-    for(int sampleIndex = 0; sampleIndex < trainLabels.rows; sampleIndex++)
-        trainCounts.at<int>(trainLabels.at<int>(sampleIndex))++;
-
-    if(countNonZero(trainCounts) != (int)trainCounts.total())
-    {
-        clusterTrainSamples();
-    }
-    else
-    {
-        // Update means_k, covs_k and weights_k from probs_ik
-        int dim = trainSamples.cols;
-
-        // Update weights
-        // not normalized first
-        reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
-
-        // Update means
-        means.create(nclusters, dim, CV_64FC1);
-        means = Scalar(0);
-        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
-        {
-            Mat clusterMean = means.row(clusterIndex);
-            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
-                clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
-            clusterMean /= weights.at<double>(clusterIndex);
-        }
-
-        // Update covsEigenValues and invCovsEigenValues
-        covs.resize(nclusters);
-        covsEigenValues.resize(nclusters);
-        if(covMatType == EM::COV_MAT_GENERIC)
-            covsRotateMats.resize(nclusters);
-        invCovsEigenValues.resize(nclusters);
-        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
-        {
-            if(covMatType != EM::COV_MAT_SPHERICAL)
-                covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
-            else
-                covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
-
-            if(covMatType == EM::COV_MAT_GENERIC)
-                covs[clusterIndex].create(dim, dim, CV_64FC1);
-
-            Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
-                covsEigenValues[clusterIndex] : covs[clusterIndex];
-
-            clusterCov = Scalar(0);
-
-            Mat centeredSample;
-            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
-            {
-                centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
-
-                if(covMatType == EM::COV_MAT_GENERIC)
-                    clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
-                else
-                {
-                    double p = trainProbs.at<double>(sampleIndex, clusterIndex);
-                    for(int di = 0; di < dim; di++ )
-                    {
-                        double val = centeredSample.at<double>(di);
-                        clusterCov.at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
-                    }
-                }
-            }
-
-            if(covMatType == EM::COV_MAT_SPHERICAL)
-                clusterCov /= dim;
-
-            clusterCov /= weights.at<double>(clusterIndex);
-
-            // Update covsRotateMats for EM::COV_MAT_GENERIC only
-            if(covMatType == EM::COV_MAT_GENERIC)
-            {
-                SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
-                covsEigenValues[clusterIndex] = svd.w;
-                covsRotateMats[clusterIndex] = svd.u;
-            }
-
-            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
-
-            // update invCovsEigenValues
-            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
-        }
-
-        // Normalize weights
-        weights /= trainSamples.rows;
-    }
+    // Update means_k, covs_k and weights_k from probs_ik
+    int dim = trainSamples.cols;
+
+    // Update weights
+    // not normalized first
+    reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
+
+    // Update means
+    means.create(nclusters, dim, CV_64FC1);
+    means = Scalar(0);
+
+    const double minPosWeight = trainSamples.rows * DBL_EPSILON;
+    double minWeight = DBL_MAX;
+    int minWeightClusterIndex = -1;
+    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
+    {
+        if(weights.at<double>(clusterIndex) <= minPosWeight)
+            continue;
+
+        if(weights.at<double>(clusterIndex) < minWeight)
+        {
+            minWeight = weights.at<double>(clusterIndex);
+            minWeightClusterIndex = clusterIndex;
+        }
+
+        Mat clusterMean = means.row(clusterIndex);
+        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
+            clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
+        clusterMean /= weights.at<double>(clusterIndex);
+    }
+
+    // Update covsEigenValues and invCovsEigenValues
+    covs.resize(nclusters);
+    covsEigenValues.resize(nclusters);
+    if(covMatType == EM::COV_MAT_GENERIC)
+        covsRotateMats.resize(nclusters);
+    invCovsEigenValues.resize(nclusters);
+    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
+    {
+        if(weights.at<double>(clusterIndex) <= minPosWeight)
+            continue;
+
+        if(covMatType != EM::COV_MAT_SPHERICAL)
+            covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
+        else
+            covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
+
+        if(covMatType == EM::COV_MAT_GENERIC)
+            covs[clusterIndex].create(dim, dim, CV_64FC1);
+
+        Mat clusterCov = covMatType != EM::COV_MAT_GENERIC ?
+            covsEigenValues[clusterIndex] : covs[clusterIndex];
+
+        clusterCov = Scalar(0);
+
+        Mat centeredSample;
+        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
+        {
+            centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
+
+            if(covMatType == EM::COV_MAT_GENERIC)
+                clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
+            else
+            {
+                double p = trainProbs.at<double>(sampleIndex, clusterIndex);
+                for(int di = 0; di < dim; di++ )
+                {
+                    double val = centeredSample.at<double>(di);
+                    clusterCov.at<double>(covMatType != EM::COV_MAT_SPHERICAL ? di : 0) += p*val*val;
+                }
+            }
+        }
+
+        if(covMatType == EM::COV_MAT_SPHERICAL)
+            clusterCov /= dim;
+
+        clusterCov /= weights.at<double>(clusterIndex);
+
+        // Update covsRotateMats for EM::COV_MAT_GENERIC only
+        if(covMatType == EM::COV_MAT_GENERIC)
+        {
+            SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
+            covsEigenValues[clusterIndex] = svd.w;
+            covsRotateMats[clusterIndex] = svd.u;
+        }
+
+        max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
+
+        // update invCovsEigenValues
+        invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
+    }
+
+    for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
+    {
+        if(weights.at<double>(clusterIndex) <= minPosWeight)
+        {
+            Mat clusterMean = means.row(clusterIndex);
+            means.row(minWeightClusterIndex).copyTo(clusterMean);
+            covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
+            covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
+            if(covMatType == EM::COV_MAT_GENERIC)
+                covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
+            invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
+        }
+    }
+
+    // Normalize weights
+    weights /= trainSamples.rows;
 }

 void EM::read(const FileNode& fn)
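Note: the restructured mStep() drops the old fallback that re-clustered all samples via clusterTrainSamples() whenever some hard-label count hit zero. Instead, a component whose summed responsibility is at most trainSamples.rows * DBL_EPSILON is skipped during the update, which avoids the 0/0 divisions that produced NaN means and covariances, and is afterwards re-seeded with the parameters of the component that received the smallest positive weight. A standalone sketch of that re-seeding pattern (illustrative names, not OpenCV API):

#include <cfloat>
#include <cstdio>
#include <vector>

struct Component { double weight; double mean; };  // stand-in for weights/means/covs

void reseedEmptyComponents(std::vector<Component>& comps, int nsamples)
{
    const double minPosWeight = nsamples * DBL_EPSILON;

    // Find the component with the smallest weight that is still positive.
    int minIdx = -1;
    double minW = DBL_MAX;
    for(size_t k = 0; k < comps.size(); k++)
        if(comps[k].weight > minPosWeight && comps[k].weight < minW)
        {
            minW = comps[k].weight;
            minIdx = (int)k;
        }

    // Copy its parameters into every (nearly) empty component, so no
    // component is left in an undefined state.
    if(minIdx >= 0)
        for(size_t k = 0; k < comps.size(); k++)
            if(comps[k].weight <= minPosWeight)
                comps[k] = comps[minIdx];
}

int main()
{
    std::vector<Component> comps;
    Component a = {120.0, 1.5}, b = {0.0, 0.0}, c = {80.0, -2.0};
    comps.push_back(a); comps.push_back(b); comps.push_back(c);
    reseedEmptyComponents(comps, 200);  // component 1 is re-seeded from component 2
    for(size_t k = 0; k < comps.size(); k++)
        std::printf("component %d: weight=%g mean=%g\n", (int)k, comps[k].weight, comps[k].mean);
    return 0;
}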
@@ -572,7 +572,106 @@ protected:
     }
 };

+class CV_EMTest_Classification : public cvtest::BaseTest
+{
+public:
+    CV_EMTest_Classification() {}
+protected:
+    virtual void run(int)
+    {
+        // This test classifies spam as follows:
+        // 1. estimates distributions of "spam" / "not spam"
+        // 2. predicts the classID using a Bayes classifier for the estimated distributions.
+
+        CvMLData data;
+        string dataFilename = string(ts->get_data_path()) + "spambase.data";
+
+        if(data.read_csv(dataFilename.c_str()) != 0)
+        {
+            ts->printf(cvtest::TS::LOG, "File with spambase dataset cannot be read.\n");
+            ts->set_failed_test_info(cvtest::TS::FAIL_INVALID_TEST_DATA);
+        }
+
+        Mat values = data.get_values();
+        CV_Assert(values.cols == 58);
+        int responseIndex = 57;
+
+        Mat samples = values.colRange(0, responseIndex);
+        Mat responses = values.col(responseIndex);
+
+        vector<int> trainSamplesMask(samples.rows, 0);
+        int trainSamplesCount = (int)(0.5f * samples.rows);
+        for(int i = 0; i < trainSamplesCount; i++)
+            trainSamplesMask[i] = 1;
+        RNG rng(0);
+        for(size_t i = 0; i < trainSamplesMask.size(); i++)
+        {
+            int i1 = rng(trainSamplesMask.size());
+            int i2 = rng(trainSamplesMask.size());
+            std::swap(trainSamplesMask[i1], trainSamplesMask[i2]);
+        }
+
+        EM model0(3), model1(3);
+        Mat samples0, samples1;
+        for(int i = 0; i < samples.rows; i++)
+        {
+            if(trainSamplesMask[i])
+            {
+                Mat sample = samples.row(i);
+                int resp = (int)responses.at<float>(i);
+                if(resp == 0)
+                    samples0.push_back(sample);
+                else
+                    samples1.push_back(sample);
+            }
+        }
+        model0.train(samples0);
+        model1.train(samples1);
+
+        Mat trainConfusionMat(2, 2, CV_32SC1, Scalar(0)),
+            testConfusionMat(2, 2, CV_32SC1, Scalar(0));
+        const double lambda = 1.;
+        for(int i = 0; i < samples.rows; i++)
+        {
+            double sampleLogLikelihoods0 = 0, sampleLogLikelihoods1 = 0;
+            Mat sample = samples.row(i);
+            model0.predict(sample, noArray(), &sampleLogLikelihoods0);
+            model1.predict(sample, noArray(), &sampleLogLikelihoods1);
+
+            int classID = sampleLogLikelihoods0 >= lambda * sampleLogLikelihoods1 ? 0 : 1;
+
+            if(trainSamplesMask[i])
+                trainConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
+            else
+                testConfusionMat.at<int>((int)responses.at<float>(i), classID)++;
+        }
+//        std::cout << trainConfusionMat << std::endl;
+//        std::cout << testConfusionMat << std::endl;
+
+        double trainError = (double)(trainConfusionMat.at<int>(1,0) + trainConfusionMat.at<int>(0,1)) / trainSamplesCount;
+        double testError = (double)(testConfusionMat.at<int>(1,0) + testConfusionMat.at<int>(0,1)) / (samples.rows - trainSamplesCount);
+        const double maxTrainError = 0.16;
+        const double maxTestError = 0.19;
+
+        int code = cvtest::TS::OK;
+        if(trainError > maxTrainError)
+        {
+            ts->printf(cvtest::TS::LOG, "Too large train classification error (calc = %f, valid=%f).\n", trainError, maxTrainError);
+            code = cvtest::TS::FAIL_INVALID_TEST_DATA;
+        }
+        if(testError > maxTestError)
+        {
+            ts->printf(cvtest::TS::LOG, "Too large test classification error (calc = %f, valid=%f).\n", testError, maxTestError);
+            code = cvtest::TS::FAIL_INVALID_TEST_DATA;
+        }
+
+        ts->set_failed_test_info(code);
+    }
+};
+
 TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
 TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
 TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
 TEST(ML_EM, save_load) { CV_EMTest_SaveLoad test; test.safe_run(); }
+TEST(ML_EM, classification) { CV_EMTest_Classification test; test.safe_run(); }
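Note: the decision rule inside the new test is a two-class Bayes classifier over the two fitted mixtures: each sample is assigned to the class whose EM model reports the higher log-likelihood, with lambda acting as a decision threshold (1.0 corresponds to equal priors and costs). A minimal sketch of that rule (mine; the classify helper is hypothetical, not part of the test):

#include <cstdio>

int classify(double logLikelihood0, double logLikelihood1, double lambda)
{
    // Compare log p(x | class 0) against lambda * log p(x | class 1),
    // exactly as the test's classID line does.
    return logLikelihood0 >= lambda * logLikelihood1 ? 0 : 1;
}

int main()
{
    std::printf("%d\n", classify(-120.5, -130.2, 1.0));  // 0: model0 explains the sample better
    std::printf("%d\n", classify(-140.0, -131.7, 1.0));  // 1: model1 explains it better
    return 0;
}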