Refactored SVMSGD class
This commit is contained in:
@@ -193,6 +193,16 @@ int str_to_boost_type( String& str )
|
||||
// 8. rtrees
|
||||
// 9. ertrees
|
||||
|
||||
int str_to_svmsgd_type( String& str )
|
||||
{
|
||||
if ( !str.compare("SGD") )
|
||||
return SVMSGD::SGD;
|
||||
if ( !str.compare("ASGD") )
|
||||
return SVMSGD::ASGD;
|
||||
CV_Error( CV_StsBadArg, "incorrect boost type string" );
|
||||
return -1;
|
||||
}
|
||||
|
||||
// ---------------------------------- MLBaseTest ---------------------------------------------------
|
||||
|
||||
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
|
||||
@@ -248,7 +258,9 @@ void CV_MLBaseTest::run( int )
|
||||
{
|
||||
string filename = ts->get_data_path();
|
||||
filename += get_validation_filename();
|
||||
|
||||
validationFS.open( filename, FileStorage::READ );
|
||||
|
||||
read_params( *validationFS );
|
||||
|
||||
int code = cvtest::TS::OK;
|
||||
@@ -436,6 +448,21 @@ int CV_MLBaseTest::train( int testCaseIdx )
|
||||
model = m;
|
||||
}
|
||||
|
||||
else if( modelName == CV_SVMSGD )
|
||||
{
|
||||
String svmsgdTypeStr;
|
||||
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
|
||||
Ptr<SVMSGD> m = SVMSGD::create();
|
||||
int type = str_to_svmsgd_type( svmsgdTypeStr );
|
||||
m->setType(type);
|
||||
//m->setType(str_to_svmsgd_type( svmsgdTypeStr ));
|
||||
m->setLambda(modelParamsNode["lambda"]);
|
||||
m->setGamma0(modelParamsNode["gamma0"]);
|
||||
m->setC(modelParamsNode["c"]);
|
||||
m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
|
||||
model = m;
|
||||
}
|
||||
|
||||
if( !model.empty() )
|
||||
is_trained = model->train(data, 0);
|
||||
|
||||
@@ -457,7 +484,7 @@ float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
|
||||
else if( modelName == CV_ANN )
|
||||
err = ann_calc_error( model, data, cls_map, type, resp );
|
||||
else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
|
||||
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST )
|
||||
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD )
|
||||
err = model->calcError( data, true, _resp );
|
||||
if( !_resp.empty() && resp )
|
||||
_resp.convertTo(*resp, CV_32F);
|
||||
@@ -485,6 +512,8 @@ void CV_MLBaseTest::load( const char* filename )
|
||||
model = Algorithm::load<Boost>( filename );
|
||||
else if( modelName == CV_RTREES )
|
||||
model = Algorithm::load<RTrees>( filename );
|
||||
else if( modelName == CV_SVMSGD )
|
||||
model = Algorithm::load<SVMSGD>( filename );
|
||||
else
|
||||
CV_Error( CV_StsNotImplemented, "invalid stat model name");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user