Added margin type, added tests with different scales of features.
Also fixed documentation, refactored sample.
This commit is contained in:
@@ -199,10 +199,19 @@ int str_to_svmsgd_type( String& str )
|
||||
return SVMSGD::SGD;
|
||||
if ( !str.compare("ASGD") )
|
||||
return SVMSGD::ASGD;
|
||||
CV_Error( CV_StsBadArg, "incorrect boost type string" );
|
||||
CV_Error( CV_StsBadArg, "incorrect svmsgd type string" );
|
||||
return -1;
|
||||
}
|
||||
|
||||
int str_to_margin_type( String& str )
|
||||
{
|
||||
if ( !str.compare("SOFT_MARGIN") )
|
||||
return SVMSGD::SOFT_MARGIN;
|
||||
if ( !str.compare("HARD_MARGIN") )
|
||||
return SVMSGD::HARD_MARGIN;
|
||||
CV_Error( CV_StsBadArg, "incorrect svmsgd margin type string" );
|
||||
return -1;
|
||||
}
|
||||
// ---------------------------------- MLBaseTest ---------------------------------------------------
|
||||
|
||||
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
|
||||
@@ -452,10 +461,16 @@ int CV_MLBaseTest::train( int testCaseIdx )
|
||||
{
|
||||
String svmsgdTypeStr;
|
||||
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
|
||||
Ptr<SVMSGD> m = SVMSGD::create();
|
||||
int type = str_to_svmsgd_type( svmsgdTypeStr );
|
||||
m->setType(type);
|
||||
//m->setType(str_to_svmsgd_type( svmsgdTypeStr ));
|
||||
|
||||
Ptr<SVMSGD> m = SVMSGD::create();
|
||||
int svmsgdType = str_to_svmsgd_type( svmsgdTypeStr );
|
||||
m->setSvmsgdType(svmsgdType);
|
||||
|
||||
String marginTypeStr;
|
||||
modelParamsNode["marginType"] >> marginTypeStr;
|
||||
int marginType = str_to_margin_type( marginTypeStr );
|
||||
m->setMarginType(marginType);
|
||||
|
||||
m->setLambda(modelParamsNode["lambda"]);
|
||||
m->setGamma0(modelParamsNode["gamma0"]);
|
||||
m->setC(modelParamsNode["c"]);
|
||||
|
||||
Reference in New Issue
Block a user