#include "opencv2/core/core.hpp" #include "opencv2/ml/ml.hpp" #include "opencv2/core/core_c.h" #include #include using namespace std; using namespace cv; void help() { printf( "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees:\n" "CvDTree dtree;\n" "CvBoost boost;\n" "CvRTrees rtrees;\n" "CvERTrees ertrees;\n" "CvGBTrees gbtrees;\n" "Usage: \n" " ./tree_engine [--response_column]= \n" "[--categorical_response]= \n" "[--csv_filename]= \n" ); } int count_classes(CvMLData& data) { cv::Mat r(data.get_responses()); std::map rmap; int i, n = (int)r.total(); for( i = 0; i < n; i++ ) { float val = r.at(i); int ival = cvRound(val); if( ival != val ) return -1; rmap[ival] = 1; } return rmap.size(); } void print_result(float train_err, float test_err, const CvMat* _var_imp) { printf( "train error %f\n", train_err ); printf( "test error %f\n\n", test_err ); if (_var_imp) { cv::Mat var_imp(_var_imp), sorted_idx; cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING); printf( "variable importance:\n" ); int i, n = (int)var_imp.total(); int type = var_imp.type(); CV_Assert(type == CV_32F || type == CV_64F); for( i = 0; i < n; i++) { int k = sorted_idx.at(i); printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at(k) : var_imp.at(k)); } } printf("\n"); } int main(int argc, const char** argv) { help(); CommandLineParser parser(argc, argv); string filename = parser.get("csv_filename"); int response_idx = parser.get("response_column", 0); bool categorical_response = (bool)parser.get("categorical_response", 1); if(filename.empty()) { printf("\n Please, select value for --csv_filename key \n"); help(); return -1; } printf("\nReading in %s...\n\n",filename.c_str()); CvDTree dtree; CvBoost boost; CvRTrees rtrees; CvERTrees ertrees; CvGBTrees gbtrees; CvMLData data; CvTrainTestSplit spl( 0.5f ); if ( data.read_csv( filename.c_str() ) == 0) { data.set_response_idx( response_idx ); if(categorical_response) data.change_var_type( response_idx, CV_VAR_CATEGORICAL ); data.set_train_test_split( &spl ); printf("======DTREE=====\n"); dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 )); print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() ); if( categorical_response && count_classes(data) == 2 ) { printf("======BOOST=====\n"); boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0)); print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance } printf("======RTREES=====\n"); rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER )); print_result( rtrees.calc_error( &data, CV_TRAIN_ERROR), rtrees.calc_error( &data, CV_TEST_ERROR ), rtrees.get_var_importance() ); printf("======ERTREES=====\n"); ertrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER )); print_result( ertrees.calc_error( &data, CV_TRAIN_ERROR), ertrees.calc_error( &data, CV_TEST_ERROR ), ertrees.get_var_importance() ); printf("======GBTREES=====\n"); gbtrees.train( &data, CvGBTreesParams(CvGBTrees::DEVIANCE_LOSS, 100, 0.05f, 0.6f, 10, true)); print_result( gbtrees.calc_error( &data, CV_TRAIN_ERROR), gbtrees.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance } else printf("File can not be read"); return 0; }