Gradient Boosting Trees (CvGBTrees) added to opencv mll. Test for all CvGBTrees public methods added.

This commit is contained in:
P. Druzhkov
2010-10-13 20:18:12 +00:00
parent d788007672
commit d611fb61fc
3 changed files with 1844 additions and 0 deletions

View File

@@ -183,6 +183,7 @@ CV_INLINE CvParamLattice cvDefaultParamLattice( void )
// Type-name string identifiers of the ML model classes.
// NOTE(review): presumably used to tag serialized models in file storage - confirm.
#define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees"
// Type-name string identifier of the gradient boosting trees model.
#define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees"
// Selectors for the kind of error to compute (see the 'type' argument of
// calc_error): error on the train set or on the test set.
#define CV_TRAIN_ERROR 0
#define CV_TEST_ERROR 1
@@ -1359,6 +1360,532 @@ protected:
};
/****************************************************************************************\
* Gradient Boosted Trees *
\****************************************************************************************/
// DataType: STRUCT CvGBTreesParams
// Parameters of GBT (Gradient Boosted trees model), including single
// tree settings and ensemble parameters.
//
// weak_count - count of trees in the ensemble
// loss_function_type - loss function used for ensemble training
// subsample_portion - portion of whole training set used for
// every single tree training.
// subsample_portion value is in (0.0, 1.0].
// subsample_portion == 1.0 when whole dataset is
// used on each step. Count of samples used on each
// step is computed as
// int(total_samples_count * subsample_portion).
// shrinkage - regularization parameter.
// Each tree prediction is multiplied by the shrinkage value.
// Parameters of the gradient boosted trees model: the single-tree settings
// inherited from CvDTreeParams plus the ensemble-level parameters below.
struct CV_EXPORTS CvGBTreesParams : public CvDTreeParams
{
// Count of trees in the ensemble.
int weak_count;
// Loss function used for ensemble training.
// NOTE(review): presumably one of the CvGBTrees loss enum values - confirm.
int loss_function_type;
// Portion of the whole training set used for training every single tree;
// the value is in (0.0, 1.0], 1.0 means the whole dataset is used on each step.
float subsample_portion;
// Regularization parameter: each tree prediction is multiplied by this value.
float shrinkage;
// Default constructor.
// NOTE(review): the default parameter values are set in the implementation,
// which is not visible from this header.
CvGBTreesParams();
// Constructor taking the ensemble parameters together with max_depth and
// use_surrogates.
// NOTE(review): max_depth and use_surrogates presumably initialize the
// corresponding CvDTreeParams settings - confirm against the implementation.
CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
float subsample_portion, int max_depth, bool use_surrogates );
};
// DataType: CLASS CvGBTrees
// Gradient Boosting Trees (GBT) algorithm implementation.
//
// data - training dataset
// params - parameters of the CvGBTrees
// weak - array[0..(class_count-1)] of CvSeq
// for storing tree ensembles
// orig_response - original responses of the training set samples
// sum_response - predictions of the current model on the training dataset.
// this matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
// step. On every iteration values of sum_response_tmp are
// computed via sum_response values. When the current
// step is complete sum_response values become equal to
// sum_response_tmp.
// sample_idx - indices of samples used for training the ensemble.
// CvGBTrees training procedure takes a set of samples
// (train_data) and a set of responses (responses).
// Only pairs (train_data[i], responses[i]), where i is
// in sample_idx are used for training the ensemble.
// subsample_train - indices of samples used for training a single decision
// tree on the current step. These indices are counted
// relative to sample_idx, so that pairs
// (train_data[sample_idx[i]], responses[sample_idx[i]])
// are used for training a decision tree.
// Training set is randomly split
// in two parts (subsample_train and subsample_test)
// on every iteration according to the subsample_portion parameter.
// subsample_test - relative indices of samples from the training set,
// which are not used for training a tree on the current
// step.
// missing - mask of the missing values in the training set. This
// matrix has the same size as train_data. 1 - missing
// value, 0 - not a missing value.
// class_labels - output class labels map.
// rng - random number generator. Used for splitting the
// training set.
// class_count - count of output classes.
// class_count == 1 in the case of regression,
// and > 1 in the case of classification.
// delta - Huber loss function parameter.
// base_value - start point of the gradient descent procedure.
// model prediction is
// f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
// f_0 is the base value.
class CV_EXPORTS CvGBTrees : public CvStatModel
{
public:
/*
// DataType: ENUM
// Loss functions implemented in CvGBTrees.
//
// SQUARED_LOSS
// problem: regression
// loss = (x - x')^2
//
// ABSOLUTE_LOSS
// problem: regression
// loss = abs(x - x')
//
// HUBER_LOSS
// problem: regression
// loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
// 1/2*(x - x')^2, if abs(x - x') <= delta,
// where delta is the alpha-quantile of pseudo responses from
// the training set.
//
// DEVIANCE_LOSS
// problem: classification
//
// Numeric values: SQUARED_LOSS == 0, ABSOLUTE_LOSS == 1, HUBER_LOSS == 3,
// DEVIANCE_LOSS == 4. The value 2 is not assigned to any loss function.
*/
enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
/*
// Default constructor. Creates a model only (without training).
// Should be followed by one form of the train(...) function.
//
// API
// CvGBTrees();
// INPUT
// OUTPUT
// RESULT
*/
CvGBTrees();
/*
// Full form constructor. Creates a gradient boosting model and does the
// training.
//
// API
// CvGBTrees( const CvMat* _train_data, int _tflag,
//            const CvMat* _responses, const CvMat* _var_idx=0,
//            const CvMat* _sample_idx=0, const CvMat* _var_type=0,
//            const CvMat* _missing_mask=0,
//            CvGBTreesParams params=CvGBTreesParams() );
// INPUT
// _train_data - a set of input feature vectors.
// size of matrix is
// <count of samples> x <variables count>
// or <variables count> x <count of samples>
// depending on the _tflag parameter.
// matrix values are float.
// _tflag - a flag showing how samples are stored in the
// _train_data matrix: row by row (tflag=CV_ROW_SAMPLE)
// or column by column (tflag=CV_COL_SAMPLE).
// _responses - a vector of responses corresponding to the samples
// in _train_data.
// _var_idx - indices of used variables. zero value means that all
// variables are active.
// _sample_idx - indices of used samples. zero value means that all
// samples from _train_data are in the training set.
// _var_type - vector of <variables count> length. gives every
// variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
// _var_type = 0 means all variables are numerical.
// _missing_mask - a mask of missing values in _train_data.
// _missing_mask = 0 means that there are no missing
// values.
// params - parameters of GBT algorithm.
// OUTPUT
// RESULT
*/
CvGBTrees( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx=0,
const CvMat* _sample_idx=0, const CvMat* _var_type=0,
const CvMat* _missing_mask=0,
CvGBTreesParams params=CvGBTreesParams() );
/*
// Destructor.
*/
virtual ~CvGBTrees();
/*
// Gradient tree boosting model training
//
// API
// virtual bool train( const CvMat* _train_data, int _tflag,
//                     const CvMat* _responses, const CvMat* _var_idx=0,
//                     const CvMat* _sample_idx=0, const CvMat* _var_type=0,
//                     const CvMat* _missing_mask=0,
//                     CvGBTreesParams params=CvGBTreesParams(),
//                     bool update=false );
// INPUT
// _train_data - a set of input feature vectors.
// size of matrix is
// <count of samples> x <variables count>
// or <variables count> x <count of samples>
// depending on the _tflag parameter.
// matrix values are float.
// _tflag - a flag showing how samples are stored in the
// _train_data matrix: row by row (tflag=CV_ROW_SAMPLE)
// or column by column (tflag=CV_COL_SAMPLE).
// _responses - a vector of responses corresponding to the samples
// in _train_data.
// _var_idx - indices of used variables. zero value means that all
// variables are active.
// _sample_idx - indices of used samples. zero value means that all
// samples from _train_data are in the training set.
// _var_type - vector of <variables count> length. gives every
// variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
// _var_type = 0 means all variables are numerical.
// _missing_mask - a mask of missing values in _train_data.
// _missing_mask = 0 means that there are no missing
// values.
// params - parameters of GBT algorithm.
// update - is not supported now. (!)
// OUTPUT
// RESULT
// Error state.
*/
virtual bool train( const CvMat* _train_data, int _tflag,
const CvMat* _responses, const CvMat* _var_idx=0,
const CvMat* _sample_idx=0, const CvMat* _var_type=0,
const CvMat* _missing_mask=0,
CvGBTreesParams params=CvGBTreesParams(),
bool update=false );
/*
// Gradient tree boosting model training
//
// API
// virtual bool train( CvMLData* data,
//                     CvGBTreesParams params=CvGBTreesParams(),
//                     bool update=false );
// INPUT
// data - training set.
// params - parameters of GBT algorithm.
// update - is not supported now. (!)
// OUTPUT
// RESULT
// Error state.
*/
virtual bool train( CvMLData* data,
CvGBTreesParams params=CvGBTreesParams(),
bool update=false );
/*
// Response value prediction
//
// API
// virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
//                        CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
//                        int k=-1 ) const;
// INPUT
// _sample - input sample of the same type as in the training set.
// _missing - missing values mask. _missing=0 if there are no
// missing values in _sample vector.
// weak_responses - predictions of all of the trees.
// not implemented (!)
// slice - part of the ensemble used for prediction.
// slice = CV_WHOLE_SEQ when all trees are used.
// k - number of ensemble used.
// k is in {-1,0,1,..,<count of output classes-1>}.
// in the case of classification problem
// <count of output classes-1> ensembles are built.
// If k = -1 ordinary prediction is the result,
// otherwise function gives the prediction of the
// k-th ensemble only.
// OUTPUT
// RESULT
// Predicted value.
*/
virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
int k=-1 ) const;
/*
// Delete all temporary data.
//
// API
// virtual void clear();
// INPUT
// OUTPUT
// delete data, weak, orig_response, sum_response,
// weak_eval, subsample_train, subsample_test,
// sample_idx, missing, class_labels
// delta = 0.0
// RESULT
*/
virtual void clear();
/*
// Compute error on the train/test set.
//
// API
// virtual float calc_error( CvMLData* _data, int type,
// std::vector<float> *resp = 0 );
//
// INPUT
// _data - dataset
// type - defines which error is to compute: train (CV_TRAIN_ERROR) or
// test (CV_TEST_ERROR).
// OUTPUT
// resp - vector of predictions
// RESULT
// Error value.
*/
virtual float calc_error( CvMLData* _data, int type,
std::vector<float> *resp = 0 );
/*
//
// Write parameters of the GBT model and data. Write learned model.
//
// API
// virtual void write( CvFileStorage* fs, const char* name ) const;
//
// INPUT
// fs - file storage to write the model to.
// name - model name.
// OUTPUT
// RESULT
*/
virtual void write( CvFileStorage* fs, const char* name ) const;
/*
//
// Read parameters of the GBT model and data. Read learned model.
//
// API
// virtual void read( CvFileStorage* fs, CvFileNode* node );
//
// INPUT
// fs - file storage to read parameters from.
// node - file node.
// OUTPUT
// RESULT
*/
virtual void read( CvFileStorage* fs, CvFileNode* node );
protected:
/*
// Compute the gradient vector components.
//
// API
// virtual void find_gradient( const int k = 0);
// INPUT
// k - used for classification problem, determining current
// tree ensemble.
// OUTPUT
// changes components of data->responses
// which correspond to samples used for training
// on the current step.
// RESULT
*/
virtual void find_gradient( const int k = 0);
/*
//
// Change values in tree leaves according to the used loss function.
//
// API
// virtual void change_values(CvDTree* tree, const int k = 0);
//
// INPUT
// tree - decision tree to change.
// k - used for classification problem, determining current
// tree ensemble.
// OUTPUT
// changes 'value' fields of the trees' leaves.
// changes sum_response_tmp.
// RESULT
*/
virtual void change_values(CvDTree* tree, const int k = 0);
/*
//
// Find optimal constant prediction value according to the used loss
// function.
// The goal is to find a constant which gives the minimal summary loss
// on the _Idx samples.
//
// API
// virtual float find_optimal_value( const CvMat* _Idx );
//
// INPUT
// _Idx - indices of the samples from the training set.
// OUTPUT
// RESULT
// optimal constant value.
*/
virtual float find_optimal_value( const CvMat* _Idx );
/*
//
// Randomly split the whole training set in two parts according
// to params.subsample_portion.
//
// API
// virtual void do_subsample();
//
// INPUT
// OUTPUT
// subsample_train - indices of samples used for training
// subsample_test - indices of samples used for test
// RESULT
*/
virtual void do_subsample();
/*
//
// Internal recursive function giving an array of the leaves of a subtree.
//
// API
// void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
//
// INPUT
// node - current node of the recursion.
// OUTPUT
// count - count of leaves in the subtree.
// leaves - array of pointers to leaves.
// RESULT
*/
void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
/*
//
// Get leaves of the tree.
//
// API
// CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
//
// INPUT
// dtree - decision tree.
// OUTPUT
// len - count of the leaves.
// RESULT
// CvDTreeNode** - array of pointers to leaves.
// NOTE(review): ownership of the returned array is not visible from
// this header - confirm who is responsible for releasing it.
*/
CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
/*
//
// Is it a regression or a classification.
//
// API
// virtual bool problem_type() const;
//
// INPUT
// OUTPUT
// RESULT
// false if it is a classification problem,
// true - if regression.
*/
virtual bool problem_type() const;
/*
//
// Write parameters of the GBT model.
//
// API
// virtual void write_params( CvFileStorage* fs ) const;
//
// INPUT
// fs - file storage to write parameters to.
// OUTPUT
// RESULT
*/
virtual void write_params( CvFileStorage* fs ) const;
/*
//
// Read parameters of the GBT model and data.
//
// API
// virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
//
// INPUT
// fs - file storage to read parameters from.
// fnode - file node.
// NOTE(review): presumably the node that holds the stored
// parameters - confirm against the implementation.
// OUTPUT
// params - parameters of the GBT model.
// data - contains information about the structure
// of the data set (count of variables,
// their types, etc.).
// class_labels - output class labels map.
// RESULT
*/
virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
// Training dataset.
CvDTreeTrainData* data;
// Parameters of the CvGBTrees model.
CvGBTreesParams params;
// Array [0..(class_count-1)] of CvSeq storing the tree ensembles.
CvSeq** weak;
// Original responses of the training set samples.
CvMat* orig_response;
// Predictions of the current model on the training dataset;
// updated on every iteration.
CvMat* sum_response;
// Predictions of the model on the training set for the next step.
CvMat* sum_response_tmp;
// NOTE(review): purpose is not documented in this header (only listed among
// the data released by clear()) - confirm against the implementation.
CvMat* weak_eval;
// Indices of samples used for training the ensemble.
CvMat* sample_idx;
// Indices (relative to sample_idx) of samples used for training a single
// decision tree on the current step.
CvMat* subsample_train;
// Indices (relative to sample_idx) of samples not used for training a tree
// on the current step.
CvMat* subsample_test;
// Mask of the missing values in the training set:
// 1 - missing value, 0 - not a missing value.
CvMat* missing;
// Output class labels map.
CvMat* class_labels;
// Random number generator used for splitting the training set.
CvRNG rng;
// Count of output classes: 1 in the case of regression,
// > 1 in the case of classification.
int class_count;
// Huber loss function parameter.
float delta;
// Start point of the gradient descent procedure (f_0 of the model).
float base_value;
};
/****************************************************************************************\
* Artificial Neural Networks (ANN) *
\****************************************************************************************/
@@ -1936,6 +2463,8 @@ typedef CvBoostTree BoostTree;
// Short aliases for the Cv-prefixed ML classes.
typedef CvBoost Boost;
typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
typedef CvANN_MLP NeuralNet_MLP;
// Aliases for the gradient boosting trees parameters and model classes.
typedef CvGBTreesParams GradientBoostingTreesParams;
typedef CvGBTrees GradientBoostingTrees;
}