Updated ml module interfaces and documentation

This commit is contained in:
Maksim Shabunin
2015-02-11 13:24:14 +03:00
parent da383e65e2
commit 79e8f0680c
32 changed files with 1403 additions and 1528 deletions

View File

@@ -178,8 +178,23 @@ build_rtrees_classifier( const string& data_filename,
{
// create classifier by using <data> and <responses>
cout << "Training the classifier ...\n";
// Params( int maxDepth, int minSampleCount,
// double regressionAccuracy, bool useSurrogates,
// int maxCategories, const Mat& priors,
// bool calcVarImportance, int nactiveVars,
// TermCriteria termCrit );
Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
model = StatModel::train<RTrees>(tdata, RTrees::Params(10,10,0,false,15,Mat(),true,4,TC(100,0.01f)));
model = RTrees::create();
model->setMaxDepth(10);
model->setMinSampleCount(10);
model->setRegressionAccuracy(0);
model->setUseSurrogates(false);
model->setMaxCategories(15);
model->setPriors(Mat());
model->setCalculateVarImportance(true);
model->setActiveVarCount(4);
model->setTermCriteria(TC(100,0.01f));
model->train(tdata);
cout << endl;
}
@@ -269,7 +284,14 @@ build_boost_classifier( const string& data_filename,
priors[1] = 26;
cout << "Training the classifier (may take a few minutes)...\n";
model = StatModel::train<Boost>(tdata, Boost::Params(Boost::GENTLE, 100, 0.95, 5, false, Mat(priors) ));
model = Boost::create();
model->setBoostType(Boost::GENTLE);
model->setWeakCount(100);
model->setWeightTrimRate(0.95);
model->setMaxDepth(5);
model->setUseSurrogates(false);
model->setPriors(Mat(priors));
model->train(tdata);
cout << endl;
}
@@ -374,11 +396,11 @@ build_mlp_classifier( const string& data_filename,
Mat layer_sizes( 1, nlayers, CV_32S, layer_sz );
#if 1
int method = ANN_MLP::Params::BACKPROP;
int method = ANN_MLP::BACKPROP;
double method_param = 0.001;
int max_iter = 300;
#else
int method = ANN_MLP::Params::RPROP;
int method = ANN_MLP::RPROP;
double method_param = 0.1;
int max_iter = 1000;
#endif
@@ -386,7 +408,12 @@ build_mlp_classifier( const string& data_filename,
Ptr<TrainData> tdata = TrainData::create(train_data, ROW_SAMPLE, train_responses);
cout << "Training the classifier (may take a few minutes)...\n";
model = StatModel::train<ANN_MLP>(tdata, ANN_MLP::Params(layer_sizes, ANN_MLP::SIGMOID_SYM, 0, 0, TC(max_iter,0), method, method_param));
model = ANN_MLP::create();
model->setLayerSizes(layer_sizes);
model->setActivationFunction(ANN_MLP::SIGMOID_SYM, 0, 0);
model->setTermCriteria(TC(max_iter,0));
model->setTrainMethod(method, method_param);
model->train(tdata);
cout << endl;
}
@@ -403,7 +430,6 @@ build_knearest_classifier( const string& data_filename, int K )
if( !ok )
return ok;
Ptr<KNearest> model;
int nsamples_all = data.rows;
int ntrain_samples = (int)(nsamples_all*0.8);
@@ -411,7 +437,10 @@ build_knearest_classifier( const string& data_filename, int K )
// create classifier by using <data> and <responses>
cout << "Training the classifier ...\n";
Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
model = StatModel::train<KNearest>(tdata, KNearest::Params(K, true));
Ptr<KNearest> model = KNearest::create();
model->setDefaultK(K);
model->setIsClassifier(true);
model->train(tdata);
cout << endl;
test_and_save_classifier(model, data, responses, ntrain_samples, 0, string());
@@ -435,7 +464,8 @@ build_nbayes_classifier( const string& data_filename )
// create classifier by using <data> and <responses>
cout << "Training the classifier ...\n";
Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
model = StatModel::train<NormalBayesClassifier>(tdata, NormalBayesClassifier::Params());
model = NormalBayesClassifier::create();
model->train(tdata);
cout << endl;
test_and_save_classifier(model, data, responses, ntrain_samples, 0, string());
@@ -471,13 +501,11 @@ build_svm_classifier( const string& data_filename,
// create classifier by using <data> and <responses>
cout << "Training the classifier ...\n";
Ptr<TrainData> tdata = prepare_train_data(data, responses, ntrain_samples);
SVM::Params params;
params.svmType = SVM::C_SVC;
params.kernelType = SVM::LINEAR;
params.C = 1;
model = StatModel::train<SVM>(tdata, params);
model = SVM::create();
model->setType(SVM::C_SVC);
model->setKernel(SVM::LINEAR);
model->setC(1);
model->train(tdata);
cout << endl;
}