Refactored SVMSGD class

This commit is contained in:
Marina Noskova 2016-01-20 12:59:44 +03:00
parent a2f0963d66
commit 40bf97c6d1
11 changed files with 980 additions and 241 deletions

View File

@ -75,6 +75,7 @@
#endif #endif
#ifdef HAVE_OPENCV_ML #ifdef HAVE_OPENCV_ML
#include "opencv2/ml.hpp" #include "opencv2/ml.hpp"
#include "opencv2/ml/svmsgd.hpp"
#endif #endif
#endif #endif

View File

@ -1513,126 +1513,6 @@ CV_EXPORTS void randMVNormal( InputArray mean, InputArray cov, int nsamples, Out
CV_EXPORTS void createConcentricSpheresTestSet( int nsamples, int nfeatures, int nclasses, CV_EXPORTS void createConcentricSpheresTestSet( int nsamples, int nfeatures, int nclasses,
OutputArray samples, OutputArray responses); OutputArray samples, OutputArray responses);
/****************************************************************************************\
* Stochastic Gradient Descent SVM Classifier *
\****************************************************************************************/
/*!
@brief Stochastic Gradient Descent SVM classifier
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which
is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example).
First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined.
Then the SVM model can be trained using the train features and the correspondent labels.
After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically.
@code
// Initialize object
SVMSGD SvmSgd;
// Train the Stochastic Gradient Descent SVM
SvmSgd.train(trainFeatures, labels);
// Predict label for the new feature vector (1xM)
predictedLabel = SvmSgd.predict(newFeatureVector);
@endcode
*/
class CV_EXPORTS_W SVMSGD {
public:
/** @brief SGDSVM constructor.
@param lambda regularization
@param learnRate learning rate
@param nIterations number of training iterations
*/
SVMSGD(float lambda = 0.000001, float learnRate = 2, uint nIterations = 100000);
/** @brief SGDSVM constructor.
@param updateFrequency online update frequency
@param learnRateDecay learn rate decay over time: learnRate = learnRate * learnDecay
@param lambda regularization
@param learnRate learning rate
@param nIterations number of training iterations
*/
SVMSGD(uint updateFrequency, float learnRateDecay = 1, float lambda = 0.000001, float learnRate = 2, uint nIterations = 100000);
virtual ~SVMSGD();
virtual SVMSGD* clone() const;
/** @brief Train the SGDSVM classifier.
The function trains the SGDSVM classifier using the train features and the correspondent labels (-1 or 1).
@param trainFeatures features used for training. Each row is a new sample.
@param labels mat (size Nx1 with N = number of features) with the label of each training feature.
*/
virtual void train(cv::Mat trainFeatures, cv::Mat labels);
/** @brief Predict the label of a new feature vector.
The function predicts and returns the label of a new feature vector, using the previously trained SVM model.
@param newFeature new feature vector used for prediction
*/
virtual float predict(cv::Mat newFeature);
/** @brief Returns the weights of the trained model.
*/
virtual std::vector<float> getWeights(){ return _weights; };
/** @brief Sets the weights of the trained model.
@param weights weights used to predict the label of a new feature vector.
*/
virtual void setWeights(std::vector<float> weights){ _weights = weights; };
private:
void updateWeights();
void generateRandomIndex();
float calcInnerProduct(float *rowDataPointer);
void updateWeights(float innerProduct, float *rowDataPointer, int label);
// Vector with SVM weights
std::vector<float> _weights;
// Random index generation
long long int _randomNumber;
unsigned int _randomIndex;
// Number of features and samples
unsigned int _nFeatures;
unsigned int _nTrainSamples;
// Parameters for learning
float _lambda; //regularization
float _learnRate; //learning rate
unsigned int _nIterations; //number of training iterations
// Vars to control the features slider matrix
bool _onlineUpdate;
bool _initPredict;
uint _slidingWindowSize;
uint _predictSlidingWindowSize;
float* _labelSlider;
float _learnRateDecay;
// Mat with features slider and correspondent counter
unsigned int _sliderCounter;
cv::Mat _featuresSlider;
};
//! @} ml //! @} ml

View File

@ -0,0 +1,134 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#ifndef __OPENCV_ML_SVMSGD_HPP__
#define __OPENCV_ML_SVMSGD_HPP__
#ifdef __cplusplus
#include "opencv2/ml.hpp"
namespace cv
{
namespace ml
{
/****************************************************************************************\
* Stochastic Gradient Descent SVM Classifier *
\****************************************************************************************/
/*!
@brief Stochastic Gradient Descent SVM classifier
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which
is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example).
First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined.
Then the SVM model can be trained using the train features and the correspondent labels.
After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically.
@code
// Initialize object
SVMSGD SvmSgd;
// Train the Stochastic Gradient Descent SVM
SvmSgd.train(trainFeatures, labels);
// Predict label for the new feature vector (1xM)
predictedLabel = SvmSgd.predict(newFeatureVector);
@endcode
*/
class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel
{
public:
enum SvmsgdType
{
ILLEGAL_VALUE,
SGD, //Stochastic Gradient Descent
ASGD //Average Stochastic Gradient Descent
};
/**
* @return the weights of the trained model.
*/
CV_WRAP virtual Mat getWeights() = 0;
CV_WRAP virtual float getShift() = 0;
CV_WRAP static Ptr<SVMSGD> create();
CV_WRAP virtual void setOptimalParameters(int type = ASGD) = 0;
CV_WRAP virtual int getType() const = 0;
CV_WRAP virtual void setType(int type) = 0;
CV_WRAP virtual float getLambda() const = 0;
CV_WRAP virtual void setLambda(float lambda) = 0;
CV_WRAP virtual float getGamma0() const = 0;
CV_WRAP virtual void setGamma0(float gamma0) = 0;
CV_WRAP virtual float getC() const = 0;
CV_WRAP virtual void setC(float c) = 0;
CV_WRAP virtual cv::TermCriteria getTermCriteria() const = 0;
CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;
};
} //ml
} //cv
#endif // __clpusplus
#endif // __OPENCV_ML_SVMSGD_HPP

View File

@ -45,7 +45,7 @@
#include "opencv2/ml.hpp" #include "opencv2/ml.hpp"
#include "opencv2/core/core_c.h" #include "opencv2/core/core_c.h"
#include "opencv2/core/utility.hpp" #include "opencv2/core/utility.hpp"
#include "opencv2/ml/svmsgd.hpp"
#include "opencv2/core/private.hpp" #include "opencv2/core/private.hpp"
#include <assert.h> #include <assert.h>

View File

@ -41,161 +41,430 @@
//M*/ //M*/
#include "precomp.hpp" #include "precomp.hpp"
#include "limits"
/****************************************************************************************\ /****************************************************************************************\
* Stochastic Gradient Descent SVM Classifier * * Stochastic Gradient Descent SVM Classifier *
\****************************************************************************************/ \****************************************************************************************/
namespace cv { namespace cv
namespace ml { {
namespace ml
{
SVMSGD::SVMSGD(float lambda, float learnRate, uint nIterations){ class SVMSGDImpl : public SVMSGD
{
// Initialize with random seed public:
_randomNumber = 1; SVMSGDImpl();
// Initialize constants virtual ~SVMSGDImpl() {}
_slidingWindowSize = 0;
_nFeatures = 0;
_predictSlidingWindowSize = 1;
// Initialize sliderCounter at index 0 virtual bool train(const Ptr<TrainData>& data, int);
_sliderCounter = 0;
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
virtual bool isClassifier() const { return params.svmsgdType == SGD || params.svmsgdType == ASGD; }
virtual bool isTrained() const;
virtual void clear();
virtual void write(FileStorage& fs) const;
virtual void read(const FileNode& fn);
virtual Mat getWeights(){ return weights_; }
virtual float getShift(){ return shift_; }
virtual int getVarCount() const { return weights_.cols; }
virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
virtual void setOptimalParameters(int type = ASGD);
virtual int getType() const;
virtual void setType(int type);
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
CV_IMPL_PROPERTY(float, C, params.c)
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
private:
void updateWeights(InputArray sample, bool is_first_class, float gamma);
float calcShift(InputArray trainSamples, InputArray trainResponses) const;
std::pair<bool,bool> areClassesEmpty(Mat responses);
void writeParams( FileStorage& fs ) const;
void readParams( const FileNode& fn );
static inline bool isFirstClass(float val) { return val > 0; }
// Vector with SVM weights
Mat weights_;
float shift_;
// Random index generation
RNG rng_;
// Parameters for learning // Parameters for learning
_lambda = lambda; // regularization struct SVMSGDParams
_learnRate = learnRate; // learning rate (ideally should be large at beginning and decay each iteration) {
_nIterations = nIterations; // number of training iterations float lambda; //regularization
float gamma0; //learning rate
float c;
TermCriteria termCrit;
SvmsgdType svmsgdType;
};
// True only in the first predict iteration SVMSGDParams params;
_initPredict = true; };
// Online update flag Ptr<SVMSGD> SVMSGD::create()
_onlineUpdate = false; {
return makePtr<SVMSGDImpl>();
} }
SVMSGD::SVMSGD(uint updateFrequency, float learnRateDecay, float lambda, float learnRate, uint nIterations){
// Initialize with random seed bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
_randomNumber = 1; {
clear();
// Initialize constants Mat trainSamples = data->getTrainSamples();
_slidingWindowSize = 0;
_nFeatures = 0;
_predictSlidingWindowSize = updateFrequency;
// Initialize sliderCounter at index 0 // Initialize varCount
_sliderCounter = 0; int trainSamplesCount_ = trainSamples.rows;
int varCount = trainSamples.cols;
// Parameters for learning
_lambda = lambda; // regularization
_learnRate = learnRate; // learning rate (ideally should be large at beginning and decay each iteration)
_nIterations = nIterations; // number of training iterations
// True only in the first predict iteration
_initPredict = true;
// Online update flag
_onlineUpdate = true;
// Learn rate decay: _learnRate = _learnRate * _learnDecay
_learnRateDecay = learnRateDecay;
}
SVMSGD::~SVMSGD(){
}
SVMSGD* SVMSGD::clone() const{
return new SVMSGD(*this);
}
void SVMSGD::train(cv::Mat trainFeatures, cv::Mat labels){
// Initialize _nFeatures
_slidingWindowSize = trainFeatures.rows;
_nFeatures = trainFeatures.cols;
float innerProduct;
// Initialize weights vector with zeros // Initialize weights vector with zeros
if (_weights.size()==0){ weights_ = Mat::zeros(1, varCount, CV_32F);
_weights.reserve(_nFeatures);
for (uint feat = 0; feat < _nFeatures; ++feat){ Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
_weights.push_back(0.0);
} std::pair<bool,bool> are_empty = areClassesEmpty(trainResponses);
if ( are_empty.first && are_empty.second )
{
weights_.release();
return false;
}
if ( are_empty.first || are_empty.second )
{
shift_ = are_empty.first ? -1 : 1;
return true;
}
Mat currentSample;
float gamma = 0;
Mat lastWeights = Mat::zeros(1, varCount, CV_32F); //weights vector for calculating terminal criterion
Mat averageWeights; //average weights vector for ASGD model
double err = DBL_MAX;
if (params.svmsgdType == ASGD)
{
averageWeights = Mat::zeros(1, varCount, CV_32F);
} }
// Stochastic gradient descent SVM // Stochastic gradient descent SVM
for (uint iter = 0; iter < _nIterations; ++iter){ for (int iter = 0; (iter < params.termCrit.maxCount)&&(err > params.termCrit.epsilon); iter++)
generateRandomIndex(); {
innerProduct = calcInnerProduct(trainFeatures.ptr<float>(_randomIndex)); //generate sample number
int label = (labels.at<int>(_randomIndex,0) > 0) ? 1 : -1; // ensure that labels are -1 or 1 int randomNumber = rng_.uniform(0, trainSamplesCount_);
updateWeights(innerProduct, trainFeatures.ptr<float>(_randomIndex), label );
}
}
float SVMSGD::predict(cv::Mat newFeature){ currentSample = trainSamples.row(randomNumber);
float innerProduct;
if (_initPredict){ //update gamma
_nFeatures = newFeature.cols; gamma = params.gamma0 * std::pow((1 + params.lambda * params.gamma0 * (float)iter), (-params.c));
_slidingWindowSize = _predictSlidingWindowSize;
_featuresSlider = cv::Mat::zeros(_slidingWindowSize, _nFeatures, CV_32F);
_initPredict = false;
_labelSlider = new float[_predictSlidingWindowSize]();
_learnRate = _learnRate * _learnRateDecay;
}
innerProduct = calcInnerProduct(newFeature.ptr<float>(0)); bool is_first_class = isFirstClass(trainResponses.at<float>(randomNumber));
updateWeights( currentSample, is_first_class, gamma );
// Resultant label (-1 or 1) //average weights (only for ASGD model)
int label = (innerProduct>=0) ? 1 : -1; if (params.svmsgdType == ASGD)
{
if (_onlineUpdate){ averageWeights = ((float)iter/ (1 + (float)iter)) * averageWeights + weights_ / (1 + (float) iter);
// Update the featuresSlider with newFeature and _labelSlider with label
newFeature.row(0).copyTo(_featuresSlider.row(_sliderCounter));
_labelSlider[_sliderCounter] = float(label);
// Update weights with a random index
if (_sliderCounter == _slidingWindowSize-1){
generateRandomIndex();
updateWeights(innerProduct, _featuresSlider.ptr<float>(_randomIndex), int(_labelSlider[_randomIndex]) );
} }
// _sliderCounter++ if < _slidingWindowSize err = norm(weights_ - lastWeights);
_sliderCounter = (_sliderCounter == _slidingWindowSize-1) ? 0 : (_sliderCounter+1); weights_.copyTo(lastWeights);
} }
return float(label); if (params.svmsgdType == ASGD)
} {
weights_ = averageWeights;
void SVMSGD::generateRandomIndex(){
// Choose random sample, using Mikolov's fast almost-uniform random number
_randomNumber = _randomNumber * (unsigned long long) 25214903917 + 11;
_randomIndex = uint(_randomNumber % (unsigned long long) _slidingWindowSize);
}
float SVMSGD::calcInnerProduct(float *rowDataPointer){
float innerProduct = 0;
for (uint feat = 0; feat < _nFeatures; ++feat){
innerProduct += _weights[feat] * rowDataPointer[feat];
} }
return innerProduct;
shift_ = calcShift(trainSamples, trainResponses);
return true;
} }
void SVMSGD::updateWeights(float innerProduct, float *rowDataPointer, int label){ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
if (label * innerProduct > 1) { {
std::pair<bool,bool> are_classes_empty(true, true);
int limit_index = responses.rows;
for(int index = 0; index < limit_index; index++)
{
if (isFirstClass(responses.at<float>(index,0)))
are_classes_empty.first = false;
else
are_classes_empty.second = false;
if (!are_classes_empty.first && ! are_classes_empty.second)
break;
}
return are_classes_empty;
}
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
{
float distance_to_classes[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
Mat trainSamples = _samples.getMat();
int trainSamplesCount = trainSamples.rows;
Mat trainResponses = _responses.getMat();
for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
{
Mat currentSample = trainSamples.row(samplesIndex);
float scalar_product = currentSample.dot(weights_);
bool is_first_class = isFirstClass(trainResponses.at<float>(samplesIndex));
int index = is_first_class ? 0:1;
float sign_to_mul = is_first_class ? 1 : -1;
float cur_distance = scalar_product * sign_to_mul ;
if (cur_distance < distance_to_classes[index])
{
distance_to_classes[index] = cur_distance;
}
}
//todo: areClassesEmpty(); make const;
return -(distance_to_classes[0] - distance_to_classes[1]) / 2.f;
}
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
{
float result = 0;
cv::Mat samples = _samples.getMat();
int nSamples = samples.rows;
cv::Mat results;
CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32F );
if( _results.needed() )
{
_results.create( nSamples, 1, samples.type() );
results = _results.getMat();
}
else
{
CV_Assert( nSamples == 1 );
results = Mat(1, 1, CV_32F, &result);
}
Mat currentSample;
float criterion = 0;
for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
{
currentSample = samples.row(sampleIndex);
criterion = currentSample.dot(weights_) + shift_;
results.at<float>(sampleIndex) = (criterion >= 0) ? 1 : -1;
}
return result;
}
void SVMSGDImpl::updateWeights(InputArray _sample, bool is_first_class, float gamma)
{
Mat sample = _sample.getMat();
int responce = is_first_class ? 1 : -1; // ensure that trainResponses are -1 or 1
if ( sample.dot(weights_) * responce > 1)
{
// Not a support vector, only apply weight decay // Not a support vector, only apply weight decay
for (uint feat = 0; feat < _nFeatures; feat++) { weights_ *= (1.f - gamma * params.lambda);
_weights[feat] -= _learnRate * _lambda * _weights[feat]; }
} else
} else { {
// It's a support vector, add it to the weights // It's a support vector, add it to the weights
for (uint feat = 0; feat < _nFeatures; feat++) { weights_ -= (gamma * params.lambda) * weights_ - gamma * responce * sample;
_weights[feat] -= _learnRate * (_lambda * _weights[feat] - label * rowDataPointer[feat]); //std::cout << "sample " << sample << std::endl;
} //std::cout << "weights_ " << weights_ << std::endl;
} }
} }
bool SVMSGDImpl::isTrained() const
{
return !weights_.empty();
} }
void SVMSGDImpl::write(FileStorage& fs) const
{
if( !isTrained() )
CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
writeParams( fs );
fs << "shift" << shift_;
fs << "weights" << weights_;
} }
void SVMSGDImpl::writeParams( FileStorage& fs ) const
{
String SvmsgdTypeStr;
switch (params.svmsgdType)
{
case SGD:
SvmsgdTypeStr = "SGD";
break;
case ASGD:
SvmsgdTypeStr = "ASGD";
break;
case ILLEGAL_VALUE:
SvmsgdTypeStr = format("Uknown_%d", params.svmsgdType);
default:
std::cout << "params.svmsgdType isn't initialized" << std::endl;
}
fs << "svmsgdType" << SvmsgdTypeStr;
fs << "lambda" << params.lambda;
fs << "gamma0" << params.gamma0;
fs << "c" << params.c;
fs << "term_criteria" << "{:";
if( params.termCrit.type & TermCriteria::EPS )
fs << "epsilon" << params.termCrit.epsilon;
if( params.termCrit.type & TermCriteria::COUNT )
fs << "iterations" << params.termCrit.maxCount;
fs << "}";
}
void SVMSGDImpl::read(const FileNode& fn)
{
clear();
readParams(fn);
shift_ = (float) fn["shift"];
fn["weights"] >> weights_;
}
void SVMSGDImpl::readParams( const FileNode& fn )
{
String svmsgdTypeStr = (String)fn["svmsgdType"];
SvmsgdType svmsgdType =
svmsgdTypeStr == "SGD" ? SGD :
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_VALUE;
if( svmsgdType == ILLEGAL_VALUE )
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
params.svmsgdType = svmsgdType;
CV_Assert ( fn["lambda"].isReal() );
params.lambda = (float)fn["lambda"];
CV_Assert ( fn["gamma0"].isReal() );
params.gamma0 = (float)fn["gamma0"];
CV_Assert ( fn["c"].isReal() );
params.c = (float)fn["c"];
FileNode tcnode = fn["term_criteria"];
if( !tcnode.empty() )
{
params.termCrit.epsilon = (double)tcnode["epsilon"];
params.termCrit.maxCount = (int)tcnode["iterations"];
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
}
else
params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );
}
void SVMSGDImpl::clear()
{
weights_.release();
shift_ = 0;
}
SVMSGDImpl::SVMSGDImpl()
{
clear();
rng_(0);
params.svmsgdType = ILLEGAL_VALUE;
// Parameters for learning
params.lambda = 0; // regularization
params.gamma0 = 0; // learning rate (ideally should be large at beginning and decay each iteration)
params.c = 0;
TermCriteria _termCrit(TermCriteria::COUNT + TermCriteria::EPS, 0, 0);
params.termCrit = _termCrit;
}
void SVMSGDImpl::setOptimalParameters(int type)
{
switch (type)
{
case SGD:
params.svmsgdType = SGD;
params.lambda = 0.00001;
params.gamma0 = 0.05;
params.c = 1;
params.termCrit.maxCount = 50000;
params.termCrit.epsilon = 0.00000001;
break;
case ASGD:
params.svmsgdType = ASGD;
params.lambda = 0.00001;
params.gamma0 = 0.5;
params.c = 0.75;
params.termCrit.maxCount = 100000;
params.termCrit.epsilon = 0.000001;
break;
default:
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
}
}
void SVMSGDImpl::setType(int type)
{
switch (type)
{
case SGD:
params.svmsgdType = SGD;
break;
case ASGD:
params.svmsgdType = ASGD;
break;
default:
params.svmsgdType = ILLEGAL_VALUE;
}
}
int SVMSGDImpl::getType() const
{
return params.svmsgdType;
}
} //ml
} //cv

View File

@ -193,6 +193,16 @@ int str_to_boost_type( String& str )
// 8. rtrees // 8. rtrees
// 9. ertrees // 9. ertrees
int str_to_svmsgd_type( String& str )
{
if ( !str.compare("SGD") )
return SVMSGD::SGD;
if ( !str.compare("ASGD") )
return SVMSGD::ASGD;
CV_Error( CV_StsBadArg, "incorrect boost type string" );
return -1;
}
// ---------------------------------- MLBaseTest --------------------------------------------------- // ---------------------------------- MLBaseTest ---------------------------------------------------
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName) CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
@ -248,7 +258,9 @@ void CV_MLBaseTest::run( int )
{ {
string filename = ts->get_data_path(); string filename = ts->get_data_path();
filename += get_validation_filename(); filename += get_validation_filename();
validationFS.open( filename, FileStorage::READ ); validationFS.open( filename, FileStorage::READ );
read_params( *validationFS ); read_params( *validationFS );
int code = cvtest::TS::OK; int code = cvtest::TS::OK;
@ -436,6 +448,21 @@ int CV_MLBaseTest::train( int testCaseIdx )
model = m; model = m;
} }
else if( modelName == CV_SVMSGD )
{
String svmsgdTypeStr;
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
Ptr<SVMSGD> m = SVMSGD::create();
int type = str_to_svmsgd_type( svmsgdTypeStr );
m->setType(type);
//m->setType(str_to_svmsgd_type( svmsgdTypeStr ));
m->setLambda(modelParamsNode["lambda"]);
m->setGamma0(modelParamsNode["gamma0"]);
m->setC(modelParamsNode["c"]);
m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
model = m;
}
if( !model.empty() ) if( !model.empty() )
is_trained = model->train(data, 0); is_trained = model->train(data, 0);
@ -457,7 +484,7 @@ float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
else if( modelName == CV_ANN ) else if( modelName == CV_ANN )
err = ann_calc_error( model, data, cls_map, type, resp ); err = ann_calc_error( model, data, cls_map, type, resp );
else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES || else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST ) modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD )
err = model->calcError( data, true, _resp ); err = model->calcError( data, true, _resp );
if( !_resp.empty() && resp ) if( !_resp.empty() && resp )
_resp.convertTo(*resp, CV_32F); _resp.convertTo(*resp, CV_32F);
@ -485,6 +512,8 @@ void CV_MLBaseTest::load( const char* filename )
model = Algorithm::load<Boost>( filename ); model = Algorithm::load<Boost>( filename );
else if( modelName == CV_RTREES ) else if( modelName == CV_RTREES )
model = Algorithm::load<RTrees>( filename ); model = Algorithm::load<RTrees>( filename );
else if( modelName == CV_SVMSGD )
model = Algorithm::load<SVMSGD>( filename );
else else
CV_Error( CV_StsNotImplemented, "invalid stat model name"); CV_Error( CV_StsNotImplemented, "invalid stat model name");
} }

View File

@ -13,6 +13,7 @@
#include <map> #include <map>
#include "opencv2/ts.hpp" #include "opencv2/ts.hpp"
#include "opencv2/ml.hpp" #include "opencv2/ml.hpp"
#include "opencv2/ml/svmsgd.hpp"
#include "opencv2/core/core_c.h" #include "opencv2/core/core_c.h"
#define CV_NBAYES "nbayes" #define CV_NBAYES "nbayes"
@ -24,6 +25,7 @@
#define CV_BOOST "boost" #define CV_BOOST "boost"
#define CV_RTREES "rtrees" #define CV_RTREES "rtrees"
#define CV_ERTREES "ertrees" #define CV_ERTREES "ertrees"
#define CV_SVMSGD "svmsgd"
enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 }; enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 };
@ -38,6 +40,7 @@ using cv::ml::ANN_MLP;
using cv::ml::DTrees; using cv::ml::DTrees;
using cv::ml::Boost; using cv::ml::Boost;
using cv::ml::RTrees; using cv::ml::RTrees;
using cv::ml::SVMSGD;
class CV_MLBaseTest : public cvtest::BaseTest class CV_MLBaseTest : public cvtest::BaseTest
{ {

View File

@ -150,12 +150,20 @@ int CV_SLMLTest::validate_test_results( int testCaseIdx )
TEST(ML_NaiveBayes, save_load) { CV_SLMLTest test( CV_NBAYES ); test.safe_run(); } TEST(ML_NaiveBayes, save_load) { CV_SLMLTest test( CV_NBAYES ); test.safe_run(); }
TEST(ML_KNearest, save_load) { CV_SLMLTest test( CV_KNEAREST ); test.safe_run(); } TEST(ML_KNearest, save_load) { CV_SLMLTest test( CV_KNEAREST ); test.safe_run(); }
TEST(ML_SVM, save_load) { CV_SLMLTest test( CV_SVM ); test.safe_run(); } TEST(ML_SVM, save_load)
{
CV_SLMLTest test( CV_SVM );
test.safe_run();
}
TEST(ML_ANN, save_load) { CV_SLMLTest test( CV_ANN ); test.safe_run(); } TEST(ML_ANN, save_load) { CV_SLMLTest test( CV_ANN ); test.safe_run(); }
TEST(ML_DTree, save_load) { CV_SLMLTest test( CV_DTREE ); test.safe_run(); } TEST(ML_DTree, save_load) { CV_SLMLTest test( CV_DTREE ); test.safe_run(); }
TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); } TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); } TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); } TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
TEST(MV_SVMSGD, save_load){
CV_SLMLTest test( CV_SVMSGD );
test.safe_run();
}
class CV_LegacyTest : public cvtest::BaseTest class CV_LegacyTest : public cvtest::BaseTest
{ {
@ -201,6 +209,8 @@ protected:
model = Algorithm::load<SVM>(filename); model = Algorithm::load<SVM>(filename);
else if (modelName == CV_RTREES) else if (modelName == CV_RTREES)
model = Algorithm::load<RTrees>(filename); model = Algorithm::load<RTrees>(filename);
else if (modelName == CV_SVMSGD)
model = Algorithm::load<SVMSGD>(filename);
if (!model) if (!model)
{ {
code = cvtest::TS::FAIL_INVALID_TEST_DATA; code = cvtest::TS::FAIL_INVALID_TEST_DATA;
@ -260,6 +270,11 @@ TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushro
TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); } TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); } TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); } TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
TEST(ML_SVMSGD, legacy_load)
{
CV_LegacyTest test(CV_SVMSGD, "_waveform.xml");
test.safe_run();
}
/*TEST(ML_SVM, throw_exception_when_save_untrained_model) /*TEST(ML_SVM, throw_exception_when_save_untrained_model)
{ {

View File

@ -0,0 +1,182 @@
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// Intel License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of Intel Corporation may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
#include "test_precomp.hpp"
#include "opencv2/highgui.hpp"
using namespace cv;
using namespace cv::ml;
using cv::ml::SVMSGD;
using cv::ml::TrainData;
class CV_SVMSGDTrainTest : public cvtest::BaseTest
{
public:
CV_SVMSGDTrainTest(Mat _weights, float _shift);
private:
virtual void run( int start_from );
float decisionFunction(Mat sample, Mat weights, float shift);
cv::Ptr<TrainData> data;
cv::Mat testSamples;
cv::Mat testResponses;
static const int TEST_VALUE_LIMIT = 50;
};
CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift)
{
int datasize = 100000;
int varCount = weights.cols;
cv::Mat samples = cv::Mat::zeros( datasize, varCount, CV_32FC1 );
cv::Mat responses = cv::Mat::zeros( datasize, 1, CV_32FC1 );
cv::RNG rng(0);
float lowerLimit = -TEST_VALUE_LIMIT;
float upperLimit = TEST_VALUE_LIMIT;
rng.fill(samples, RNG::UNIFORM, lowerLimit, upperLimit);
for (int sampleIndex = 0; sampleIndex < datasize; sampleIndex++)
{
responses.at<float>( sampleIndex ) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1;
}
data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
int testSamplesCount = 100000;
testSamples.create(testSamplesCount, varCount, CV_32FC1);
rng.fill(testSamples, RNG::UNIFORM, lowerLimit, upperLimit);
testResponses.create(testSamplesCount, 1, CV_32FC1);
for (int i = 0 ; i < testSamplesCount; i++)
{
testResponses.at<float>(i) = decisionFunction(testSamples.row(i), weights, shift) > 0 ? 1 : -1;
}
}
void CV_SVMSGDTrainTest::run( int /*start_from*/ )
{
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
svmsgd->setOptimalParameters(SVMSGD::ASGD);
svmsgd->train( data );
Mat responses;
svmsgd->predict(testSamples, responses);
int errCount = 0;
int testSamplesCount = testSamples.rows;
for (int i = 0; i < testSamplesCount; i++)
{
if (responses.at<float>(i) * testResponses.at<float>(i) < 0 )
errCount++;
}
float err = (float)errCount / testSamplesCount;
std::cout << "err " << err << std::endl;
if ( err > 0.01 )
{
ts->set_failed_test_info( cvtest::TS::FAIL_BAD_ACCURACY );
}
}
float CV_SVMSGDTrainTest::decisionFunction(Mat sample, Mat weights, float shift)
{
return sample.dot(weights) + shift;
}
TEST(ML_SVMSGD, train0)
{
int varCount = 2;
Mat weights;
weights.create(1, varCount, CV_32FC1);
weights.at<float>(0) = 1;
weights.at<float>(1) = 0;
float shift = 5;
CV_SVMSGDTrainTest test(weights, shift);
test.safe_run();
}
TEST(ML_SVMSGD, train1)
{
int varCount = 5;
Mat weights;
weights.create(1, varCount, CV_32FC1);
float lowerLimit = -1;
float upperLimit = 1;
cv::RNG rng(0);
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
float shift = rng.uniform(-5.f, 5.f);
CV_SVMSGDTrainTest test(weights, shift);
test.safe_run();
}
TEST(ML_SVMSGD, train2)
{
int varCount = 100;
Mat weights;
weights.create(1, varCount, CV_32FC1);
float lowerLimit = -1;
float upperLimit = 1;
cv::RNG rng(0);
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
float shift = rng.uniform(-1000.f, 1000.f);
CV_SVMSGDTrainTest test(weights, shift);
test.safe_run();
}

View File

@ -5659,7 +5659,7 @@ class TestCaseNameIs {
// Returns true iff the name of test_case matches name_. // Returns true iff the name of test_case matches name_.
bool operator()(const TestCase* test_case) const { bool operator()(const TestCase* test_case) const {
return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0; return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0;
} }
private: private:

View File

@ -0,0 +1,226 @@
#include <opencv2/opencv.hpp>
#include "opencv2/video/tracking.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include "opencv2/highgui/highgui.hpp"
using namespace cv;
using namespace cv::ml;
#define WIDTH 841
#define HEIGHT 594
struct Data
{
Mat img;
Mat samples;
Mat responses;
RNG rng;
//Point points[2];
Data()
{
img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
imshow("Train svmsgd", img);
}
};
bool doTrain(const Mat samples,const Mat responses, Mat &weights, float &shift);
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]);
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
void fillSegments(std::vector<std::pair<Point,Point> > &segments);
void redraw(Data data, const Point points[2]);
void addPointsRetrainAndRedraw(Data &data, int x, int y);
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
{
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
svmsgd->setOptimalParameters(SVMSGD::ASGD);
svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 50000, 0.0000001));
svmsgd->setLambda(0.01);
svmsgd->setGamma0(1);
// svmsgd->setC(5);
cv::Ptr<TrainData> train_data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
svmsgd->train( train_data );
if (svmsgd->isTrained())
{
weights = svmsgd->getWeights();
shift = svmsgd->getShift();
std::cout << weights << std::endl;
std::cout << shift << std::endl;
return true;
}
return false;
}
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
{
int x = 0;
int y = 0;
//с (0,0) всё плохо
if (segment.first.x == segment.second.x && weights.at<float>(1) != 0)
{
x = segment.first.x;
y = -(weights.at<float>(0) * x + shift) / weights.at<float>(1);
if (y >= 0 && y <= HEIGHT)
{
crossPoint.x = x;
crossPoint.y = y;
return true;
}
}
else if (segment.first.y == segment.second.y && weights.at<float>(0) != 0)
{
y = segment.first.y;
x = - (weights.at<float>(1) * y + shift) / weights.at<float>(0);
if (x >= 0 && x <= WIDTH)
{
crossPoint.x = x;
crossPoint.y = y;
return true;
}
}
return false;
}
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2])
{
if (weights.empty())
{
return false;
}
int foundPointsCount = 0;
std::vector<std::pair<Point,Point> > segments;
fillSegments(segments);
for (int i = 0; i < 4; i++)
{
if (findCrossPoint(weights, shift, segments[i], points[foundPointsCount]))
foundPointsCount++;
if (foundPointsCount > 2)
break;
}
return true;
}
void fillSegments(std::vector<std::pair<Point,Point> > &segments)
{
std::pair<Point,Point> curSegment;
curSegment.first = Point(0,0);
curSegment.second = Point(0,HEIGHT);
segments.push_back(curSegment);
curSegment.first = Point(0,0);
curSegment.second = Point(WIDTH,0);
segments.push_back(curSegment);
curSegment.first = Point(WIDTH,0);
curSegment.second = Point(WIDTH,HEIGHT);
segments.push_back(curSegment);
curSegment.first = Point(0,HEIGHT);
curSegment.second = Point(WIDTH,HEIGHT);
segments.push_back(curSegment);
}
void redraw(Data data, const Point points[2])
{
data.img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
Point center;
int radius = 3;
Scalar color;
for (int i = 0; i < data.samples.rows; i++)
{
center.x = data.samples.at<float>(i,0);
center.y = data.samples.at<float>(i,1);
color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
circle(data.img, center, radius, color, 5);
}
line(data.img, points[0],points[1],cv::Scalar(1,255,1));
imshow("Train svmsgd", data.img);
}
void addPointsRetrainAndRedraw(Data &data, int x, int y)
{
Mat currentSample(1, 2, CV_32F);
//start
/*
Mat _weights;
_weights.create(1, 2, CV_32FC1);
_weights.at<float>(0) = 1;
_weights.at<float>(1) = -1;
int _x, _y;
for (int i=0;i<199;i++)
{
_x = data.rng.uniform(0,800);
_y = data.rng.uniform(0,500);*/
currentSample.at<float>(0,0) = x;
currentSample.at<float>(0,1) = y;
//if (currentSample.dot(_weights) > 0)
//data.responses.push_back(1);
// else data.responses.push_back(-1);
//finish
data.samples.push_back(currentSample);
Mat weights(1, 2, CV_32F);
float shift = 0;
if (doTrain(data.samples, data.responses, weights, shift))
{
Point points[2];
shift = 0;
findPointsForLine(weights, shift, points);
redraw(data, points);
}
}
static void onMouse( int event, int x, int y, int, void* pData)
{
Data &data = *(Data*)pData;
switch( event )
{
case CV_EVENT_LBUTTONUP:
data.responses.push_back(1);
addPointsRetrainAndRedraw(data, x, y);
break;
case CV_EVENT_RBUTTONDOWN:
data.responses.push_back(-1);
addPointsRetrainAndRedraw(data, x, y);
break;
}
}
int main()
{
Data data;
setMouseCallback( "Train svmsgd", onMouse, &data );
waitKey();
return 0;
}