Updated ml module interfaces and documentation
This commit is contained in:
@@ -102,7 +102,7 @@ static void predict_and_paint(const Ptr<StatModel>& model, Mat& dst)
|
||||
static void find_decision_boundary_NBC()
|
||||
{
|
||||
// learn classifier
|
||||
Ptr<NormalBayesClassifier> normalBayesClassifier = StatModel::train<NormalBayesClassifier>(prepare_train_data(), NormalBayesClassifier::Params());
|
||||
Ptr<NormalBayesClassifier> normalBayesClassifier = StatModel::train<NormalBayesClassifier>(prepare_train_data());
|
||||
|
||||
predict_and_paint(normalBayesClassifier, imgDst);
|
||||
}
|
||||
@@ -112,15 +112,29 @@ static void find_decision_boundary_NBC()
|
||||
#if _KNN_
|
||||
static void find_decision_boundary_KNN( int K )
|
||||
{
|
||||
Ptr<KNearest> knn = StatModel::train<KNearest>(prepare_train_data(), KNearest::Params(K, true));
|
||||
|
||||
Ptr<KNearest> knn = KNearest::create();
|
||||
knn->setDefaultK(K);
|
||||
knn->setIsClassifier(true);
|
||||
knn->train(prepare_train_data());
|
||||
predict_and_paint(knn, imgDst);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if _SVM_
|
||||
static void find_decision_boundary_SVM( SVM::Params params )
|
||||
static void find_decision_boundary_SVM( double C )
|
||||
{
|
||||
Ptr<SVM> svm = StatModel::train<SVM>(prepare_train_data(), params);
|
||||
Ptr<SVM> svm = SVM::create();
|
||||
svm->setType(SVM::C_SVC);
|
||||
svm->setKernel(SVM::POLY); //SVM::LINEAR;
|
||||
svm->setDegree(0.5);
|
||||
svm->setGamma(1);
|
||||
svm->setCoef0(1);
|
||||
svm->setNu(0.5);
|
||||
svm->setP(0);
|
||||
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01));
|
||||
svm->setC(C);
|
||||
svm->train(prepare_train_data());
|
||||
predict_and_paint(svm, imgDst);
|
||||
|
||||
Mat sv = svm->getSupportVectors();
|
||||
@@ -135,16 +149,14 @@ static void find_decision_boundary_SVM( SVM::Params params )
|
||||
#if _DT_
|
||||
static void find_decision_boundary_DT()
|
||||
{
|
||||
DTrees::Params params;
|
||||
params.maxDepth = 8;
|
||||
params.minSampleCount = 2;
|
||||
params.useSurrogates = false;
|
||||
params.CVFolds = 0; // the number of cross-validation folds
|
||||
params.use1SERule = false;
|
||||
params.truncatePrunedTree = false;
|
||||
|
||||
Ptr<DTrees> dtree = StatModel::train<DTrees>(prepare_train_data(), params);
|
||||
|
||||
Ptr<DTrees> dtree = DTrees::create();
|
||||
dtree->setMaxDepth(8);
|
||||
dtree->setMinSampleCount(2);
|
||||
dtree->setUseSurrogates(false);
|
||||
dtree->setCVFolds(0); // the number of cross-validation folds
|
||||
dtree->setUse1SERule(false);
|
||||
dtree->setTruncatePrunedTree(false);
|
||||
dtree->train(prepare_train_data());
|
||||
predict_and_paint(dtree, imgDst);
|
||||
}
|
||||
#endif
|
||||
@@ -152,15 +164,14 @@ static void find_decision_boundary_DT()
|
||||
#if _BT_
|
||||
static void find_decision_boundary_BT()
|
||||
{
|
||||
Boost::Params params( Boost::DISCRETE, // boost_type
|
||||
100, // weak_count
|
||||
0.95, // weight_trim_rate
|
||||
2, // max_depth
|
||||
false, //use_surrogates
|
||||
Mat() // priors
|
||||
);
|
||||
|
||||
Ptr<Boost> boost = StatModel::train<Boost>(prepare_train_data(), params);
|
||||
Ptr<Boost> boost = Boost::create();
|
||||
boost->setBoostType(Boost::DISCRETE);
|
||||
boost->setWeakCount(100);
|
||||
boost->setWeightTrimRate(0.95);
|
||||
boost->setMaxDepth(2);
|
||||
boost->setUseSurrogates(false);
|
||||
boost->setPriors(Mat());
|
||||
boost->train(prepare_train_data());
|
||||
predict_and_paint(boost, imgDst);
|
||||
}
|
||||
|
||||
@@ -185,18 +196,17 @@ static void find_decision_boundary_GBT()
|
||||
#if _RF_
|
||||
static void find_decision_boundary_RF()
|
||||
{
|
||||
RTrees::Params params( 4, // max_depth,
|
||||
2, // min_sample_count,
|
||||
0.f, // regression_accuracy,
|
||||
false, // use_surrogates,
|
||||
16, // max_categories,
|
||||
Mat(), // priors,
|
||||
false, // calc_var_importance,
|
||||
1, // nactive_vars,
|
||||
TermCriteria(TermCriteria::MAX_ITER, 5, 0) // max_num_of_trees_in_the_forest,
|
||||
);
|
||||
|
||||
Ptr<RTrees> rtrees = StatModel::train<RTrees>(prepare_train_data(), params);
|
||||
Ptr<RTrees> rtrees = RTrees::create();
|
||||
rtrees->setMaxDepth(4);
|
||||
rtrees->setMinSampleCount(2);
|
||||
rtrees->setRegressionAccuracy(0.f);
|
||||
rtrees->setUseSurrogates(false);
|
||||
rtrees->setMaxCategories(16);
|
||||
rtrees->setPriors(Mat());
|
||||
rtrees->setCalculateVarImportance(false);
|
||||
rtrees->setActiveVarCount(1);
|
||||
rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 5, 0));
|
||||
rtrees->train(prepare_train_data());
|
||||
predict_and_paint(rtrees, imgDst);
|
||||
}
|
||||
|
||||
@@ -205,9 +215,6 @@ static void find_decision_boundary_RF()
|
||||
#if _ANN_
|
||||
static void find_decision_boundary_ANN( const Mat& layer_sizes )
|
||||
{
|
||||
ANN_MLP::Params params(layer_sizes, ANN_MLP::SIGMOID_SYM, 1, 1, TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON),
|
||||
ANN_MLP::Params::BACKPROP, 0.001);
|
||||
|
||||
Mat trainClasses = Mat::zeros( (int)trainedPoints.size(), (int)classColors.size(), CV_32FC1 );
|
||||
for( int i = 0; i < trainClasses.rows; i++ )
|
||||
{
|
||||
@@ -217,7 +224,12 @@ static void find_decision_boundary_ANN( const Mat& layer_sizes )
|
||||
Mat samples = prepare_train_samples(trainedPoints);
|
||||
Ptr<TrainData> tdata = TrainData::create(samples, ROW_SAMPLE, trainClasses);
|
||||
|
||||
Ptr<ANN_MLP> ann = StatModel::train<ANN_MLP>(tdata, params);
|
||||
Ptr<ANN_MLP> ann = ANN_MLP::create();
|
||||
ann->setLayerSizes(layer_sizes);
|
||||
ann->setActivationFunction(ANN_MLP::SIGMOID_SYM, 1, 1);
|
||||
ann->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 300, FLT_EPSILON));
|
||||
ann->setTrainMethod(ANN_MLP::BACKPROP, 0.001);
|
||||
ann->train(tdata);
|
||||
predict_and_paint(ann, imgDst);
|
||||
}
|
||||
#endif
|
||||
@@ -247,8 +259,11 @@ static void find_decision_boundary_EM()
|
||||
// learn models
|
||||
if( !modelSamples.empty() )
|
||||
{
|
||||
em_models[i] = EM::train(modelSamples, noArray(), noArray(), noArray(),
|
||||
EM::Params(componentCount, EM::COV_MAT_DIAGONAL));
|
||||
Ptr<EM> em = EM::create();
|
||||
em->setClustersNumber(componentCount);
|
||||
em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL);
|
||||
em->trainEM(modelSamples, noArray(), noArray(), noArray());
|
||||
em_models[i] = em;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -332,33 +347,20 @@ int main()
|
||||
imshow( "NormalBayesClassifier", imgDst );
|
||||
#endif
|
||||
#if _KNN_
|
||||
int K = 3;
|
||||
find_decision_boundary_KNN( K );
|
||||
find_decision_boundary_KNN( 3 );
|
||||
imshow( "kNN", imgDst );
|
||||
|
||||
K = 15;
|
||||
find_decision_boundary_KNN( K );
|
||||
find_decision_boundary_KNN( 15 );
|
||||
imshow( "kNN2", imgDst );
|
||||
#endif
|
||||
|
||||
#if _SVM_
|
||||
//(1)-(2)separable and not sets
|
||||
SVM::Params params;
|
||||
params.svmType = SVM::C_SVC;
|
||||
params.kernelType = SVM::POLY; //CvSVM::LINEAR;
|
||||
params.degree = 0.5;
|
||||
params.gamma = 1;
|
||||
params.coef0 = 1;
|
||||
params.C = 1;
|
||||
params.nu = 0.5;
|
||||
params.p = 0;
|
||||
params.termCrit = TermCriteria(TermCriteria::MAX_ITER+TermCriteria::EPS, 1000, 0.01);
|
||||
|
||||
find_decision_boundary_SVM( params );
|
||||
find_decision_boundary_SVM( 1 );
|
||||
imshow( "classificationSVM1", imgDst );
|
||||
|
||||
params.C = 10;
|
||||
find_decision_boundary_SVM( params );
|
||||
find_decision_boundary_SVM( 10 );
|
||||
imshow( "classificationSVM2", imgDst );
|
||||
#endif
|
||||
|
||||
|
Reference in New Issue
Block a user