improved tree_engine.cpp sample (added train file data specification; print sorted variable importance table)

This commit is contained in:
Vadim Pisarevsky 2011-04-05 15:13:10 +00:00
parent ce474db8eb
commit bbdd0aecbd

View File

@ -1,5 +1,7 @@
#include "opencv2/ml/ml.hpp" #include "opencv2/ml/ml.hpp"
#include "opencv2/core/core_c.h"
#include <stdio.h> #include <stdio.h>
#include <map>
void help() void help()
{ {
@ -10,41 +12,81 @@ void help()
"CvRTrees rtrees;\n" "CvRTrees rtrees;\n"
"CvERTrees ertrees;\n" "CvERTrees ertrees;\n"
"CvGBTrees gbtrees;\n" "CvGBTrees gbtrees;\n"
"Date is hard coded to come from filename = \"../../../opencv/samples/c/waveform.data\";\n" "Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
"Or can come from filename = \"../../../opencv/samples/c/waveform.data\";\n" "where -r <response_column> specified the 0-based index of the response (0 by default)\n"
"Call:\n" "-c specifies that the response is categorical (it's ordered by default) and\n"
"./tree_engine\n\n"); "<csv filename> is the name of training data file in comma-separated value format\n\n");
} }
void print_result(float train_err, float test_err, const CvMat* var_imp)
// Counts the distinct class labels found in the responses of `data`.
//
// Every response value must be integer-valued, i.e. equal to its
// cvRound()-ed counterpart. Returns the number of distinct labels, or -1 as
// soon as a response that is not integer-valued is encountered.
int count_classes(CvMLData& data)
{
    // NOTE(review): elements are read with at<float>, so the responses are
    // assumed to be stored as 32-bit floats -- confirm against CvMLData.
    cv::Mat r(data.get_responses());
    std::map<int, int> rmap; // used as a set: only the keys matter
    const int n = (int)r.total();
    for( int i = 0; i < n; i++ )
    {
        const float val = r.at<float>(i);
        const int ival = cvRound(val);
        if( ival != val )
            return -1; // not an integer label
        rmap[ival] = 1;
    }
    // size() yields size_t; convert explicitly because the function returns int.
    return static_cast<int>(rmap.size());
}
void print_result(float train_err, float test_err, const CvMat* _var_imp)
{ {
printf( "train error %f\n", train_err ); printf( "train error %f\n", train_err );
printf( "test error %f\n\n", test_err ); printf( "test error %f\n\n", test_err );
if (var_imp) if (_var_imp)
{ {
bool is_flt = false; cv::Mat var_imp(_var_imp), sorted_idx;
if ( CV_MAT_TYPE( var_imp->type ) == CV_32FC1) cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
is_flt = true;
printf( "variable impotance\n" ); printf( "variable importance:\n" );
for( int i = 0; i < var_imp->cols; i++) int i, n = (int)var_imp.total();
int type = var_imp.type();
CV_Assert(type == CV_32F || type == CV_64F);
for( i = 0; i < n; i++)
{ {
printf( "%d %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] ); int k = sorted_idx.at<int>(i);
printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
} }
} }
printf("\n"); printf("\n");
} }
int main() int main(int argc, char** argv)
{ {
const int train_sample_count = 300; if(argc < 2)
{
#define LEPIOTA //Turn on discrete data set help();
#ifdef LEPIOTA //Of course, you might have to set the path here to what's on your machine ... return 0;
const char* filename = "../../opencv/samples/c/agaricus-lepiota.data"; }
#else const char* filename = 0;
const char* filename = "../../opencv/samples/c/waveform.data"; int response_idx = 0;
#endif bool categorical_response = false;
printf("\n Reading in %s. If it is not found, you may have to change this hard-coded path in tree_engine.cpp\n\n",filename);
for(int i = 1; i < argc; i++)
{
if(strcmp(argv[i], "-r") == 0)
sscanf(argv[++i], "%d", &response_idx);
else if(strcmp(argv[i], "-c") == 0)
categorical_response = true;
else if(argv[i][0] != '-' )
filename = argv[i];
else
{
printf("Error. Invalid option %s\n", argv[i]);
help();
return -1;
}
}
printf("\nReading in %s...\n\n",filename);
CvDTree dtree; CvDTree dtree;
CvBoost boost; CvBoost boost;
CvRTrees rtrees; CvRTrees rtrees;
@ -53,29 +95,26 @@ int main()
CvMLData data; CvMLData data;
CvTrainTestSplit spl( train_sample_count );
CvTrainTestSplit spl( 0.5f );
if ( data.read_csv( filename ) == 0) if ( data.read_csv( filename ) == 0)
{ {
data.set_response_idx( response_idx );
#ifdef LEPIOTA if(categorical_response)
data.set_response_idx( 0 ); data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
#else
data.set_response_idx( 21 );
data.change_var_type( 21, CV_VAR_CATEGORICAL );
#endif
data.set_train_test_split( &spl ); data.set_train_test_split( &spl );
printf("======DTREE=====\n"); printf("======DTREE=====\n");
dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 )); dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() ); print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
#ifdef LEPIOTA if( categorical_response && count_classes(data) == 2 )
{
printf("======BOOST=====\n"); printf("======BOOST=====\n");
boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0)); boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
#endif }
printf("======RTREES=====\n"); printf("======RTREES=====\n");
rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER )); rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));