updated points_classifier sample to use bayes classifier after distributions estimation by EM
This commit is contained in:
parent
eaf0d38f03
commit
7120355e06
@ -442,16 +442,30 @@ void find_decision_boundary_EM()
|
||||
Mat trainSamples, trainClasses;
|
||||
prepare_train_data( trainSamples, trainClasses );
|
||||
|
||||
cv::EM em;
|
||||
cv::EM::Params params;
|
||||
params.nclusters = classColors.size();
|
||||
params.covMatType = cv::EM::COV_MAT_GENERIC;
|
||||
params.startStep = cv::EM::START_AUTO_STEP;
|
||||
params.termCrit = cv::TermCriteria(cv::TermCriteria::COUNT + cv::TermCriteria::COUNT, 10, 0.1);
|
||||
vector<cv::EM> em_models(classColors.size());
|
||||
|
||||
// learn classifier
|
||||
em.train( trainSamples, Mat(), params, &trainClasses );
|
||||
CV_Assert((int)trainClasses.total() == trainSamples.rows);
|
||||
CV_Assert((int)trainClasses.type() == CV_32SC1);
|
||||
|
||||
for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
|
||||
{
|
||||
const int componentCount = 3;
|
||||
em_models[modelIndex] = EM(componentCount, cv::EM::COV_MAT_DIAGONAL);
|
||||
|
||||
Mat modelSamples;
|
||||
for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
|
||||
{
|
||||
if(trainClasses.at<int>(sampleIndex) == (int)modelIndex)
|
||||
modelSamples.push_back(trainSamples.row(sampleIndex));
|
||||
}
|
||||
|
||||
// learn models
|
||||
if(!modelSamples.empty())
|
||||
em_models[modelIndex].train(modelSamples);
|
||||
}
|
||||
|
||||
// classify coordinate plane points using the bayes classifier, i.e.
|
||||
// y(x) = arg max_i=1_modelsCount likelihoods_i(x)
|
||||
Mat testSample(1, 2, CV_32FC1 );
|
||||
for( int y = 0; y < img.rows; y += testStep )
|
||||
{
|
||||
@ -460,7 +474,16 @@ void find_decision_boundary_EM()
|
||||
testSample.at<float>(0) = (float)x;
|
||||
testSample.at<float>(1) = (float)y;
|
||||
|
||||
int response = (int)em.predict( testSample );
|
||||
Mat logLikelihoods(1, em_models.size(), CV_64FC1, Scalar(-DBL_MAX));
|
||||
for(size_t modelIndex = 0; modelIndex < em_models.size(); modelIndex++)
|
||||
{
|
||||
if(em_models[modelIndex].isTrained())
|
||||
em_models[modelIndex].predict( testSample, noArray(), &logLikelihoods.at<double>(modelIndex) );
|
||||
}
|
||||
Point maxLoc;
|
||||
minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
|
||||
|
||||
int response = maxLoc.x;
|
||||
circle( imgDst, Point(x,y), 2, classColors[response], 1 );
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user