Gradient Boosting Trees (CvGBTrees) added to the OpenCV ML module. Tests for all CvGBTrees public methods added.

P. Druzhkov 2010-10-13 20:18:12 +00:00
parent d788007672
commit d611fb61fc
3 changed files with 1844 additions and 0 deletions


@@ -183,6 +183,7 @@ CV_INLINE CvParamLattice cvDefaultParamLattice( void )
#define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees"
#define CV_TRAIN_ERROR 0
#define CV_TEST_ERROR 1
@@ -1359,6 +1360,532 @@ protected:
};
/****************************************************************************************\
* Gradient Boosted Trees *
\****************************************************************************************/
// DataType: STRUCT CvGBTreesParams
// Parameters of the GBT (Gradient Boosted Trees) model, including single
// tree settings and ensemble parameters.
//
// weak_count - count of trees in the ensemble
// loss_function_type - loss function used for ensemble training
// subsample_portion - portion of the whole training set used for
// every single tree training.
// subsample_portion value is in (0.0, 1.0].
// subsample_portion == 1.0 when the whole dataset is
// used on each step. The count of samples used on each
// step is computed as
// int(total_samples_count * subsample_portion).
// shrinkage - regularization parameter.
// Each tree prediction is multiplied by the shrinkage value.
struct CV_EXPORTS CvGBTreesParams : public CvDTreeParams
{
int weak_count;
int loss_function_type;
float subsample_portion;
float shrinkage;
CvGBTreesParams();
CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
float subsample_portion, int max_depth, bool use_surrogates );
};
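/*
// Usage sketch (illustrative addition, not part of the original commit):
// a parameter set for a 100-tree regression ensemble. With
// subsample_portion = 0.8f, each iteration trains its tree on
// int(total_samples_count * 0.8) randomly chosen samples.
//
// CvGBTreesParams params( CvGBTrees::SQUARED_LOSS, // loss_function_type
//                         100,                     // weak_count
//                         0.05f,                   // shrinkage
//                         0.8f,                    // subsample_portion
//                         3,                       // max_depth
//                         false );                 // use_surrogates
*/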
// DataType: CLASS CvGBTrees
// Gradient Boosting Trees (GBT) algorithm implementation.
//
// data - training dataset
// params - parameters of the CvGBTrees
// weak - array[0..(class_count-1)] of CvSeq
// for storing tree ensembles
// orig_response - original responses of the training set samples
// sum_response - predictions of the current model on the training dataset.
// this matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
// step. On every iteration values of sum_response_tmp are
// computed via sum_response values. When the current
// step is complete sum_response values become equal to
// sum_response_tmp.
// sample_idx - indices of samples used for training the ensemble.
// CvGBTrees training procedure takes a set of samples
// (train_data) and a set of responses (responses).
// Only pairs (train_data[i], responses[i]), where i is
// in sample_idx are used for training the ensemble.
// subsample_train - indices of samples used for training a single decision
// tree on the current step. These indices are counted
// relative to sample_idx, so that pairs
// (train_data[sample_idx[i]], responses[sample_idx[i]])
// are used for training a decision tree.
// The training set is randomly split
// into two parts (subsample_train and subsample_test)
// on every iteration according to the portion parameter.
// subsample_test - relative indices of samples from the training set,
// which are not used for training a tree on the current
// step.
// missing - mask of the missing values in the training set. This
// matrix has the same size as train_data. 1 - missing
// value, 0 - not a missing value.
// class_labels - output class labels map.
// rng - random number generator. Used for splitting the
// training set.
// class_count - count of output classes.
// class_count == 1 in the case of regression,
// and > 1 in the case of classification.
// delta - Huber loss function parameter.
// base_value - start point of the gradient descent procedure.
// model prediction is
// f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
// f_0 is the base value.
class CV_EXPORTS CvGBTrees : public CvStatModel
{
public:
/*
// DataType: ENUM
// Loss functions implemented in CvGBTrees.
//
// SQUARED_LOSS
// problem: regression
// loss = (x - x')^2
//
// ABSOLUTE_LOSS
// problem: regression
// loss = abs(x - x')
//
// HUBER_LOSS
// problem: regression
// loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
// 1/2*(x - x')^2, if abs(x - x') <= delta,
// where delta is the alpha-quantile of pseudo responses from
// the training set.
//
// DEVIANCE_LOSS
// problem: classification
//
*/
enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
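/*
// Loss evaluation sketch (illustrative addition, not part of the original
// commit): for a single regression residual r = x - x', the formulas above
// evaluate as
//
// float r = x - x_predicted;
// float squared  = r*r;                         // SQUARED_LOSS
// float absolute = fabs(r);                     // ABSOLUTE_LOSS
// float huber    = (fabs(r) <= delta) ?         // HUBER_LOSS
//     0.5f*r*r : delta*(fabs(r) - 0.5f*delta);
*/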
/*
// Default constructor. Creates a model only (without training).
// Should be followed by one form of the train(...) function.
//
// API
// CvGBTrees();
// INPUT
// OUTPUT
// RESULT
*/
CvGBTrees();
/*
// Full form constructor. Creates a gradient boosting model and does the
// train.
//
// API
// CvGBTrees( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx=0,
const CvMat* _sample_idx=0, const CvMat* _var_type=0,
const CvMat* _missing_mask=0,
CvGBTreesParams params=CvGBTreesParams() );
// INPUT
// _train_data - a set of input feature vectors.
// size of matrix is
// <count of samples> x <variables count>
// or <variables count> x <count of samples>
// depending on the _tflag parameter.
// matrix values are float.
// _tflag - a flag showing how the samples are stored in the
// _train_data matrix: row by row (tflag=CV_ROW_SAMPLE)
// or column by column (tflag=CV_COL_SAMPLE).
// _responses - a vector of responses corresponding to the samples
// in _train_data.
// _var_idx - indices of used variables. zero value means that all
// variables are active.
// _sample_idx - indices of used samples. zero value means that all
// samples from _train_data are in the training set.
// _var_type - vector of <variables count> length. gives every
// variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
// _var_type = 0 means all variables are numerical.
// _missing_mask - a mask of missing values in _train_data.
// _missing_mask = 0 means that there are no missing
// values.
// params - parameters of GTB algorithm.
// OUTPUT
// RESULT
*/
CvGBTrees( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx=0,
const CvMat* _sample_idx=0, const CvMat* _var_type=0,
const CvMat* _missing_mask=0,
CvGBTreesParams params=CvGBTreesParams() );
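/*
// Usage sketch (illustrative addition, not part of the original commit):
// construct-and-train in one step from a row-sample matrix; nsamples and
// nvars are placeholders.
//
// CvMat* train_data = cvCreateMat( nsamples, nvars, CV_32F );
// CvMat* responses  = cvCreateMat( nsamples, 1, CV_32F );
// ... fill train_data and responses ...
// CvGBTrees gbt( train_data, CV_ROW_SAMPLE, responses,
//                0, 0, 0, 0, CvGBTreesParams() );
*/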
/*
// Destructor.
*/
virtual ~CvGBTrees();
/*
// Gradient tree boosting model training
//
// API
// virtual bool train( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx=0,
const CvMat* _sample_idx=0, const CvMat* _var_type=0,
const CvMat* _missing_mask=0,
CvGBTreesParams params=CvGBTreesParams(),
bool update=false );
// INPUT
// _train_data - a set of input feature vectors.
// size of matrix is
// <count of samples> x <variables count>
// or <variables count> x <count of samples>
// depending on the _tflag parameter.
// matrix values are float.
// _tflag - a flag showing how the samples are stored in the
// _train_data matrix: row by row (tflag=CV_ROW_SAMPLE)
// or column by column (tflag=CV_COL_SAMPLE).
// _responses - a vector of responses corresponding to the samples
// in _train_data.
// _var_idx - indices of used variables. zero value means that all
// variables are active.
// _sample_idx - indices of used samples. zero value means that all
// samples from _train_data are in the training set.
// _var_type - vector of <variables count> length. gives every
// variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
// _var_type = 0 means all variables are numerical.
// _missing_mask - a mask of missing values in _train_data.
// _missing_mask = 0 means that there are no missing
// values.
// params - parameters of GTB algorithm.
// update - is not supported yet. (!)
// OUTPUT
// RESULT
// Error state.
*/
virtual bool train( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx=0,
const CvMat* _sample_idx=0, const CvMat* _var_type=0,
const CvMat* _missing_mask=0,
CvGBTreesParams params=CvGBTreesParams(),
bool update=false );
/*
// Gradient tree boosting model training
//
// API
// virtual bool train( CvMLData* data,
CvGBTreesParams params=CvGBTreesParams(),
bool update=false ) {return false;};
// INPUT
// data - training set.
// params - parameters of GTB algorithm.
// update - is not supported yet. (!)
// OUTPUT
// RESULT
// Error state.
*/
virtual bool train( CvMLData* data,
CvGBTreesParams params=CvGBTreesParams(),
bool update=false );
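/*
// Usage sketch (illustrative addition, not part of the original commit):
// training from a CSV file through CvMLData, as the accompanying test
// does; the file name and response column are placeholders. read_csv()
// returns 0 on success.
//
// CvMLData data;
// data.set_delimiter(',');
// if( data.read_csv("dataset.csv") == 0 )
// {
//     data.set_response_idx(13);
//     CvGBTrees gbt;
//     gbt.train( &data, CvGBTreesParams() );
// }
*/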
/*
// Response value prediction
//
// API
// virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
// INPUT
// _sample - input sample of the same type as in the training set.
// _missing - missing values mask. _missing=0 if there are no
// missing values in _sample vector.
// weak_responses - predictions of all of the trees.
// not implemented (!)
// slice - part of the ensemble used for prediction.
// slice = CV_WHOLE_SEQ when all trees are used.
// k - index of the ensemble to use.
// k is in {-1,0,1,..,<count of output classes>-1}.
// in the case of a classification problem
// <count of output classes> ensembles are built.
// If k = -1 the ordinary prediction is the result,
// otherwise the function gives the prediction of the
// k-th ensemble only.
// OUTPUT
// RESULT
// Predicted value.
*/
virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
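/*
// Usage sketch (illustrative addition, not part of the original commit):
// predict on a single CV_32F row vector; with slice = CV_WHOLE_SEQ and
// k = -1 the full ordinary prediction is returned.
//
// CvMat* sample = cvCreateMat( 1, nvars, CV_32F );
// ... fill sample ...
// float y = gbt.predict( sample, 0, 0, CV_WHOLE_SEQ, -1 );
*/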
/*
// Delete all temporary data.
//
// API
// virtual void clear();
// INPUT
// OUTPUT
// delete data, weak, orig_response, sum_response,
// weak_eval, subsample_train, subsample_test,
// sample_idx, missing, class_labels
// delta = 0.0
// RESULT
*/
virtual void clear();
/*
// Compute error on the train/test set.
//
// API
// virtual float calc_error( CvMLData* _data, int type,
// std::vector<float> *resp = 0 );
//
// INPUT
// data - dataset
// type - defines which error to compute: train (CV_TRAIN_ERROR) or
// test (CV_TEST_ERROR).
// OUTPUT
// resp - vector of predictions
// RESULT
// Error value.
*/
virtual float calc_error( CvMLData* _data, int type,
std::vector<float> *resp = 0 );
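/*
// Usage sketch (illustrative addition, not part of the original commit):
// compute the test-set error and collect per-sample predictions, as the
// accompanying test does.
//
// std::vector<float> preds;
// float test_error = gbt.calc_error( &data, CV_TEST_ERROR, &preds );
*/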
/*
//
// Write parameters of the gtb model and data. Write learned model.
//
// API
// virtual void write( CvFileStorage* fs, const char* name ) const;
//
// INPUT
// fs - file storage to write parameters to.
// name - model name.
// OUTPUT
// RESULT
*/
virtual void write( CvFileStorage* fs, const char* name ) const;
/*
//
// Read parameters of the gtb model and data. Read learned model.
//
// API
// virtual void read( CvFileStorage* fs, CvFileNode* node );
//
// INPUT
// fs - file storage to read parameters from.
// node - file node.
// OUTPUT
// RESULT
*/
virtual void read( CvFileStorage* fs, CvFileNode* node );
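/*
// Usage sketch (illustrative addition, not part of the original commit):
// write() and read() are normally invoked through the save()/load()
// methods inherited from CvStatModel, as the accompanying test does.
//
// gbt.save( "gbt_model.xml" );
// gbt.load( "gbt_model.xml" );
*/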
protected:
/*
// Compute the gradient vector components.
//
// API
// virtual void find_gradient( const int k = 0);
// INPUT
// k - used for classification problems; determines the current
// tree ensemble.
// OUTPUT
// changes components of data->responses
// which correspond to samples used for training
// on the current step.
// RESULT
*/
virtual void find_gradient( const int k = 0);
/*
//
// Change values in tree leaves according to the used loss function.
//
// API
// virtual void change_values(CvDTree* tree, const int k = 0);
//
// INPUT
// tree - decision tree to change.
// k - used for classification problems; determines the current
// tree ensemble.
// OUTPUT
// changes 'value' fields of the trees' leaves.
// changes sum_response_tmp.
// RESULT
*/
virtual void change_values(CvDTree* tree, const int k = 0);
/*
//
// Find optimal constant prediction value according to the used loss
// function.
// The goal is to find a constant which gives the minimal total loss
// on the _Idx samples.
//
// API
// virtual float find_optimal_value( const CvMat* _Idx );
//
// INPUT
// _Idx - indices of the samples from the training set.
// OUTPUT
// RESULT
// optimal constant value.
*/
virtual float find_optimal_value( const CvMat* _Idx );
/*
//
// Randomly split the whole training set into two parts according
// to params.subsample_portion.
//
// API
// virtual void do_subsample();
//
// INPUT
// OUTPUT
// subsample_train - indices of samples used for training
// subsample_test - indices of samples used for test
// RESULT
*/
virtual void do_subsample();
/*
//
// Internal recursive function collecting an array of the subtree's leaves.
//
// API
// void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
//
// INPUT
// node - current leaf.
// OUTPUT
// count - count of leaves in the subtree.
// leaves - array of pointers to leaves.
// RESULT
*/
void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
/*
//
// Get leaves of the tree.
//
// API
// CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
//
// INPUT
// dtree - decision tree.
// OUTPUT
// len - count of the leaves.
// RESULT
// CvDTreeNode** - array of pointers to leaves.
*/
CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
/*
//
// Checks whether it is a regression or a classification problem.
//
// API
// virtual bool problem_type() const;
//
// INPUT
// OUTPUT
// RESULT
// false if it is a classification problem,
// true if it is a regression problem.
*/
virtual bool problem_type() const;
/*
//
// Write parameters of the gtb model.
//
// API
// virtual void write_params( CvFileStorage* fs ) const;
//
// INPUT
// fs - file storage to write parameters to.
// OUTPUT
// RESULT
*/
virtual void write_params( CvFileStorage* fs ) const;
/*
//
// Read parameters of the gtb model and data.
//
// API
// virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
//
// INPUT
// fs - file storage to read parameters from.
// OUTPUT
// params - parameters of the gtb model.
// data - contains information about the structure
// of the data set (count of variables,
// their types, etc.).
// class_labels - output class labels map.
// RESULT
*/
virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
CvDTreeTrainData* data;
CvGBTreesParams params;
CvSeq** weak;
CvMat* orig_response;
CvMat* sum_response;
CvMat* sum_response_tmp;
CvMat* weak_eval;
CvMat* sample_idx;
CvMat* subsample_train;
CvMat* subsample_test;
CvMat* missing;
CvMat* class_labels;
CvRNG rng;
int class_count;
float delta;
float base_value;
};
/****************************************************************************************\
* Artificial Neural Networks (ANN) *
\****************************************************************************************/
@@ -1936,6 +2463,8 @@ typedef CvBoostTree BoostTree;
typedef CvBoost Boost;
typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
typedef CvANN_MLP NeuralNet_MLP;
typedef CvGBTreesParams GradientBoostingTreesParams;
typedef CvGBTrees GradientBoostingTrees;
}

modules/ml/src/gbt.cpp (new file, 1044 lines)

File diff suppressed because it is too large.

tests/ml/src/gbttest.cpp (new file, 271 lines)

@@ -0,0 +1,271 @@
#include "mltest.h"
#include <string>
#include <fstream>
#include <iostream>
using namespace std;
class CV_GBTreesTest : public CvTest
{
public:
CV_GBTreesTest();
~CV_GBTreesTest();
protected:
void run(int);
int TestTrainPredict(int test_num);
int TestSaveLoad();
int checkPredictError(int test_num);
int checkLoadSave();
//string model_file_name1;
//string model_file_name2;
char model_file_name1[50];
char model_file_name2[50];
string* datasets;
string data_path;
CvMLData* data;
CvGBTrees* gtb;
vector<float> test_resps1;
vector<float> test_resps2;
};
int _get_len(const CvMat* mat)
{
return (mat->cols > mat->rows) ? mat->cols : mat->rows;
}
CV_GBTreesTest::CV_GBTreesTest() :
CvTest( "CvGBTrees_test",
"all public methods (train, predict, save, load)" )
{
datasets = 0;
data = 0;
gtb = 0;
}
CV_GBTreesTest::~CV_GBTreesTest()
{
if (data)
delete data;
delete[] datasets;
}
int CV_GBTreesTest::TestTrainPredict(int test_num)
{
int code = CvTS::OK;
int weak_count = 200;
float shrinkage = 0.1f;
float subsample_portion = 0.5f;
int max_depth = 5;
bool use_surrogates = true;
int loss_function_type = 0;
switch (test_num)
{
case (1) : loss_function_type = CvGBTrees::SQUARED_LOSS; break;
case (2) : loss_function_type = CvGBTrees::ABSOLUTE_LOSS; break;
case (3) : loss_function_type = CvGBTrees::HUBER_LOSS; break;
case (0) : loss_function_type = CvGBTrees::DEVIANCE_LOSS; break;
default :
{
ts->printf( CvTS::LOG, "Bad test_num value in CV_GBTreesTest::TestTrainPredict(..) function." );
return CvTS::FAIL_BAD_ARG_CHECK;
}
}
int dataset_num = test_num == 0 ? 0 : 1;
if (!data)
{
data = new CvMLData();
data->set_delimiter(',');
if (data->read_csv(datasets[dataset_num].c_str()))
{
ts->printf( CvTS::LOG, "File reading error." );
return CvTS::FAIL_INVALID_TEST_DATA;
}
if (test_num == 0)
{
data->set_response_idx(57);
data->set_var_types("ord[0-56],cat[57]");
}
else
{
data->set_response_idx(13);
data->set_var_types("ord[0-2,4-13],cat[3]");
subsample_portion = 0.7f;
}
int train_sample_count = cvFloor(_get_len(data->get_responses())*0.5f);
CvTrainTestSplit spl( train_sample_count );
data->set_train_test_split( &spl );
}
data->mix_train_and_test_idx();
if (gtb) delete gtb;
gtb = new CvGBTrees();
bool tmp_code = true;
tmp_code = gtb->train(data, CvGBTreesParams(loss_function_type, weak_count,
shrinkage, subsample_portion,
max_depth, use_surrogates));
if (!tmp_code)
{
ts->printf( CvTS::LOG, "Model training was failed.");
return CvTS::FAIL_INVALID_OUTPUT;
}
code = checkPredictError(test_num);
return code;
}
int CV_GBTreesTest::checkPredictError(int test_num)
{
if (!gtb)
return CvTS::FAIL_GENERIC;
float mean[] = {5.3555f, 11.2241f, 11.9212f, 12.0848f};
float sigma[] = {0.362127f, 3.4906f, 3.4906f, 3.64994f};
float current_error = gtb->calc_error(data, CV_TEST_ERROR);
if ( fabs( current_error - mean[test_num]) > 6*sigma[test_num] )
{
ts->printf( CvTS::LOG, "Test error is out of range:\n"
"abs(%f/*curEr*/ - %f/*mean*/ > %f/*6*sigma*/",
current_error, mean[test_num], 6*sigma[test_num] );
return CvTS::FAIL_BAD_ACCURACY;
}
return CvTS::OK;
}
int CV_GBTreesTest::TestSaveLoad()
{
if (!gtb)
return CvTS::FAIL_GENERIC;
tmpnam(model_file_name1);
tmpnam(model_file_name2);
gtb->save(model_file_name1);
gtb->calc_error(data, CV_TEST_ERROR, &test_resps1);
gtb->load(model_file_name1);
gtb->calc_error(data, CV_TEST_ERROR, &test_resps2);
gtb->save(model_file_name2);
return checkLoadSave();
}
int CV_GBTreesTest::checkLoadSave()
{
int code = CvTS::OK;
// 1. compare files
ifstream f1( model_file_name1 ), f2( model_file_name2 );
string s1, s2;
int lineIdx = 0;
CV_Assert( f1.is_open() && f2.is_open() );
for( ; !f1.eof() && !f2.eof(); lineIdx++ )
{
getline( f1, s1 );
getline( f2, s2 );
if( s1.compare(s2) )
{
ts->printf( CvTS::LOG, "first and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
code = CvTS::FAIL_INVALID_OUTPUT;
}
}
if( !f1.eof() || !f2.eof() )
{
ts->printf( CvTS::LOG, "First and second saved files differ in %n-line; first %n line: %s; second %n-line: %s",
lineIdx, lineIdx, s1.c_str(), lineIdx, s2.c_str() );
code = CvTS::FAIL_INVALID_OUTPUT;
}
f1.close();
f2.close();
// delete temporary files
remove( model_file_name1 );
remove( model_file_name2 );
// 2. compare responses
CV_Assert( test_resps1.size() == test_resps2.size() );
vector<float>::const_iterator it1 = test_resps1.begin(), it2 = test_resps2.begin();
for( ; it1 != test_resps1.end(); ++it1, ++it2 )
{
if( fabs(*it1 - *it2) > FLT_EPSILON )
{
ts->printf( CvTS::LOG, "Responses predicted before saving and after loading are different" );
code = CvTS::FAIL_INVALID_OUTPUT;
}
}
return code;
}
void CV_GBTreesTest::run(int)
{
string data_path = string(ts->get_data_path());
datasets = new string[2];
datasets[0] = data_path + string("spambase.data"); /*string("dataset_classification.csv");*/
datasets[1] = data_path + string("housing_.data"); /*string("dataset_regression.csv");*/
int code = CvTS::OK;
for (int i = 0; i < 4; i++)
{
int temp_code = TestTrainPredict(i);
if (temp_code != CvTS::OK)
{
code = temp_code;
break;
}
else if (i==0)
{
temp_code = TestSaveLoad();
if (temp_code != CvTS::OK)
code = temp_code;
delete data;
data = 0;
}
delete gtb;
gtb = 0;
}
delete data;
data = 0;
ts->set_failed_test_info( code );
}
/////////////////////////////////////////////////////////////////////////////
//////////////////// test registration /////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////
CV_GBTreesTest gbtrees_test;