updated points_classifier sample to use bayes classifier after distributions estimation by EM
This commit is contained in:
@@ -442,16 +442,30 @@ void find_decision_boundary_EM()
|
|||||||
Mat trainSamples, trainClasses;
|
Mat trainSamples, trainClasses;
|
||||||
prepare_train_data( trainSamples, trainClasses );
|
prepare_train_data( trainSamples, trainClasses );
|
||||||
|
|
||||||
cv::EM em;
|
vector<cv::EM> em_models(classColors.size());
|
||||||
cv::EM::Params params;
|
|
||||||
params.nclusters = classColors.size();
|
|
||||||
params.covMatType = cv::EM::COV_MAT_GENERIC;
|
|
||||||
params.startStep = cv::EM::START_AUTO_STEP;
|
|
||||||
params.termCrit = cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::COUNT, 10, 0.1);
|
|
||||||
|
|
||||||
// learn classifier
|
CV_Assert((int)trainClasses.total() == trainSamples.rows);
|
||||||
em.train( trainSamples, Mat(), params, &trainClasses );
|
CV_Assert((int)trainClasses.type() == CV_32SC1);
|
||||||
|
|
||||||
|
for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
|
||||||
|
{
|
||||||
|
const int componentCount = 3;
|
||||||
|
em_models[modelIndex] = EM(componentCount, cv::EM::COV_MAT_DIAGONAL);
|
||||||
|
|
||||||
|
Mat modelSamples;
|
||||||
|
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
|
||||||
|
{
|
||||||
|
if(trainClasses.at<int>(sampleIndex) == (int)modelIndex)
|
||||||
|
modelSamples.push_back(trainSamples.row(sampleIndex));
|
||||||
|
}
|
||||||
|
|
||||||
|
// learn models
|
||||||
|
if(!modelSamples.empty())
|
||||||
|
em_models[modelIndex].train(modelSamples);
|
||||||
|
}
|
||||||
|
|
||||||
|
// classify coordinate plane points using the bayes classifier, i.e.
|
||||||
|
// y(x) = arg max_i=1_modelsCount likelihoods_i(x)
|
||||||
Mat testSample(1, 2, CV_32FC1 );
|
Mat testSample(1, 2, CV_32FC1 );
|
||||||
for( int y = 0; y < img.rows; y += testStep )
|
for( int y = 0; y < img.rows; y += testStep )
|
||||||
{
|
{
|
||||||
@@ -460,7 +474,16 @@ void find_decision_boundary_EM()
|
|||||||
testSample.at<float>(0) = (float)x;
|
testSample.at<float>(0) = (float)x;
|
||||||
testSample.at<float>(1) = (float)y;
|
testSample.at<float>(1) = (float)y;
|
||||||
|
|
||||||
int response = (int)em.predict( testSample );
|
Mat logLikelihoods(1, em_models.size(), CV_64FC1, Scalar(-DBL_MAX));
|
||||||
|
for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
|
||||||
|
{
|
||||||
|
if(em_models[modelIndex].isTrained())
|
||||||
|
em_models[modelIndex].predict( testSample, noArray(), &logLikelihoods.at<double>(modelIndex) );
|
||||||
|
}
|
||||||
|
Point maxLoc;
|
||||||
|
minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
|
||||||
|
|
||||||
|
int response = maxLoc.x;
|
||||||
circle( imgDst, Point(x,y), 2, classColors[response], 1 );
|
circle( imgDst, Point(x,y), 2, classColors[response], 1 );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user