2010-11-26 18:59:40 +01:00
|
|
|
#include "opencv2/ml/ml.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string.h>
#include <string>
#include <map>
|
2010-12-03 03:49:09 +01:00
|
|
|
|
2014-08-02 23:41:09 +02:00
|
|
|
using namespace cv;
|
|
|
|
using namespace cv::ml;
|
|
|
|
|
2012-06-07 19:21:29 +02:00
|
|
|
// Prints the command-line usage of this sample to stdout.
static void help()
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t./tree_engine [-r <response_column>] [-ts type_spec] <csv filename>\n"
        "where -r <response_column> specifies the 0-based index of the response (0 by default)\n"
        "-ts specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of training data file in comma-separated value format\n\n");
}
|
2011-04-05 17:13:10 +02:00
|
|
|
|
2014-08-02 23:41:09 +02:00
|
|
|
/*
 * Trains `model` on `data` and reports the resulting errors on stdout.
 *
 * If training fails, prints "Training failed" and stops without evaluating.
 * Otherwise prints the error measured on the training subset, followed by the
 * error measured on the test subset of `data`.
 */
static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
{
    if( !model->train(data) )
    {
        printf("Training failed\n");
        return;
    }

    // The second argument of calcError selects the subset:
    // false -> training samples, true -> test samples.
    const double trainError = model->calcError(data, false, noArray());
    printf( "train error: %f\n", trainError );

    const double testError = model->calcError(data, true, noArray());
    printf( "test error: %f\n\n", testError );
}
|
|
|
|
|
2011-06-09 14:01:47 +02:00
|
|
|
int main(int argc, char** argv)
|
2010-05-11 19:44:00 +02:00
|
|
|
{
|
2011-06-09 14:01:47 +02:00
|
|
|
if(argc < 2)
|
2011-04-05 17:13:10 +02:00
|
|
|
{
|
|
|
|
help();
|
2011-06-09 14:01:47 +02:00
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
const char* filename = 0;
|
|
|
|
int response_idx = 0;
|
2014-08-02 23:41:09 +02:00
|
|
|
std::string typespec;
|
2012-06-07 19:21:29 +02:00
|
|
|
|
2011-06-09 14:01:47 +02:00
|
|
|
for(int i = 1; i < argc; i++)
|
|
|
|
{
|
|
|
|
if(strcmp(argv[i], "-r") == 0)
|
|
|
|
sscanf(argv[++i], "%d", &response_idx);
|
2014-08-02 23:41:09 +02:00
|
|
|
else if(strcmp(argv[i], "-ts") == 0)
|
|
|
|
typespec = argv[++i];
|
2011-06-09 14:01:47 +02:00
|
|
|
else if(argv[i][0] != '-' )
|
|
|
|
filename = argv[i];
|
|
|
|
else
|
|
|
|
{
|
|
|
|
printf("Error. Invalid option %s\n", argv[i]);
|
|
|
|
help();
|
|
|
|
return -1;
|
|
|
|
}
|
2011-04-05 17:13:10 +02:00
|
|
|
}
|
2012-06-07 19:21:29 +02:00
|
|
|
|
2011-06-09 14:01:47 +02:00
|
|
|
printf("\nReading in %s...\n\n",filename);
|
2014-08-02 23:41:09 +02:00
|
|
|
const double train_test_split_ratio = 0.5;
|
2010-05-11 19:44:00 +02:00
|
|
|
|
2014-08-02 23:41:09 +02:00
|
|
|
Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
|
2012-06-07 19:21:29 +02:00
|
|
|
|
2014-08-02 23:41:09 +02:00
|
|
|
if( data.empty() )
|
2010-05-11 19:44:00 +02:00
|
|
|
{
|
2014-08-02 23:41:09 +02:00
|
|
|
printf("ERROR: File %s can not be read\n", filename);
|
|
|
|
return 0;
|
|
|
|
}
|
2010-05-11 19:44:00 +02:00
|
|
|
|
2014-08-02 23:41:09 +02:00
|
|
|
data->setTrainTestSplitRatio(train_test_split_ratio);
|
2010-05-11 19:44:00 +02:00
|
|
|
|
2014-08-02 23:41:09 +02:00
|
|
|
printf("======DTREE=====\n");
|
|
|
|
Ptr<DTrees> dtree = DTrees::create(DTrees::Params( 10, 2, 0, false, 16, 0, false, false, Mat() ));
|
|
|
|
train_and_print_errs(dtree, data);
|
2010-10-17 21:18:42 +02:00
|
|
|
|
2014-08-02 23:41:09 +02:00
|
|
|
if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
|
|
|
|
{
|
|
|
|
printf("======BOOST=====\n");
|
|
|
|
Ptr<Boost> boost = Boost::create(Boost::Params(Boost::GENTLE, 100, 0.95, 2, false, Mat()));
|
|
|
|
train_and_print_errs(boost, data);
|
2010-05-11 19:44:00 +02:00
|
|
|
}
|
2014-08-02 23:41:09 +02:00
|
|
|
|
|
|
|
printf("======RTREES=====\n");
|
|
|
|
Ptr<RTrees> rtrees = RTrees::create(RTrees::Params(10, 2, 0, false, 16, Mat(), false, 0, TermCriteria(TermCriteria::MAX_ITER, 100, 0)));
|
|
|
|
train_and_print_errs(rtrees, data);
|
2010-05-11 19:44:00 +02:00
|
|
|
|
|
|
|
return 0;
|
2010-07-08 13:24:32 +02:00
|
|
|
}
|