Deleted illegal type values.

This commit is contained in:
Marina Noskova 2016-02-15 15:09:59 +03:00
parent ff54952769
commit 02cd8cf039
2 changed files with 21 additions and 73 deletions

View File

@ -1588,7 +1588,6 @@ public:
ASGD is often the preferable choice. */ ASGD is often the preferable choice. */
enum SvmsgdType enum SvmsgdType
{ {
ILLEGAL_SVMSGD_TYPE,
SGD, //!< Stochastic Gradient Descent SGD, //!< Stochastic Gradient Descent
ASGD //!< Average Stochastic Gradient Descent ASGD //!< Average Stochastic Gradient Descent
}; };
@ -1596,7 +1595,6 @@ public:
/** Margin type.*/ /** Margin type.*/
enum MarginType enum MarginType
{ {
ILLEGAL_MARGIN_TYPE,
SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers. SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers.
HARD_MARGIN //!< More accurate for the case of linearly separable sets. HARD_MARGIN //!< More accurate for the case of linearly separable sets.
}; };

View File

@ -89,14 +89,8 @@ public:
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN); virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
virtual int getSvmsgdType() const; CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
CV_IMPL_PROPERTY(int, MarginType, params.marginType)
virtual void setSvmsgdType(int svmsgdType);
virtual int getMarginType() const;
virtual void setMarginType(int marginType);
CV_IMPL_PROPERTY(float, Lambda, params.lambda) CV_IMPL_PROPERTY(float, Lambda, params.lambda)
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0) CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
CV_IMPL_PROPERTY(float, C, params.c) CV_IMPL_PROPERTY(float, C, params.c)
@ -132,8 +126,8 @@ private:
float gamma0; //learning rate float gamma0; //learning rate
float c; float c;
TermCriteria termCrit; TermCriteria termCrit;
SvmsgdType svmsgdType; int svmsgdType;
MarginType marginType; int marginType;
}; };
SVMSGDParams params; SVMSGDParams params;
@ -148,9 +142,9 @@ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
{ {
CV_Assert(responses.cols == 1 || responses.rows == 1); CV_Assert(responses.cols == 1 || responses.rows == 1);
std::pair<bool,bool> emptyInClasses(true, true); std::pair<bool,bool> emptyInClasses(true, true);
int limit_index = responses.rows; int limitIndex = responses.rows;
for(int index = 0; index < limit_index; index++) for(int index = 0; index < limitIndex; index++)
{ {
if (isPositive(responses.at<float>(index))) if (isPositive(responses.at<float>(index)))
emptyInClasses.first = false; emptyInClasses.first = false;
@ -276,9 +270,9 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
int extendedTrainSamplesCount = extendedTrainSamples.rows; int extendedTrainSamplesCount = extendedTrainSamples.rows;
int extendedFeatureCount = extendedTrainSamples.cols; int extendedFeatureCount = extendedTrainSamples.cols;
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); // Initialize extendedWeights vector with zeros Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); //extendedWeights vector for calculating terminal criterion Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
Mat averageExtendedWeights; //average extendedWeights vector for ASGD model Mat averageExtendedWeights;
if (params.svmsgdType == ASGD) if (params.svmsgdType == ASGD)
{ {
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
@ -407,10 +401,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
case ASGD: case ASGD:
SvmsgdTypeStr = "ASGD"; SvmsgdTypeStr = "ASGD";
break; break;
case ILLEGAL_SVMSGD_TYPE:
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
default: default:
std::cout << "params.svmsgdType isn't initialized" << std::endl; SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
} }
fs << "svmsgdType" << SvmsgdTypeStr; fs << "svmsgdType" << SvmsgdTypeStr;
@ -425,10 +417,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
case HARD_MARGIN: case HARD_MARGIN:
marginTypeStr = "HARD_MARGIN"; marginTypeStr = "HARD_MARGIN";
break; break;
case ILLEGAL_MARGIN_TYPE:
marginTypeStr = format("Unknown_%d", params.marginType);
default: default:
std::cout << "params.marginType isn't initialized" << std::endl; marginTypeStr = format("Unknown_%d", params.marginType);
} }
fs << "marginType" << marginTypeStr; fs << "marginType" << marginTypeStr;
@ -458,21 +448,21 @@ void SVMSGDImpl::read(const FileNode& fn)
void SVMSGDImpl::readParams( const FileNode& fn ) void SVMSGDImpl::readParams( const FileNode& fn )
{ {
String svmsgdTypeStr = (String)fn["svmsgdType"]; String svmsgdTypeStr = (String)fn["svmsgdType"];
SvmsgdType svmsgdType = int svmsgdType =
svmsgdTypeStr == "SGD" ? SGD : svmsgdTypeStr == "SGD" ? SGD :
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE; svmsgdTypeStr == "ASGD" ? ASGD : -1;
if( svmsgdType == ILLEGAL_SVMSGD_TYPE ) if( svmsgdType < 0 )
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" ); CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
params.svmsgdType = svmsgdType; params.svmsgdType = svmsgdType;
String marginTypeStr = (String)fn["marginType"]; String marginTypeStr = (String)fn["marginType"];
MarginType marginType = int marginType =
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN : marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
if( marginType == ILLEGAL_MARGIN_TYPE ) if( marginType < 0 )
CV_Error( CV_StsParseError, "Missing or invalid margin type" ); CV_Error( CV_StsParseError, "Missing or invalid margin type" );
params.marginType = marginType; params.marginType = marginType;
@ -510,8 +500,8 @@ SVMSGDImpl::SVMSGDImpl()
{ {
clear(); clear();
params.svmsgdType = ILLEGAL_SVMSGD_TYPE; params.svmsgdType = -1;
params.marginType = ILLEGAL_MARGIN_TYPE; params.marginType = -1;
// Parameters for learning // Parameters for learning
params.lambda = 0; // regularization params.lambda = 0; // regularization
@ -529,7 +519,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
case SGD: case SGD:
params.svmsgdType = SGD; params.svmsgdType = SGD;
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
params.lambda = 0.0001f; params.lambda = 0.0001f;
params.gamma0 = 0.05f; params.gamma0 = 0.05f;
params.c = 1.f; params.c = 1.f;
@ -539,7 +529,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
case ASGD: case ASGD:
params.svmsgdType = ASGD; params.svmsgdType = ASGD;
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN : params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE; (marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
params.lambda = 0.00001f; params.lambda = 0.00001f;
params.gamma0 = 0.05f; params.gamma0 = 0.05f;
params.c = 0.75f; params.c = 0.75f;
@ -550,45 +540,5 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" ); CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
} }
} }
void SVMSGDImpl::setSvmsgdType(int type)
{
switch (type)
{
case SGD:
params.svmsgdType = SGD;
break;
case ASGD:
params.svmsgdType = ASGD;
break;
default:
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
}
}
int SVMSGDImpl::getSvmsgdType() const
{
return params.svmsgdType;
}
void SVMSGDImpl::setMarginType(int type)
{
switch (type)
{
case HARD_MARGIN:
params.marginType = HARD_MARGIN;
break;
case SOFT_MARGIN:
params.marginType = SOFT_MARGIN;
break;
default:
params.marginType = ILLEGAL_MARGIN_TYPE;
}
}
int SVMSGDImpl::getMarginType() const
{
return params.marginType;
}
} //ml } //ml
} //cv } //cv