Deleted illegal type values.
This commit is contained in:
parent
ff54952769
commit
02cd8cf039
@ -1588,7 +1588,6 @@ public:
|
|||||||
ASGD is often the preferable choice. */
|
ASGD is often the preferable choice. */
|
||||||
enum SvmsgdType
|
enum SvmsgdType
|
||||||
{
|
{
|
||||||
ILLEGAL_SVMSGD_TYPE,
|
|
||||||
SGD, //!< Stochastic Gradient Descent
|
SGD, //!< Stochastic Gradient Descent
|
||||||
ASGD //!< Average Stochastic Gradient Descent
|
ASGD //!< Average Stochastic Gradient Descent
|
||||||
};
|
};
|
||||||
@ -1596,7 +1595,6 @@ public:
|
|||||||
/** Margin type.*/
|
/** Margin type.*/
|
||||||
enum MarginType
|
enum MarginType
|
||||||
{
|
{
|
||||||
ILLEGAL_MARGIN_TYPE,
|
|
||||||
SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers.
|
SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers.
|
||||||
HARD_MARGIN //!< More accurate for the case of linearly separable sets.
|
HARD_MARGIN //!< More accurate for the case of linearly separable sets.
|
||||||
};
|
};
|
||||||
|
@ -89,14 +89,8 @@ public:
|
|||||||
|
|
||||||
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
|
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
|
||||||
|
|
||||||
virtual int getSvmsgdType() const;
|
CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
|
||||||
|
CV_IMPL_PROPERTY(int, MarginType, params.marginType)
|
||||||
virtual void setSvmsgdType(int svmsgdType);
|
|
||||||
|
|
||||||
virtual int getMarginType() const;
|
|
||||||
|
|
||||||
virtual void setMarginType(int marginType);
|
|
||||||
|
|
||||||
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
|
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
|
||||||
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
|
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
|
||||||
CV_IMPL_PROPERTY(float, C, params.c)
|
CV_IMPL_PROPERTY(float, C, params.c)
|
||||||
@ -132,8 +126,8 @@ private:
|
|||||||
float gamma0; //learning rate
|
float gamma0; //learning rate
|
||||||
float c;
|
float c;
|
||||||
TermCriteria termCrit;
|
TermCriteria termCrit;
|
||||||
SvmsgdType svmsgdType;
|
int svmsgdType;
|
||||||
MarginType marginType;
|
int marginType;
|
||||||
};
|
};
|
||||||
|
|
||||||
SVMSGDParams params;
|
SVMSGDParams params;
|
||||||
@ -148,9 +142,9 @@ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
|
|||||||
{
|
{
|
||||||
CV_Assert(responses.cols == 1 || responses.rows == 1);
|
CV_Assert(responses.cols == 1 || responses.rows == 1);
|
||||||
std::pair<bool,bool> emptyInClasses(true, true);
|
std::pair<bool,bool> emptyInClasses(true, true);
|
||||||
int limit_index = responses.rows;
|
int limitIndex = responses.rows;
|
||||||
|
|
||||||
for(int index = 0; index < limit_index; index++)
|
for(int index = 0; index < limitIndex; index++)
|
||||||
{
|
{
|
||||||
if (isPositive(responses.at<float>(index)))
|
if (isPositive(responses.at<float>(index)))
|
||||||
emptyInClasses.first = false;
|
emptyInClasses.first = false;
|
||||||
@ -276,9 +270,9 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
|||||||
int extendedTrainSamplesCount = extendedTrainSamples.rows;
|
int extendedTrainSamplesCount = extendedTrainSamples.rows;
|
||||||
int extendedFeatureCount = extendedTrainSamples.cols;
|
int extendedFeatureCount = extendedTrainSamples.cols;
|
||||||
|
|
||||||
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); // Initialize extendedWeights vector with zeros
|
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||||
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); //extendedWeights vector for calculating terminal criterion
|
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||||
Mat averageExtendedWeights; //average extendedWeights vector for ASGD model
|
Mat averageExtendedWeights;
|
||||||
if (params.svmsgdType == ASGD)
|
if (params.svmsgdType == ASGD)
|
||||||
{
|
{
|
||||||
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||||
@ -407,10 +401,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
|||||||
case ASGD:
|
case ASGD:
|
||||||
SvmsgdTypeStr = "ASGD";
|
SvmsgdTypeStr = "ASGD";
|
||||||
break;
|
break;
|
||||||
case ILLEGAL_SVMSGD_TYPE:
|
|
||||||
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
|
|
||||||
default:
|
default:
|
||||||
std::cout << "params.svmsgdType isn't initialized" << std::endl;
|
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
|
||||||
}
|
}
|
||||||
|
|
||||||
fs << "svmsgdType" << SvmsgdTypeStr;
|
fs << "svmsgdType" << SvmsgdTypeStr;
|
||||||
@ -425,10 +417,8 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
|||||||
case HARD_MARGIN:
|
case HARD_MARGIN:
|
||||||
marginTypeStr = "HARD_MARGIN";
|
marginTypeStr = "HARD_MARGIN";
|
||||||
break;
|
break;
|
||||||
case ILLEGAL_MARGIN_TYPE:
|
|
||||||
marginTypeStr = format("Unknown_%d", params.marginType);
|
|
||||||
default:
|
default:
|
||||||
std::cout << "params.marginType isn't initialized" << std::endl;
|
marginTypeStr = format("Unknown_%d", params.marginType);
|
||||||
}
|
}
|
||||||
|
|
||||||
fs << "marginType" << marginTypeStr;
|
fs << "marginType" << marginTypeStr;
|
||||||
@ -458,21 +448,21 @@ void SVMSGDImpl::read(const FileNode& fn)
|
|||||||
void SVMSGDImpl::readParams( const FileNode& fn )
|
void SVMSGDImpl::readParams( const FileNode& fn )
|
||||||
{
|
{
|
||||||
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
||||||
SvmsgdType svmsgdType =
|
int svmsgdType =
|
||||||
svmsgdTypeStr == "SGD" ? SGD :
|
svmsgdTypeStr == "SGD" ? SGD :
|
||||||
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE;
|
svmsgdTypeStr == "ASGD" ? ASGD : -1;
|
||||||
|
|
||||||
if( svmsgdType == ILLEGAL_SVMSGD_TYPE )
|
if( svmsgdType < 0 )
|
||||||
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
|
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
|
||||||
|
|
||||||
params.svmsgdType = svmsgdType;
|
params.svmsgdType = svmsgdType;
|
||||||
|
|
||||||
String marginTypeStr = (String)fn["marginType"];
|
String marginTypeStr = (String)fn["marginType"];
|
||||||
MarginType marginType =
|
int marginType =
|
||||||
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
|
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
|
||||||
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
|
||||||
|
|
||||||
if( marginType == ILLEGAL_MARGIN_TYPE )
|
if( marginType < 0 )
|
||||||
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
|
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
|
||||||
|
|
||||||
params.marginType = marginType;
|
params.marginType = marginType;
|
||||||
@ -510,8 +500,8 @@ SVMSGDImpl::SVMSGDImpl()
|
|||||||
{
|
{
|
||||||
clear();
|
clear();
|
||||||
|
|
||||||
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
|
params.svmsgdType = -1;
|
||||||
params.marginType = ILLEGAL_MARGIN_TYPE;
|
params.marginType = -1;
|
||||||
|
|
||||||
// Parameters for learning
|
// Parameters for learning
|
||||||
params.lambda = 0; // regularization
|
params.lambda = 0; // regularization
|
||||||
@ -529,7 +519,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
|||||||
case SGD:
|
case SGD:
|
||||||
params.svmsgdType = SGD;
|
params.svmsgdType = SGD;
|
||||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||||
params.lambda = 0.0001f;
|
params.lambda = 0.0001f;
|
||||||
params.gamma0 = 0.05f;
|
params.gamma0 = 0.05f;
|
||||||
params.c = 1.f;
|
params.c = 1.f;
|
||||||
@ -539,7 +529,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
|||||||
case ASGD:
|
case ASGD:
|
||||||
params.svmsgdType = ASGD;
|
params.svmsgdType = ASGD;
|
||||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||||
params.lambda = 0.00001f;
|
params.lambda = 0.00001f;
|
||||||
params.gamma0 = 0.05f;
|
params.gamma0 = 0.05f;
|
||||||
params.c = 0.75f;
|
params.c = 0.75f;
|
||||||
@ -550,45 +540,5 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
|||||||
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
|
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SVMSGDImpl::setSvmsgdType(int type)
|
|
||||||
{
|
|
||||||
switch (type)
|
|
||||||
{
|
|
||||||
case SGD:
|
|
||||||
params.svmsgdType = SGD;
|
|
||||||
break;
|
|
||||||
case ASGD:
|
|
||||||
params.svmsgdType = ASGD;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int SVMSGDImpl::getSvmsgdType() const
|
|
||||||
{
|
|
||||||
return params.svmsgdType;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SVMSGDImpl::setMarginType(int type)
|
|
||||||
{
|
|
||||||
switch (type)
|
|
||||||
{
|
|
||||||
case HARD_MARGIN:
|
|
||||||
params.marginType = HARD_MARGIN;
|
|
||||||
break;
|
|
||||||
case SOFT_MARGIN:
|
|
||||||
params.marginType = SOFT_MARGIN;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
params.marginType = ILLEGAL_MARGIN_TYPE;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int SVMSGDImpl::getMarginType() const
|
|
||||||
{
|
|
||||||
return params.marginType;
|
|
||||||
}
|
|
||||||
} //ml
|
} //ml
|
||||||
} //cv
|
} //cv
|
||||||
|
Loading…
Reference in New Issue
Block a user