Refactored SVMSGD class
This commit is contained in:
parent
a2f0963d66
commit
40bf97c6d1
@ -75,6 +75,7 @@
|
||||
#endif
|
||||
#ifdef HAVE_OPENCV_ML
|
||||
#include "opencv2/ml.hpp"
|
||||
#include "opencv2/ml/svmsgd.hpp"
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
@ -1513,126 +1513,6 @@ CV_EXPORTS void randMVNormal( InputArray mean, InputArray cov, int nsamples, Out
|
||||
CV_EXPORTS void createConcentricSpheresTestSet( int nsamples, int nfeatures, int nclasses,
|
||||
OutputArray samples, OutputArray responses);
|
||||
|
||||
/****************************************************************************************\
|
||||
* Stochastic Gradient Descent SVM Classifier *
|
||||
\****************************************************************************************/
|
||||
|
||||
/*!
|
||||
@brief Stochastic Gradient Descent SVM classifier
|
||||
|
||||
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
|
||||
The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which
|
||||
is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example).
|
||||
|
||||
First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined.
|
||||
|
||||
Then the SVM model can be trained using the train features and the correspondent labels.
|
||||
|
||||
After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically.
|
||||
|
||||
@code
|
||||
// Initialize object
|
||||
SVMSGD SvmSgd;
|
||||
|
||||
// Train the Stochastic Gradient Descent SVM
|
||||
SvmSgd.train(trainFeatures, labels);
|
||||
|
||||
// Predict label for the new feature vector (1xM)
|
||||
predictedLabel = SvmSgd.predict(newFeatureVector);
|
||||
@endcode
|
||||
|
||||
*/
|
||||
class CV_EXPORTS_W SVMSGD {
|
||||
|
||||
public:
|
||||
/** @brief SGDSVM constructor.
|
||||
|
||||
@param lambda regularization
|
||||
@param learnRate learning rate
|
||||
@param nIterations number of training iterations
|
||||
|
||||
*/
|
||||
SVMSGD(float lambda = 0.000001, float learnRate = 2, uint nIterations = 100000);
|
||||
|
||||
/** @brief SGDSVM constructor.
|
||||
|
||||
@param updateFrequency online update frequency
|
||||
@param learnRateDecay learn rate decay over time: learnRate = learnRate * learnDecay
|
||||
@param lambda regularization
|
||||
@param learnRate learning rate
|
||||
@param nIterations number of training iterations
|
||||
|
||||
*/
|
||||
SVMSGD(uint updateFrequency, float learnRateDecay = 1, float lambda = 0.000001, float learnRate = 2, uint nIterations = 100000);
|
||||
virtual ~SVMSGD();
|
||||
virtual SVMSGD* clone() const;
|
||||
|
||||
/** @brief Train the SGDSVM classifier.
|
||||
|
||||
The function trains the SGDSVM classifier using the train features and the correspondent labels (-1 or 1).
|
||||
|
||||
@param trainFeatures features used for training. Each row is a new sample.
|
||||
@param labels mat (size Nx1 with N = number of features) with the label of each training feature.
|
||||
|
||||
*/
|
||||
virtual void train(cv::Mat trainFeatures, cv::Mat labels);
|
||||
|
||||
/** @brief Predict the label of a new feature vector.
|
||||
|
||||
The function predicts and returns the label of a new feature vector, using the previously trained SVM model.
|
||||
|
||||
@param newFeature new feature vector used for prediction
|
||||
|
||||
*/
|
||||
virtual float predict(cv::Mat newFeature);
|
||||
|
||||
/** @brief Returns the weights of the trained model.
|
||||
|
||||
*/
|
||||
virtual std::vector<float> getWeights(){ return _weights; };
|
||||
|
||||
/** @brief Sets the weights of the trained model.
|
||||
|
||||
@param weights weights used to predict the label of a new feature vector.
|
||||
|
||||
*/
|
||||
virtual void setWeights(std::vector<float> weights){ _weights = weights; };
|
||||
|
||||
private:
|
||||
void updateWeights();
|
||||
void generateRandomIndex();
|
||||
float calcInnerProduct(float *rowDataPointer);
|
||||
void updateWeights(float innerProduct, float *rowDataPointer, int label);
|
||||
|
||||
// Vector with SVM weights
|
||||
std::vector<float> _weights;
|
||||
|
||||
// Random index generation
|
||||
long long int _randomNumber;
|
||||
unsigned int _randomIndex;
|
||||
|
||||
// Number of features and samples
|
||||
unsigned int _nFeatures;
|
||||
unsigned int _nTrainSamples;
|
||||
|
||||
// Parameters for learning
|
||||
float _lambda; //regularization
|
||||
float _learnRate; //learning rate
|
||||
unsigned int _nIterations; //number of training iterations
|
||||
|
||||
// Vars to control the features slider matrix
|
||||
bool _onlineUpdate;
|
||||
bool _initPredict;
|
||||
uint _slidingWindowSize;
|
||||
uint _predictSlidingWindowSize;
|
||||
float* _labelSlider;
|
||||
float _learnRateDecay;
|
||||
|
||||
// Mat with features slider and correspondent counter
|
||||
unsigned int _sliderCounter;
|
||||
cv::Mat _featuresSlider;
|
||||
|
||||
};
|
||||
|
||||
//! @} ml
|
||||
|
||||
|
134
modules/ml/include/opencv2/ml/svmsgd.hpp
Normal file
134
modules/ml/include/opencv2/ml/svmsgd.hpp
Normal file
@ -0,0 +1,134 @@
|
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
||||
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
|
||||
// Copyright (C) 2014, Itseez Inc, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#ifndef __OPENCV_ML_SVMSGD_HPP__
|
||||
#define __OPENCV_ML_SVMSGD_HPP__
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
#include "opencv2/ml.hpp"
|
||||
|
||||
namespace cv
|
||||
{
|
||||
namespace ml
|
||||
{
|
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* Stochastic Gradient Descent SVM Classifier *
|
||||
\****************************************************************************************/
|
||||
|
||||
/*!
|
||||
@brief Stochastic Gradient Descent SVM classifier
|
||||
|
||||
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
|
||||
The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which
|
||||
is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example).
|
||||
|
||||
First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined.
|
||||
|
||||
Then the SVM model can be trained using the train features and the correspondent labels.
|
||||
|
||||
After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically.
|
||||
|
||||
@code
|
||||
// Initialize object
|
||||
SVMSGD SvmSgd;
|
||||
|
||||
// Train the Stochastic Gradient Descent SVM
|
||||
SvmSgd.train(trainFeatures, labels);
|
||||
|
||||
// Predict label for the new feature vector (1xM)
|
||||
predictedLabel = SvmSgd.predict(newFeatureVector);
|
||||
@endcode
|
||||
|
||||
*/
|
||||
|
||||
class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel
|
||||
{
|
||||
public:
|
||||
|
||||
enum SvmsgdType
|
||||
{
|
||||
ILLEGAL_VALUE,
|
||||
SGD, //Stochastic Gradient Descent
|
||||
ASGD //Average Stochastic Gradient Descent
|
||||
};
|
||||
|
||||
/**
|
||||
* @return the weights of the trained model.
|
||||
*/
|
||||
CV_WRAP virtual Mat getWeights() = 0;
|
||||
|
||||
CV_WRAP virtual float getShift() = 0;
|
||||
|
||||
CV_WRAP static Ptr<SVMSGD> create();
|
||||
|
||||
CV_WRAP virtual void setOptimalParameters(int type = ASGD) = 0;
|
||||
|
||||
CV_WRAP virtual int getType() const = 0;
|
||||
|
||||
CV_WRAP virtual void setType(int type) = 0;
|
||||
|
||||
CV_WRAP virtual float getLambda() const = 0;
|
||||
|
||||
CV_WRAP virtual void setLambda(float lambda) = 0;
|
||||
|
||||
CV_WRAP virtual float getGamma0() const = 0;
|
||||
|
||||
CV_WRAP virtual void setGamma0(float gamma0) = 0;
|
||||
|
||||
CV_WRAP virtual float getC() const = 0;
|
||||
|
||||
CV_WRAP virtual void setC(float c) = 0;
|
||||
|
||||
CV_WRAP virtual cv::TermCriteria getTermCriteria() const = 0;
|
||||
|
||||
CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;
|
||||
};
|
||||
|
||||
} //ml
|
||||
} //cv
|
||||
|
||||
#endif // __clpusplus
|
||||
#endif // __OPENCV_ML_SVMSGD_HPP
|
@ -45,7 +45,7 @@
|
||||
#include "opencv2/ml.hpp"
|
||||
#include "opencv2/core/core_c.h"
|
||||
#include "opencv2/core/utility.hpp"
|
||||
|
||||
#include "opencv2/ml/svmsgd.hpp"
|
||||
#include "opencv2/core/private.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
@ -41,161 +41,430 @@
|
||||
//M*/
|
||||
|
||||
#include "precomp.hpp"
|
||||
#include "limits"
|
||||
|
||||
/****************************************************************************************\
|
||||
* Stochastic Gradient Descent SVM Classifier *
|
||||
\****************************************************************************************/
|
||||
|
||||
namespace cv {
|
||||
namespace ml {
|
||||
namespace cv
|
||||
{
|
||||
namespace ml
|
||||
{
|
||||
|
||||
SVMSGD::SVMSGD(float lambda, float learnRate, uint nIterations){
|
||||
class SVMSGDImpl : public SVMSGD
|
||||
{
|
||||
|
||||
// Initialize with random seed
|
||||
_randomNumber = 1;
|
||||
public:
|
||||
SVMSGDImpl();
|
||||
|
||||
// Initialize constants
|
||||
_slidingWindowSize = 0;
|
||||
_nFeatures = 0;
|
||||
_predictSlidingWindowSize = 1;
|
||||
virtual ~SVMSGDImpl() {}
|
||||
|
||||
// Initialize sliderCounter at index 0
|
||||
_sliderCounter = 0;
|
||||
virtual bool train(const Ptr<TrainData>& data, int);
|
||||
|
||||
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
|
||||
|
||||
virtual bool isClassifier() const { return params.svmsgdType == SGD || params.svmsgdType == ASGD; }
|
||||
|
||||
virtual bool isTrained() const;
|
||||
|
||||
virtual void clear();
|
||||
|
||||
virtual void write(FileStorage& fs) const;
|
||||
|
||||
virtual void read(const FileNode& fn);
|
||||
|
||||
virtual Mat getWeights(){ return weights_; }
|
||||
|
||||
virtual float getShift(){ return shift_; }
|
||||
|
||||
virtual int getVarCount() const { return weights_.cols; }
|
||||
|
||||
virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
|
||||
|
||||
virtual void setOptimalParameters(int type = ASGD);
|
||||
|
||||
virtual int getType() const;
|
||||
|
||||
virtual void setType(int type);
|
||||
|
||||
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
|
||||
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
|
||||
CV_IMPL_PROPERTY(float, C, params.c)
|
||||
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
||||
|
||||
private:
|
||||
void updateWeights(InputArray sample, bool is_first_class, float gamma);
|
||||
float calcShift(InputArray trainSamples, InputArray trainResponses) const;
|
||||
std::pair<bool,bool> areClassesEmpty(Mat responses);
|
||||
void writeParams( FileStorage& fs ) const;
|
||||
void readParams( const FileNode& fn );
|
||||
static inline bool isFirstClass(float val) { return val > 0; }
|
||||
|
||||
|
||||
// Vector with SVM weights
|
||||
Mat weights_;
|
||||
float shift_;
|
||||
|
||||
// Random index generation
|
||||
RNG rng_;
|
||||
|
||||
// Parameters for learning
|
||||
_lambda = lambda; // regularization
|
||||
_learnRate = learnRate; // learning rate (ideally should be large at beginning and decay each iteration)
|
||||
_nIterations = nIterations; // number of training iterations
|
||||
struct SVMSGDParams
|
||||
{
|
||||
float lambda; //regularization
|
||||
float gamma0; //learning rate
|
||||
float c;
|
||||
TermCriteria termCrit;
|
||||
SvmsgdType svmsgdType;
|
||||
};
|
||||
|
||||
// True only in the first predict iteration
|
||||
_initPredict = true;
|
||||
SVMSGDParams params;
|
||||
};
|
||||
|
||||
// Online update flag
|
||||
_onlineUpdate = false;
|
||||
Ptr<SVMSGD> SVMSGD::create()
|
||||
{
|
||||
return makePtr<SVMSGDImpl>();
|
||||
}
|
||||
|
||||
SVMSGD::SVMSGD(uint updateFrequency, float learnRateDecay, float lambda, float learnRate, uint nIterations){
|
||||
|
||||
// Initialize with random seed
|
||||
_randomNumber = 1;
|
||||
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
{
|
||||
clear();
|
||||
|
||||
// Initialize constants
|
||||
_slidingWindowSize = 0;
|
||||
_nFeatures = 0;
|
||||
_predictSlidingWindowSize = updateFrequency;
|
||||
Mat trainSamples = data->getTrainSamples();
|
||||
|
||||
// Initialize sliderCounter at index 0
|
||||
_sliderCounter = 0;
|
||||
// Initialize varCount
|
||||
int trainSamplesCount_ = trainSamples.rows;
|
||||
int varCount = trainSamples.cols;
|
||||
|
||||
// Parameters for learning
|
||||
_lambda = lambda; // regularization
|
||||
_learnRate = learnRate; // learning rate (ideally should be large at beginning and decay each iteration)
|
||||
_nIterations = nIterations; // number of training iterations
|
||||
|
||||
// True only in the first predict iteration
|
||||
_initPredict = true;
|
||||
|
||||
// Online update flag
|
||||
_onlineUpdate = true;
|
||||
|
||||
// Learn rate decay: _learnRate = _learnRate * _learnDecay
|
||||
_learnRateDecay = learnRateDecay;
|
||||
}
|
||||
|
||||
SVMSGD::~SVMSGD(){
|
||||
|
||||
}
|
||||
|
||||
SVMSGD* SVMSGD::clone() const{
|
||||
return new SVMSGD(*this);
|
||||
}
|
||||
|
||||
void SVMSGD::train(cv::Mat trainFeatures, cv::Mat labels){
|
||||
|
||||
// Initialize _nFeatures
|
||||
_slidingWindowSize = trainFeatures.rows;
|
||||
_nFeatures = trainFeatures.cols;
|
||||
|
||||
float innerProduct;
|
||||
// Initialize weights vector with zeros
|
||||
if (_weights.size()==0){
|
||||
_weights.reserve(_nFeatures);
|
||||
for (uint feat = 0; feat < _nFeatures; ++feat){
|
||||
_weights.push_back(0.0);
|
||||
}
|
||||
weights_ = Mat::zeros(1, varCount, CV_32F);
|
||||
|
||||
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
|
||||
|
||||
std::pair<bool,bool> are_empty = areClassesEmpty(trainResponses);
|
||||
|
||||
if ( are_empty.first && are_empty.second )
|
||||
{
|
||||
weights_.release();
|
||||
return false;
|
||||
}
|
||||
if ( are_empty.first || are_empty.second )
|
||||
{
|
||||
shift_ = are_empty.first ? -1 : 1;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
Mat currentSample;
|
||||
float gamma = 0;
|
||||
Mat lastWeights = Mat::zeros(1, varCount, CV_32F); //weights vector for calculating terminal criterion
|
||||
Mat averageWeights; //average weights vector for ASGD model
|
||||
double err = DBL_MAX;
|
||||
if (params.svmsgdType == ASGD)
|
||||
{
|
||||
averageWeights = Mat::zeros(1, varCount, CV_32F);
|
||||
}
|
||||
|
||||
// Stochastic gradient descent SVM
|
||||
for (uint iter = 0; iter < _nIterations; ++iter){
|
||||
generateRandomIndex();
|
||||
innerProduct = calcInnerProduct(trainFeatures.ptr<float>(_randomIndex));
|
||||
int label = (labels.at<int>(_randomIndex,0) > 0) ? 1 : -1; // ensure that labels are -1 or 1
|
||||
updateWeights(innerProduct, trainFeatures.ptr<float>(_randomIndex), label );
|
||||
}
|
||||
}
|
||||
for (int iter = 0; (iter < params.termCrit.maxCount)&&(err > params.termCrit.epsilon); iter++)
|
||||
{
|
||||
//generate sample number
|
||||
int randomNumber = rng_.uniform(0, trainSamplesCount_);
|
||||
|
||||
float SVMSGD::predict(cv::Mat newFeature){
|
||||
float innerProduct;
|
||||
currentSample = trainSamples.row(randomNumber);
|
||||
|
||||
if (_initPredict){
|
||||
_nFeatures = newFeature.cols;
|
||||
_slidingWindowSize = _predictSlidingWindowSize;
|
||||
_featuresSlider = cv::Mat::zeros(_slidingWindowSize, _nFeatures, CV_32F);
|
||||
_initPredict = false;
|
||||
_labelSlider = new float[_predictSlidingWindowSize]();
|
||||
_learnRate = _learnRate * _learnRateDecay;
|
||||
}
|
||||
//update gamma
|
||||
gamma = params.gamma0 * std::pow((1 + params.lambda * params.gamma0 * (float)iter), (-params.c));
|
||||
|
||||
innerProduct = calcInnerProduct(newFeature.ptr<float>(0));
|
||||
bool is_first_class = isFirstClass(trainResponses.at<float>(randomNumber));
|
||||
updateWeights( currentSample, is_first_class, gamma );
|
||||
|
||||
// Resultant label (-1 or 1)
|
||||
int label = (innerProduct>=0) ? 1 : -1;
|
||||
|
||||
if (_onlineUpdate){
|
||||
// Update the featuresSlider with newFeature and _labelSlider with label
|
||||
newFeature.row(0).copyTo(_featuresSlider.row(_sliderCounter));
|
||||
_labelSlider[_sliderCounter] = float(label);
|
||||
|
||||
// Update weights with a random index
|
||||
if (_sliderCounter == _slidingWindowSize-1){
|
||||
generateRandomIndex();
|
||||
updateWeights(innerProduct, _featuresSlider.ptr<float>(_randomIndex), int(_labelSlider[_randomIndex]) );
|
||||
//average weights (only for ASGD model)
|
||||
if (params.svmsgdType == ASGD)
|
||||
{
|
||||
averageWeights = ((float)iter/ (1 + (float)iter)) * averageWeights + weights_ / (1 + (float) iter);
|
||||
}
|
||||
|
||||
// _sliderCounter++ if < _slidingWindowSize
|
||||
_sliderCounter = (_sliderCounter == _slidingWindowSize-1) ? 0 : (_sliderCounter+1);
|
||||
err = norm(weights_ - lastWeights);
|
||||
weights_.copyTo(lastWeights);
|
||||
}
|
||||
|
||||
return float(label);
|
||||
}
|
||||
|
||||
void SVMSGD::generateRandomIndex(){
|
||||
// Choose random sample, using Mikolov's fast almost-uniform random number
|
||||
_randomNumber = _randomNumber * (unsigned long long) 25214903917 + 11;
|
||||
_randomIndex = uint(_randomNumber % (unsigned long long) _slidingWindowSize);
|
||||
}
|
||||
|
||||
float SVMSGD::calcInnerProduct(float *rowDataPointer){
|
||||
float innerProduct = 0;
|
||||
for (uint feat = 0; feat < _nFeatures; ++feat){
|
||||
innerProduct += _weights[feat] * rowDataPointer[feat];
|
||||
if (params.svmsgdType == ASGD)
|
||||
{
|
||||
weights_ = averageWeights;
|
||||
}
|
||||
return innerProduct;
|
||||
|
||||
shift_ = calcShift(trainSamples, trainResponses);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void SVMSGD::updateWeights(float innerProduct, float *rowDataPointer, int label){
|
||||
if (label * innerProduct > 1) {
|
||||
std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
|
||||
{
|
||||
std::pair<bool,bool> are_classes_empty(true, true);
|
||||
int limit_index = responses.rows;
|
||||
|
||||
for(int index = 0; index < limit_index; index++)
|
||||
{
|
||||
if (isFirstClass(responses.at<float>(index,0)))
|
||||
are_classes_empty.first = false;
|
||||
else
|
||||
are_classes_empty.second = false;
|
||||
|
||||
if (!are_classes_empty.first && ! are_classes_empty.second)
|
||||
break;
|
||||
}
|
||||
|
||||
return are_classes_empty;
|
||||
}
|
||||
|
||||
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||
{
|
||||
float distance_to_classes[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
|
||||
|
||||
Mat trainSamples = _samples.getMat();
|
||||
int trainSamplesCount = trainSamples.rows;
|
||||
|
||||
Mat trainResponses = _responses.getMat();
|
||||
|
||||
for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
|
||||
{
|
||||
Mat currentSample = trainSamples.row(samplesIndex);
|
||||
float scalar_product = currentSample.dot(weights_);
|
||||
|
||||
bool is_first_class = isFirstClass(trainResponses.at<float>(samplesIndex));
|
||||
int index = is_first_class ? 0:1;
|
||||
float sign_to_mul = is_first_class ? 1 : -1;
|
||||
float cur_distance = scalar_product * sign_to_mul ;
|
||||
|
||||
if (cur_distance < distance_to_classes[index])
|
||||
{
|
||||
distance_to_classes[index] = cur_distance;
|
||||
}
|
||||
}
|
||||
|
||||
//todo: areClassesEmpty(); make const;
|
||||
return -(distance_to_classes[0] - distance_to_classes[1]) / 2.f;
|
||||
}
|
||||
|
||||
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
|
||||
{
|
||||
float result = 0;
|
||||
cv::Mat samples = _samples.getMat();
|
||||
int nSamples = samples.rows;
|
||||
cv::Mat results;
|
||||
|
||||
CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32F );
|
||||
|
||||
if( _results.needed() )
|
||||
{
|
||||
_results.create( nSamples, 1, samples.type() );
|
||||
results = _results.getMat();
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_Assert( nSamples == 1 );
|
||||
results = Mat(1, 1, CV_32F, &result);
|
||||
}
|
||||
|
||||
Mat currentSample;
|
||||
float criterion = 0;
|
||||
|
||||
for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
|
||||
{
|
||||
currentSample = samples.row(sampleIndex);
|
||||
criterion = currentSample.dot(weights_) + shift_;
|
||||
results.at<float>(sampleIndex) = (criterion >= 0) ? 1 : -1;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::updateWeights(InputArray _sample, bool is_first_class, float gamma)
|
||||
{
|
||||
Mat sample = _sample.getMat();
|
||||
|
||||
int responce = is_first_class ? 1 : -1; // ensure that trainResponses are -1 or 1
|
||||
|
||||
if ( sample.dot(weights_) * responce > 1)
|
||||
{
|
||||
// Not a support vector, only apply weight decay
|
||||
for (uint feat = 0; feat < _nFeatures; feat++) {
|
||||
_weights[feat] -= _learnRate * _lambda * _weights[feat];
|
||||
}
|
||||
} else {
|
||||
weights_ *= (1.f - gamma * params.lambda);
|
||||
}
|
||||
else
|
||||
{
|
||||
// It's a support vector, add it to the weights
|
||||
for (uint feat = 0; feat < _nFeatures; feat++) {
|
||||
_weights[feat] -= _learnRate * (_lambda * _weights[feat] - label * rowDataPointer[feat]);
|
||||
}
|
||||
weights_ -= (gamma * params.lambda) * weights_ - gamma * responce * sample;
|
||||
//std::cout << "sample " << sample << std::endl;
|
||||
//std::cout << "weights_ " << weights_ << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool SVMSGDImpl::isTrained() const
|
||||
{
|
||||
return !weights_.empty();
|
||||
}
|
||||
|
||||
void SVMSGDImpl::write(FileStorage& fs) const
|
||||
{
|
||||
if( !isTrained() )
|
||||
CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
|
||||
|
||||
writeParams( fs );
|
||||
|
||||
fs << "shift" << shift_;
|
||||
fs << "weights" << weights_;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
{
|
||||
String SvmsgdTypeStr;
|
||||
|
||||
switch (params.svmsgdType)
|
||||
{
|
||||
case SGD:
|
||||
SvmsgdTypeStr = "SGD";
|
||||
break;
|
||||
case ASGD:
|
||||
SvmsgdTypeStr = "ASGD";
|
||||
break;
|
||||
case ILLEGAL_VALUE:
|
||||
SvmsgdTypeStr = format("Uknown_%d", params.svmsgdType);
|
||||
default:
|
||||
std::cout << "params.svmsgdType isn't initialized" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
fs << "svmsgdType" << SvmsgdTypeStr;
|
||||
|
||||
fs << "lambda" << params.lambda;
|
||||
fs << "gamma0" << params.gamma0;
|
||||
fs << "c" << params.c;
|
||||
|
||||
fs << "term_criteria" << "{:";
|
||||
if( params.termCrit.type & TermCriteria::EPS )
|
||||
fs << "epsilon" << params.termCrit.epsilon;
|
||||
if( params.termCrit.type & TermCriteria::COUNT )
|
||||
fs << "iterations" << params.termCrit.maxCount;
|
||||
fs << "}";
|
||||
}
|
||||
|
||||
|
||||
|
||||
void SVMSGDImpl::read(const FileNode& fn)
|
||||
{
|
||||
clear();
|
||||
|
||||
readParams(fn);
|
||||
|
||||
shift_ = (float) fn["shift"];
|
||||
fn["weights"] >> weights_;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::readParams( const FileNode& fn )
|
||||
{
|
||||
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
||||
SvmsgdType svmsgdType =
|
||||
svmsgdTypeStr == "SGD" ? SGD :
|
||||
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_VALUE;
|
||||
|
||||
if( svmsgdType == ILLEGAL_VALUE )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
|
||||
|
||||
params.svmsgdType = svmsgdType;
|
||||
|
||||
CV_Assert ( fn["lambda"].isReal() );
|
||||
params.lambda = (float)fn["lambda"];
|
||||
|
||||
CV_Assert ( fn["gamma0"].isReal() );
|
||||
params.gamma0 = (float)fn["gamma0"];
|
||||
|
||||
CV_Assert ( fn["c"].isReal() );
|
||||
params.c = (float)fn["c"];
|
||||
|
||||
FileNode tcnode = fn["term_criteria"];
|
||||
if( !tcnode.empty() )
|
||||
{
|
||||
params.termCrit.epsilon = (double)tcnode["epsilon"];
|
||||
params.termCrit.maxCount = (int)tcnode["iterations"];
|
||||
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
|
||||
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
|
||||
}
|
||||
else
|
||||
params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );
|
||||
|
||||
}
|
||||
|
||||
void SVMSGDImpl::clear()
|
||||
{
|
||||
weights_.release();
|
||||
shift_ = 0;
|
||||
}
|
||||
|
||||
|
||||
SVMSGDImpl::SVMSGDImpl()
|
||||
{
|
||||
clear();
|
||||
rng_(0);
|
||||
|
||||
params.svmsgdType = ILLEGAL_VALUE;
|
||||
|
||||
// Parameters for learning
|
||||
params.lambda = 0; // regularization
|
||||
params.gamma0 = 0; // learning rate (ideally should be large at beginning and decay each iteration)
|
||||
params.c = 0;
|
||||
|
||||
TermCriteria _termCrit(TermCriteria::COUNT + TermCriteria::EPS, 0, 0);
|
||||
params.termCrit = _termCrit;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setOptimalParameters(int type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case SGD:
|
||||
params.svmsgdType = SGD;
|
||||
params.lambda = 0.00001;
|
||||
params.gamma0 = 0.05;
|
||||
params.c = 1;
|
||||
params.termCrit.maxCount = 50000;
|
||||
params.termCrit.epsilon = 0.00000001;
|
||||
break;
|
||||
|
||||
case ASGD:
|
||||
params.svmsgdType = ASGD;
|
||||
params.lambda = 0.00001;
|
||||
params.gamma0 = 0.5;
|
||||
params.c = 0.75;
|
||||
params.termCrit.maxCount = 100000;
|
||||
params.termCrit.epsilon = 0.000001;
|
||||
break;
|
||||
|
||||
default:
|
||||
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
|
||||
}
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setType(int type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case SGD:
|
||||
params.svmsgdType = SGD;
|
||||
break;
|
||||
case ASGD:
|
||||
params.svmsgdType = ASGD;
|
||||
break;
|
||||
default:
|
||||
params.svmsgdType = ILLEGAL_VALUE;
|
||||
}
|
||||
}
|
||||
|
||||
int SVMSGDImpl::getType() const
|
||||
{
|
||||
return params.svmsgdType;
|
||||
}
|
||||
} //ml
|
||||
} //cv
|
||||
|
@ -193,6 +193,16 @@ int str_to_boost_type( String& str )
|
||||
// 8. rtrees
|
||||
// 9. ertrees
|
||||
|
||||
int str_to_svmsgd_type( String& str )
|
||||
{
|
||||
if ( !str.compare("SGD") )
|
||||
return SVMSGD::SGD;
|
||||
if ( !str.compare("ASGD") )
|
||||
return SVMSGD::ASGD;
|
||||
CV_Error( CV_StsBadArg, "incorrect boost type string" );
|
||||
return -1;
|
||||
}
|
||||
|
||||
// ---------------------------------- MLBaseTest ---------------------------------------------------
|
||||
|
||||
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
|
||||
@ -248,7 +258,9 @@ void CV_MLBaseTest::run( int )
|
||||
{
|
||||
string filename = ts->get_data_path();
|
||||
filename += get_validation_filename();
|
||||
|
||||
validationFS.open( filename, FileStorage::READ );
|
||||
|
||||
read_params( *validationFS );
|
||||
|
||||
int code = cvtest::TS::OK;
|
||||
@ -436,6 +448,21 @@ int CV_MLBaseTest::train( int testCaseIdx )
|
||||
model = m;
|
||||
}
|
||||
|
||||
else if( modelName == CV_SVMSGD )
|
||||
{
|
||||
String svmsgdTypeStr;
|
||||
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
|
||||
Ptr<SVMSGD> m = SVMSGD::create();
|
||||
int type = str_to_svmsgd_type( svmsgdTypeStr );
|
||||
m->setType(type);
|
||||
//m->setType(str_to_svmsgd_type( svmsgdTypeStr ));
|
||||
m->setLambda(modelParamsNode["lambda"]);
|
||||
m->setGamma0(modelParamsNode["gamma0"]);
|
||||
m->setC(modelParamsNode["c"]);
|
||||
m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
|
||||
model = m;
|
||||
}
|
||||
|
||||
if( !model.empty() )
|
||||
is_trained = model->train(data, 0);
|
||||
|
||||
@ -457,7 +484,7 @@ float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
|
||||
else if( modelName == CV_ANN )
|
||||
err = ann_calc_error( model, data, cls_map, type, resp );
|
||||
else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
|
||||
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST )
|
||||
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD )
|
||||
err = model->calcError( data, true, _resp );
|
||||
if( !_resp.empty() && resp )
|
||||
_resp.convertTo(*resp, CV_32F);
|
||||
@ -485,6 +512,8 @@ void CV_MLBaseTest::load( const char* filename )
|
||||
model = Algorithm::load<Boost>( filename );
|
||||
else if( modelName == CV_RTREES )
|
||||
model = Algorithm::load<RTrees>( filename );
|
||||
else if( modelName == CV_SVMSGD )
|
||||
model = Algorithm::load<SVMSGD>( filename );
|
||||
else
|
||||
CV_Error( CV_StsNotImplemented, "invalid stat model name");
|
||||
}
|
||||
|
@ -13,6 +13,7 @@
|
||||
#include <map>
|
||||
#include "opencv2/ts.hpp"
|
||||
#include "opencv2/ml.hpp"
|
||||
#include "opencv2/ml/svmsgd.hpp"
|
||||
#include "opencv2/core/core_c.h"
|
||||
|
||||
#define CV_NBAYES "nbayes"
|
||||
@ -24,6 +25,7 @@
|
||||
#define CV_BOOST "boost"
|
||||
#define CV_RTREES "rtrees"
|
||||
#define CV_ERTREES "ertrees"
|
||||
#define CV_SVMSGD "svmsgd"
|
||||
|
||||
enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 };
|
||||
|
||||
@ -38,6 +40,7 @@ using cv::ml::ANN_MLP;
|
||||
using cv::ml::DTrees;
|
||||
using cv::ml::Boost;
|
||||
using cv::ml::RTrees;
|
||||
using cv::ml::SVMSGD;
|
||||
|
||||
class CV_MLBaseTest : public cvtest::BaseTest
|
||||
{
|
||||
|
@ -150,12 +150,20 @@ int CV_SLMLTest::validate_test_results( int testCaseIdx )
|
||||
|
||||
TEST(ML_NaiveBayes, save_load) { CV_SLMLTest test( CV_NBAYES ); test.safe_run(); }
|
||||
TEST(ML_KNearest, save_load) { CV_SLMLTest test( CV_KNEAREST ); test.safe_run(); }
|
||||
TEST(ML_SVM, save_load) { CV_SLMLTest test( CV_SVM ); test.safe_run(); }
|
||||
TEST(ML_SVM, save_load)
|
||||
{
|
||||
CV_SLMLTest test( CV_SVM );
|
||||
test.safe_run();
|
||||
}
|
||||
TEST(ML_ANN, save_load) { CV_SLMLTest test( CV_ANN ); test.safe_run(); }
|
||||
TEST(ML_DTree, save_load) { CV_SLMLTest test( CV_DTREE ); test.safe_run(); }
|
||||
TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
|
||||
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
|
||||
TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
|
||||
TEST(MV_SVMSGD, save_load){
|
||||
CV_SLMLTest test( CV_SVMSGD );
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
class CV_LegacyTest : public cvtest::BaseTest
|
||||
{
|
||||
@ -201,6 +209,8 @@ protected:
|
||||
model = Algorithm::load<SVM>(filename);
|
||||
else if (modelName == CV_RTREES)
|
||||
model = Algorithm::load<RTrees>(filename);
|
||||
else if (modelName == CV_SVMSGD)
|
||||
model = Algorithm::load<SVMSGD>(filename);
|
||||
if (!model)
|
||||
{
|
||||
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
||||
@ -260,6 +270,11 @@ TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushro
|
||||
TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_SVMSGD, legacy_load)
|
||||
{
|
||||
CV_LegacyTest test(CV_SVMSGD, "_waveform.xml");
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
|
||||
{
|
||||
|
182
modules/ml/test/test_svmsgd.cpp
Normal file
182
modules/ml/test/test_svmsgd.cpp
Normal file
@ -0,0 +1,182 @@
|
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// Intel License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of Intel Corporation may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "test_precomp.hpp"
|
||||
#include "opencv2/highgui.hpp"
|
||||
|
||||
using namespace cv;
|
||||
using namespace cv::ml;
|
||||
using cv::ml::SVMSGD;
|
||||
using cv::ml::TrainData;
|
||||
|
||||
|
||||
|
||||
class CV_SVMSGDTrainTest : public cvtest::BaseTest
|
||||
{
|
||||
public:
|
||||
CV_SVMSGDTrainTest(Mat _weights, float _shift);
|
||||
private:
|
||||
virtual void run( int start_from );
|
||||
float decisionFunction(Mat sample, Mat weights, float shift);
|
||||
|
||||
cv::Ptr<TrainData> data;
|
||||
cv::Mat testSamples;
|
||||
cv::Mat testResponses;
|
||||
static const int TEST_VALUE_LIMIT = 50;
|
||||
};
|
||||
|
||||
CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift)
|
||||
{
|
||||
int datasize = 100000;
|
||||
int varCount = weights.cols;
|
||||
cv::Mat samples = cv::Mat::zeros( datasize, varCount, CV_32FC1 );
|
||||
cv::Mat responses = cv::Mat::zeros( datasize, 1, CV_32FC1 );
|
||||
cv::RNG rng(0);
|
||||
|
||||
float lowerLimit = -TEST_VALUE_LIMIT;
|
||||
float upperLimit = TEST_VALUE_LIMIT;
|
||||
|
||||
|
||||
rng.fill(samples, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
for (int sampleIndex = 0; sampleIndex < datasize; sampleIndex++)
|
||||
{
|
||||
responses.at<float>( sampleIndex ) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1;
|
||||
}
|
||||
|
||||
data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
|
||||
|
||||
int testSamplesCount = 100000;
|
||||
|
||||
testSamples.create(testSamplesCount, varCount, CV_32FC1);
|
||||
rng.fill(testSamples, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
testResponses.create(testSamplesCount, 1, CV_32FC1);
|
||||
|
||||
for (int i = 0 ; i < testSamplesCount; i++)
|
||||
{
|
||||
testResponses.at<float>(i) = decisionFunction(testSamples.row(i), weights, shift) > 0 ? 1 : -1;
|
||||
}
|
||||
}
|
||||
|
||||
void CV_SVMSGDTrainTest::run( int /*start_from*/ )
|
||||
{
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
|
||||
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
||||
|
||||
svmsgd->train( data );
|
||||
|
||||
Mat responses;
|
||||
|
||||
svmsgd->predict(testSamples, responses);
|
||||
|
||||
int errCount = 0;
|
||||
int testSamplesCount = testSamples.rows;
|
||||
|
||||
for (int i = 0; i < testSamplesCount; i++)
|
||||
{
|
||||
if (responses.at<float>(i) * testResponses.at<float>(i) < 0 )
|
||||
errCount++;
|
||||
}
|
||||
|
||||
float err = (float)errCount / testSamplesCount;
|
||||
std::cout << "err " << err << std::endl;
|
||||
|
||||
if ( err > 0.01 )
|
||||
{
|
||||
ts->set_failed_test_info( cvtest::TS::FAIL_BAD_ACCURACY );
|
||||
}
|
||||
}
|
||||
|
||||
float CV_SVMSGDTrainTest::decisionFunction(Mat sample, Mat weights, float shift)
|
||||
{
|
||||
return sample.dot(weights) + shift;
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, train0)
|
||||
{
|
||||
int varCount = 2;
|
||||
|
||||
Mat weights;
|
||||
weights.create(1, varCount, CV_32FC1);
|
||||
weights.at<float>(0) = 1;
|
||||
weights.at<float>(1) = 0;
|
||||
|
||||
float shift = 5;
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, train1)
|
||||
{
|
||||
int varCount = 5;
|
||||
|
||||
Mat weights;
|
||||
weights.create(1, varCount, CV_32FC1);
|
||||
|
||||
float lowerLimit = -1;
|
||||
float upperLimit = 1;
|
||||
cv::RNG rng(0);
|
||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
|
||||
float shift = rng.uniform(-5.f, 5.f);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, train2)
|
||||
{
|
||||
int varCount = 100;
|
||||
|
||||
Mat weights;
|
||||
weights.create(1, varCount, CV_32FC1);
|
||||
|
||||
float lowerLimit = -1;
|
||||
float upperLimit = 1;
|
||||
cv::RNG rng(0);
|
||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
|
||||
float shift = rng.uniform(-1000.f, 1000.f);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift);
|
||||
test.safe_run();
|
||||
}
|
@ -5659,7 +5659,7 @@ class TestCaseNameIs {
|
||||
|
||||
// Returns true iff the name of test_case matches name_.
|
||||
bool operator()(const TestCase* test_case) const {
|
||||
return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0;
|
||||
return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0;
|
||||
}
|
||||
|
||||
private:
|
||||
|
226
samples/cpp/train_svmsgd.cpp
Normal file
226
samples/cpp/train_svmsgd.cpp
Normal file
@ -0,0 +1,226 @@
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include "opencv2/video/tracking.hpp"
|
||||
#include "opencv2/imgproc/imgproc.hpp"
|
||||
#include "opencv2/highgui/highgui.hpp"
|
||||
|
||||
using namespace cv;
|
||||
using namespace cv::ml;
|
||||
|
||||
#define WIDTH 841
|
||||
#define HEIGHT 594
|
||||
|
||||
struct Data
|
||||
{
|
||||
Mat img;
|
||||
Mat samples;
|
||||
Mat responses;
|
||||
RNG rng;
|
||||
//Point points[2];
|
||||
|
||||
Data()
|
||||
{
|
||||
img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
|
||||
imshow("Train svmsgd", img);
|
||||
}
|
||||
};
|
||||
|
||||
bool doTrain(const Mat samples,const Mat responses, Mat &weights, float &shift);
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]);
|
||||
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments);
|
||||
void redraw(Data data, const Point points[2]);
|
||||
void addPointsRetrainAndRedraw(Data &data, int x, int y);
|
||||
|
||||
|
||||
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
|
||||
{
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
||||
svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 50000, 0.0000001));
|
||||
svmsgd->setLambda(0.01);
|
||||
svmsgd->setGamma0(1);
|
||||
// svmsgd->setC(5);
|
||||
|
||||
cv::Ptr<TrainData> train_data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
|
||||
svmsgd->train( train_data );
|
||||
|
||||
if (svmsgd->isTrained())
|
||||
{
|
||||
weights = svmsgd->getWeights();
|
||||
shift = svmsgd->getShift();
|
||||
|
||||
std::cout << weights << std::endl;
|
||||
std::cout << shift << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
|
||||
{
|
||||
int x = 0;
|
||||
int y = 0;
|
||||
//с (0,0) всё плохо
|
||||
if (segment.first.x == segment.second.x && weights.at<float>(1) != 0)
|
||||
{
|
||||
x = segment.first.x;
|
||||
y = -(weights.at<float>(0) * x + shift) / weights.at<float>(1);
|
||||
if (y >= 0 && y <= HEIGHT)
|
||||
{
|
||||
crossPoint.x = x;
|
||||
crossPoint.y = y;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else if (segment.first.y == segment.second.y && weights.at<float>(0) != 0)
|
||||
{
|
||||
y = segment.first.y;
|
||||
x = - (weights.at<float>(1) * y + shift) / weights.at<float>(0);
|
||||
if (x >= 0 && x <= WIDTH)
|
||||
{
|
||||
crossPoint.x = x;
|
||||
crossPoint.y = y;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2])
|
||||
{
|
||||
if (weights.empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
int foundPointsCount = 0;
|
||||
std::vector<std::pair<Point,Point> > segments;
|
||||
fillSegments(segments);
|
||||
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
if (findCrossPoint(weights, shift, segments[i], points[foundPointsCount]))
|
||||
foundPointsCount++;
|
||||
if (foundPointsCount > 2)
|
||||
break;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments)
|
||||
{
|
||||
std::pair<Point,Point> curSegment;
|
||||
|
||||
curSegment.first = Point(0,0);
|
||||
curSegment.second = Point(0,HEIGHT);
|
||||
segments.push_back(curSegment);
|
||||
|
||||
curSegment.first = Point(0,0);
|
||||
curSegment.second = Point(WIDTH,0);
|
||||
segments.push_back(curSegment);
|
||||
|
||||
curSegment.first = Point(WIDTH,0);
|
||||
curSegment.second = Point(WIDTH,HEIGHT);
|
||||
segments.push_back(curSegment);
|
||||
|
||||
curSegment.first = Point(0,HEIGHT);
|
||||
curSegment.second = Point(WIDTH,HEIGHT);
|
||||
segments.push_back(curSegment);
|
||||
}
|
||||
|
||||
void redraw(Data data, const Point points[2])
|
||||
{
|
||||
data.img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
|
||||
Point center;
|
||||
int radius = 3;
|
||||
Scalar color;
|
||||
for (int i = 0; i < data.samples.rows; i++)
|
||||
{
|
||||
center.x = data.samples.at<float>(i,0);
|
||||
center.y = data.samples.at<float>(i,1);
|
||||
color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
|
||||
circle(data.img, center, radius, color, 5);
|
||||
}
|
||||
line(data.img, points[0],points[1],cv::Scalar(1,255,1));
|
||||
|
||||
imshow("Train svmsgd", data.img);
|
||||
}
|
||||
|
||||
void addPointsRetrainAndRedraw(Data &data, int x, int y)
|
||||
{
|
||||
|
||||
Mat currentSample(1, 2, CV_32F);
|
||||
//start
|
||||
/*
|
||||
Mat _weights;
|
||||
_weights.create(1, 2, CV_32FC1);
|
||||
_weights.at<float>(0) = 1;
|
||||
_weights.at<float>(1) = -1;
|
||||
|
||||
int _x, _y;
|
||||
|
||||
for (int i=0;i<199;i++)
|
||||
{
|
||||
_x = data.rng.uniform(0,800);
|
||||
_y = data.rng.uniform(0,500);*/
|
||||
currentSample.at<float>(0,0) = x;
|
||||
currentSample.at<float>(0,1) = y;
|
||||
//if (currentSample.dot(_weights) > 0)
|
||||
//data.responses.push_back(1);
|
||||
// else data.responses.push_back(-1);
|
||||
|
||||
//finish
|
||||
data.samples.push_back(currentSample);
|
||||
|
||||
|
||||
|
||||
Mat weights(1, 2, CV_32F);
|
||||
float shift = 0;
|
||||
|
||||
if (doTrain(data.samples, data.responses, weights, shift))
|
||||
{
|
||||
Point points[2];
|
||||
shift = 0;
|
||||
|
||||
findPointsForLine(weights, shift, points);
|
||||
|
||||
redraw(data, points);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void onMouse( int event, int x, int y, int, void* pData)
|
||||
{
|
||||
Data &data = *(Data*)pData;
|
||||
|
||||
switch( event )
|
||||
{
|
||||
case CV_EVENT_LBUTTONUP:
|
||||
data.responses.push_back(1);
|
||||
addPointsRetrainAndRedraw(data, x, y);
|
||||
|
||||
break;
|
||||
|
||||
case CV_EVENT_RBUTTONDOWN:
|
||||
data.responses.push_back(-1);
|
||||
addPointsRetrainAndRedraw(data, x, y);
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
|
||||
Data data;
|
||||
|
||||
setMouseCallback( "Train svmsgd", onMouse, &data );
|
||||
waitKey();
|
||||
|
||||
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user