diff --git a/tests/ml/src/amltests.cpp b/tests/ml/src/amltests.cpp index c4bd3f7fa..db21e5ae9 100644 --- a/tests/ml/src/amltests.cpp +++ b/tests/ml/src/amltests.cpp @@ -100,12 +100,17 @@ int CV_AMLTest::validate_test_results( int testCaseIdx ) resultNode["mean"] >> mean; resultNode["sigma"] >> sigma; float curErr = get_error( testCaseIdx, CV_TEST_ERROR ); - if ( abs( curErr - mean) > 6*sigma ) + const int coeff = 3; + ts->printf( CvTS::LOG, "Test case = %d; test error = %f; mean error = %f (diff=%f), %d*sigma = %f", + testCaseIdx, curErr, mean, abs( curErr - mean), coeff, coeff*sigma ); + if ( abs( curErr - mean) > coeff*sigma ) { - ts->printf( CvTS::LOG, "in test case %d test error is out of range:\n" - "abs(%f/*curEr*/ - %f/*mean*/ > %f/*6*sigma*/", testCaseIdx, curErr, mean, 6*sigma ); + ts->printf( CvTS::LOG, "abs(%f - %f) > %f - OUT OF RANGE!\n", curErr, mean, coeff*sigma, coeff ); return CvTS::FAIL_BAD_ACCURACY; } + else + ts->printf( CvTS::LOG, ".\n" ); + } else { diff --git a/tests/ml/src/mltest_main.cpp b/tests/ml/src/mltest_main.cpp index b40f7a4f9..89653c5bb 100644 --- a/tests/ml/src/mltest_main.cpp +++ b/tests/ml/src/mltest_main.cpp @@ -45,10 +45,10 @@ CvTS test_system("ml"); const char* blacklist[] = { - "adtree", //ticket 662 + /*"adtree", //ticket 662 "artrees", //ticket 460 "aboost", //ticket 474 - "aertrees", + "aertrees",*/ //ticket 505 0 }; diff --git a/tests/ml/src/mltests.cpp b/tests/ml/src/mltests.cpp index 33db9517e..f04001ac3 100644 --- a/tests/ml/src/mltests.cpp +++ b/tests/ml/src/mltests.cpp @@ -43,32 +43,32 @@ // auxiliary functions // 1. nbayes -void nbayes_check_data( CvMLData* _data ) -{ - if( _data->get_missing() ) - CV_Error( CV_StsBadArg, "missing values are not supported" ); - const CvMat* var_types = _data->get_var_types(); - bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL; +void nbayes_check_data( CvMLData* _data ) +{ + if( _data->get_missing() ) + CV_Error( CV_StsBadArg, "missing values are not supported" ); + const CvMat* var_types = _data->get_var_types(); + bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL; if( ( fabs( cvNorm( var_types, 0, CV_L1 ) - (var_types->rows + var_types->cols - 2)*CV_VAR_ORDERED - CV_VAR_CATEGORICAL ) > FLT_EPSILON ) || !is_classifier ) - CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" ); -} -bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data ) -{ + CV_Error( CV_StsBadArg, "incorrect types of predictors or responses" ); +} +bool nbayes_train( CvNormalBayesClassifier* nbayes, CvMLData* _data ) +{ nbayes_check_data( _data ); const CvMat* values = _data->get_values(); const CvMat* responses = _data->get_responses(); const CvMat* train_sidx = _data->get_train_sample_idx(); const CvMat* var_idx = _data->get_var_idx(); return nbayes->train( values, responses, var_idx, train_sidx ); -} +} float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int type, vector *resp ) { float err = 0; nbayes_check_data( _data ); - const CvMat* values = _data->get_values(); - const CvMat* response = _data->get_responses(); + const CvMat* values = _data->get_values(); + const CvMat* response = _data->get_responses(); const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx(); int* sidx = sample_idx ? sample_idx->data.i : 0; int r_step = CV_IS_MAT_CONT(response->type) ? @@ -98,10 +98,10 @@ float nbayes_calc_error( CvNormalBayesClassifier* nbayes, CvMLData* _data, int t } // 2. knearest -void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors ) -{ - const CvMat* values = _data->get_values(); - const CvMat* var_idx = _data->get_var_idx(); +void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors ) +{ + const CvMat* values = _data->get_values(); + const CvMat* var_idx = _data->get_var_idx(); if( var_idx->cols + var_idx->rows != values->cols ) CV_Error( CV_StsBadArg, "var_idx is not supported" ); if( _data->get_missing() ) @@ -113,18 +113,18 @@ void knearest_check_data_and_get_predictors( CvMLData* _data, CvMat* _predictors cvGetCols( values, _predictors, 0, values->cols - 1 ); else CV_Error( CV_StsBadArg, "responses must be in the first or last column; other cases are not supported" ); -} -bool knearest_train( CvKNearest* knearest, CvMLData* _data ) -{ +} +bool knearest_train( CvKNearest* knearest, CvMLData* _data ) +{ const CvMat* responses = _data->get_responses(); const CvMat* train_sidx = _data->get_train_sample_idx(); bool is_regression = _data->get_var_type( _data->get_response_idx() ) == CV_VAR_ORDERED; CvMat predictors; knearest_check_data_and_get_predictors( _data, &predictors ); - return knearest->train( &predictors, responses, train_sidx, is_regression ); -} -float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector *resp ) -{ + return knearest->train( &predictors, responses, train_sidx, is_regression ); +} +float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int type, vector *resp ) +{ float err = 0; const CvMat* response = _data->get_responses(); const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx(); @@ -172,7 +172,7 @@ float knearest_calc_error( CvKNearest* knearest, CvMLData* _data, int k, int typ } err = sample_count ? err / (float)sample_count : -FLT_MAX; } - return err; + return err; } // 3. svm @@ -204,44 +204,44 @@ int str_to_svm_kernel_type( string& str ) CV_Error( CV_StsBadArg, "incorrect svm type string" ); return -1; } -void svm_check_data( CvMLData* _data ) -{ - if( _data->get_missing() ) - CV_Error( CV_StsBadArg, "missing values are not supported" ); - const CvMat* var_types = _data->get_var_types(); - for( int i = 0; i < var_types->cols-1; i++ ) - if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL) - { - char msg[50]; - sprintf( msg, "incorrect type of %d-predictor", i ); - CV_Error( CV_StsBadArg, msg ); - } -} -bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params ) -{ - svm_check_data(_data); - const CvMat* _train_data = _data->get_values(); - const CvMat* _responses = _data->get_responses(); - const CvMat* _var_idx = _data->get_var_idx(); - const CvMat* _sample_idx = _data->get_train_sample_idx(); - return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params ); -} -bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params, - int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid, - CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid, - CvParamGrid degree_grid ) -{ - svm_check_data(_data); - const CvMat* _train_data = _data->get_values(); - const CvMat* _responses = _data->get_responses(); - const CvMat* _var_idx = _data->get_var_idx(); - const CvMat* _sample_idx = _data->get_train_sample_idx(); - return svm->train_auto( _train_data, _responses, _var_idx, - _sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid ); -} -float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector *resp ) -{ - svm_check_data(_data); +void svm_check_data( CvMLData* _data ) +{ + if( _data->get_missing() ) + CV_Error( CV_StsBadArg, "missing values are not supported" ); + const CvMat* var_types = _data->get_var_types(); + for( int i = 0; i < var_types->cols-1; i++ ) + if (var_types->data.ptr[i] == CV_VAR_CATEGORICAL) + { + char msg[50]; + sprintf( msg, "incorrect type of %d-predictor", i ); + CV_Error( CV_StsBadArg, msg ); + } +} +bool svm_train( CvSVM* svm, CvMLData* _data, CvSVMParams _params ) +{ + svm_check_data(_data); + const CvMat* _train_data = _data->get_values(); + const CvMat* _responses = _data->get_responses(); + const CvMat* _var_idx = _data->get_var_idx(); + const CvMat* _sample_idx = _data->get_train_sample_idx(); + return svm->train( _train_data, _responses, _var_idx, _sample_idx, _params ); +} +bool svm_train_auto( CvSVM* svm, CvMLData* _data, CvSVMParams _params, + int k_fold, CvParamGrid C_grid, CvParamGrid gamma_grid, + CvParamGrid p_grid, CvParamGrid nu_grid, CvParamGrid coef_grid, + CvParamGrid degree_grid ) +{ + svm_check_data(_data); + const CvMat* _train_data = _data->get_values(); + const CvMat* _responses = _data->get_responses(); + const CvMat* _var_idx = _data->get_var_idx(); + const CvMat* _sample_idx = _data->get_train_sample_idx(); + return svm->train_auto( _train_data, _responses, _var_idx, + _sample_idx, _params, k_fold, C_grid, gamma_grid, p_grid, nu_grid, coef_grid, degree_grid ); +} +float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector *resp ) +{ + svm_check_data(_data); float err = 0; const CvMat* values = _data->get_values(); const CvMat* response = _data->get_responses(); @@ -289,7 +289,7 @@ float svm_calc_error( CvSVM* svm, CvMLData* _data, int type, vector *resp } err = sample_count ? err / (float)sample_count : -FLT_MAX; } - return err; + return err; } // 4. em @@ -303,10 +303,10 @@ int str_to_ann_train_method( string& str ) CV_Error( CV_StsBadArg, "incorrect ann train method string" ); return -1; } -void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs ) -{ - const CvMat* values = _data->get_values(); - const CvMat* var_idx = _data->get_var_idx(); +void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs ) +{ + const CvMat* values = _data->get_values(); + const CvMat* var_idx = _data->get_var_idx(); if( var_idx->cols + var_idx->rows != values->cols ) CV_Error( CV_StsBadArg, "var_idx is not supported" ); if( _data->get_missing() ) @@ -318,10 +318,10 @@ void ann_check_data_and_get_predictors( CvMLData* _data, CvMat* _inputs ) cvGetCols( values, _inputs, 0, values->cols - 1 ); else CV_Error( CV_StsBadArg, "outputs must be in the first or last column; other cases are not supported" ); -} -void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map& cls_map ) -{ - const CvMat* train_sidx = _data->get_train_sample_idx(); +} +void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map& cls_map ) +{ + const CvMat* train_sidx = _data->get_train_sample_idx(); int* train_sidx_ptr = train_sidx->data.i; const CvMat* responses = _data->get_responses(); float* responses_ptr = responses->data.fl; @@ -348,18 +348,18 @@ void ann_get_new_responses( CvMLData* _data, Mat& new_responses, map& int r = cvRound(responses_ptr[sidx*r_step]); int cidx = cls_map[r]; new_responses.ptr(sidx)[cidx] = 1; - } -} -int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 ) -{ + } +} +int ann_train( CvANN_MLP* ann, CvMLData* _data, Mat& new_responses, CvANN_MLP_TrainParams _params, int flags = 0 ) +{ const CvMat* train_sidx = _data->get_train_sample_idx(); CvMat predictors; ann_check_data_and_get_predictors( _data, &predictors ); CvMat _new_responses = CvMat( new_responses ); - return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags ); -} -float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map& cls_map, int type , vector *resp_labels ) -{ + return ann->train( &predictors, &_new_responses, 0, train_sidx, _params, flags ); +} +float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map& cls_map, int type , vector *resp_labels ) +{ float err = 0; const CvMat* responses = _data->get_responses(); const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx(); @@ -396,16 +396,16 @@ float ann_calc_error( CvANN_MLP* ann, CvMLData* _data, map& cls_map, i cvGetRow( &predictors, &sample, si ); ann->predict( &sample, &_output ); CvPoint best_cls = {0,0}; - cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 ); - int r = cvRound(responses->data.fl[si*r_step]); - CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON ); - r = cls_map[r]; - int d = best_cls.x == r ? 0 : 1; + cvMinMaxLoc( &_output, 0, 0, 0, &best_cls, 0 ); + int r = cvRound(responses->data.fl[si*r_step]); + CV_DbgAssert( fabs(responses->data.fl[si*r_step]-r) < FLT_EPSILON ); + r = cls_map[r]; + int d = best_cls.x == r ? 0 : 1; err += d; pred_resp[i] = (float)best_cls.x; } err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX; - return err; + return err; } // 6. dtree @@ -429,10 +429,24 @@ int str_to_boost_type( string& str ) // ---------------------------------- MLBaseTest --------------------------------------------------- -CV_MLBaseTest::CV_MLBaseTest( const char* _modelName, const char* _testName, const char* _testFuncs ) : -CvTest( _testName, _testFuncs ) -{ - modelName = _modelName; +CV_MLBaseTest::CV_MLBaseTest( const char* _modelName, const char* _testName, const char* _testFuncs ) : +CvTest( _testName, _testFuncs ) +{ + int64 seeds[] = { 0x00009fff4f9c8d52, + 0x0000a17166072c7c, + 0x0201b32115cd1f9a, + 0x0513cb37abcd1234, + 0x0001a2b3c4d5f678 + }; + + int seedCount = sizeof(seeds)/sizeof(seeds[0]); + RNG& rng = theRNG(); + + initSeed = rng.state; + + rng.state = seeds[rng(seedCount)]; + + modelName = _modelName; nbayes = 0; knearest = 0; svm = 0; @@ -441,35 +455,35 @@ CvTest( _testName, _testFuncs ) dtree = 0; boost = 0; rtrees = 0; - ertrees = 0; - if( !modelName.compare(CV_NBAYES) ) - nbayes = new CvNormalBayesClassifier; - else if( !modelName.compare(CV_KNEAREST) ) - knearest = new CvKNearest; - else if( !modelName.compare(CV_SVM) ) - svm = new CvSVM; - else if( !modelName.compare(CV_EM) ) - em = new CvEM; - else if( !modelName.compare(CV_ANN) ) - ann = new CvANN_MLP; - else if( !modelName.compare(CV_DTREE) ) - dtree = new CvDTree; - else if( !modelName.compare(CV_BOOST) ) - boost = new CvBoost; - else if( !modelName.compare(CV_RTREES) ) - rtrees = new CvRTrees; - else if( !modelName.compare(CV_ERTREES) ) - ertrees = new CvERTrees; -} + ertrees = 0; + if( !modelName.compare(CV_NBAYES) ) + nbayes = new CvNormalBayesClassifier; + else if( !modelName.compare(CV_KNEAREST) ) + knearest = new CvKNearest; + else if( !modelName.compare(CV_SVM) ) + svm = new CvSVM; + else if( !modelName.compare(CV_EM) ) + em = new CvEM; + else if( !modelName.compare(CV_ANN) ) + ann = new CvANN_MLP; + else if( !modelName.compare(CV_DTREE) ) + dtree = new CvDTree; + else if( !modelName.compare(CV_BOOST) ) + boost = new CvBoost; + else if( !modelName.compare(CV_RTREES) ) + rtrees = new CvRTrees; + else if( !modelName.compare(CV_ERTREES) ) + ertrees = new CvERTrees; +} int CV_MLBaseTest::init( CvTS* system ) { - clear(); - ts = system; - - string filename = ts->get_data_path(); - filename += get_validation_filename(); - validationFS.open( filename, FileStorage::READ ); + clear(); + ts = system; + + string filename = ts->get_data_path(); + filename += get_validation_filename(); + validationFS.open( filename, FileStorage::READ ); return read_params( *validationFS ); } @@ -495,6 +509,7 @@ CV_MLBaseTest::~CV_MLBaseTest() delete rtrees; if( ertrees ) delete ertrees; + theRNG().state = initSeed; } int CV_MLBaseTest::read_params( CvFileStorage* _fs ) @@ -519,182 +534,182 @@ int CV_MLBaseTest::read_params( CvFileStorage* _fs ) return CvTS::OK;; } -void CV_MLBaseTest::run( int start_from ) -{ - int code = CvTS::OK; - start_from = 0; - for (int i = 0; i < test_case_count; i++) - { - int temp_code = run_test_case( i ); - if (temp_code == CvTS::OK) - temp_code = validate_test_results( i ); - if (temp_code != CvTS::OK) - code = temp_code; - } - if ( test_case_count <= 0) +void CV_MLBaseTest::run( int start_from ) +{ + int code = CvTS::OK; + start_from = 0; + for (int i = 0; i < test_case_count; i++) { - ts->printf( CvTS::LOG, "validation file is not determined or not correct" ); + int temp_code = run_test_case( i ); + if (temp_code == CvTS::OK) + temp_code = validate_test_results( i ); + if (temp_code != CvTS::OK) + code = temp_code; + } + if ( test_case_count <= 0) + { + ts->printf( CvTS::LOG, "validation file is not determined or not correct" ); code = CvTS::FAIL_INVALID_TEST_DATA; - } - ts->set_failed_test_info( code ); -} - -int CV_MLBaseTest::prepare_test_case( int test_case_idx ) -{ - int trainSampleCount, respIdx; - string varTypes; - clear(); + } + ts->set_failed_test_info( code ); +} - string dataPath = ts->get_data_path(); +int CV_MLBaseTest::prepare_test_case( int test_case_idx ) +{ + int trainSampleCount, respIdx; + string varTypes; + clear(); + + string dataPath = ts->get_data_path(); if ( dataPath.empty() ) { - ts->printf( CvTS::LOG, "data path is empty" ); + ts->printf( CvTS::LOG, "data path is empty" ); return CvTS::FAIL_INVALID_TEST_DATA; - } - - string dataName = dataSetNames[test_case_idx], - filename = dataPath + dataName + ".data"; - if ( data.read_csv( filename.c_str() ) != 0) + } + + string dataName = dataSetNames[test_case_idx], + filename = dataPath + dataName + ".data"; + if ( data.read_csv( filename.c_str() ) != 0) { char msg[100]; sprintf( msg, "file %s can not be read", filename.c_str() ); - ts->printf( CvTS::LOG, msg ); + ts->printf( CvTS::LOG, msg ); return CvTS::FAIL_INVALID_TEST_DATA; - } - - FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"]; - CV_DbgAssert( !dataParamsNode.empty() ); - - CV_DbgAssert( !dataParamsNode["LS"].empty() ); - dataParamsNode["LS"] >> trainSampleCount; - CvTrainTestSplit spl( trainSampleCount ); - data.set_train_test_split( &spl ); - - CV_DbgAssert( !dataParamsNode["resp_idx"].empty() ); - dataParamsNode["resp_idx"] >> respIdx; - data.set_response_idx( respIdx ); - - CV_DbgAssert( !dataParamsNode["types"].empty() ); - dataParamsNode["types"] >> varTypes; - data.set_var_types( varTypes.c_str() ); - - return CvTS::OK; -} - -string& CV_MLBaseTest::get_validation_filename() -{ - return validationFN; -} - + } + + FileNode dataParamsNode = validationFS.getFirstTopLevelNode()["validation"][modelName][dataName]["data_params"]; + CV_DbgAssert( !dataParamsNode.empty() ); + + CV_DbgAssert( !dataParamsNode["LS"].empty() ); + dataParamsNode["LS"] >> trainSampleCount; + CvTrainTestSplit spl( trainSampleCount ); + data.set_train_test_split( &spl ); + + CV_DbgAssert( !dataParamsNode["resp_idx"].empty() ); + dataParamsNode["resp_idx"] >> respIdx; + data.set_response_idx( respIdx ); + + CV_DbgAssert( !dataParamsNode["types"].empty() ); + dataParamsNode["types"] >> varTypes; + data.set_var_types( varTypes.c_str() ); + + return CvTS::OK; +} + +string& CV_MLBaseTest::get_validation_filename() +{ + return validationFN; +} + int CV_MLBaseTest::train( int testCaseIdx ) { bool is_trained = false; - FileNode modelParamsNode = - validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]; + FileNode modelParamsNode = + validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]; - if( !modelName.compare(CV_NBAYES) ) - is_trained = nbayes_train( nbayes, &data ); - else if( !modelName.compare(CV_KNEAREST) ) - { - assert( 0 ); - //is_trained = knearest->train( &data ); - } - else if( !modelName.compare(CV_SVM) ) - { - string svm_type_str, kernel_type_str; - modelParamsNode["svm_type"] >> svm_type_str; - modelParamsNode["kernel_type"] >> kernel_type_str; - CvSVMParams params; - params.svm_type = str_to_svm_type( svm_type_str ); - params.kernel_type = str_to_svm_kernel_type( kernel_type_str ); - modelParamsNode["degree"] >> params.degree; - modelParamsNode["gamma"] >> params.gamma; - modelParamsNode["coef0"] >> params.coef0; - modelParamsNode["C"] >> params.C; - modelParamsNode["nu"] >> params.nu; - modelParamsNode["p"] >> params.p; - is_trained = svm_train( svm, &data, params ); - } - else if( !modelName.compare(CV_EM) ) - { - assert( 0 ); - } - else if( !modelName.compare(CV_ANN) ) - { - string train_method_str; - double param1, param2; - modelParamsNode["train_method"] >> train_method_str; - modelParamsNode["param1"] >> param1; - modelParamsNode["param2"] >> param2; - Mat new_responses; - ann_get_new_responses( &data, new_responses, cls_map ); - int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() }; - CvMat layer_sizes = - cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz ); - ann->create( &layer_sizes ); - is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01), - str_to_ann_train_method(train_method_str), param1, param2) ) >= 0; - } - else if( !modelName.compare(CV_DTREE) ) - { + if( !modelName.compare(CV_NBAYES) ) + is_trained = nbayes_train( nbayes, &data ); + else if( !modelName.compare(CV_KNEAREST) ) + { + assert( 0 ); + //is_trained = knearest->train( &data ); + } + else if( !modelName.compare(CV_SVM) ) + { + string svm_type_str, kernel_type_str; + modelParamsNode["svm_type"] >> svm_type_str; + modelParamsNode["kernel_type"] >> kernel_type_str; + CvSVMParams params; + params.svm_type = str_to_svm_type( svm_type_str ); + params.kernel_type = str_to_svm_kernel_type( kernel_type_str ); + modelParamsNode["degree"] >> params.degree; + modelParamsNode["gamma"] >> params.gamma; + modelParamsNode["coef0"] >> params.coef0; + modelParamsNode["C"] >> params.C; + modelParamsNode["nu"] >> params.nu; + modelParamsNode["p"] >> params.p; + is_trained = svm_train( svm, &data, params ); + } + else if( !modelName.compare(CV_EM) ) + { + assert( 0 ); + } + else if( !modelName.compare(CV_ANN) ) + { + string train_method_str; + double param1, param2; + modelParamsNode["train_method"] >> train_method_str; + modelParamsNode["param1"] >> param1; + modelParamsNode["param2"] >> param2; + Mat new_responses; + ann_get_new_responses( &data, new_responses, cls_map ); + int layer_sz[] = { data.get_values()->cols - 1, 100, 100, (int)cls_map.size() }; + CvMat layer_sizes = + cvMat( 1, (int)(sizeof(layer_sz)/sizeof(layer_sz[0])), CV_32S, layer_sz ); + ann->create( &layer_sizes ); + is_trained = ann_train( ann, &data, new_responses, CvANN_MLP_TrainParams(cvTermCriteria(CV_TERMCRIT_ITER,300,0.01), + str_to_ann_train_method(train_method_str), param1, param2) ) >= 0; + } + else if( !modelName.compare(CV_DTREE) ) + { int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS; float REG_ACCURACY = 0; - bool USE_SURROGATE, IS_PRUNED; - modelParamsNode["max_depth"] >> MAX_DEPTH; - modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT; - modelParamsNode["use_surrogate"] >> USE_SURROGATE; - modelParamsNode["max_categories"] >> MAX_CATEGORIES; - modelParamsNode["cv_folds"] >> CV_FOLDS; - modelParamsNode["is_pruned"] >> IS_PRUNED; + bool USE_SURROGATE, IS_PRUNED; + modelParamsNode["max_depth"] >> MAX_DEPTH; + modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT; + modelParamsNode["use_surrogate"] >> USE_SURROGATE; + modelParamsNode["max_categories"] >> MAX_CATEGORIES; + modelParamsNode["cv_folds"] >> CV_FOLDS; + modelParamsNode["is_pruned"] >> IS_PRUNED; is_trained = dtree->train( &data, CvDTreeParams(MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE, MAX_CATEGORIES, CV_FOLDS, false, IS_PRUNED, 0 )) != 0; - } - else if( !modelName.compare(CV_BOOST) ) - { + } + else if( !modelName.compare(CV_BOOST) ) + { int BOOST_TYPE, WEAK_COUNT, MAX_DEPTH; float WEIGHT_TRIM_RATE; - bool USE_SURROGATE; - string typeStr; - modelParamsNode["type"] >> typeStr; - BOOST_TYPE = str_to_boost_type( typeStr ); + bool USE_SURROGATE; + string typeStr; + modelParamsNode["type"] >> typeStr; + BOOST_TYPE = str_to_boost_type( typeStr ); modelParamsNode["weak_count"] >> WEAK_COUNT; modelParamsNode["weight_trim_rate"] >> WEIGHT_TRIM_RATE; modelParamsNode["max_depth"] >> MAX_DEPTH; modelParamsNode["use_surrogate"] >> USE_SURROGATE; is_trained = boost->train( &data, - CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0; - } - else if( !modelName.compare(CV_RTREES) ) - { - int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM; - float REG_ACCURACY = 0, OOB_EPS = 0.0; - bool USE_SURROGATE, IS_PRUNED; - modelParamsNode["max_depth"] >> MAX_DEPTH; - modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT; - modelParamsNode["use_surrogate"] >> USE_SURROGATE; - modelParamsNode["max_categories"] >> MAX_CATEGORIES; - modelParamsNode["cv_folds"] >> CV_FOLDS; - modelParamsNode["is_pruned"] >> IS_PRUNED; - modelParamsNode["nactive_vars"] >> NACTIVE_VARS; - modelParamsNode["max_trees_num"] >> MAX_TREES_NUM; - is_trained = rtrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, - USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance - NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0; - } - else if( !modelName.compare(CV_ERTREES) ) + CvBoostParams(BOOST_TYPE, WEAK_COUNT, WEIGHT_TRIM_RATE, MAX_DEPTH, USE_SURROGATE, 0) ) != 0; + } + else if( !modelName.compare(CV_RTREES) ) { int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM; float REG_ACCURACY = 0, OOB_EPS = 0.0; bool USE_SURROGATE, IS_PRUNED; - modelParamsNode["max_depth"] >> MAX_DEPTH; - modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT; - modelParamsNode["use_surrogate"] >> USE_SURROGATE; - modelParamsNode["max_categories"] >> MAX_CATEGORIES; - modelParamsNode["cv_folds"] >> CV_FOLDS; - modelParamsNode["is_pruned"] >> IS_PRUNED; - modelParamsNode["nactive_vars"] >> NACTIVE_VARS; - modelParamsNode["max_trees_num"] >> MAX_TREES_NUM; + modelParamsNode["max_depth"] >> MAX_DEPTH; + modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT; + modelParamsNode["use_surrogate"] >> USE_SURROGATE; + modelParamsNode["max_categories"] >> MAX_CATEGORIES; + modelParamsNode["cv_folds"] >> CV_FOLDS; + modelParamsNode["is_pruned"] >> IS_PRUNED; + modelParamsNode["nactive_vars"] >> NACTIVE_VARS; + modelParamsNode["max_trees_num"] >> MAX_TREES_NUM; + is_trained = rtrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, + USE_SURROGATE, MAX_CATEGORIES, 0, true, // (calc_var_importance == true) <=> RF processes variable importance + NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0; + } + else if( !modelName.compare(CV_ERTREES) ) + { + int MAX_DEPTH, MIN_SAMPLE_COUNT, MAX_CATEGORIES, CV_FOLDS, NACTIVE_VARS, MAX_TREES_NUM; + float REG_ACCURACY = 0, OOB_EPS = 0.0; + bool USE_SURROGATE, IS_PRUNED; + modelParamsNode["max_depth"] >> MAX_DEPTH; + modelParamsNode["min_sample_count"] >> MIN_SAMPLE_COUNT; + modelParamsNode["use_surrogate"] >> USE_SURROGATE; + modelParamsNode["max_categories"] >> MAX_CATEGORIES; + modelParamsNode["cv_folds"] >> CV_FOLDS; + modelParamsNode["is_pruned"] >> IS_PRUNED; + modelParamsNode["nactive_vars"] >> NACTIVE_VARS; + modelParamsNode["max_trees_num"] >> MAX_TREES_NUM; is_trained = ertrees->train( &data, CvRTParams( MAX_DEPTH, MIN_SAMPLE_COUNT, REG_ACCURACY, USE_SURROGATE, MAX_CATEGORIES, 0, false, // (calc_var_importance == true) <=> RF processes variable importance NACTIVE_VARS, MAX_TREES_NUM, OOB_EPS, CV_TERMCRIT_ITER)) != 0; @@ -711,74 +726,74 @@ int CV_MLBaseTest::train( int testCaseIdx ) float CV_MLBaseTest::get_error( int testCaseIdx, int type, vector *resp ) { float err = 0; - if( !modelName.compare(CV_NBAYES) ) - err = nbayes_calc_error( nbayes, &data, type, resp ); - else if( !modelName.compare(CV_KNEAREST) ) - { - assert( 0 ); - testCaseIdx = 0; - /*int k = 2; - validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k; - err = knearest->calc_error( &data, k, type, resp );*/ - } - else if( !modelName.compare(CV_SVM) ) - err = svm_calc_error( svm, &data, type, resp ); - else if( !modelName.compare(CV_EM) ) - assert( 0 ); - else if( !modelName.compare(CV_ANN) ) - err = ann_calc_error( ann, &data, cls_map, type, resp ); - else if( !modelName.compare(CV_DTREE) ) - err = dtree->calc_error( &data, type, resp ); - else if( !modelName.compare(CV_BOOST) ) - err = boost->calc_error( &data, type, resp ); - else if( !modelName.compare(CV_RTREES) ) - err = rtrees->calc_error( &data, type, resp ); - else if( !modelName.compare(CV_ERTREES) ) + if( !modelName.compare(CV_NBAYES) ) + err = nbayes_calc_error( nbayes, &data, type, resp ); + else if( !modelName.compare(CV_KNEAREST) ) + { + assert( 0 ); + testCaseIdx = 0; + /*int k = 2; + validationFS.getFirstTopLevelNode()["validation"][modelName][dataSetNames[testCaseIdx]]["model_params"]["k"] >> k; + err = knearest->calc_error( &data, k, type, resp );*/ + } + else if( !modelName.compare(CV_SVM) ) + err = svm_calc_error( svm, &data, type, resp ); + else if( !modelName.compare(CV_EM) ) + assert( 0 ); + else if( !modelName.compare(CV_ANN) ) + err = ann_calc_error( ann, &data, cls_map, type, resp ); + else if( !modelName.compare(CV_DTREE) ) + err = dtree->calc_error( &data, type, resp ); + else if( !modelName.compare(CV_BOOST) ) + err = boost->calc_error( &data, type, resp ); + else if( !modelName.compare(CV_RTREES) ) + err = rtrees->calc_error( &data, type, resp ); + else if( !modelName.compare(CV_ERTREES) ) err = ertrees->calc_error( &data, type, resp ); return err; } void CV_MLBaseTest::save( const char* filename ) { - if( !modelName.compare(CV_NBAYES) ) - nbayes->save( filename ); - else if( !modelName.compare(CV_KNEAREST) ) - knearest->save( filename ); - else if( !modelName.compare(CV_SVM) ) - svm->save( filename ); - else if( !modelName.compare(CV_EM) ) - em->save( filename ); - else if( !modelName.compare(CV_ANN) ) - ann->save( filename ); - else if( !modelName.compare(CV_DTREE) ) - dtree->save( filename ); - else if( !modelName.compare(CV_BOOST) ) - boost->save( filename ); - else if( !modelName.compare(CV_RTREES) ) - rtrees->save( filename ); - else if( !modelName.compare(CV_ERTREES) ) + if( !modelName.compare(CV_NBAYES) ) + nbayes->save( filename ); + else if( !modelName.compare(CV_KNEAREST) ) + knearest->save( filename ); + else if( !modelName.compare(CV_SVM) ) + svm->save( filename ); + else if( !modelName.compare(CV_EM) ) + em->save( filename ); + else if( !modelName.compare(CV_ANN) ) + ann->save( filename ); + else if( !modelName.compare(CV_DTREE) ) + dtree->save( filename ); + else if( !modelName.compare(CV_BOOST) ) + boost->save( filename ); + else if( !modelName.compare(CV_RTREES) ) + rtrees->save( filename ); + else if( !modelName.compare(CV_ERTREES) ) ertrees->save( filename ); } void CV_MLBaseTest::load( const char* filename ) { - if( !modelName.compare(CV_NBAYES) ) - nbayes->load( filename ); - else if( !modelName.compare(CV_KNEAREST) ) - knearest->load( filename ); - else if( !modelName.compare(CV_SVM) ) - svm->load( filename ); - else if( !modelName.compare(CV_EM) ) - em->load( filename ); - else if( !modelName.compare(CV_ANN) ) - ann->load( filename ); - else if( !modelName.compare(CV_DTREE) ) - dtree->load( filename ); - else if( !modelName.compare(CV_BOOST) ) - boost->load( filename ); - else if( !modelName.compare(CV_RTREES) ) - rtrees->load( filename ); - else if( !modelName.compare(CV_ERTREES) ) + if( !modelName.compare(CV_NBAYES) ) + nbayes->load( filename ); + else if( !modelName.compare(CV_KNEAREST) ) + knearest->load( filename ); + else if( !modelName.compare(CV_SVM) ) + svm->load( filename ); + else if( !modelName.compare(CV_EM) ) + em->load( filename ); + else if( !modelName.compare(CV_ANN) ) + ann->load( filename ); + else if( !modelName.compare(CV_DTREE) ) + dtree->load( filename ); + else if( !modelName.compare(CV_BOOST) ) + boost->load( filename ); + else if( !modelName.compare(CV_RTREES) ) + rtrees->load( filename ); + else if( !modelName.compare(CV_ERTREES) ) ertrees->load( filename ); }