Delete function areClassesEmpty().
This commit is contained in:
@@ -99,8 +99,6 @@ public:
|
||||
private:
|
||||
void updateWeights(InputArray sample, bool isPositive, float stepSize, Mat &weights);
|
||||
|
||||
std::pair<bool,bool> areClassesEmpty(Mat responses);
|
||||
|
||||
void writeParams( FileStorage &fs ) const;
|
||||
|
||||
void readParams( const FileNode &fn );
|
||||
@@ -138,26 +136,6 @@ Ptr<SVMSGD> SVMSGD::create()
|
||||
return makePtr<SVMSGDImpl>();
|
||||
}
|
||||
|
||||
std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
|
||||
{
|
||||
CV_Assert(responses.cols == 1 || responses.rows == 1);
|
||||
std::pair<bool,bool> emptyInClasses(true, true);
|
||||
int limitIndex = responses.rows;
|
||||
|
||||
for(int index = 0; index < limitIndex; index++)
|
||||
{
|
||||
if (isPositive(responses.at<float>(index)))
|
||||
emptyInClasses.first = false;
|
||||
else
|
||||
emptyInClasses.second = false;
|
||||
|
||||
if (!emptyInClasses.first && ! emptyInClasses.second)
|
||||
break;
|
||||
}
|
||||
|
||||
return emptyInClasses;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
|
||||
{
|
||||
int featuresCount = samples.cols;
|
||||
@@ -248,16 +226,20 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
int featureCount = trainSamples.cols;
|
||||
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
|
||||
|
||||
std::pair<bool,bool> areEmpty = areClassesEmpty(trainResponses);
|
||||
CV_Assert(trainResponses.rows == trainSamples.rows);
|
||||
|
||||
if ( areEmpty.first && areEmpty.second )
|
||||
if (trainResponses.empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if ( areEmpty.first || areEmpty.second )
|
||||
|
||||
int positiveCount = countNonZero(trainResponses >= 0);
|
||||
int negativeCount = countNonZero(trainResponses < 0);
|
||||
|
||||
if ( positiveCount <= 0 || negativeCount <= 0 )
|
||||
{
|
||||
weights_ = Mat::zeros(1, featureCount, CV_32F);
|
||||
shift_ = areEmpty.first ? -1.f : 1.f;
|
||||
shift_ = (positiveCount > 0) ? 1.f : -1.f;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -340,7 +322,7 @@ float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) cons
|
||||
int nSamples = samples.rows;
|
||||
cv::Mat results;
|
||||
|
||||
CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32F );
|
||||
CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32FC1);
|
||||
|
||||
if( _results.needed() )
|
||||
{
|
||||
@@ -498,17 +480,7 @@ void SVMSGDImpl::clear()
|
||||
SVMSGDImpl::SVMSGDImpl()
|
||||
{
|
||||
clear();
|
||||
|
||||
params.svmsgdType = -1;
|
||||
params.marginType = -1;
|
||||
|
||||
// Parameters for learning
|
||||
params.marginRegularization = 0; // regularization
|
||||
params.initialStepSize = 0; // learning rate (ideally should be large at beginning and decay each iteration)
|
||||
params.stepDecreasingPower = 0;
|
||||
|
||||
TermCriteria _termCrit(TermCriteria::COUNT + TermCriteria::EPS, 0, 0);
|
||||
params.termCrit = _termCrit;
|
||||
setOptimalParameters();
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
|
||||
Reference in New Issue
Block a user