Refactored SVMSGD class

This commit is contained in:
Marina Noskova
2016-01-20 12:59:44 +03:00
parent a2f0963d66
commit 40bf97c6d1
11 changed files with 980 additions and 241 deletions

View File

@@ -193,6 +193,16 @@ int str_to_boost_type( String& str )
// 8. rtrees
// 9. ertrees
int str_to_svmsgd_type( String& str )
{
if ( !str.compare("SGD") )
return SVMSGD::SGD;
if ( !str.compare("ASGD") )
return SVMSGD::ASGD;
CV_Error( CV_StsBadArg, "incorrect boost type string" );
return -1;
}
// ---------------------------------- MLBaseTest ---------------------------------------------------
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
@@ -248,7 +258,9 @@ void CV_MLBaseTest::run( int )
{
string filename = ts->get_data_path();
filename += get_validation_filename();
validationFS.open( filename, FileStorage::READ );
read_params( *validationFS );
int code = cvtest::TS::OK;
@@ -436,6 +448,21 @@ int CV_MLBaseTest::train( int testCaseIdx )
model = m;
}
else if( modelName == CV_SVMSGD )
{
String svmsgdTypeStr;
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
Ptr<SVMSGD> m = SVMSGD::create();
int type = str_to_svmsgd_type( svmsgdTypeStr );
m->setType(type);
//m->setType(str_to_svmsgd_type( svmsgdTypeStr ));
m->setLambda(modelParamsNode["lambda"]);
m->setGamma0(modelParamsNode["gamma0"]);
m->setC(modelParamsNode["c"]);
m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
model = m;
}
if( !model.empty() )
is_trained = model->train(data, 0);
@@ -457,7 +484,7 @@ float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
else if( modelName == CV_ANN )
err = ann_calc_error( model, data, cls_map, type, resp );
else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST )
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD )
err = model->calcError( data, true, _resp );
if( !_resp.empty() && resp )
_resp.convertTo(*resp, CV_32F);
@@ -485,6 +512,8 @@ void CV_MLBaseTest::load( const char* filename )
model = Algorithm::load<Boost>( filename );
else if( modelName == CV_RTREES )
model = Algorithm::load<RTrees>( filename );
else if( modelName == CV_SVMSGD )
model = Algorithm::load<SVMSGD>( filename );
else
CV_Error( CV_StsNotImplemented, "invalid stat model name");
}