removed contrib, legacy and softcsscade modules; removed latentsvm and datamatrix detector from objdetect. removed haartraining and sft apps.
some of the stuff will be moved to opencv_contrib module. in order to make this PR pass buildbot, please, comment off opencv_legacy, opencv_contrib and opencv_softcascade test runs.
This commit is contained in:
139
samples/cpp/tree_engine.cpp
Normal file
139
samples/cpp/tree_engine.cpp
Normal file
@@ -0,0 +1,139 @@
|
||||
#include "opencv2/ml/ml.hpp"
|
||||
#include "opencv2/core/core_c.h"
|
||||
#include "opencv2/core/utility.hpp"
|
||||
#include <stdio.h>
|
||||
#include <map>
|
||||
|
||||
static void help()
|
||||
{
|
||||
printf(
|
||||
"\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n"
|
||||
"CvDTree dtree;\n"
|
||||
"CvBoost boost;\n"
|
||||
"CvRTrees rtrees;\n"
|
||||
"CvERTrees ertrees;\n"
|
||||
"CvGBTrees gbtrees;\n"
|
||||
"Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
|
||||
"where -r <response_column> specified the 0-based index of the response (0 by default)\n"
|
||||
"-c specifies that the response is categorical (it's ordered by default) and\n"
|
||||
"<csv filename> is the name of training data file in comma-separated value format\n\n");
|
||||
}
|
||||
|
||||
|
||||
static int count_classes(CvMLData& data)
|
||||
{
|
||||
cv::Mat r = cv::cvarrToMat(data.get_responses());
|
||||
std::map<int, int> rmap;
|
||||
int i, n = (int)r.total();
|
||||
for( i = 0; i < n; i++ )
|
||||
{
|
||||
float val = r.at<float>(i);
|
||||
int ival = cvRound(val);
|
||||
if( ival != val )
|
||||
return -1;
|
||||
rmap[ival] = 1;
|
||||
}
|
||||
return (int)rmap.size();
|
||||
}
|
||||
|
||||
static void print_result(float train_err, float test_err, const CvMat* _var_imp)
|
||||
{
|
||||
printf( "train error %f\n", train_err );
|
||||
printf( "test error %f\n\n", test_err );
|
||||
|
||||
if (_var_imp)
|
||||
{
|
||||
cv::Mat var_imp = cv::cvarrToMat(_var_imp), sorted_idx;
|
||||
cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
|
||||
|
||||
printf( "variable importance:\n" );
|
||||
int i, n = (int)var_imp.total();
|
||||
int type = var_imp.type();
|
||||
CV_Assert(type == CV_32F || type == CV_64F);
|
||||
|
||||
for( i = 0; i < n; i++)
|
||||
{
|
||||
int k = sorted_idx.at<int>(i);
|
||||
printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
if(argc < 2)
|
||||
{
|
||||
help();
|
||||
return 0;
|
||||
}
|
||||
const char* filename = 0;
|
||||
int response_idx = 0;
|
||||
bool categorical_response = false;
|
||||
|
||||
for(int i = 1; i < argc; i++)
|
||||
{
|
||||
if(strcmp(argv[i], "-r") == 0)
|
||||
sscanf(argv[++i], "%d", &response_idx);
|
||||
else if(strcmp(argv[i], "-c") == 0)
|
||||
categorical_response = true;
|
||||
else if(argv[i][0] != '-' )
|
||||
filename = argv[i];
|
||||
else
|
||||
{
|
||||
printf("Error. Invalid option %s\n", argv[i]);
|
||||
help();
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
printf("\nReading in %s...\n\n",filename);
|
||||
CvDTree dtree;
|
||||
CvBoost boost;
|
||||
CvRTrees rtrees;
|
||||
CvERTrees ertrees;
|
||||
CvGBTrees gbtrees;
|
||||
|
||||
CvMLData data;
|
||||
|
||||
|
||||
CvTrainTestSplit spl( 0.5f );
|
||||
|
||||
if ( data.read_csv( filename ) == 0)
|
||||
{
|
||||
data.set_response_idx( response_idx );
|
||||
if(categorical_response)
|
||||
data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
|
||||
data.set_train_test_split( &spl );
|
||||
|
||||
printf("======DTREE=====\n");
|
||||
dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
|
||||
print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
|
||||
|
||||
if( categorical_response && count_classes(data) == 2 )
|
||||
{
|
||||
printf("======BOOST=====\n");
|
||||
boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
|
||||
print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
|
||||
}
|
||||
|
||||
printf("======RTREES=====\n");
|
||||
rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
|
||||
print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() );
|
||||
|
||||
printf("======ERTREES=====\n");
|
||||
ertrees.train( &data, CvRTParams( 18, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));
|
||||
print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() );
|
||||
|
||||
printf("======GBTREES=====\n");
|
||||
if (categorical_response)
|
||||
gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.1f, 0.8f, 5, false));
|
||||
else
|
||||
gbtrees.train( &data, CvGBTreesParams(CvGBTrees::SQUARED_LOSS, 100, 0.1f, 0.8f, 5, false));
|
||||
print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
|
||||
}
|
||||
else
|
||||
printf("File can not be read");
|
||||
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user