Added margin type, added tests with different scales of features.
Also fixed documentation, refactored sample.
This commit is contained in:
@@ -42,7 +42,6 @@
|
||||
|
||||
#include "precomp.hpp"
|
||||
#include "limits"
|
||||
//#include "math.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -76,9 +75,9 @@ public:
|
||||
|
||||
virtual void clear();
|
||||
|
||||
virtual void write(FileStorage& fs) const;
|
||||
virtual void write(FileStorage &fs) const;
|
||||
|
||||
virtual void read(const FileNode& fn);
|
||||
virtual void read(const FileNode &fn);
|
||||
|
||||
virtual Mat getWeights(){ return weights_; }
|
||||
|
||||
@@ -88,11 +87,15 @@ public:
|
||||
|
||||
virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
|
||||
|
||||
virtual void setOptimalParameters(int type = ASGD);
|
||||
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
|
||||
|
||||
virtual int getType() const;
|
||||
virtual int getSvmsgdType() const;
|
||||
|
||||
virtual void setType(int type);
|
||||
virtual void setSvmsgdType(int svmsgdType);
|
||||
|
||||
virtual int getMarginType() const;
|
||||
|
||||
virtual void setMarginType(int marginType);
|
||||
|
||||
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
|
||||
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
|
||||
@@ -100,21 +103,21 @@ public:
|
||||
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
||||
|
||||
private:
|
||||
void updateWeights(InputArray sample, bool isFirstClass, float gamma, Mat weights);
|
||||
void updateWeights(InputArray sample, bool isFirstClass, float gamma, Mat &weights);
|
||||
|
||||
std::pair<bool,bool> areClassesEmpty(Mat responses);
|
||||
|
||||
void writeParams( FileStorage& fs ) const;
|
||||
void writeParams( FileStorage &fs ) const;
|
||||
|
||||
void readParams( const FileNode& fn );
|
||||
void readParams( const FileNode &fn );
|
||||
|
||||
static inline bool isFirstClass(float val) { return val > 0; }
|
||||
|
||||
static void normalizeSamples(Mat &matrix, Mat &multiplier, Mat &average);
|
||||
static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);
|
||||
|
||||
float calcShift(InputArray _samples, InputArray _responses) const;
|
||||
|
||||
static void makeExtendedTrainSamples(const Mat trainSamples, Mat &extendedTrainSamples, Mat &multiplier);
|
||||
static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
|
||||
|
||||
|
||||
|
||||
@@ -130,6 +133,7 @@ private:
|
||||
float c;
|
||||
TermCriteria termCrit;
|
||||
SvmsgdType svmsgdType;
|
||||
MarginType marginType;
|
||||
};
|
||||
|
||||
SVMSGDParams params;
|
||||
@@ -160,7 +164,7 @@ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
|
||||
return emptyInClasses;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &multiplier, Mat &average)
|
||||
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
|
||||
{
|
||||
int featuresCount = samples.cols;
|
||||
int samplesCount = samples.rows;
|
||||
@@ -176,37 +180,25 @@ void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &multiplier, Mat &average)
|
||||
samples.row(sampleIndex) -= average;
|
||||
}
|
||||
|
||||
Mat featureNorm(1, featuresCount, samples.type());
|
||||
for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
|
||||
{
|
||||
featureNorm.at<float>(featureIndex) = norm(samples.col(featureIndex));
|
||||
}
|
||||
double normValue = norm(samples);
|
||||
|
||||
multiplier = sqrt(samplesCount) / featureNorm;
|
||||
for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
|
||||
{
|
||||
samples.row(sampleIndex) = samples.row(sampleIndex).mul(multiplier);
|
||||
}
|
||||
multiplier = sqrt(samples.total()) / normValue;
|
||||
|
||||
samples *= multiplier;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::makeExtendedTrainSamples(const Mat trainSamples, Mat &extendedTrainSamples, Mat &multiplier)
|
||||
void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
|
||||
{
|
||||
Mat normalisedTrainSamples = trainSamples.clone();
|
||||
int samplesCount = normalisedTrainSamples.rows;
|
||||
|
||||
Mat average;
|
||||
|
||||
normalizeSamples(normalisedTrainSamples, multiplier, average);
|
||||
normalizeSamples(normalisedTrainSamples, average, multiplier);
|
||||
|
||||
Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
|
||||
cv::hconcat(normalisedTrainSamples, onesCol, extendedTrainSamples);
|
||||
|
||||
//cout << "SVMSGDImpl::makeExtendedTrainSamples average: \n" << average << endl;
|
||||
//cout << "SVMSGDImpl::makeExtendedTrainSamples multiplier: \n" << multiplier << endl;
|
||||
}
|
||||
|
||||
|
||||
void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma, Mat weights)
|
||||
void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma, Mat& weights)
|
||||
{
|
||||
Mat sample = _sample.getMat();
|
||||
|
||||
@@ -226,7 +218,7 @@ void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma,
|
||||
|
||||
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||
{
|
||||
float distance_to_classes[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
|
||||
float distanceToClasses[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
|
||||
|
||||
Mat trainSamples = _samples.getMat();
|
||||
int trainSamplesCount = trainSamples.rows;
|
||||
@@ -241,36 +233,29 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||
bool firstClass = isFirstClass(trainResponses.at<float>(samplesIndex));
|
||||
int index = firstClass ? 0:1;
|
||||
float signToMul = firstClass ? 1 : -1;
|
||||
float cur_distance = dotProduct * signToMul;
|
||||
float curDistance = dotProduct * signToMul;
|
||||
|
||||
if (cur_distance < distance_to_classes[index])
|
||||
if (curDistance < distanceToClasses[index])
|
||||
{
|
||||
distance_to_classes[index] = cur_distance;
|
||||
distanceToClasses[index] = curDistance;
|
||||
}
|
||||
}
|
||||
|
||||
return -(distance_to_classes[0] - distance_to_classes[1]) / 2.f;
|
||||
return -(distanceToClasses[0] - distanceToClasses[1]) / 2.f;
|
||||
}
|
||||
|
||||
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
{
|
||||
//cout << "SVMSGDImpl::train begin" << endl;
|
||||
clear();
|
||||
CV_Assert( isClassifier() ); //toDo: consider
|
||||
|
||||
Mat trainSamples = data->getTrainSamples();
|
||||
|
||||
//cout << "SVMSGDImpl::train trainSamples: \n" << trainSamples << endl;
|
||||
|
||||
int featureCount = trainSamples.cols;
|
||||
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
|
||||
|
||||
//cout << "SVMSGDImpl::train trainresponses: \n" << trainResponses << endl;
|
||||
|
||||
std::pair<bool,bool> areEmpty = areClassesEmpty(trainResponses);
|
||||
|
||||
//cout << "SVMSGDImpl::train areEmpty" << areEmpty.first << "," << areEmpty.second << endl;
|
||||
|
||||
if ( areEmpty.first && areEmpty.second )
|
||||
{
|
||||
return false;
|
||||
@@ -283,10 +268,9 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
}
|
||||
|
||||
Mat extendedTrainSamples;
|
||||
Mat multiplier;
|
||||
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, multiplier);
|
||||
|
||||
//cout << "SVMSGDImpl::train extendedTrainSamples: \n" << extendedTrainSamples << endl;
|
||||
Mat average;
|
||||
float multiplier = 0;
|
||||
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);
|
||||
|
||||
int extendedTrainSamplesCount = extendedTrainSamples.rows;
|
||||
int extendedFeatureCount = extendedTrainSamples.cols;
|
||||
@@ -301,6 +285,7 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
|
||||
RNG rng(0);
|
||||
|
||||
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
|
||||
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
|
||||
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
|
||||
|
||||
@@ -336,17 +321,20 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
extendedWeights = averageExtendedWeights;
|
||||
}
|
||||
|
||||
//cout << "SVMSGDImpl::train extendedWeights: \n" << extendedWeights << endl;
|
||||
|
||||
Rect roi(0, 0, featureCount, 1);
|
||||
weights_ = extendedWeights(roi);
|
||||
weights_ = weights_.mul(1/multiplier);
|
||||
weights_ *= multiplier;
|
||||
|
||||
//cout << "SVMSGDImpl::train weights: \n" << weights_ << endl;
|
||||
CV_Assert(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN);
|
||||
|
||||
shift_ = calcShift(trainSamples, trainResponses);
|
||||
|
||||
//cout << "SVMSGDImpl::train shift = " << shift_ << endl;
|
||||
if (params.marginType == SOFT_MARGIN)
|
||||
{
|
||||
shift_ = extendedWeights.at<float>(featureCount) - weights_.dot(average);
|
||||
}
|
||||
else
|
||||
{
|
||||
shift_ = calcShift(trainSamples, trainResponses);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -385,6 +373,8 @@ float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) cons
|
||||
// Reports whether the current parameter set describes a usable classifier.
// Returns true only when all of the following hold:
//   - params.svmsgdType is SGD or ASGD,
//   - params.marginType is SOFT_MARGIN or HARD_MARGIN,
//   - params.lambda and params.gamma0 are strictly positive,
//   - params.c is non-negative.
bool SVMSGDImpl::isClassifier() const
|
||||
{
|
||||
return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
|
||||
&&
|
||||
(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
|
||||
&&
|
||||
(params.lambda > 0) && (params.gamma0 > 0) && (params.c >= 0);
|
||||
}
|
||||
@@ -417,15 +407,32 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
case ASGD:
|
||||
SvmsgdTypeStr = "ASGD";
|
||||
break;
|
||||
case ILLEGAL_VALUE:
|
||||
SvmsgdTypeStr = format("Uknown_%d", params.svmsgdType);
|
||||
case ILLEGAL_SVMSGD_TYPE:
|
||||
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
|
||||
default:
|
||||
std::cout << "params.svmsgdType isn't initialized" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
fs << "svmsgdType" << SvmsgdTypeStr;
|
||||
|
||||
String marginTypeStr;
|
||||
|
||||
switch (params.marginType)
|
||||
{
|
||||
case SOFT_MARGIN:
|
||||
marginTypeStr = "SOFT_MARGIN";
|
||||
break;
|
||||
case HARD_MARGIN:
|
||||
marginTypeStr = "HARD_MARGIN";
|
||||
break;
|
||||
case ILLEGAL_MARGIN_TYPE:
|
||||
marginTypeStr = format("Unknown_%d", params.marginType);
|
||||
default:
|
||||
std::cout << "params.marginType isn't initialized" << std::endl;
|
||||
}
|
||||
|
||||
fs << "marginType" << marginTypeStr;
|
||||
|
||||
fs << "lambda" << params.lambda;
|
||||
fs << "gamma0" << params.gamma0;
|
||||
fs << "c" << params.c;
|
||||
@@ -438,8 +445,6 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
fs << "}";
|
||||
}
|
||||
|
||||
|
||||
|
||||
void SVMSGDImpl::read(const FileNode& fn)
|
||||
{
|
||||
clear();
|
||||
@@ -455,13 +460,23 @@ void SVMSGDImpl::readParams( const FileNode& fn )
|
||||
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
||||
SvmsgdType svmsgdType =
|
||||
svmsgdTypeStr == "SGD" ? SGD :
|
||||
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_VALUE;
|
||||
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE;
|
||||
|
||||
if( svmsgdType == ILLEGAL_VALUE )
|
||||
if( svmsgdType == ILLEGAL_SVMSGD_TYPE )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
|
||||
|
||||
params.svmsgdType = svmsgdType;
|
||||
|
||||
String marginTypeStr = (String)fn["marginType"];
|
||||
MarginType marginType =
|
||||
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
|
||||
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
|
||||
if( marginType == ILLEGAL_MARGIN_TYPE )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
|
||||
|
||||
params.marginType = marginType;
|
||||
|
||||
CV_Assert ( fn["lambda"].isReal() );
|
||||
params.lambda = (float)fn["lambda"];
|
||||
|
||||
@@ -494,7 +509,8 @@ SVMSGDImpl::SVMSGDImpl()
|
||||
{
|
||||
clear();
|
||||
|
||||
params.svmsgdType = ILLEGAL_VALUE;
|
||||
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
|
||||
params.marginType = ILLEGAL_MARGIN_TYPE;
|
||||
|
||||
// Parameters for learning
|
||||
params.lambda = 0; // regularization
|
||||
@@ -505,26 +521,28 @@ SVMSGDImpl::SVMSGDImpl()
|
||||
params.termCrit = _termCrit;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setOptimalParameters(int type)
|
||||
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
{
|
||||
switch (type)
|
||||
switch (svmsgdType)
|
||||
{
|
||||
case SGD:
|
||||
params.svmsgdType = SGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
params.lambda = 0.0001;
|
||||
params.gamma0 = 0.05;
|
||||
params.c = 1;
|
||||
params.termCrit.maxCount = 100000;
|
||||
params.termCrit.epsilon = 0.00001;
|
||||
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
|
||||
break;
|
||||
|
||||
case ASGD:
|
||||
params.svmsgdType = ASGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
params.lambda = 0.00001;
|
||||
params.gamma0 = 0.05;
|
||||
params.c = 0.75;
|
||||
params.termCrit.maxCount = 100000;
|
||||
params.termCrit.epsilon = 0.00001;
|
||||
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -532,7 +550,7 @@ void SVMSGDImpl::setOptimalParameters(int type)
|
||||
}
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setType(int type)
|
||||
void SVMSGDImpl::setSvmsgdType(int type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
@@ -543,13 +561,33 @@ void SVMSGDImpl::setType(int type)
|
||||
params.svmsgdType = ASGD;
|
||||
break;
|
||||
default:
|
||||
params.svmsgdType = ILLEGAL_VALUE;
|
||||
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
|
||||
}
|
||||
}
|
||||
|
||||
int SVMSGDImpl::getType() const
|
||||
// Returns the value currently stored in params.svmsgdType, as an int.
int SVMSGDImpl::getSvmsgdType() const
|
||||
{
|
||||
return params.svmsgdType;
|
||||
}
|
||||
|
||||
// Stores the requested margin type in params.marginType.
// `type` is expected to be HARD_MARGIN or SOFT_MARGIN; any other value is
// not rejected here but recorded as ILLEGAL_MARGIN_TYPE.
void SVMSGDImpl::setMarginType(int type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case HARD_MARGIN:
|
||||
params.marginType = HARD_MARGIN;
|
||||
break;
|
||||
case SOFT_MARGIN:
|
||||
params.marginType = SOFT_MARGIN;
|
||||
break;
|
||||
// Unrecognised value: remember it as illegal rather than raising an error.
default:
|
||||
params.marginType = ILLEGAL_MARGIN_TYPE;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the value currently stored in params.marginType, as an int.
int SVMSGDImpl::getMarginType() const
|
||||
{
|
||||
return params.marginType;
|
||||
}
|
||||
} //ml
|
||||
} //cv
|
||||
|
||||
Reference in New Issue
Block a user