limited rng seed choice by 5 values in dtree, boost, rtrees, ertrees tests; hope this make them more stable.

This commit is contained in:
Maria Dimashova 2010-11-29 15:22:15 +00:00
parent 56745b5400
commit 78ccde56b2
3 changed files with 344 additions and 324 deletions

View File

@ -100,12 +100,17 @@ int CV_AMLTest::validate_test_results( int testCaseIdx )
resultNode["mean"] >> mean;
resultNode["sigma"] >> sigma;
float curErr = get_error( testCaseIdx, CV_TEST_ERROR );
if ( abs( curErr - mean) > 6*sigma )
const int coeff = 3;
ts->printf( CvTS::LOG, "Test case = %d; test error = %f; mean error = %f (diff=%f), %d*sigma = %f",
testCaseIdx, curErr, mean, abs( curErr - mean), coeff, coeff*sigma );
if ( abs( curErr - mean) > coeff*sigma )
{
ts->printf( CvTS::LOG, "in test case %d test error is out of range:\n"
"abs(%f/*curEr*/ - %f/*mean*/ > %f/*6*sigma*/", testCaseIdx, curErr, mean, 6*sigma );
ts->printf( CvTS::LOG, "abs(%f - %f) > %f - OUT OF RANGE!\n", curErr, mean, coeff*sigma, coeff );
return CvTS::FAIL_BAD_ACCURACY;
}
else
ts->printf( CvTS::LOG, ".\n" );
}
else
{

View File

@ -45,10 +45,10 @@ CvTS test_system("ml");
const char* blacklist[] =
{
"adtree", //ticket 662
/*"adtree", //ticket 662
"artrees", //ticket 460
"aboost", //ticket 474
"aertrees",
"aertrees",*/ //ticket 505
0
};

View File

@ -43,32 +43,32 @@
// auxiliary functions
// 1. nbayes
void nbayes_check_data( CvMLData* _data )
{
if( _data->get_missing() )
CV_Error( CV_StsBadArg, "missing values are not supported" );
const CvMat* var_types = _data->get_var_types();
bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
void nbayes_check_data( CvMLData* _data )
{
if( _data->get_missing() )
CV_Error( CV_StsBadArg, "missing values are not supported" );
const CvMat* var_types = _data->get_var_types();
bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
if( ( fabs( cvNorm( var_types, 0, CV_L1 ) -
(var_types->rows + var_types->cols - 2)*CV_VAR_ORDERED - CV_VAR_CATEGORICAL ) > FLT_EPSILON ) ||
!is_classifier )
CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" );
}
bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data )
{
CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" );
}
bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data )
{
nbayes_check_data( _data );
const CvMat* values = _data->get_values();
const CvMat* responses = _data->get_responses();
const CvMat* train_sidx = _data->get_train_sample_idx();
const CvMat* var_idx = _data->get_var_idx();
return nbayes->train( values, responses, var_idx, train_sidx );
}
}
float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int type, vector<float> *resp )
{
float err = 0;
nbayes_check_data( _data );
const CvMat* values = _data->get_values();
const CvMat* response = _data->get_responses();
const CvMat* values = _data->get_values();
const CvMat* response = _data->get_responses();
const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
int* sidx = sample_idx ? sample_idx->data.i : 0;
int r_step = CV_IS_MAT_CONT(response->type) ?
@ -98,10 +98,10 @@ float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int t
}
// 2. knearest
void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors )
{
const CvMat* values = _data->get_values();
const CvMat* var_idx = _data->get_var_idx();
void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors )
{
const CvMat* values = _data->get_values();
const CvMat* var_idx = _data->get_var_idx();
if( var_idx->cols + var_idx->rows != values->cols )
CV_Error( CV_StsBadArg, "var_idx is not supported" );
if( _data->get_missing() )
@ -113,18 +113,18 @@ void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors
cvGetCols( values, _predictors, 0, values->cols - 1 );
else
CV_Error( CV_StsBadArg, "responses must be in the first or last column; other cases are not supported" );
}
bool knearest_train( CvKNearest* knearest, CvMLData* _data )
{
}
bool knearest_train( CvKNearest* knearest, CvMLData* _data )
{
const CvMat* responses = _data->get_responses();
const CvMat* train_sidx = _data->get_train_sample_idx();
bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED;
CvMat predictors;
knearest_check_data_and_get_predictors( _data, &predictors );
return knearest->train( &predictors, responses, train_sidx, is_regression );
}
float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector<float> *resp )
{
return knearest->train( &predictors, responses, train_sidx, is_regression );
}
float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector<float> *resp )
{
float err = 0;
const CvMat* response = _data->get_responses();
const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
@ -172,7 +172,7 @@ float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int typ
}
err = sample_count ? err / (float)sample_count : -FLT_MAX;
}
return err;
return err;
}
// 3. svm
@ -204,44 +204,44 @@ int str_to_svm_kernel_type( string& str )
CV_Error( CV_StsBadArg, "incorrect svm type string" );
return -1;
}
void svm_check_data( CvMLData* _data )
{
if( _data->get_missing() )
CV_Error( CV_StsBadArg, "missing values are not supported" );
const CvMat* var_types = _data->get_var_types();
for( int i = 0; i < var_types->cols-1; i++ )
if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL)
{
char msg[50];
sprintf( msg, "incorrect type of %d-predictor", i );
CV_Error( CV_StsBadArg, msg );
}
}
bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params )
{
svm_check_data(_data);
const CvMat* _train_data = _data->get_values();
const CvMat* _responses = _data->get_responses();
const CvMat* _var_idx = _data->get_var_idx();
const CvMat* _sample_idx = _data->get_train_sample_idx();
return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params );
}
bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params,
int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid,
CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid,
CvParamGrid degree_grid )
{
svm_check_data(_data);
const CvMat* _train_data = _data->get_values();
const CvMat* _responses = _data->get_responses();
const CvMat* _var_idx = _data->get_var_idx();
const CvMat* _sample_idx = _data->get_train_sample_idx();
return svm->train_auto( _train_data, _responses, _var_idx,
_sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
}
float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector<float> *resp )
{
svm_check_data(_data);
void svm_check_data( CvMLData* _data )
{
if( _data->get_missing() )
CV_Error( CV_StsBadArg, "missing values are not supported" );
const CvMat* var_types = _data->get_var_types();
for( int i = 0; i < var_types->cols-1; i++ )
if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL)
{
char msg[50];
sprintf( msg, "incorrect type of %d-predictor", i );
CV_Error( CV_StsBadArg, msg );
}
}
bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params )
{
svm_check_data(_data);
const CvMat* _train_data = _data->get_values();
const CvMat* _responses = _data->get_responses();
const CvMat* _var_idx = _data->get_var_idx();
const CvMat* _sample_idx = _data->get_train_sample_idx();
return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params );
}
bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params,
int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid,
CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid,
CvParamGrid degree_grid )
{
svm_check_data(_data);
const CvMat* _train_data = _data->get_values();
const CvMat* _responses = _data->get_responses();
const CvMat* _var_idx = _data->get_var_idx();
const CvMat* _sample_idx = _data->get_train_sample_idx();
return svm->train_auto( _train_data, _responses, _var_idx,
_sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid );
}
float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector<float> *resp )
{
svm_check_data(_data);
float err = 0;
const CvMat* values = _data->get_values();
const CvMat* response = _data->get_responses();
@ -289,7 +289,7 @@ float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector<float> *resp
}
err = sample_count ? err / (float)sample_count : -FLT_MAX;
}
return err;
return err;
}
// 4. em
@ -303,10 +303,10 @@ int str_to_ann_train_method( string& str )
CV_Error( CV_StsBadArg, "incorrect ann train method string" );
return -1;
}
void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs )
{
const CvMat* values = _data->get_values();
const CvMat* var_idx = _data->get_var_idx();
void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs )
{
const CvMat* values = _data->get_values();
const CvMat* var_idx = _data->get_var_idx();
if( var_idx->cols + var_idx->rows != values->cols )
CV_Error( CV_StsBadArg, "var_idx is not supported" );
if( _data->get_missing() )
@ -318,10 +318,10 @@ void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs )
cvGetCols( values, _inputs, 0, values->cols - 1 );
else
CV_Error( CV_StsBadArg, "outputs must be in the first or last column; other cases are not supported" );
}
void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map<int, int>& cls_map )
{
const CvMat* train_sidx = _data->get_train_sample_idx();
}
void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map<int, int>& cls_map )
{
const CvMat* train_sidx = _data->get_train_sample_idx();
int* train_sidx_ptr = train_sidx->data.i;
const CvMat* responses = _data->get_responses();
float* responses_ptr = responses->data.fl;
@ -348,18 +348,18 @@ void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map<int, int>&
int r = cvRound(responses_ptr[sidx*r_step]);
int cidx = cls_map[r];
new_responses.ptr<float>(sidx)[cidx] = 1;
}
}
int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 )
{
}
}
int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 )
{
const CvMat* train_sidx = _data->get_train_sample_idx();
CvMat predictors;
ann_check_data_and_get_predictors( _data, &predictors );
CvMat _new_responses = CvMat( new_responses );
return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags );
}
float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map<int, int>& cls_map, int type , vector<float> *resp_labels )
{
return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags );
}
float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map<int, int>& cls_map, int type , vector<float> *resp_labels )
{
float err = 0;
const CvMat* responses = _data->get_responses();
const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
@ -396,16 +396,16 @@ float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map<int, int>& cls_map, i
cvGetRow( &predictors, &sample, si );
ann->predict( &sample, &_output );
CvPoint best_cls = {0,0};
cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 );
int r = cvRound(responses->data.fl[si*r_step]);
CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON );
r = cls_map[r];
int d = best_cls.x == r ? 0 : 1;
cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 );
int r = cvRound(responses->data.fl[si*r_step]);
CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON );
r = cls_map[r];
int d = best_cls.x == r ? 0 : 1;
err += d;
pred_resp[i] = (float)best_cls.x;
}
err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
return err;
return err;
}
// 6. dtree
@ -429,10 +429,24 @@ int str_to_boost_type( string& str )
// ---------------------------------- MLBaseTest ---------------------------------------------------
CV_MLBaseTest::CV_MLBaseTest( const char* _modelName, const char* _testName, const char* _testFuncs ) :
CvTest( _testName, _testFuncs )
{
modelName = _modelName;
CV_MLBaseTest::CV_MLBaseTest( const char* _modelName, const char* _testName, const char* _testFuncs ) :
CvTest( _testName, _testFuncs )
{
int64 seeds[] = { 0x00009fff4f9c8d52,
0x0000a17166072c7c,
0x0201b32115cd1f9a,
0x0513cb37abcd1234,
0x0001a2b3c4d5f678
};
int seedCount = sizeof(seeds)/sizeof(seeds[0]);
RNG& rng = theRNG();
initSeed = rng.state;
rng.state = seeds[rng(seedCount)];
modelName = _modelName;
nbayes = 0;
knearest = 0;
svm = 0;
@ -441,35 +455,35 @@ CvTest( _testName, _testFuncs )
dtree = 0;
boost = 0;
rtrees = 0;
ertrees = 0;
if( !modelName.compare(CV_NBAYES) )
nbayes = new CvNormalBayesClassifier;
else if( !modelName.compare(CV_KNEAREST) )
knearest = new CvKNearest;
else if( !modelName.compare(CV_SVM) )
svm = new CvSVM;
else if( !modelName.compare(CV_EM) )
em = new CvEM;
else if( !modelName.compare(CV_ANN) )
ann = new CvANN_MLP;
else if( !modelName.compare(CV_DTREE) )
dtree = new CvDTree;
else if( !modelName.compare(CV_BOOST) )
boost = new CvBoost;
else if( !modelName.compare(CV_RTREES) )
rtrees = new CvRTrees;
else if( !modelName.compare(CV_ERTREES) )
ertrees = new CvERTrees;
}
ertrees = 0;
if( !modelName.compare(CV_NBAYES) )
nbayes = new CvNormalBayesClassifier;
else if( !modelName.compare(CV_KNEAREST) )
knearest = new CvKNearest;
else if( !modelName.compare(CV_SVM) )
svm = new CvSVM;
else if( !modelName.compare(CV_EM) )
em = new CvEM;
else if( !modelName.compare(CV_ANN) )
ann = new CvANN_MLP;
else if( !modelName.compare(CV_DTREE) )
dtree = new CvDTree;
else if( !modelName.compare(CV_BOOST) )
boost = new CvBoost;
else if( !modelName.compare(CV_RTREES) )
rtrees = new CvRTrees;
else if( !modelName.compare(CV_ERTREES) )
ertrees = new CvERTrees;
}
int CV_MLBaseTest::init( CvTS* system )
{
clear();
ts = system;
string filename = ts->get_data_path();
filename += get_validation_filename();
validationFS.open( filename, FileStorage::READ );
clear();
ts = system;
string filename = ts->get_data_path();
filename += get_validation_filename();
validationFS.open( filename, FileStorage::READ );
return read_params( *validationFS );
}
@ -495,6 +509,7 @@ CV_MLBaseTest::~CV_MLBaseTest()
delete rtrees;
if( ertrees )
delete ertrees;
theRNG().state = initSeed;
}
int CV_MLBaseTest::read_params( CvFileStorage* _fs )
@ -519,182 +534,182 @@ int CV_MLBaseTest::read_params( CvFileStorage* _fs )
return CvTS::OK;;
}
void CV_MLBaseTest::run( int start_from )
{
int code = CvTS::OK;
start_from = 0;
for (int i = 0; i < test_case_count; i++)
{
int temp_code = run_test_case( i );
if (temp_code == CvTS::OK)
temp_code = validate_test_results( i );
if (temp_code != CvTS::OK)
code = temp_code;
}
if ( test_case_count <= 0)
void CV_MLBaseTest::run( int start_from )
{
int code = CvTS::OK;
start_from = 0;
for (int i = 0; i < test_case_count; i++)
{
ts->printf( CvTS::LOG, "validation file is not determined or not correct" );
int temp_code = run_test_case( i );
if (temp_code == CvTS::OK)
temp_code = validate_test_results( i );
if (temp_code != CvTS::OK)
code = temp_code;
}
if ( test_case_count <= 0)
{
ts->printf( CvTS::LOG, "validation file is not determined or not correct" );
code = CvTS::FAIL_INVALID_TEST_DATA;
}
ts->set_failed_test_info( code );
}
int CV_MLBaseTest::prepare_test_case( int test_case_idx )
{
int trainSampleCount, respIdx;
string varTypes;
clear();
}
ts->set_failed_test_info( code );
}
string dataPath = ts->get_data_path();
int CV_MLBaseTest::prepare_test_case( int test_case_idx )
{
int trainSampleCount, respIdx;
string varTypes;
clear();
string dataPath = ts->get_data_path();
if ( dataPath.empty() )
{
ts->printf( CvTS::LOG, "data path is empty" );
ts->printf( CvTS::LOG, "data path is empty" );
return CvTS::FAIL_INVALID_TEST_DATA;
}
string dataName = dataSetNames[test_case_idx],
filename = dataPath + dataName + ".data";
if ( data.read_csv( filename.c_str() ) != 0)
}
string dataName = dataSetNames[test_case_idx],
filename = dataPath + dataName + ".data";
if ( data.read_csv( filename.c_str() ) != 0)
{
char msg[100];
sprintf( msg, "file %s can not be read", filename.c_str() );
ts->printf( CvTS::LOG, msg );
ts->printf( CvTS::LOG, msg );
return CvTS::FAIL_INVALID_TEST_DATA;
}
FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
CV_DbgAssert( !dataParamsNode.empty() );
CV_DbgAssert( !dataParamsNode["LS"].empty() );
dataParamsNode["LS"] >> trainSampleCount;
CvTrainTestSplit spl( trainSampleCount );
data.set_train_test_split( &spl );
CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
dataParamsNode["resp_idx"] >> respIdx;
data.set_response_idx( respIdx );
CV_DbgAssert( !dataParamsNode["types"].empty() );
dataParamsNode["types"] >> varTypes;
data.set_var_types( varTypes.c_str() );
return CvTS::OK;
}
string& CV_MLBaseTest::get_validation_filename()
{
return validationFN;
}
}
FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"];
CV_DbgAssert( !dataParamsNode.empty() );
CV_DbgAssert( !dataParamsNode["LS"].empty() );
dataParamsNode["LS"] >> trainSampleCount;
CvTrainTestSplit spl( trainSampleCount );
data.set_train_test_split( &spl );
CV_DbgAssert( !dataParamsNode["resp_idx"].empty() );
dataParamsNode["resp_idx"] >> respIdx;
data.set_response_idx( respIdx );
CV_DbgAssert( !dataParamsNode["types"].empty() );
dataParamsNode["types"] >> varTypes;
data.set_var_types( varTypes.c_str() );
return CvTS::OK;
}
string& CV_MLBaseTest::get_validation_filename()
{
return validationFN;
}
int CV_MLBaseTest::train( int testCaseIdx )
{
bool is_trained = false;
FileNode modelParamsNode =
validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
FileNode modelParamsNode =
validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"];
if( !modelName.compare(CV_NBAYES) )
is_trained = nbayes_train( nbayes, &data );
else if( !modelName.compare(CV_KNEAREST) )
{
assert( 0 );
//is_trained = knearest->train( &data );
}
else if( !modelName.compare(CV_SVM) )
{
string svm_type_str, kernel_type_str;
modelParamsNode["svm_type"] >> svm_type_str;
modelParamsNode["kernel_type"] >> kernel_type_str;
CvSVMParams params;
params.svm_type = str_to_svm_type( svm_type_str );
params.kernel_type = str_to_svm_kernel_type( kernel_type_str );
modelParamsNode["degree"] >> params.degree;
modelParamsNode["gamma"] >> params.gamma;
modelParamsNode["coef0"] >> params.coef0;
modelParamsNode["C"] >> params.C;
modelParamsNode["nu"] >> params.nu;
modelParamsNode["p"] >> params.p;
is_trained = svm_train( svm, &data, params );
}
else if( !modelName.compare(CV_EM) )
{
assert( 0 );
}
else if( !modelName.compare(CV_ANN) )
{
string train_method_str;
double param1, param2;
modelParamsNode["train_method"] >> train_method_str;
modelParamsNode["param1"] >> param1;
modelParamsNode["param2"] >> param2;
Mat new_responses;
ann_get_new_responses( &data, new_responses, cls_map );
int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() };
CvMat layer_sizes =
cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
ann->create( &layer_sizes );
is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01),
str_to_ann_train_method(train_method_str), param1, param2) ) >= 0;
}
else if( !modelName.compare(CV_DTREE) )
{
if( !modelName.compare(CV_NBAYES) )
is_trained = nbayes_train( nbayes, &data );
else if( !modelName.compare(CV_KNEAREST) )
{
assert( 0 );
//is_trained = knearest->train( &data );
}
else if( !modelName.compare(CV_SVM) )
{
string svm_type_str, kernel_type_str;
modelParamsNode["svm_type"] >> svm_type_str;
modelParamsNode["kernel_type"] >> kernel_type_str;
CvSVMParams params;
params.svm_type = str_to_svm_type( svm_type_str );
params.kernel_type = str_to_svm_kernel_type( kernel_type_str );
modelParamsNode["degree"] >> params.degree;
modelParamsNode["gamma"] >> params.gamma;
modelParamsNode["coef0"] >> params.coef0;
modelParamsNode["C"] >> params.C;
modelParamsNode["nu"] >> params.nu;
modelParamsNode["p"] >> params.p;
is_trained = svm_train( svm, &data, params );
}
else if( !modelName.compare(CV_EM) )
{
assert( 0 );
}
else if( !modelName.compare(CV_ANN) )
{
string train_method_str;
double param1, param2;
modelParamsNode["train_method"] >> train_method_str;
modelParamsNode["param1"] >> param1;
modelParamsNode["param2"] >> param2;
Mat new_responses;
ann_get_new_responses( &data, new_responses, cls_map );
int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() };
CvMat layer_sizes =
cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz );
ann->create( &layer_sizes );
is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01),
str_to_ann_train_method(train_method_str), param1, param2) ) >= 0;
}
else if( !modelName.compare(CV_DTREE) )
{
int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS;
float REG_ACCURACY = 0;
bool USE_SURROGATE, IS_PRUNED;
modelParamsNode["max_depth"] >> MAX_DEPTH;
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
modelParamsNode["use_surrogate"] >> USE_SURROGATE;
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
modelParamsNode["cv_folds"] >> CV_FOLDS;
modelParamsNode["is_pruned"] >> IS_PRUNED;
bool USE_SURROGATE, IS_PRUNED;
modelParamsNode["max_depth"] >> MAX_DEPTH;
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
modelParamsNode["use_surrogate"] >> USE_SURROGATE;
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
modelParamsNode["cv_folds"] >> CV_FOLDS;
modelParamsNode["is_pruned"] >> IS_PRUNED;
is_trained = dtree->train( &data,
CvDTreeParams(MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE,
MAX_CATEGORIES, CV_FOLDS, false, IS_PRUNED, 0 )) != 0;
}
else if( !modelName.compare(CV_BOOST) )
{
}
else if( !modelName.compare(CV_BOOST) )
{
int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH;
float WEIGHT_TRIM_RATE;
bool USE_SURROGATE;
string typeStr;
modelParamsNode["type"] >> typeStr;
BOOST_TYPE = str_to_boost_type( typeStr );
bool USE_SURROGATE;
string typeStr;
modelParamsNode["type"] >> typeStr;
BOOST_TYPE = str_to_boost_type( typeStr );
modelParamsNode["weak_count"] >> WEAK_COUNT;
modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE;
modelParamsNode["max_depth"] >> MAX_DEPTH;
modelParamsNode["use_surrogate"] >> USE_SURROGATE;
is_trained = boost->train( &data,
CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0;
}
else if( !modelName.compare(CV_RTREES) )
{
int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
float REG_ACCURACY = 0, OOB_EPS = 0.0;
bool USE_SURROGATE, IS_PRUNED;
modelParamsNode["max_depth"] >> MAX_DEPTH;
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
modelParamsNode["use_surrogate"] >> USE_SURROGATE;
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
modelParamsNode["cv_folds"] >> CV_FOLDS;
modelParamsNode["is_pruned"] >> IS_PRUNED;
modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
is_trained = rtrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance
NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
}
else if( !modelName.compare(CV_ERTREES) )
CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0;
}
else if( !modelName.compare(CV_RTREES) )
{
int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
float REG_ACCURACY = 0, OOB_EPS = 0.0;
bool USE_SURROGATE, IS_PRUNED;
modelParamsNode["max_depth"] >> MAX_DEPTH;
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
modelParamsNode["use_surrogate"] >> USE_SURROGATE;
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
modelParamsNode["cv_folds"] >> CV_FOLDS;
modelParamsNode["is_pruned"] >> IS_PRUNED;
modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
modelParamsNode["max_depth"] >> MAX_DEPTH;
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
modelParamsNode["use_surrogate"] >> USE_SURROGATE;
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
modelParamsNode["cv_folds"] >> CV_FOLDS;
modelParamsNode["is_pruned"] >> IS_PRUNED;
modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
is_trained = rtrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance
NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
}
else if( !modelName.compare(CV_ERTREES) )
{
int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM;
float REG_ACCURACY = 0, OOB_EPS = 0.0;
bool USE_SURROGATE, IS_PRUNED;
modelParamsNode["max_depth"] >> MAX_DEPTH;
modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT;
modelParamsNode["use_surrogate"] >> USE_SURROGATE;
modelParamsNode["max_categories"] >> MAX_CATEGORIES;
modelParamsNode["cv_folds"] >> CV_FOLDS;
modelParamsNode["is_pruned"] >> IS_PRUNED;
modelParamsNode["nactive_vars"] >> NACTIVE_VARS;
modelParamsNode["max_trees_num"] >> MAX_TREES_NUM;
is_trained = ertrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY,
USE_SURROGATE, MAX_CATEGORIES, 0, false, // (calc_var_importance == true) <=> RF processes variable importance
NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0;
@ -711,74 +726,74 @@ int CV_MLBaseTest::train( int testCaseIdx )
float CV_MLBaseTest::get_error( int testCaseIdx, int type, vector<float> *resp )
{
float err = 0;
if( !modelName.compare(CV_NBAYES) )
err = nbayes_calc_error( nbayes, &data, type, resp );
else if( !modelName.compare(CV_KNEAREST) )
{
assert( 0 );
testCaseIdx = 0;
/*int k = 2;
validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k;
err = knearest->calc_error( &data, k, type, resp );*/
}
else if( !modelName.compare(CV_SVM) )
err = svm_calc_error( svm, &data, type, resp );
else if( !modelName.compare(CV_EM) )
assert( 0 );
else if( !modelName.compare(CV_ANN) )
err = ann_calc_error( ann, &data, cls_map, type, resp );
else if( !modelName.compare(CV_DTREE) )
err = dtree->calc_error( &data, type, resp );
else if( !modelName.compare(CV_BOOST) )
err = boost->calc_error( &data, type, resp );
else if( !modelName.compare(CV_RTREES) )
err = rtrees->calc_error( &data, type, resp );
else if( !modelName.compare(CV_ERTREES) )
if( !modelName.compare(CV_NBAYES) )
err = nbayes_calc_error( nbayes, &data, type, resp );
else if( !modelName.compare(CV_KNEAREST) )
{
assert( 0 );
testCaseIdx = 0;
/*int k = 2;
validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k;
err = knearest->calc_error( &data, k, type, resp );*/
}
else if( !modelName.compare(CV_SVM) )
err = svm_calc_error( svm, &data, type, resp );
else if( !modelName.compare(CV_EM) )
assert( 0 );
else if( !modelName.compare(CV_ANN) )
err = ann_calc_error( ann, &data, cls_map, type, resp );
else if( !modelName.compare(CV_DTREE) )
err = dtree->calc_error( &data, type, resp );
else if( !modelName.compare(CV_BOOST) )
err = boost->calc_error( &data, type, resp );
else if( !modelName.compare(CV_RTREES) )
err = rtrees->calc_error( &data, type, resp );
else if( !modelName.compare(CV_ERTREES) )
err = ertrees->calc_error( &data, type, resp );
return err;
}
void CV_MLBaseTest::save( const char* filename )
{
if( !modelName.compare(CV_NBAYES) )
nbayes->save( filename );
else if( !modelName.compare(CV_KNEAREST) )
knearest->save( filename );
else if( !modelName.compare(CV_SVM) )
svm->save( filename );
else if( !modelName.compare(CV_EM) )
em->save( filename );
else if( !modelName.compare(CV_ANN) )
ann->save( filename );
else if( !modelName.compare(CV_DTREE) )
dtree->save( filename );
else if( !modelName.compare(CV_BOOST) )
boost->save( filename );
else if( !modelName.compare(CV_RTREES) )
rtrees->save( filename );
else if( !modelName.compare(CV_ERTREES) )
if( !modelName.compare(CV_NBAYES) )
nbayes->save( filename );
else if( !modelName.compare(CV_KNEAREST) )
knearest->save( filename );
else if( !modelName.compare(CV_SVM) )
svm->save( filename );
else if( !modelName.compare(CV_EM) )
em->save( filename );
else if( !modelName.compare(CV_ANN) )
ann->save( filename );
else if( !modelName.compare(CV_DTREE) )
dtree->save( filename );
else if( !modelName.compare(CV_BOOST) )
boost->save( filename );
else if( !modelName.compare(CV_RTREES) )
rtrees->save( filename );
else if( !modelName.compare(CV_ERTREES) )
ertrees->save( filename );
}
void CV_MLBaseTest::load( const char* filename )
{
if( !modelName.compare(CV_NBAYES) )
nbayes->load( filename );
else if( !modelName.compare(CV_KNEAREST) )
knearest->load( filename );
else if( !modelName.compare(CV_SVM) )
svm->load( filename );
else if( !modelName.compare(CV_EM) )
em->load( filename );
else if( !modelName.compare(CV_ANN) )
ann->load( filename );
else if( !modelName.compare(CV_DTREE) )
dtree->load( filename );
else if( !modelName.compare(CV_BOOST) )
boost->load( filename );
else if( !modelName.compare(CV_RTREES) )
rtrees->load( filename );
else if( !modelName.compare(CV_ERTREES) )
if( !modelName.compare(CV_NBAYES) )
nbayes->load( filename );
else if( !modelName.compare(CV_KNEAREST) )
knearest->load( filename );
else if( !modelName.compare(CV_SVM) )
svm->load( filename );
else if( !modelName.compare(CV_EM) )
em->load( filename );
else if( !modelName.compare(CV_ANN) )
ann->load( filename );
else if( !modelName.compare(CV_DTREE) )
dtree->load( filename );
else if( !modelName.compare(CV_BOOST) )
boost->load( filename );
else if( !modelName.compare(CV_RTREES) )
rtrees->load( filename );
else if( !modelName.compare(CV_ERTREES) )
ertrees->load( filename );
}