Merge pull request #6096 from mnoskova:mn/SVMSGD_to_opencv3_0
This commit is contained in:
commit
fbc221d334
@ -874,3 +874,11 @@
|
||||
year={2007},
|
||||
organization={IEEE}
|
||||
}
|
||||
@incollection{bottou2010large,
|
||||
title={Large-scale machine learning with stochastic gradient descent},
|
||||
author={Bottou, L{\'e}on},
|
||||
booktitle={Proceedings of COMPSTAT'2010},
|
||||
pages={177--186},
|
||||
year={2010},
|
||||
publisher={Springer}
|
||||
}
|
||||
|
@ -1499,6 +1499,165 @@ public:
|
||||
CV_WRAP static Ptr<LogisticRegression> create();
|
||||
};
|
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* Stochastic Gradient Descent SVM Classifier *
|
||||
\****************************************************************************************/
|
||||
|
||||
/*!
|
||||
@brief Stochastic Gradient Descent SVM classifier
|
||||
|
||||
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach,
|
||||
as presented in @cite bottou2010large.
|
||||
|
||||
The classifier has following parameters:
|
||||
- model type,
|
||||
- margin type,
|
||||
- margin regularization (\f$\lambda\f$),
|
||||
- initial step size (\f$\gamma_0\f$),
|
||||
- step decreasing power (\f$c\f$),
|
||||
- and termination criteria.
|
||||
|
||||
The model type may have one of the following values: \ref SGD and \ref ASGD.
|
||||
|
||||
- \ref SGD is the classic version of SVMSGD classifier: every next step is calculated by the formula
|
||||
\f[w_{t+1} = w_t - \gamma(t) \frac{dQ_i}{dw} |_{w = w_t}\f]
|
||||
where
|
||||
- \f$w_t\f$ is the weights vector for decision function at step \f$t\f$,
|
||||
- \f$\gamma(t)\f$ is the step size of model parameters at the iteration \f$t\f$, it is decreased on each step by the formula
|
||||
\f$\gamma(t) = \gamma_0 (1 + \lambda \gamma_0 t) ^ {-c}\f$
|
||||
- \f$Q_i\f$ is the target functional from SVM task for sample with number \f$i\f$, this sample is chosen stochastically on each step of the algorithm.
|
||||
|
||||
- \ref ASGD is Average Stochastic Gradient Descent SVM Classifier. ASGD classifier averages weights vector on each step of algorithm by the formula
|
||||
\f$\widehat{w}_{t+1} = \frac{t}{1+t}\widehat{w}_{t} + \frac{1}{1+t}w_{t+1}\f$
|
||||
|
||||
The recommended model type is ASGD (following @cite bottou2010large).
|
||||
|
||||
The margin type may have one of the following values: \ref SOFT_MARGIN or \ref HARD_MARGIN.
|
||||
|
||||
- You should use \ref HARD_MARGIN type, if you have linearly separable sets.
|
||||
- You should use \ref SOFT_MARGIN type, if you have non-linearly separable sets or sets with outliers.
|
||||
- In the general case (if you know nothing about linear separability of your sets), use SOFT_MARGIN.
|
||||
|
||||
The other parameters may be described as follows:
|
||||
- Margin regularization parameter is responsible for weights decreasing at each step and for the strength of restrictions on outliers
|
||||
(the less the parameter, the less probability that an outlier will be ignored).
|
||||
Recommended value for SGD model is 0.0001, for ASGD model is 0.00001.
|
||||
|
||||
- Initial step size parameter is the initial value for the step size \f$\gamma(t)\f$.
|
||||
You will have to find the best initial step for your problem.
|
||||
|
||||
- Step decreasing power is the power parameter for \f$\gamma(t)\f$ decreasing by the formula, mentioned above.
|
||||
Recommended value for SGD model is 1, for ASGD model is 0.75.
|
||||
|
||||
- Termination criteria can be TermCriteria::COUNT, TermCriteria::EPS or TermCriteria::COUNT + TermCriteria::EPS.
|
||||
You will have to find the best termination criteria for your problem.
|
||||
|
||||
Note that the parameters margin regularization, initial step size, and step decreasing power should be positive.
|
||||
|
||||
To use SVMSGD algorithm do as follows:
|
||||
|
||||
- first, create the SVMSGD object. The algoorithm will set optimal parameters by default, but you can set your own parameters via functions setSvmsgdType(),
|
||||
setMarginType(), setMarginRegularization(), setInitialStepSize(), and setStepDecreasingPower().
|
||||
|
||||
- then the SVM model can be trained using the train features and the correspondent labels by the method train().
|
||||
|
||||
- after that, the label of a new feature vector can be predicted using the method predict().
|
||||
|
||||
@code
|
||||
// Create empty object
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
|
||||
// Train the Stochastic Gradient Descent SVM
|
||||
svmsgd->train(trainData);
|
||||
|
||||
// Predict labels for the new samples
|
||||
svmsgd->predict(samples, responses);
|
||||
@endcode
|
||||
|
||||
*/
|
||||
|
||||
class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel
|
||||
{
|
||||
public:
|
||||
|
||||
/** SVMSGD type.
|
||||
ASGD is often the preferable choice. */
|
||||
enum SvmsgdType
|
||||
{
|
||||
SGD, //!< Stochastic Gradient Descent
|
||||
ASGD //!< Average Stochastic Gradient Descent
|
||||
};
|
||||
|
||||
/** Margin type.*/
|
||||
enum MarginType
|
||||
{
|
||||
SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers.
|
||||
HARD_MARGIN //!< More accurate for the case of linearly separable sets.
|
||||
};
|
||||
|
||||
/**
|
||||
* @return the weights of the trained model (decision function f(x) = weights * x + shift).
|
||||
*/
|
||||
CV_WRAP virtual Mat getWeights() = 0;
|
||||
|
||||
/**
|
||||
* @return the shift of the trained model (decision function f(x) = weights * x + shift).
|
||||
*/
|
||||
CV_WRAP virtual float getShift() = 0;
|
||||
|
||||
/** @brief Creates empty model.
|
||||
* Use StatModel::train to train the model. Since %SVMSGD has several parameters, you may want to
|
||||
* find the best parameters for your problem or use setOptimalParameters() to set some default parameters.
|
||||
*/
|
||||
CV_WRAP static Ptr<SVMSGD> create();
|
||||
|
||||
/** @brief Function sets optimal parameters values for chosen SVM SGD model.
|
||||
* @param svmsgdType is the type of SVMSGD classifier.
|
||||
* @param marginType is the type of margin constraint.
|
||||
*/
|
||||
CV_WRAP virtual void setOptimalParameters(int svmsgdType = SVMSGD::ASGD, int marginType = SVMSGD::SOFT_MARGIN) = 0;
|
||||
|
||||
/** @brief %Algorithm type, one of SVMSGD::SvmsgdType. */
|
||||
/** @see setSvmsgdType */
|
||||
CV_WRAP virtual int getSvmsgdType() const = 0;
|
||||
/** @copybrief getSvmsgdType @see getSvmsgdType */
|
||||
CV_WRAP virtual void setSvmsgdType(int svmsgdType) = 0;
|
||||
|
||||
/** @brief %Margin type, one of SVMSGD::MarginType. */
|
||||
/** @see setMarginType */
|
||||
CV_WRAP virtual int getMarginType() const = 0;
|
||||
/** @copybrief getMarginType @see getMarginType */
|
||||
CV_WRAP virtual void setMarginType(int marginType) = 0;
|
||||
|
||||
/** @brief Parameter marginRegularization of a %SVMSGD optimization problem. */
|
||||
/** @see setMarginRegularization */
|
||||
CV_WRAP virtual float getMarginRegularization() const = 0;
|
||||
/** @copybrief getMarginRegularization @see getMarginRegularization */
|
||||
CV_WRAP virtual void setMarginRegularization(float marginRegularization) = 0;
|
||||
|
||||
/** @brief Parameter initialStepSize of a %SVMSGD optimization problem. */
|
||||
/** @see setInitialStepSize */
|
||||
CV_WRAP virtual float getInitialStepSize() const = 0;
|
||||
/** @copybrief getInitialStepSize @see getInitialStepSize */
|
||||
CV_WRAP virtual void setInitialStepSize(float InitialStepSize) = 0;
|
||||
|
||||
/** @brief Parameter stepDecreasingPower of a %SVMSGD optimization problem. */
|
||||
/** @see setStepDecreasingPower */
|
||||
CV_WRAP virtual float getStepDecreasingPower() const = 0;
|
||||
/** @copybrief getStepDecreasingPower @see getStepDecreasingPower */
|
||||
CV_WRAP virtual void setStepDecreasingPower(float stepDecreasingPower) = 0;
|
||||
|
||||
/** @brief Termination criteria of the training algorithm.
|
||||
You can specify the maximum number of iterations (maxCount) and/or how much the error could
|
||||
change between the iterations to make the algorithm continue (epsilon).*/
|
||||
/** @see setTermCriteria */
|
||||
CV_WRAP virtual TermCriteria getTermCriteria() const = 0;
|
||||
/** @copybrief getTermCriteria @see getTermCriteria */
|
||||
CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;
|
||||
};
|
||||
|
||||
|
||||
/****************************************************************************************\
|
||||
* Auxilary functions declarations *
|
||||
\****************************************************************************************/
|
||||
|
510
modules/ml/src/svmsgd.cpp
Normal file
510
modules/ml/src/svmsgd.cpp
Normal file
@ -0,0 +1,510 @@
|
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
||||
// Copyright (C) 2016, Itseez Inc, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of the copyright holders may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "precomp.hpp"
|
||||
#include "limits"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
/****************************************************************************************\
|
||||
* Stochastic Gradient Descent SVM Classifier *
|
||||
\****************************************************************************************/
|
||||
|
||||
namespace cv
|
||||
{
|
||||
namespace ml
|
||||
{
|
||||
|
||||
class SVMSGDImpl : public SVMSGD
|
||||
{
|
||||
|
||||
public:
|
||||
SVMSGDImpl();
|
||||
|
||||
virtual ~SVMSGDImpl() {}
|
||||
|
||||
virtual bool train(const Ptr<TrainData>& data, int);
|
||||
|
||||
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
|
||||
|
||||
virtual bool isClassifier() const;
|
||||
|
||||
virtual bool isTrained() const;
|
||||
|
||||
virtual void clear();
|
||||
|
||||
virtual void write(FileStorage &fs) const;
|
||||
|
||||
virtual void read(const FileNode &fn);
|
||||
|
||||
virtual Mat getWeights(){ return weights_; }
|
||||
|
||||
virtual float getShift(){ return shift_; }
|
||||
|
||||
virtual int getVarCount() const { return weights_.cols; }
|
||||
|
||||
virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
|
||||
|
||||
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
|
||||
|
||||
CV_IMPL_PROPERTY(int, SvmsgdType, params.svmsgdType)
|
||||
CV_IMPL_PROPERTY(int, MarginType, params.marginType)
|
||||
CV_IMPL_PROPERTY(float, MarginRegularization, params.marginRegularization)
|
||||
CV_IMPL_PROPERTY(float, InitialStepSize, params.initialStepSize)
|
||||
CV_IMPL_PROPERTY(float, StepDecreasingPower, params.stepDecreasingPower)
|
||||
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
||||
|
||||
private:
|
||||
void updateWeights(InputArray sample, bool positive, float stepSize, Mat &weights);
|
||||
|
||||
void writeParams( FileStorage &fs ) const;
|
||||
|
||||
void readParams( const FileNode &fn );
|
||||
|
||||
static inline bool isPositive(float val) { return val > 0; }
|
||||
|
||||
static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);
|
||||
|
||||
float calcShift(InputArray _samples, InputArray _responses) const;
|
||||
|
||||
static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
|
||||
|
||||
// Vector with SVM weights
|
||||
Mat weights_;
|
||||
float shift_;
|
||||
|
||||
// Parameters for learning
|
||||
struct SVMSGDParams
|
||||
{
|
||||
float marginRegularization;
|
||||
float initialStepSize;
|
||||
float stepDecreasingPower;
|
||||
TermCriteria termCrit;
|
||||
int svmsgdType;
|
||||
int marginType;
|
||||
};
|
||||
|
||||
SVMSGDParams params;
|
||||
};
|
||||
|
||||
Ptr<SVMSGD> SVMSGD::create()
|
||||
{
|
||||
return makePtr<SVMSGDImpl>();
|
||||
}
|
||||
|
||||
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
|
||||
{
|
||||
int featuresCount = samples.cols;
|
||||
int samplesCount = samples.rows;
|
||||
|
||||
average = Mat(1, featuresCount, samples.type());
|
||||
CV_Assert(average.type() == CV_32FC1);
|
||||
for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
|
||||
{
|
||||
average.at<float>(featureIndex) = static_cast<float>(mean(samples.col(featureIndex))[0]);
|
||||
}
|
||||
|
||||
for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
|
||||
{
|
||||
samples.row(sampleIndex) -= average;
|
||||
}
|
||||
|
||||
double normValue = norm(samples);
|
||||
|
||||
multiplier = static_cast<float>(sqrt(samples.total()) / normValue);
|
||||
|
||||
samples *= multiplier;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
|
||||
{
|
||||
Mat normalizedTrainSamples = trainSamples.clone();
|
||||
int samplesCount = normalizedTrainSamples.rows;
|
||||
|
||||
normalizeSamples(normalizedTrainSamples, average, multiplier);
|
||||
|
||||
Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
|
||||
cv::hconcat(normalizedTrainSamples, onesCol, extendedTrainSamples);
|
||||
}
|
||||
|
||||
void SVMSGDImpl::updateWeights(InputArray _sample, bool positive, float stepSize, Mat& weights)
|
||||
{
|
||||
Mat sample = _sample.getMat();
|
||||
|
||||
int response = positive ? 1 : -1; // ensure that trainResponses are -1 or 1
|
||||
|
||||
if ( sample.dot(weights) * response > 1)
|
||||
{
|
||||
// Not a support vector, only apply weight decay
|
||||
weights *= (1.f - stepSize * params.marginRegularization);
|
||||
}
|
||||
else
|
||||
{
|
||||
// It's a support vector, add it to the weights
|
||||
weights -= (stepSize * params.marginRegularization) * weights - (stepSize * response) * sample;
|
||||
}
|
||||
}
|
||||
|
||||
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||
{
|
||||
float margin[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
|
||||
|
||||
Mat trainSamples = _samples.getMat();
|
||||
int trainSamplesCount = trainSamples.rows;
|
||||
|
||||
Mat trainResponses = _responses.getMat();
|
||||
|
||||
CV_Assert(trainResponses.type() == CV_32FC1);
|
||||
for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
|
||||
{
|
||||
Mat currentSample = trainSamples.row(samplesIndex);
|
||||
float dotProduct = static_cast<float>(currentSample.dot(weights_));
|
||||
|
||||
bool positive = isPositive(trainResponses.at<float>(samplesIndex));
|
||||
int index = positive ? 0 : 1;
|
||||
float signToMul = positive ? 1.f : -1.f;
|
||||
float curMargin = dotProduct * signToMul;
|
||||
|
||||
if (curMargin < margin[index])
|
||||
{
|
||||
margin[index] = curMargin;
|
||||
}
|
||||
}
|
||||
|
||||
return -(margin[0] - margin[1]) / 2.f;
|
||||
}
|
||||
|
||||
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
{
|
||||
clear();
|
||||
CV_Assert( isClassifier() ); //toDo: consider
|
||||
|
||||
Mat trainSamples = data->getTrainSamples();
|
||||
|
||||
int featureCount = trainSamples.cols;
|
||||
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
|
||||
|
||||
CV_Assert(trainResponses.rows == trainSamples.rows);
|
||||
|
||||
if (trainResponses.empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
int positiveCount = countNonZero(trainResponses >= 0);
|
||||
int negativeCount = countNonZero(trainResponses < 0);
|
||||
|
||||
if ( positiveCount <= 0 || negativeCount <= 0 )
|
||||
{
|
||||
weights_ = Mat::zeros(1, featureCount, CV_32F);
|
||||
shift_ = (positiveCount > 0) ? 1.f : -1.f;
|
||||
return true;
|
||||
}
|
||||
|
||||
Mat extendedTrainSamples;
|
||||
Mat average;
|
||||
float multiplier = 0;
|
||||
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);
|
||||
|
||||
int extendedTrainSamplesCount = extendedTrainSamples.rows;
|
||||
int extendedFeatureCount = extendedTrainSamples.cols;
|
||||
|
||||
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||
Mat averageExtendedWeights;
|
||||
if (params.svmsgdType == ASGD)
|
||||
{
|
||||
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||
}
|
||||
|
||||
RNG rng(0);
|
||||
|
||||
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
|
||||
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
|
||||
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
|
||||
|
||||
double err = DBL_MAX;
|
||||
CV_Assert (trainResponses.type() == CV_32FC1);
|
||||
// Stochastic gradient descent SVM
|
||||
for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
|
||||
{
|
||||
int randomNumber = rng.uniform(0, extendedTrainSamplesCount); //generate sample number
|
||||
|
||||
Mat currentSample = extendedTrainSamples.row(randomNumber);
|
||||
|
||||
float stepSize = params.initialStepSize * std::pow((1 + params.marginRegularization * params.initialStepSize * (float)iter), (-params.stepDecreasingPower)); //update stepSize
|
||||
|
||||
updateWeights( currentSample, isPositive(trainResponses.at<float>(randomNumber)), stepSize, extendedWeights );
|
||||
|
||||
//average weights (only for ASGD model)
|
||||
if (params.svmsgdType == ASGD)
|
||||
{
|
||||
averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights + extendedWeights / (1 + (float) iter);
|
||||
err = norm(averageExtendedWeights - previousWeights);
|
||||
averageExtendedWeights.copyTo(previousWeights);
|
||||
}
|
||||
else
|
||||
{
|
||||
err = norm(extendedWeights - previousWeights);
|
||||
extendedWeights.copyTo(previousWeights);
|
||||
}
|
||||
}
|
||||
|
||||
if (params.svmsgdType == ASGD)
|
||||
{
|
||||
extendedWeights = averageExtendedWeights;
|
||||
}
|
||||
|
||||
Rect roi(0, 0, featureCount, 1);
|
||||
weights_ = extendedWeights(roi);
|
||||
weights_ *= multiplier;
|
||||
|
||||
CV_Assert((params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN) && (extendedWeights.type() == CV_32FC1));
|
||||
|
||||
if (params.marginType == SOFT_MARGIN)
|
||||
{
|
||||
shift_ = extendedWeights.at<float>(featureCount) - static_cast<float>(weights_.dot(average));
|
||||
}
|
||||
else
|
||||
{
|
||||
shift_ = calcShift(trainSamples, trainResponses);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
|
||||
{
|
||||
float result = 0;
|
||||
cv::Mat samples = _samples.getMat();
|
||||
int nSamples = samples.rows;
|
||||
cv::Mat results;
|
||||
|
||||
CV_Assert( samples.cols == weights_.cols && samples.type() == CV_32FC1);
|
||||
|
||||
if( _results.needed() )
|
||||
{
|
||||
_results.create( nSamples, 1, samples.type() );
|
||||
results = _results.getMat();
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_Assert( nSamples == 1 );
|
||||
results = Mat(1, 1, CV_32FC1, &result);
|
||||
}
|
||||
|
||||
for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
|
||||
{
|
||||
Mat currentSample = samples.row(sampleIndex);
|
||||
float criterion = static_cast<float>(currentSample.dot(weights_)) + shift_;
|
||||
results.at<float>(sampleIndex) = (criterion >= 0) ? 1.f : -1.f;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool SVMSGDImpl::isClassifier() const
|
||||
{
|
||||
return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
|
||||
&&
|
||||
(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
|
||||
&&
|
||||
(params.marginRegularization > 0) && (params.initialStepSize > 0) && (params.stepDecreasingPower >= 0);
|
||||
}
|
||||
|
||||
bool SVMSGDImpl::isTrained() const
|
||||
{
|
||||
return !weights_.empty();
|
||||
}
|
||||
|
||||
void SVMSGDImpl::write(FileStorage& fs) const
|
||||
{
|
||||
if( !isTrained() )
|
||||
CV_Error( CV_StsParseError, "SVMSGD model data is invalid, it hasn't been trained" );
|
||||
|
||||
writeParams( fs );
|
||||
|
||||
fs << "weights" << weights_;
|
||||
fs << "shift" << shift_;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
{
|
||||
String SvmsgdTypeStr;
|
||||
|
||||
switch (params.svmsgdType)
|
||||
{
|
||||
case SGD:
|
||||
SvmsgdTypeStr = "SGD";
|
||||
break;
|
||||
case ASGD:
|
||||
SvmsgdTypeStr = "ASGD";
|
||||
break;
|
||||
default:
|
||||
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
|
||||
}
|
||||
|
||||
fs << "svmsgdType" << SvmsgdTypeStr;
|
||||
|
||||
String marginTypeStr;
|
||||
|
||||
switch (params.marginType)
|
||||
{
|
||||
case SOFT_MARGIN:
|
||||
marginTypeStr = "SOFT_MARGIN";
|
||||
break;
|
||||
case HARD_MARGIN:
|
||||
marginTypeStr = "HARD_MARGIN";
|
||||
break;
|
||||
default:
|
||||
marginTypeStr = format("Unknown_%d", params.marginType);
|
||||
}
|
||||
|
||||
fs << "marginType" << marginTypeStr;
|
||||
|
||||
fs << "marginRegularization" << params.marginRegularization;
|
||||
fs << "initialStepSize" << params.initialStepSize;
|
||||
fs << "stepDecreasingPower" << params.stepDecreasingPower;
|
||||
|
||||
fs << "term_criteria" << "{:";
|
||||
if( params.termCrit.type & TermCriteria::EPS )
|
||||
fs << "epsilon" << params.termCrit.epsilon;
|
||||
if( params.termCrit.type & TermCriteria::COUNT )
|
||||
fs << "iterations" << params.termCrit.maxCount;
|
||||
fs << "}";
|
||||
}
|
||||
void SVMSGDImpl::readParams( const FileNode& fn )
|
||||
{
|
||||
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
||||
int svmsgdType =
|
||||
svmsgdTypeStr == "SGD" ? SGD :
|
||||
svmsgdTypeStr == "ASGD" ? ASGD : -1;
|
||||
|
||||
if( svmsgdType < 0 )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
|
||||
|
||||
params.svmsgdType = svmsgdType;
|
||||
|
||||
String marginTypeStr = (String)fn["marginType"];
|
||||
int marginType =
|
||||
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
|
||||
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : -1;
|
||||
|
||||
if( marginType < 0 )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
|
||||
|
||||
params.marginType = marginType;
|
||||
|
||||
CV_Assert ( fn["marginRegularization"].isReal() );
|
||||
params.marginRegularization = (float)fn["marginRegularization"];
|
||||
|
||||
CV_Assert ( fn["initialStepSize"].isReal() );
|
||||
params.initialStepSize = (float)fn["initialStepSize"];
|
||||
|
||||
CV_Assert ( fn["stepDecreasingPower"].isReal() );
|
||||
params.stepDecreasingPower = (float)fn["stepDecreasingPower"];
|
||||
|
||||
FileNode tcnode = fn["term_criteria"];
|
||||
CV_Assert(!tcnode.empty());
|
||||
params.termCrit.epsilon = (double)tcnode["epsilon"];
|
||||
params.termCrit.maxCount = (int)tcnode["iterations"];
|
||||
params.termCrit.type = (params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
|
||||
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
|
||||
CV_Assert ((params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS));
|
||||
}
|
||||
|
||||
void SVMSGDImpl::read(const FileNode& fn)
|
||||
{
|
||||
clear();
|
||||
|
||||
readParams(fn);
|
||||
|
||||
fn["weights"] >> weights_;
|
||||
fn["shift"] >> shift_;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::clear()
|
||||
{
|
||||
weights_.release();
|
||||
shift_ = 0;
|
||||
}
|
||||
|
||||
|
||||
SVMSGDImpl::SVMSGDImpl()
|
||||
{
|
||||
clear();
|
||||
setOptimalParameters();
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
{
|
||||
switch (svmsgdType)
|
||||
{
|
||||
case SGD:
|
||||
params.svmsgdType = SGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||
params.marginRegularization = 0.0001f;
|
||||
params.initialStepSize = 0.05f;
|
||||
params.stepDecreasingPower = 1.f;
|
||||
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
|
||||
break;
|
||||
|
||||
case ASGD:
|
||||
params.svmsgdType = ASGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : -1;
|
||||
params.marginRegularization = 0.00001f;
|
||||
params.initialStepSize = 0.05f;
|
||||
params.stepDecreasingPower = 0.75f;
|
||||
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
|
||||
break;
|
||||
|
||||
default:
|
||||
CV_Error( CV_StsParseError, "SVMSGD model data is invalid" );
|
||||
}
|
||||
}
|
||||
} //ml
|
||||
} //cv
|
@ -193,6 +193,25 @@ int str_to_boost_type( String& str )
|
||||
// 8. rtrees
|
||||
// 9. ertrees
|
||||
|
||||
int str_to_svmsgd_type( String& str )
|
||||
{
|
||||
if ( !str.compare("SGD") )
|
||||
return SVMSGD::SGD;
|
||||
if ( !str.compare("ASGD") )
|
||||
return SVMSGD::ASGD;
|
||||
CV_Error( CV_StsBadArg, "incorrect svmsgd type string" );
|
||||
return -1;
|
||||
}
|
||||
|
||||
int str_to_margin_type( String& str )
|
||||
{
|
||||
if ( !str.compare("SOFT_MARGIN") )
|
||||
return SVMSGD::SOFT_MARGIN;
|
||||
if ( !str.compare("HARD_MARGIN") )
|
||||
return SVMSGD::HARD_MARGIN;
|
||||
CV_Error( CV_StsBadArg, "incorrect svmsgd margin type string" );
|
||||
return -1;
|
||||
}
|
||||
// ---------------------------------- MLBaseTest ---------------------------------------------------
|
||||
|
||||
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
|
||||
@ -436,6 +455,27 @@ int CV_MLBaseTest::train( int testCaseIdx )
|
||||
model = m;
|
||||
}
|
||||
|
||||
else if( modelName == CV_SVMSGD )
|
||||
{
|
||||
String svmsgdTypeStr;
|
||||
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
|
||||
|
||||
Ptr<SVMSGD> m = SVMSGD::create();
|
||||
int svmsgdType = str_to_svmsgd_type( svmsgdTypeStr );
|
||||
m->setSvmsgdType(svmsgdType);
|
||||
|
||||
String marginTypeStr;
|
||||
modelParamsNode["marginType"] >> marginTypeStr;
|
||||
int marginType = str_to_margin_type( marginTypeStr );
|
||||
m->setMarginType(marginType);
|
||||
|
||||
m->setMarginRegularization(modelParamsNode["marginRegularization"]);
|
||||
m->setInitialStepSize(modelParamsNode["initialStepSize"]);
|
||||
m->setStepDecreasingPower(modelParamsNode["stepDecreasingPower"]);
|
||||
m->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 10000, 0.00001));
|
||||
model = m;
|
||||
}
|
||||
|
||||
if( !model.empty() )
|
||||
is_trained = model->train(data, 0);
|
||||
|
||||
@ -457,7 +497,7 @@ float CV_MLBaseTest::get_test_error( int /*testCaseIdx*/, vector<float> *resp )
|
||||
else if( modelName == CV_ANN )
|
||||
err = ann_calc_error( model, data, cls_map, type, resp );
|
||||
else if( modelName == CV_DTREE || modelName == CV_BOOST || modelName == CV_RTREES ||
|
||||
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST )
|
||||
modelName == CV_SVM || modelName == CV_NBAYES || modelName == CV_KNEAREST || modelName == CV_SVMSGD )
|
||||
err = model->calcError( data, true, _resp );
|
||||
if( !_resp.empty() && resp )
|
||||
_resp.convertTo(*resp, CV_32F);
|
||||
@ -485,6 +525,8 @@ void CV_MLBaseTest::load( const char* filename )
|
||||
model = Algorithm::load<Boost>( filename );
|
||||
else if( modelName == CV_RTREES )
|
||||
model = Algorithm::load<RTrees>( filename );
|
||||
else if( modelName == CV_SVMSGD )
|
||||
model = Algorithm::load<SVMSGD>( filename );
|
||||
else
|
||||
CV_Error( CV_StsNotImplemented, "invalid stat model name");
|
||||
}
|
||||
|
@ -24,6 +24,7 @@
|
||||
#define CV_BOOST "boost"
|
||||
#define CV_RTREES "rtrees"
|
||||
#define CV_ERTREES "ertrees"
|
||||
#define CV_SVMSGD "svmsgd"
|
||||
|
||||
enum { CV_TRAIN_ERROR=0, CV_TEST_ERROR=1 };
|
||||
|
||||
@ -38,6 +39,7 @@ using cv::ml::ANN_MLP;
|
||||
using cv::ml::DTrees;
|
||||
using cv::ml::Boost;
|
||||
using cv::ml::RTrees;
|
||||
using cv::ml::SVMSGD;
|
||||
|
||||
class CV_MLBaseTest : public cvtest::BaseTest
|
||||
{
|
||||
|
@ -156,6 +156,7 @@ TEST(ML_DTree, save_load) { CV_SLMLTest test( CV_DTREE ); test.safe_run(); }
|
||||
TEST(ML_Boost, save_load) { CV_SLMLTest test( CV_BOOST ); test.safe_run(); }
|
||||
TEST(ML_RTrees, save_load) { CV_SLMLTest test( CV_RTREES ); test.safe_run(); }
|
||||
TEST(DISABLED_ML_ERTrees, save_load) { CV_SLMLTest test( CV_ERTREES ); test.safe_run(); }
|
||||
TEST(MV_SVMSGD, save_load){ CV_SLMLTest test( CV_SVMSGD ); test.safe_run(); }
|
||||
|
||||
class CV_LegacyTest : public cvtest::BaseTest
|
||||
{
|
||||
@ -201,6 +202,8 @@ protected:
|
||||
model = Algorithm::load<SVM>(filename);
|
||||
else if (modelName == CV_RTREES)
|
||||
model = Algorithm::load<RTrees>(filename);
|
||||
else if (modelName == CV_SVMSGD)
|
||||
model = Algorithm::load<SVMSGD>(filename);
|
||||
if (!model)
|
||||
{
|
||||
code = cvtest::TS::FAIL_INVALID_TEST_DATA;
|
||||
@ -260,6 +263,7 @@ TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushro
|
||||
TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_SVMSGD, legacy_load) { CV_LegacyTest test(CV_SVMSGD, "_waveform.xml"); test.safe_run(); }
|
||||
|
||||
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
|
||||
{
|
||||
|
318
modules/ml/test/test_svmsgd.cpp
Normal file
318
modules/ml/test/test_svmsgd.cpp
Normal file
@ -0,0 +1,318 @@
|
||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
||||
//
|
||||
// By downloading, copying, installing or using the software you agree to this license.
|
||||
// If you do not agree to this license, do not download, install,
|
||||
// copy or use the software.
|
||||
//
|
||||
//
|
||||
// Intel License Agreement
|
||||
// For Open Source Computer Vision Library
|
||||
//
|
||||
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
||||
// Third party copyrights are property of their respective owners.
|
||||
//
|
||||
// Redistribution and use in source and binary forms, with or without modification,
|
||||
// are permitted provided that the following conditions are met:
|
||||
//
|
||||
// * Redistribution's of source code must retain the above copyright notice,
|
||||
// this list of conditions and the following disclaimer.
|
||||
//
|
||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
||||
// this list of conditions and the following disclaimer in the documentation
|
||||
// and/or other materials provided with the distribution.
|
||||
//
|
||||
// * The name of Intel Corporation may not be used to endorse or promote products
|
||||
// derived from this software without specific prior written permission.
|
||||
//
|
||||
// This software is provided by the copyright holders and contributors "as is" and
|
||||
// any express or implied warranties, including, but not limited to, the implied
|
||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
||||
// indirect, incidental, special, exemplary, or consequential damages
|
||||
// (including, but not limited to, procurement of substitute goods or services;
|
||||
// loss of use, data, or profits; or business interruption) however caused
|
||||
// and on any theory of liability, whether in contract, strict liability,
|
||||
// or tort (including negligence or otherwise) arising in any way out of
|
||||
// the use of this software, even if advised of the possibility of such damage.
|
||||
//
|
||||
//M*/
|
||||
|
||||
#include "test_precomp.hpp"
|
||||
#include "opencv2/highgui.hpp"
|
||||
|
||||
using namespace cv;
|
||||
using namespace cv::ml;
|
||||
using cv::ml::SVMSGD;
|
||||
using cv::ml::TrainData;
|
||||
|
||||
|
||||
|
||||
class CV_SVMSGDTrainTest : public cvtest::BaseTest
|
||||
{
|
||||
public:
|
||||
enum TrainDataType
|
||||
{
|
||||
UNIFORM_SAME_SCALE,
|
||||
UNIFORM_DIFFERENT_SCALES
|
||||
};
|
||||
|
||||
CV_SVMSGDTrainTest(const Mat &_weights, float shift, TrainDataType type, double precision = 0.01);
|
||||
private:
|
||||
virtual void run( int start_from );
|
||||
static float decisionFunction(const Mat &sample, const Mat &weights, float shift);
|
||||
void makeData(int samplesCount, const Mat &weights, float shift, RNG &rng, Mat &samples, Mat & responses);
|
||||
void generateSameBorders(int featureCount);
|
||||
void generateDifferentBorders(int featureCount);
|
||||
|
||||
TrainDataType type;
|
||||
double precision;
|
||||
std::vector<std::pair<float,float> > borders;
|
||||
cv::Ptr<TrainData> data;
|
||||
cv::Mat testSamples;
|
||||
cv::Mat testResponses;
|
||||
static const int TEST_VALUE_LIMIT = 500;
|
||||
};
|
||||
|
||||
void CV_SVMSGDTrainTest::generateSameBorders(int featureCount)
|
||||
{
|
||||
float lowerLimit = -TEST_VALUE_LIMIT;
|
||||
float upperLimit = TEST_VALUE_LIMIT;
|
||||
|
||||
for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
|
||||
{
|
||||
borders.push_back(std::pair<float,float>(lowerLimit, upperLimit));
|
||||
}
|
||||
}
|
||||
|
||||
void CV_SVMSGDTrainTest::generateDifferentBorders(int featureCount)
|
||||
{
|
||||
float lowerLimit = -TEST_VALUE_LIMIT;
|
||||
float upperLimit = TEST_VALUE_LIMIT;
|
||||
cv::RNG rng(0);
|
||||
|
||||
for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
|
||||
{
|
||||
int crit = rng.uniform(0, 2);
|
||||
|
||||
if (crit > 0)
|
||||
{
|
||||
borders.push_back(std::pair<float,float>(lowerLimit, upperLimit));
|
||||
}
|
||||
else
|
||||
{
|
||||
borders.push_back(std::pair<float,float>(lowerLimit/1000, upperLimit/1000));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float CV_SVMSGDTrainTest::decisionFunction(const Mat &sample, const Mat &weights, float shift)
|
||||
{
|
||||
return static_cast<float>(sample.dot(weights)) + shift;
|
||||
}
|
||||
|
||||
void CV_SVMSGDTrainTest::makeData(int samplesCount, const Mat &weights, float shift, RNG &rng, Mat &samples, Mat & responses)
|
||||
{
|
||||
int featureCount = weights.cols;
|
||||
|
||||
samples.create(samplesCount, featureCount, CV_32FC1);
|
||||
for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
|
||||
{
|
||||
rng.fill(samples.col(featureIndex), RNG::UNIFORM, borders[featureIndex].first, borders[featureIndex].second);
|
||||
}
|
||||
|
||||
responses.create(samplesCount, 1, CV_32FC1);
|
||||
|
||||
for (int i = 0 ; i < samplesCount; i++)
|
||||
{
|
||||
responses.at<float>(i) = decisionFunction(samples.row(i), weights, shift) > 0 ? 1.f : -1.f;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(const Mat &weights, float shift, TrainDataType _type, double _precision)
|
||||
{
|
||||
type = _type;
|
||||
precision = _precision;
|
||||
|
||||
int featureCount = weights.cols;
|
||||
|
||||
switch(type)
|
||||
{
|
||||
case UNIFORM_SAME_SCALE:
|
||||
generateSameBorders(featureCount);
|
||||
break;
|
||||
case UNIFORM_DIFFERENT_SCALES:
|
||||
generateDifferentBorders(featureCount);
|
||||
break;
|
||||
default:
|
||||
CV_Error(CV_StsBadArg, "Unknown train data type");
|
||||
}
|
||||
|
||||
RNG rng(0);
|
||||
|
||||
Mat trainSamples;
|
||||
Mat trainResponses;
|
||||
int trainSamplesCount = 10000;
|
||||
makeData(trainSamplesCount, weights, shift, rng, trainSamples, trainResponses);
|
||||
data = TrainData::create(trainSamples, cv::ml::ROW_SAMPLE, trainResponses);
|
||||
|
||||
int testSamplesCount = 100000;
|
||||
makeData(testSamplesCount, weights, shift, rng, testSamples, testResponses);
|
||||
}
|
||||
|
||||
void CV_SVMSGDTrainTest::run( int /*start_from*/ )
|
||||
{
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
|
||||
svmsgd->train(data);
|
||||
|
||||
Mat responses;
|
||||
|
||||
svmsgd->predict(testSamples, responses);
|
||||
|
||||
int errCount = 0;
|
||||
int testSamplesCount = testSamples.rows;
|
||||
|
||||
CV_Assert((responses.type() == CV_32FC1) && (testResponses.type() == CV_32FC1));
|
||||
for (int i = 0; i < testSamplesCount; i++)
|
||||
{
|
||||
if (responses.at<float>(i) * testResponses.at<float>(i) < 0)
|
||||
errCount++;
|
||||
}
|
||||
|
||||
float err = (float)errCount / testSamplesCount;
|
||||
|
||||
if ( err > precision )
|
||||
{
|
||||
ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY);
|
||||
}
|
||||
}
|
||||
|
||||
void makeWeightsAndShift(int featureCount, Mat &weights, float &shift)
|
||||
{
|
||||
weights.create(1, featureCount, CV_32FC1);
|
||||
cv::RNG rng(0);
|
||||
double lowerLimit = -1;
|
||||
double upperLimit = 1;
|
||||
|
||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
shift = static_cast<float>(rng.uniform(-featureCount, featureCount));
|
||||
}
|
||||
|
||||
|
||||
TEST(ML_SVMSGD, trainSameScale2)
|
||||
{
|
||||
int featureCount = 2;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainSameScale5)
|
||||
{
|
||||
int featureCount = 5;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainSameScale100)
|
||||
{
|
||||
int featureCount = 100;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE, 0.02);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainDifferentScales2)
|
||||
{
|
||||
int featureCount = 2;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.01);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainDifferentScales5)
|
||||
{
|
||||
int featureCount = 5;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.01);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainDifferentScales100)
|
||||
{
|
||||
int featureCount = 100;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.01);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, twoPoints)
|
||||
{
|
||||
Mat samples(2, 2, CV_32FC1);
|
||||
samples.at<float>(0,0) = 0;
|
||||
samples.at<float>(0,1) = 0;
|
||||
samples.at<float>(1,0) = 1000;
|
||||
samples.at<float>(1,1) = 1;
|
||||
|
||||
Mat responses(2, 1, CV_32FC1);
|
||||
responses.at<float>(0) = -1;
|
||||
responses.at<float>(1) = 1;
|
||||
|
||||
cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
|
||||
|
||||
Mat realWeights(1, 2, CV_32FC1);
|
||||
realWeights.at<float>(0) = 1000;
|
||||
realWeights.at<float>(1) = 1;
|
||||
|
||||
float realShift = -500000.5;
|
||||
|
||||
float normRealWeights = static_cast<float>(norm(realWeights));
|
||||
realWeights /= normRealWeights;
|
||||
realShift /= normRealWeights;
|
||||
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
svmsgd->setOptimalParameters();
|
||||
svmsgd->train( trainData );
|
||||
|
||||
Mat foundWeights = svmsgd->getWeights();
|
||||
float foundShift = svmsgd->getShift();
|
||||
|
||||
float normFoundWeights = static_cast<float>(norm(foundWeights));
|
||||
foundWeights /= normFoundWeights;
|
||||
foundShift /= normFoundWeights;
|
||||
CV_Assert((norm(foundWeights - realWeights) < 0.001) && (abs((foundShift - realShift) / realShift) < 0.05));
|
||||
}
|
210
samples/cpp/train_svmsgd.cpp
Normal file
210
samples/cpp/train_svmsgd.cpp
Normal file
@ -0,0 +1,210 @@
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include "opencv2/video/tracking.hpp"
|
||||
#include "opencv2/imgproc/imgproc.hpp"
|
||||
#include "opencv2/highgui/highgui.hpp"
|
||||
|
||||
using namespace cv;
|
||||
using namespace cv::ml;
|
||||
|
||||
|
||||
struct Data
|
||||
{
|
||||
Mat img;
|
||||
Mat samples; //Set of train samples. Contains points on image
|
||||
Mat responses; //Set of responses for train samples
|
||||
|
||||
Data()
|
||||
{
|
||||
const int WIDTH = 841;
|
||||
const int HEIGHT = 594;
|
||||
img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
|
||||
imshow("Train svmsgd", img);
|
||||
}
|
||||
};
|
||||
|
||||
//Train with SVMSGD algorithm
|
||||
//(samples, responses) is a train set
|
||||
//weights is a required vector for decision function of SVMSGD algorithm
|
||||
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);
|
||||
|
||||
//function finds two points for drawing line (wx = 0)
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point points[], int width, int height);
|
||||
|
||||
// function finds cross point of line (wx = 0) and segment ( (y = HEIGHT, 0 <= x <= WIDTH) or (x = WIDTH, 0 <= y <= HEIGHT) )
|
||||
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
|
||||
|
||||
//segments' initialization ( (y = HEIGHT, 0 <= x <= WIDTH) and (x = WIDTH, 0 <= y <= HEIGHT) )
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);
|
||||
|
||||
//redraw points' set and line (wx = 0)
|
||||
void redraw(Data data, const Point points[2]);
|
||||
|
||||
//add point in train set, train SVMSGD algorithm and draw results on image
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);
|
||||
|
||||
|
||||
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
|
||||
{
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
|
||||
cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
|
||||
svmsgd->train( trainData );
|
||||
|
||||
if (svmsgd->isTrained())
|
||||
{
|
||||
weights = svmsgd->getWeights();
|
||||
shift = svmsgd->getShift();
|
||||
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
|
||||
{
|
||||
std::pair<Point,Point> currentSegment;
|
||||
|
||||
currentSegment.first = Point(width, 0);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, height);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(width, 0);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(0, height);
|
||||
segments.push_back(currentSegment);
|
||||
}
|
||||
|
||||
|
||||
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
|
||||
{
|
||||
int x = 0;
|
||||
int y = 0;
|
||||
int xMin = std::min(segment.first.x, segment.second.x);
|
||||
int xMax = std::max(segment.first.x, segment.second.x);
|
||||
int yMin = std::min(segment.first.y, segment.second.y);
|
||||
int yMax = std::max(segment.first.y, segment.second.y);
|
||||
|
||||
CV_Assert(weights.type() == CV_32FC1);
|
||||
CV_Assert(xMin == xMax || yMin == yMax);
|
||||
|
||||
if (xMin == xMax && weights.at<float>(1) != 0)
|
||||
{
|
||||
x = xMin;
|
||||
y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));
|
||||
if (y >= yMin && y <= yMax)
|
||||
{
|
||||
crossPoint.x = x;
|
||||
crossPoint.y = y;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else if (yMin == yMax && weights.at<float>(0) != 0)
|
||||
{
|
||||
y = yMin;
|
||||
x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));
|
||||
if (x >= xMin && x <= xMax)
|
||||
{
|
||||
crossPoint.x = x;
|
||||
crossPoint.y = y;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
|
||||
{
|
||||
if (weights.empty())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
int foundPointsCount = 0;
|
||||
std::vector<std::pair<Point,Point> > segments;
|
||||
fillSegments(segments, width, height);
|
||||
|
||||
for (uint i = 0; i < segments.size(); i++)
|
||||
{
|
||||
if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
|
||||
foundPointsCount++;
|
||||
if (foundPointsCount >= 2)
|
||||
break;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void redraw(Data data, const Point points[2])
|
||||
{
|
||||
data.img.setTo(0);
|
||||
Point center;
|
||||
int radius = 3;
|
||||
Scalar color;
|
||||
CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1));
|
||||
for (int i = 0; i < data.samples.rows; i++)
|
||||
{
|
||||
center.x = static_cast<int>(data.samples.at<float>(i,0));
|
||||
center.y = static_cast<int>(data.samples.at<float>(i,1));
|
||||
color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
|
||||
circle(data.img, center, radius, color, 5);
|
||||
}
|
||||
line(data.img, points[0], points[1],cv::Scalar(1,255,1));
|
||||
|
||||
imshow("Train svmsgd", data.img);
|
||||
}
|
||||
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
|
||||
{
|
||||
Mat currentSample(1, 2, CV_32FC1);
|
||||
|
||||
currentSample.at<float>(0,0) = (float)x;
|
||||
currentSample.at<float>(0,1) = (float)y;
|
||||
data.samples.push_back(currentSample);
|
||||
data.responses.push_back(response);
|
||||
|
||||
Mat weights(1, 2, CV_32FC1);
|
||||
float shift = 0;
|
||||
|
||||
if (doTrain(data.samples, data.responses, weights, shift))
|
||||
{
|
||||
Point points[2];
|
||||
findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
|
||||
|
||||
redraw(data, points);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void onMouse( int event, int x, int y, int, void* pData)
|
||||
{
|
||||
Data &data = *(Data*)pData;
|
||||
|
||||
switch( event )
|
||||
{
|
||||
case CV_EVENT_LBUTTONUP:
|
||||
addPointRetrainAndRedraw(data, x, y, 1);
|
||||
break;
|
||||
|
||||
case CV_EVENT_RBUTTONDOWN:
|
||||
addPointRetrainAndRedraw(data, x, y, -1);
|
||||
break;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
Data data;
|
||||
|
||||
setMouseCallback( "Train svmsgd", onMouse, &data );
|
||||
waitKey();
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue
Block a user