Deleted default value for parameters in docs.
Added some asserts.
This commit is contained in:
@@ -97,7 +97,7 @@ public:
|
||||
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
||||
|
||||
private:
|
||||
void updateWeights(InputArray sample, bool isPositive, float stepSize, Mat &weights);
|
||||
void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);
|
||||
|
||||
void writeParams( FileStorage &fs ) const;
|
||||
|
||||
@@ -111,8 +111,6 @@ private:
|
||||
|
||||
static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
|
||||
|
||||
|
||||
|
||||
// Vector with SVM weights
|
||||
Mat weights_;
|
||||
float shift_;
|
||||
@@ -263,11 +261,12 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
|
||||
RNG rng(0);
|
||||
|
||||
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS) && (trainResponses.type() == CV_32FC1));
|
||||
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
|
||||
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
|
||||
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
|
||||
|
||||
double err = DBL_MAX;
|
||||
CV_Assert (trainResponses.type() == CV_32FC1);
|
||||
// Stochastic gradient descent SVM
|
||||
for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
|
||||
{
|
||||
@@ -288,8 +287,8 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
}
|
||||
else
|
||||
{
|
||||
err = norm(extendedWeights - previousWeights);
|
||||
extendedWeights.copyTo(previousWeights);
|
||||
err = norm(extendedWeights - previousWeights);
|
||||
extendedWeights.copyTo(previousWeights);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -316,7 +315,6 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
|
||||
{
|
||||
float result = 0;
|
||||
@@ -417,17 +415,6 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
fs << "iterations" << params.termCrit.maxCount;
|
||||
fs << "}";
|
||||
}
|
||||
|
||||
void SVMSGDImpl::read(const FileNode& fn)
|
||||
{
|
||||
clear();
|
||||
|
||||
readParams(fn);
|
||||
|
||||
fn["weights"] >> weights_;
|
||||
fn["shift"] >> shift_;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::readParams( const FileNode& fn )
|
||||
{
|
||||
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
||||
@@ -443,7 +430,7 @@ void SVMSGDImpl::readParams( const FileNode& fn )
|
||||
String marginTypeStr = (String)fn["marginType"];
|
||||
int marginType =
|
||||
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
|
||||
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
|
||||
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
|
||||
|
||||
if( marginType < 0 )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
|
||||
@@ -460,16 +447,22 @@ void SVMSGDImpl::readParams( const FileNode& fn )
|
||||
params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
|
||||
|
||||
FileNode tcnode = fn["term_criteria"];
|
||||
if( !tcnode.empty() )
|
||||
{
|
||||
params.termCrit.epsilon = (double)tcnode["epsilon"];
|
||||
params.termCrit.maxCount = (int)tcnode["iterations"];
|
||||
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
|
||||
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
|
||||
}
|
||||
else
|
||||
params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 100000, FLT_EPSILON );
|
||||
CV_Assert(!tcnode.empty());
|
||||
params.termCrit.epsilon = (double)tcnode["epsilon"];
|
||||
params.termCrit.maxCount = (int)tcnode["iterations"];
|
||||
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
|
||||
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
|
||||
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
|
||||
}
|
||||
|
||||
void SVMSGDImpl::read(const FileNode& fn)
|
||||
{
|
||||
clear();
|
||||
|
||||
readParams(fn);
|
||||
|
||||
fn["weights"] >> weights_;
|
||||
fn["shift"] >> shift_;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::clear()
|
||||
@@ -492,7 +485,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
case SGD:
|
||||
params.svmsgdType = SGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||
params.marginRegularization = 0.0001f;
|
||||
params.initialStepSize = 0.05f;
|
||||
params.stepDecreasingPower = 1.f;
|
||||
@@ -502,7 +495,7 @@ void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
case ASGD:
|
||||
params.svmsgdType = ASGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||
params.marginRegularization = 0.00001f;
|
||||
params.initialStepSize = 0.05f;
|
||||
params.stepDecreasingPower = 0.75f;
|
||||
|
||||
Reference in New Issue
Block a user