Deleted illegal type values.
This commit is contained in:
parent
ff54952769
commit
02cd8cf039
@ -1588,7 +1588,6 @@ public:
|
||||
ASGD is often the preferable choice. */
|
||||
enum SvmsgdType
|
||||
{
|
||||
ILLEGAL_SVMSGD_TYPE,
|
||||
SGD, //!< Stochastic Gradient Descent
|
||||
ASGD //!< Average Stochastic Gradient Descent
|
||||
};
|
||||
@ -1596,7 +1595,6 @@ public:
|
||||
/** Margin type.*/
|
||||
enum MarginType
|
||||
{
|
||||
ILLEGAL_MARGIN_TYPE,
|
||||
SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers.
|
||||
HARD_MARGIN //!< More accurate for the case of linearly separable sets.
|
||||
};
|
||||
|
@ -89,14 +89,8 @@ public:
|
||||
|
||||
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
|
||||
|
||||
virtual int getSvmsgdType() const;
|
||||
|
||||
virtual void setSvmsgdType(int svmsgdType);
|
||||
|
||||
virtual int getMarginType() const;
|
||||
|
||||
virtual void setMarginType(int marginType);
|
||||
|
||||
CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
|
||||
CV_IMPL_PROPERTY(int, MarginType, params.marginType)
|
||||
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
|
||||
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
|
||||
CV_IMPL_PROPERTY(float, C, params.c)
|
||||
@ -132,8 +126,8 @@ private:
|
||||
float gamma0; //learning rate
|
||||
float c;
|
||||
TermCriteria termCrit;
|
||||
SvmsgdType svmsgdType;
|
||||
MarginType marginType;
|
||||
int svmsgdType;
|
||||
int marginType;
|
||||
};
|
||||
|
||||
SVMSGDParams params;
|
||||
@ -148,9 +142,9 @@ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
|
||||
{
|
||||
CV_Assert(responses.cols == 1 || responses.rows == 1);
|
||||
std::pair<bool,bool> emptyInClasses(true, true);
|
||||
int limit_index = responses.rows;
|
||||
int limitIndex = responses.rows;
|
||||
|
||||
for(int index = 0; index < limit_index; index++)
|
||||
for(int index = 0; index < limitIndex; index++)
|
||||
{
|
||||
if (isPositive(responses.at<float>(index)))
|
||||
emptyInClasses.first = false;
|
||||
@ -276,9 +270,9 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
int extendedTrainSamplesCount = extendedTrainSamples.rows;
|
||||
int extendedFeatureCount = extendedTrainSamples.cols;
|
||||
|
||||
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); // Initialize extendedWeights vector with zeros
|
||||
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); //extendedWeights vector for calculating terminal criterion
|
||||
Mat averageExtendedWeights; //average extendedWeights vector for ASGD model
|
||||
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||
Mat averageExtendedWeights;
|
||||
if (params.svmsgdType == ASGD)
|
||||
{
|
||||
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||
@ -407,10 +401,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
case ASGD:
|
||||
SvmsgdTypeStr = "ASGD";
|
||||
break;
|
||||
case ILLEGAL_SVMSGD_TYPE:
|
||||
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
|
||||
default:
|
||||
std::cout << "params.svmsgdType isn't initialized" << std::endl;
|
||||
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
|
||||
}
|
||||
|
||||
fs << "svmsgdType" << SvmsgdTypeStr;
|
||||
@ -425,10 +417,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
case HARD_MARGIN:
|
||||
marginTypeStr = "HARD_MARGIN";
|
||||
break;
|
||||
case ILLEGAL_MARGIN_TYPE:
|
||||
marginTypeStr = format("Unknown_%d", params.marginType);
|
||||
default:
|
||||
std::cout << "params.marginType isn't initialized" << std::endl;
|
||||
marginTypeStr = format("Unknown_%d", params.marginType);
|
||||
}
|
||||
|
||||
fs << "marginType" << marginTypeStr;
|
||||
@ -458,21 +448,21 @@ void SVMSGDImpl::read(const FileNode& fn)
|
||||
void SVMSGDImpl::readParams( const FileNode& fn )
|
||||
{
|
||||
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
||||
SvmsgdType svmsgdType =
|
||||
int svmsgdType =
|
||||
svmsgdTypeStr == "SGD" ? SGD :
|
||||
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE;
|
||||
svmsgdTypeStr == "ASGD" ? ASGD : -1;
|
||||
|
||||
if( svmsgdType == ILLEGAL_SVMSGD_TYPE )
|
||||
if( svmsgdType < 0 )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
|
||||
|
||||
params.svmsgdType = svmsgdType;
|
||||
|
||||
String marginTypeStr = (String)fn["marginType"];
|
||||
MarginType marginType =
|
||||
int marginType =
|
||||
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
|
||||
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
|
||||
|
||||
if( marginType == ILLEGAL_MARGIN_TYPE )
|
||||
if( marginType < 0 )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
|
||||
|
||||
params.marginType = marginType;
|
||||
@ -510,8 +500,8 @@ SVMSGDImpl::SVMSGDImpl()
|
||||
{
|
||||
clear();
|
||||
|
||||
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
|
||||
params.marginType = ILLEGAL_MARGIN_TYPE;
|
||||
params.svmsgdType = -1;
|
||||
params.marginType = -1;
|
||||
|
||||
// Parameters for learning
|
||||
params.lambda = 0; // regularization
|
||||
@ -529,7 +519,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
case SGD:
|
||||
params.svmsgdType = SGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||
params.lambda = 0.0001f;
|
||||
params.gamma0 = 0.05f;
|
||||
params.c = 1.f;
|
||||
@ -539,7 +529,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
case ASGD:
|
||||
params.svmsgdType = ASGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||
params.lambda = 0.00001f;
|
||||
params.gamma0 = 0.05f;
|
||||
params.c = 0.75f;
|
||||
@ -550,45 +540,5 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
|
||||
}
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setSvmsgdType(int type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case SGD:
|
||||
params.svmsgdType = SGD;
|
||||
break;
|
||||
case ASGD:
|
||||
params.svmsgdType = ASGD;
|
||||
break;
|
||||
default:
|
||||
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
|
||||
}
|
||||
}
|
||||
|
||||
int SVMSGDImpl::getSvmsgdType() const
|
||||
{
|
||||
return params.svmsgdType;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setMarginType(int type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case HARD_MARGIN:
|
||||
params.marginType = HARD_MARGIN;
|
||||
break;
|
||||
case SOFT_MARGIN:
|
||||
params.marginType = SOFT_MARGIN;
|
||||
break;
|
||||
default:
|
||||
params.marginType = ILLEGAL_MARGIN_TYPE;
|
||||
}
|
||||
}
|
||||
|
||||
int SVMSGDImpl::getMarginType() const
|
||||
{
|
||||
return params.marginType;
|
||||
}
|
||||
} //ml
|
||||
} //cv
|
||||
|
Loading…
Reference in New Issue
Block a user