Increasing the dimension of features space in the SVMSGD::train function.
This commit is contained in:
parent
40bf97c6d1
commit
acd74037b3
@ -75,7 +75,6 @@
|
|||||||
#endif
|
#endif
|
||||||
#ifdef HAVE_OPENCV_ML
|
#ifdef HAVE_OPENCV_ML
|
||||||
#include "opencv2/ml.hpp"
|
#include "opencv2/ml.hpp"
|
||||||
#include "opencv2/ml/svmsgd.hpp"
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -1496,6 +1496,121 @@ public:
|
|||||||
CV_WRAP static Ptr<LogisticRegression> create();
|
CV_WRAP static Ptr<LogisticRegression> create();
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/****************************************************************************************\
|
||||||
|
* Stochastic Gradient Descent SVM Classifier *
|
||||||
|
\****************************************************************************************/
|
||||||
|
|
||||||
|
/*!
|
||||||
|
@brief Stochastic Gradient Descent SVM classifier
|
||||||
|
|
||||||
|
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
|
||||||
|
The gradient descent show amazing performance for large-scale problems, reducing the computing time.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
First, create the SVMSGD object. Set parametrs of model (type, lambda, gamma0, c) using the functions setType, setLambda, setGamma0 and setC or the function setOptimalParametrs.
|
||||||
|
Recommended model type is ASGD.
|
||||||
|
|
||||||
|
Then the SVM model can be trained using the train features and the correspondent labels.
|
||||||
|
|
||||||
|
After that, the label of a new feature vector can be predicted using the predict function.
|
||||||
|
|
||||||
|
@code
|
||||||
|
// Initialize object
|
||||||
|
SVMSGD SvmSgd;
|
||||||
|
|
||||||
|
// Train the Stochastic Gradient Descent SVM
|
||||||
|
SvmSgd.train(trainFeatures, labels);
|
||||||
|
|
||||||
|
// Predict label for the new feature vector (1xM)
|
||||||
|
predictedLabel = SvmSgd.predict(newFeatureVector);
|
||||||
|
@endcode
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
|
||||||
|
/** SVMSGD type.
|
||||||
|
ASGD is often the preferable choice. */
|
||||||
|
enum SvmsgdType
|
||||||
|
{
|
||||||
|
ILLEGAL_VALUE,
|
||||||
|
SGD, //!Stochastic Gradient Descent
|
||||||
|
ASGD //!Average Stochastic Gradient Descent
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the weights of the trained model (decision function f(x) = weights * x + shift).
|
||||||
|
*/
|
||||||
|
CV_WRAP virtual Mat getWeights() = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the shift of the trained model (decision function f(x) = weights * x + shift).
|
||||||
|
*/
|
||||||
|
CV_WRAP virtual float getShift() = 0;
|
||||||
|
|
||||||
|
|
||||||
|
/** Creates empty model.
|
||||||
|
Use StatModel::train to train the model. Since %SVMSGD has several parameters, you may want to
|
||||||
|
find the best parameters for your problem or use setOptimalParameters() to set some default parameters.
|
||||||
|
*/
|
||||||
|
CV_WRAP static Ptr<SVMSGD> create();
|
||||||
|
|
||||||
|
/** Function sets optimal parameters values for chosen SVM SGD model.
|
||||||
|
* If chosen type is ASGD, function sets the following values for parameters of model:
|
||||||
|
* lambda = 0.00001;
|
||||||
|
* gamma0 = 0.05;
|
||||||
|
* c = 0.75;
|
||||||
|
* termCrit.maxCount = 100000;
|
||||||
|
* termCrit.epsilon = 0.00001;
|
||||||
|
*
|
||||||
|
* If SGD:
|
||||||
|
* lambda = 0.0001;
|
||||||
|
* gamma0 = 0.05;
|
||||||
|
* c = 1;
|
||||||
|
* termCrit.maxCount = 100000;
|
||||||
|
* termCrit.epsilon = 0.00001;
|
||||||
|
* @param type is the type of SVMSGD classifier. Legal values are SvmsgdType::SGD and SvmsgdType::ASGD.
|
||||||
|
* Recommended value is SvmsgdType::ASGD (by default).
|
||||||
|
*/
|
||||||
|
CV_WRAP virtual void setOptimalParameters(int type = ASGD) = 0;
|
||||||
|
|
||||||
|
/** %Algorithm type, one of SVMSGD::SvmsgdType. */
|
||||||
|
/** @see setAlgorithmType */
|
||||||
|
CV_WRAP virtual int getType() const = 0;
|
||||||
|
/** @copybrief getAlgorithmType @see getAlgorithmType */
|
||||||
|
CV_WRAP virtual void setType(int type) = 0;
|
||||||
|
|
||||||
|
/** Parameter _Lambda_ of a %SVMSGD optimization problem. Default value is 0. */
|
||||||
|
/** @see setLambda */
|
||||||
|
CV_WRAP virtual float getLambda() const = 0;
|
||||||
|
/** @copybrief getLambda @see getLambda */
|
||||||
|
CV_WRAP virtual void setLambda(float lambda) = 0;
|
||||||
|
|
||||||
|
/** Parameter _Gamma0_ of a %SVMSGD optimization problem. Default value is 0. */
|
||||||
|
/** @see setGamma0 */
|
||||||
|
CV_WRAP virtual float getGamma0() const = 0;
|
||||||
|
CV_WRAP virtual void setGamma0(float gamma0) = 0;
|
||||||
|
|
||||||
|
/** Parameter _C_ of a %SVMSGD optimization problem. Default value is 0. */
|
||||||
|
/** @see setC */
|
||||||
|
CV_WRAP virtual float getC() const = 0;
|
||||||
|
/** @copybrief getC @see getC */
|
||||||
|
CV_WRAP virtual void setC(float c) = 0;
|
||||||
|
|
||||||
|
/** @brief Termination criteria of the training algorithm.
|
||||||
|
You can specify the maximum number of iterations (maxCount) and/or how much the error could
|
||||||
|
change between the iterations to make the algorithm continue (epsilon).*/
|
||||||
|
/** @see setTermCriteria */
|
||||||
|
CV_WRAP virtual TermCriteria getTermCriteria() const = 0;
|
||||||
|
/** @copybrief getTermCriteria @see getTermCriteria */
|
||||||
|
CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
/****************************************************************************************\
|
/****************************************************************************************\
|
||||||
* Auxilary functions declarations *
|
* Auxilary functions declarations *
|
||||||
\****************************************************************************************/
|
\****************************************************************************************/
|
||||||
|
@ -1,134 +0,0 @@
|
|||||||
/*M///////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
//
|
|
||||||
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
|
|
||||||
//
|
|
||||||
// By downloading, copying, installing or using the software you agree to this license.
|
|
||||||
// If you do not agree to this license, do not download, install,
|
|
||||||
// copy or use the software.
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// License Agreement
|
|
||||||
// For Open Source Computer Vision Library
|
|
||||||
//
|
|
||||||
// Copyright (C) 2000, Intel Corporation, all rights reserved.
|
|
||||||
// Copyright (C) 2013, OpenCV Foundation, all rights reserved.
|
|
||||||
// Copyright (C) 2014, Itseez Inc, all rights reserved.
|
|
||||||
// Third party copyrights are property of their respective owners.
|
|
||||||
//
|
|
||||||
// Redistribution and use in source and binary forms, with or without modification,
|
|
||||||
// are permitted provided that the following conditions are met:
|
|
||||||
//
|
|
||||||
// * Redistribution's of source code must retain the above copyright notice,
|
|
||||||
// this list of conditions and the following disclaimer.
|
|
||||||
//
|
|
||||||
// * Redistribution's in binary form must reproduce the above copyright notice,
|
|
||||||
// this list of conditions and the following disclaimer in the documentation
|
|
||||||
// and/or other materials provided with the distribution.
|
|
||||||
//
|
|
||||||
// * The name of the copyright holders may not be used to endorse or promote products
|
|
||||||
// derived from this software without specific prior written permission.
|
|
||||||
//
|
|
||||||
// This software is provided by the copyright holders and contributors "as is" and
|
|
||||||
// any express or implied warranties, including, but not limited to, the implied
|
|
||||||
// warranties of merchantability and fitness for a particular purpose are disclaimed.
|
|
||||||
// In no event shall the Intel Corporation or contributors be liable for any direct,
|
|
||||||
// indirect, incidental, special, exemplary, or consequential damages
|
|
||||||
// (including, but not limited to, procurement of substitute goods or services;
|
|
||||||
// loss of use, data, or profits; or business interruption) however caused
|
|
||||||
// and on any theory of liability, whether in contract, strict liability,
|
|
||||||
// or tort (including negligence or otherwise) arising in any way out of
|
|
||||||
// the use of this software, even if advised of the possibility of such damage.
|
|
||||||
//
|
|
||||||
//M*/
|
|
||||||
|
|
||||||
#ifndef __OPENCV_ML_SVMSGD_HPP__
|
|
||||||
#define __OPENCV_ML_SVMSGD_HPP__
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
|
||||||
|
|
||||||
#include "opencv2/ml.hpp"
|
|
||||||
|
|
||||||
namespace cv
|
|
||||||
{
|
|
||||||
namespace ml
|
|
||||||
{
|
|
||||||
|
|
||||||
|
|
||||||
/****************************************************************************************\
|
|
||||||
* Stochastic Gradient Descent SVM Classifier *
|
|
||||||
\****************************************************************************************/
|
|
||||||
|
|
||||||
/*!
|
|
||||||
@brief Stochastic Gradient Descent SVM classifier
|
|
||||||
|
|
||||||
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
|
|
||||||
The gradient descent show amazing performance for large-scale problems, reducing the computing time. This allows a fast and reliable online update of the classifier for each new feature which
|
|
||||||
is fundamental when dealing with variations of data over time (like weather and illumination changes in videosurveillance, for example).
|
|
||||||
|
|
||||||
First, create the SVMSGD object. To enable the online update, a value for updateFrequency should be defined.
|
|
||||||
|
|
||||||
Then the SVM model can be trained using the train features and the correspondent labels.
|
|
||||||
|
|
||||||
After that, the label of a new feature vector can be predicted using the predict function. If the updateFrequency was defined in the constructor, the predict function will update the weights automatically.
|
|
||||||
|
|
||||||
@code
|
|
||||||
// Initialize object
|
|
||||||
SVMSGD SvmSgd;
|
|
||||||
|
|
||||||
// Train the Stochastic Gradient Descent SVM
|
|
||||||
SvmSgd.train(trainFeatures, labels);
|
|
||||||
|
|
||||||
// Predict label for the new feature vector (1xM)
|
|
||||||
predictedLabel = SvmSgd.predict(newFeatureVector);
|
|
||||||
@endcode
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
class CV_EXPORTS_W SVMSGD : public cv::ml::StatModel
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
|
|
||||||
enum SvmsgdType
|
|
||||||
{
|
|
||||||
ILLEGAL_VALUE,
|
|
||||||
SGD, //Stochastic Gradient Descent
|
|
||||||
ASGD //Average Stochastic Gradient Descent
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @return the weights of the trained model.
|
|
||||||
*/
|
|
||||||
CV_WRAP virtual Mat getWeights() = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual float getShift() = 0;
|
|
||||||
|
|
||||||
CV_WRAP static Ptr<SVMSGD> create();
|
|
||||||
|
|
||||||
CV_WRAP virtual void setOptimalParameters(int type = ASGD) = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual int getType() const = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual void setType(int type) = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual float getLambda() const = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual void setLambda(float lambda) = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual float getGamma0() const = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual void setGamma0(float gamma0) = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual float getC() const = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual void setC(float c) = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual cv::TermCriteria getTermCriteria() const = 0;
|
|
||||||
|
|
||||||
CV_WRAP virtual void setTermCriteria(const cv::TermCriteria &val) = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
} //ml
|
|
||||||
} //cv
|
|
||||||
|
|
||||||
#endif // __clpusplus
|
|
||||||
#endif // __OPENCV_ML_SVMSGD_HPP
|
|
@ -45,7 +45,6 @@
|
|||||||
#include "opencv2/ml.hpp"
|
#include "opencv2/ml.hpp"
|
||||||
#include "opencv2/core/core_c.h"
|
#include "opencv2/core/core_c.h"
|
||||||
#include "opencv2/core/utility.hpp"
|
#include "opencv2/core/utility.hpp"
|
||||||
#include "opencv2/ml/svmsgd.hpp"
|
|
||||||
#include "opencv2/core/private.hpp"
|
#include "opencv2/core/private.hpp"
|
||||||
|
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
@ -42,6 +42,12 @@
|
|||||||
|
|
||||||
#include "precomp.hpp"
|
#include "precomp.hpp"
|
||||||
#include "limits"
|
#include "limits"
|
||||||
|
//#include "math.h"
|
||||||
|
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
|
||||||
/****************************************************************************************\
|
/****************************************************************************************\
|
||||||
* Stochastic Gradient Descent SVM Classifier *
|
* Stochastic Gradient Descent SVM Classifier *
|
||||||
@ -64,7 +70,7 @@ public:
|
|||||||
|
|
||||||
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
|
virtual float predict( InputArray samples, OutputArray results=noArray(), int flags = 0 ) const;
|
||||||
|
|
||||||
virtual bool isClassifier() const { return params.svmsgdType == SGD || params.svmsgdType == ASGD; }
|
virtual bool isClassifier() const;
|
||||||
|
|
||||||
virtual bool isTrained() const;
|
virtual bool isTrained() const;
|
||||||
|
|
||||||
@ -93,22 +99,29 @@ public:
|
|||||||
CV_IMPL_PROPERTY(float, C, params.c)
|
CV_IMPL_PROPERTY(float, C, params.c)
|
||||||
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void updateWeights(InputArray sample, bool is_first_class, float gamma);
|
void updateWeights(InputArray sample, bool isFirstClass, float gamma, Mat weights);
|
||||||
float calcShift(InputArray trainSamples, InputArray trainResponses) const;
|
|
||||||
std::pair<bool,bool> areClassesEmpty(Mat responses);
|
std::pair<bool,bool> areClassesEmpty(Mat responses);
|
||||||
|
|
||||||
void writeParams( FileStorage& fs ) const;
|
void writeParams( FileStorage& fs ) const;
|
||||||
|
|
||||||
void readParams( const FileNode& fn );
|
void readParams( const FileNode& fn );
|
||||||
|
|
||||||
static inline bool isFirstClass(float val) { return val > 0; }
|
static inline bool isFirstClass(float val) { return val > 0; }
|
||||||
|
|
||||||
|
static void normalizeSamples(Mat &matrix, Mat &multiplier, Mat &average);
|
||||||
|
|
||||||
|
float calcShift(InputArray _samples, InputArray _responses) const;
|
||||||
|
|
||||||
|
static void makeExtendedTrainSamples(const Mat trainSamples, Mat &extendedTrainSamples, Mat &multiplier);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// Vector with SVM weights
|
// Vector with SVM weights
|
||||||
Mat weights_;
|
Mat weights_;
|
||||||
float shift_;
|
float shift_;
|
||||||
|
|
||||||
// Random index generation
|
|
||||||
RNG rng_;
|
|
||||||
|
|
||||||
// Parameters for learning
|
// Parameters for learning
|
||||||
struct SVMSGDParams
|
struct SVMSGDParams
|
||||||
{
|
{
|
||||||
@ -127,97 +140,88 @@ Ptr<SVMSGD> SVMSGD::create()
|
|||||||
return makePtr<SVMSGDImpl>();
|
return makePtr<SVMSGDImpl>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
|
||||||
{
|
|
||||||
clear();
|
|
||||||
|
|
||||||
Mat trainSamples = data->getTrainSamples();
|
|
||||||
|
|
||||||
// Initialize varCount
|
|
||||||
int trainSamplesCount_ = trainSamples.rows;
|
|
||||||
int varCount = trainSamples.cols;
|
|
||||||
|
|
||||||
// Initialize weights vector with zeros
|
|
||||||
weights_ = Mat::zeros(1, varCount, CV_32F);
|
|
||||||
|
|
||||||
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
|
|
||||||
|
|
||||||
std::pair<bool,bool> are_empty = areClassesEmpty(trainResponses);
|
|
||||||
|
|
||||||
if ( are_empty.first && are_empty.second )
|
|
||||||
{
|
|
||||||
weights_.release();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
if ( are_empty.first || are_empty.second )
|
|
||||||
{
|
|
||||||
shift_ = are_empty.first ? -1 : 1;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
Mat currentSample;
|
|
||||||
float gamma = 0;
|
|
||||||
Mat lastWeights = Mat::zeros(1, varCount, CV_32F); //weights vector for calculating terminal criterion
|
|
||||||
Mat averageWeights; //average weights vector for ASGD model
|
|
||||||
double err = DBL_MAX;
|
|
||||||
if (params.svmsgdType == ASGD)
|
|
||||||
{
|
|
||||||
averageWeights = Mat::zeros(1, varCount, CV_32F);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stochastic gradient descent SVM
|
|
||||||
for (int iter = 0; (iter < params.termCrit.maxCount)&&(err > params.termCrit.epsilon); iter++)
|
|
||||||
{
|
|
||||||
//generate sample number
|
|
||||||
int randomNumber = rng_.uniform(0, trainSamplesCount_);
|
|
||||||
|
|
||||||
currentSample = trainSamples.row(randomNumber);
|
|
||||||
|
|
||||||
//update gamma
|
|
||||||
gamma = params.gamma0 * std::pow((1 + params.lambda * params.gamma0 * (float)iter), (-params.c));
|
|
||||||
|
|
||||||
bool is_first_class = isFirstClass(trainResponses.at<float>(randomNumber));
|
|
||||||
updateWeights( currentSample, is_first_class, gamma );
|
|
||||||
|
|
||||||
//average weights (only for ASGD model)
|
|
||||||
if (params.svmsgdType == ASGD)
|
|
||||||
{
|
|
||||||
averageWeights = ((float)iter/ (1 + (float)iter)) * averageWeights + weights_ / (1 + (float) iter);
|
|
||||||
}
|
|
||||||
|
|
||||||
err = norm(weights_ - lastWeights);
|
|
||||||
weights_.copyTo(lastWeights);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (params.svmsgdType == ASGD)
|
|
||||||
{
|
|
||||||
weights_ = averageWeights;
|
|
||||||
}
|
|
||||||
|
|
||||||
shift_ = calcShift(trainSamples, trainResponses);
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
|
std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
|
||||||
{
|
{
|
||||||
std::pair<bool,bool> are_classes_empty(true, true);
|
CV_Assert(responses.cols == 1);
|
||||||
|
std::pair<bool,bool> emptyInClasses(true, true);
|
||||||
int limit_index = responses.rows;
|
int limit_index = responses.rows;
|
||||||
|
|
||||||
for(int index = 0; index < limit_index; index++)
|
for(int index = 0; index < limit_index; index++)
|
||||||
{
|
{
|
||||||
if (isFirstClass(responses.at<float>(index,0)))
|
if (isFirstClass(responses.at<float>(index)))
|
||||||
are_classes_empty.first = false;
|
emptyInClasses.first = false;
|
||||||
else
|
else
|
||||||
are_classes_empty.second = false;
|
emptyInClasses.second = false;
|
||||||
|
|
||||||
if (!are_classes_empty.first && ! are_classes_empty.second)
|
if (!emptyInClasses.first && ! emptyInClasses.second)
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return are_classes_empty;
|
return emptyInClasses;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &multiplier, Mat &average)
|
||||||
|
{
|
||||||
|
int featuresCount = samples.cols;
|
||||||
|
int samplesCount = samples.rows;
|
||||||
|
|
||||||
|
average = Mat(1, featuresCount, samples.type());
|
||||||
|
for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
|
||||||
|
{
|
||||||
|
average.at<float>(featureIndex) = mean(samples.col(featureIndex))[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
|
||||||
|
{
|
||||||
|
samples.row(sampleIndex) -= average;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat featureNorm(1, featuresCount, samples.type());
|
||||||
|
for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
|
||||||
|
{
|
||||||
|
featureNorm.at<float>(featureIndex) = norm(samples.col(featureIndex));
|
||||||
|
}
|
||||||
|
|
||||||
|
multiplier = sqrt(samplesCount) / featureNorm;
|
||||||
|
for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
|
||||||
|
{
|
||||||
|
samples.row(sampleIndex) = samples.row(sampleIndex).mul(multiplier);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void SVMSGDImpl::makeExtendedTrainSamples(const Mat trainSamples, Mat &extendedTrainSamples, Mat &multiplier)
|
||||||
|
{
|
||||||
|
Mat normalisedTrainSamples = trainSamples.clone();
|
||||||
|
int samplesCount = normalisedTrainSamples.rows;
|
||||||
|
|
||||||
|
Mat average;
|
||||||
|
|
||||||
|
normalizeSamples(normalisedTrainSamples, multiplier, average);
|
||||||
|
|
||||||
|
Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
|
||||||
|
cv::hconcat(normalisedTrainSamples, onesCol, extendedTrainSamples);
|
||||||
|
|
||||||
|
//cout << "SVMSGDImpl::makeExtendedTrainSamples average: \n" << average << endl;
|
||||||
|
//cout << "SVMSGDImpl::makeExtendedTrainSamples multiplier: \n" << multiplier << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma, Mat weights)
|
||||||
|
{
|
||||||
|
Mat sample = _sample.getMat();
|
||||||
|
|
||||||
|
int response = firstClass ? 1 : -1; // ensure that trainResponses are -1 or 1
|
||||||
|
|
||||||
|
if ( sample.dot(weights) * response > 1)
|
||||||
|
{
|
||||||
|
// Not a support vector, only apply weight decay
|
||||||
|
weights *= (1.f - gamma * params.lambda);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// It's a support vector, add it to the weights
|
||||||
|
weights -= (gamma * params.lambda) * weights - (gamma * response) * sample;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||||
@ -232,12 +236,12 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
|||||||
for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
|
for (int samplesIndex = 0; samplesIndex < trainSamplesCount; samplesIndex++)
|
||||||
{
|
{
|
||||||
Mat currentSample = trainSamples.row(samplesIndex);
|
Mat currentSample = trainSamples.row(samplesIndex);
|
||||||
float scalar_product = currentSample.dot(weights_);
|
float dotProduct = currentSample.dot(weights_);
|
||||||
|
|
||||||
bool is_first_class = isFirstClass(trainResponses.at<float>(samplesIndex));
|
bool firstClass = isFirstClass(trainResponses.at<float>(samplesIndex));
|
||||||
int index = is_first_class ? 0:1;
|
int index = firstClass ? 0:1;
|
||||||
float sign_to_mul = is_first_class ? 1 : -1;
|
float signToMul = firstClass ? 1 : -1;
|
||||||
float cur_distance = scalar_product * sign_to_mul ;
|
float cur_distance = dotProduct * signToMul;
|
||||||
|
|
||||||
if (cur_distance < distance_to_classes[index])
|
if (cur_distance < distance_to_classes[index])
|
||||||
{
|
{
|
||||||
@ -245,10 +249,109 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//todo: areClassesEmpty(); make const;
|
|
||||||
return -(distance_to_classes[0] - distance_to_classes[1]) / 2.f;
|
return -(distance_to_classes[0] - distance_to_classes[1]) / 2.f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||||
|
{
|
||||||
|
//cout << "SVMSGDImpl::train begin" << endl;
|
||||||
|
clear();
|
||||||
|
CV_Assert( isClassifier() ); //toDo: consider
|
||||||
|
|
||||||
|
Mat trainSamples = data->getTrainSamples();
|
||||||
|
|
||||||
|
//cout << "SVMSGDImpl::train trainSamples: \n" << trainSamples << endl;
|
||||||
|
|
||||||
|
int featureCount = trainSamples.cols;
|
||||||
|
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
|
||||||
|
|
||||||
|
//cout << "SVMSGDImpl::train trainresponses: \n" << trainResponses << endl;
|
||||||
|
|
||||||
|
std::pair<bool,bool> areEmpty = areClassesEmpty(trainResponses);
|
||||||
|
|
||||||
|
//cout << "SVMSGDImpl::train areEmpty" << areEmpty.first << "," << areEmpty.second << endl;
|
||||||
|
|
||||||
|
if ( areEmpty.first && areEmpty.second )
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if ( areEmpty.first || areEmpty.second )
|
||||||
|
{
|
||||||
|
weights_ = Mat::zeros(1, featureCount, CV_32F);
|
||||||
|
shift_ = areEmpty.first ? -1 : 1;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat extendedTrainSamples;
|
||||||
|
Mat multiplier;
|
||||||
|
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, multiplier);
|
||||||
|
|
||||||
|
//cout << "SVMSGDImpl::train extendedTrainSamples: \n" << extendedTrainSamples << endl;
|
||||||
|
|
||||||
|
int extendedTrainSamplesCount = extendedTrainSamples.rows;
|
||||||
|
int extendedFeatureCount = extendedTrainSamples.cols;
|
||||||
|
|
||||||
|
Mat extendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); // Initialize extendedWeights vector with zeros
|
||||||
|
Mat previousWeights = Mat::zeros(1, extendedFeatureCount, CV_32F); //extendedWeights vector for calculating terminal criterion
|
||||||
|
Mat averageExtendedWeights; //average extendedWeights vector for ASGD model
|
||||||
|
if (params.svmsgdType == ASGD)
|
||||||
|
{
|
||||||
|
averageExtendedWeights = Mat::zeros(1, extendedFeatureCount, CV_32F);
|
||||||
|
}
|
||||||
|
|
||||||
|
RNG rng(0);
|
||||||
|
|
||||||
|
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
|
||||||
|
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
|
||||||
|
|
||||||
|
double err = DBL_MAX;
|
||||||
|
// Stochastic gradient descent SVM
|
||||||
|
for (int iter = 0; (iter < maxCount) && (err > epsilon); iter++)
|
||||||
|
{
|
||||||
|
int randomNumber = rng.uniform(0, extendedTrainSamplesCount); //generate sample number
|
||||||
|
|
||||||
|
Mat currentSample = extendedTrainSamples.row(randomNumber);
|
||||||
|
bool firstClass = isFirstClass(trainResponses.at<float>(randomNumber));
|
||||||
|
|
||||||
|
float gamma = params.gamma0 * std::pow((1 + params.lambda * params.gamma0 * (float)iter), (-params.c)); //update gamma
|
||||||
|
|
||||||
|
updateWeights( currentSample, firstClass, gamma, extendedWeights );
|
||||||
|
|
||||||
|
//average weights (only for ASGD model)
|
||||||
|
if (params.svmsgdType == ASGD)
|
||||||
|
{
|
||||||
|
averageExtendedWeights = ((float)iter/ (1 + (float)iter)) * averageExtendedWeights + extendedWeights / (1 + (float) iter);
|
||||||
|
err = norm(averageExtendedWeights - previousWeights);
|
||||||
|
averageExtendedWeights.copyTo(previousWeights);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
err = norm(extendedWeights - previousWeights);
|
||||||
|
extendedWeights.copyTo(previousWeights);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.svmsgdType == ASGD)
|
||||||
|
{
|
||||||
|
extendedWeights = averageExtendedWeights;
|
||||||
|
}
|
||||||
|
|
||||||
|
//cout << "SVMSGDImpl::train extendedWeights: \n" << extendedWeights << endl;
|
||||||
|
|
||||||
|
Rect roi(0, 0, featureCount, 1);
|
||||||
|
weights_ = extendedWeights(roi);
|
||||||
|
weights_ = weights_.mul(1/multiplier);
|
||||||
|
|
||||||
|
//cout << "SVMSGDImpl::train weights: \n" << weights_ << endl;
|
||||||
|
|
||||||
|
shift_ = calcShift(trainSamples, trainResponses);
|
||||||
|
|
||||||
|
//cout << "SVMSGDImpl::train shift = " << shift_ << endl;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
|
float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) const
|
||||||
{
|
{
|
||||||
float result = 0;
|
float result = 0;
|
||||||
@ -269,37 +372,21 @@ float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) cons
|
|||||||
results = Mat(1, 1, CV_32F, &result);
|
results = Mat(1, 1, CV_32F, &result);
|
||||||
}
|
}
|
||||||
|
|
||||||
Mat currentSample;
|
|
||||||
float criterion = 0;
|
|
||||||
|
|
||||||
for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
|
for (int sampleIndex = 0; sampleIndex < nSamples; sampleIndex++)
|
||||||
{
|
{
|
||||||
currentSample = samples.row(sampleIndex);
|
Mat currentSample = samples.row(sampleIndex);
|
||||||
criterion = currentSample.dot(weights_) + shift_;
|
float criterion = currentSample.dot(weights_) + shift_;
|
||||||
results.at<float>(sampleIndex) = (criterion >= 0) ? 1 : -1;
|
results.at<float>(sampleIndex) = (criterion >= 0) ? 1 : -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SVMSGDImpl::updateWeights(InputArray _sample, bool is_first_class, float gamma)
|
bool SVMSGDImpl::isClassifier() const
|
||||||
{
|
{
|
||||||
Mat sample = _sample.getMat();
|
return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
|
||||||
|
&&
|
||||||
int responce = is_first_class ? 1 : -1; // ensure that trainResponses are -1 or 1
|
(params.lambda > 0) && (params.gamma0 > 0) && (params.c >= 0);
|
||||||
|
|
||||||
if ( sample.dot(weights_) * responce > 1)
|
|
||||||
{
|
|
||||||
// Not a support vector, only apply weight decay
|
|
||||||
weights_ *= (1.f - gamma * params.lambda);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
// It's a support vector, add it to the weights
|
|
||||||
weights_ -= (gamma * params.lambda) * weights_ - gamma * responce * sample;
|
|
||||||
//std::cout << "sample " << sample << std::endl;
|
|
||||||
//std::cout << "weights_ " << weights_ << std::endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool SVMSGDImpl::isTrained() const
|
bool SVMSGDImpl::isTrained() const
|
||||||
@ -314,8 +401,8 @@ void SVMSGDImpl::write(FileStorage& fs) const
|
|||||||
|
|
||||||
writeParams( fs );
|
writeParams( fs );
|
||||||
|
|
||||||
fs << "shift" << shift_;
|
|
||||||
fs << "weights" << weights_;
|
fs << "weights" << weights_;
|
||||||
|
fs << "shift" << shift_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||||
@ -359,8 +446,8 @@ void SVMSGDImpl::read(const FileNode& fn)
|
|||||||
|
|
||||||
readParams(fn);
|
readParams(fn);
|
||||||
|
|
||||||
shift_ = (float) fn["shift"];
|
|
||||||
fn["weights"] >> weights_;
|
fn["weights"] >> weights_;
|
||||||
|
fn["shift"] >> shift_;
|
||||||
}
|
}
|
||||||
|
|
||||||
void SVMSGDImpl::readParams( const FileNode& fn )
|
void SVMSGDImpl::readParams( const FileNode& fn )
|
||||||
@ -393,21 +480,19 @@ void SVMSGDImpl::readParams( const FileNode& fn )
|
|||||||
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
|
(params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );
|
params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 100000, FLT_EPSILON );
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SVMSGDImpl::clear()
|
void SVMSGDImpl::clear()
|
||||||
{
|
{
|
||||||
weights_.release();
|
weights_.release();
|
||||||
shift_ = 0;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
SVMSGDImpl::SVMSGDImpl()
|
SVMSGDImpl::SVMSGDImpl()
|
||||||
{
|
{
|
||||||
clear();
|
clear();
|
||||||
rng_(0);
|
|
||||||
|
|
||||||
params.svmsgdType = ILLEGAL_VALUE;
|
params.svmsgdType = ILLEGAL_VALUE;
|
||||||
|
|
||||||
@ -426,20 +511,20 @@ void SVMSGDImpl::setOptimalParameters(int type)
|
|||||||
{
|
{
|
||||||
case SGD:
|
case SGD:
|
||||||
params.svmsgdType = SGD;
|
params.svmsgdType = SGD;
|
||||||
params.lambda = 0.00001;
|
params.lambda = 0.0001;
|
||||||
params.gamma0 = 0.05;
|
params.gamma0 = 0.05;
|
||||||
params.c = 1;
|
params.c = 1;
|
||||||
params.termCrit.maxCount = 50000;
|
params.termCrit.maxCount = 100000;
|
||||||
params.termCrit.epsilon = 0.00000001;
|
params.termCrit.epsilon = 0.00001;
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case ASGD:
|
case ASGD:
|
||||||
params.svmsgdType = ASGD;
|
params.svmsgdType = ASGD;
|
||||||
params.lambda = 0.00001;
|
params.lambda = 0.00001;
|
||||||
params.gamma0 = 0.5;
|
params.gamma0 = 0.05;
|
||||||
params.c = 0.75;
|
params.c = 0.75;
|
||||||
params.termCrit.maxCount = 100000;
|
params.termCrit.maxCount = 100000;
|
||||||
params.termCrit.epsilon = 0.000001;
|
params.termCrit.epsilon = 0.00001;
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include "opencv2/ts.hpp"
|
#include "opencv2/ts.hpp"
|
||||||
#include "opencv2/ml.hpp"
|
#include "opencv2/ml.hpp"
|
||||||
#include "opencv2/ml/svmsgd.hpp"
|
|
||||||
#include "opencv2/core/core_c.h"
|
#include "opencv2/core/core_c.h"
|
||||||
|
|
||||||
#define CV_NBAYES "nbayes"
|
#define CV_NBAYES "nbayes"
|
||||||
|
@ -52,7 +52,7 @@ using cv::ml::TrainData;
|
|||||||
class CV_SVMSGDTrainTest : public cvtest::BaseTest
|
class CV_SVMSGDTrainTest : public cvtest::BaseTest
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
CV_SVMSGDTrainTest(Mat _weights, float _shift);
|
CV_SVMSGDTrainTest(Mat _weights, float shift);
|
||||||
private:
|
private:
|
||||||
virtual void run( int start_from );
|
virtual void run( int start_from );
|
||||||
float decisionFunction(Mat sample, Mat weights, float shift);
|
float decisionFunction(Mat sample, Mat weights, float shift);
|
||||||
@ -60,7 +60,7 @@ private:
|
|||||||
cv::Ptr<TrainData> data;
|
cv::Ptr<TrainData> data;
|
||||||
cv::Mat testSamples;
|
cv::Mat testSamples;
|
||||||
cv::Mat testResponses;
|
cv::Mat testResponses;
|
||||||
static const int TEST_VALUE_LIMIT = 50;
|
static const int TEST_VALUE_LIMIT = 500;
|
||||||
};
|
};
|
||||||
|
|
||||||
CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift)
|
CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift)
|
||||||
@ -81,6 +81,11 @@ CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift)
|
|||||||
responses.at<float>( sampleIndex ) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1;
|
responses.at<float>( sampleIndex ) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::cout << "real weights\n" << weights/norm(weights) << "\n" << std::endl;
|
||||||
|
std::cout << "real shift \n" << shift/norm(weights) << "\n" << std::endl;
|
||||||
|
|
||||||
data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
|
data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
|
||||||
|
|
||||||
int testSamplesCount = 100000;
|
int testSamplesCount = 100000;
|
||||||
@ -100,8 +105,9 @@ void CV_SVMSGDTrainTest::run( int /*start_from*/ )
|
|||||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||||
|
|
||||||
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
||||||
|
svmsgd->setTermCriteria(TermCriteria(TermCriteria::EPS, 0, 0.00005));
|
||||||
|
|
||||||
svmsgd->train( data );
|
svmsgd->train(data);
|
||||||
|
|
||||||
Mat responses;
|
Mat responses;
|
||||||
|
|
||||||
@ -116,6 +122,12 @@ void CV_SVMSGDTrainTest::run( int /*start_from*/ )
|
|||||||
errCount++;
|
errCount++;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
float normW = norm(svmsgd->getWeights());
|
||||||
|
|
||||||
|
std::cout << "found weights\n" << svmsgd->getWeights()/normW << "\n" << std::endl;
|
||||||
|
std::cout << "found shift \n" << svmsgd->getShift()/normW << "\n" << std::endl;
|
||||||
|
|
||||||
float err = (float)errCount / testSamplesCount;
|
float err = (float)errCount / testSamplesCount;
|
||||||
std::cout << "err " << err << std::endl;
|
std::cout << "err " << err << std::endl;
|
||||||
|
|
||||||
@ -138,8 +150,8 @@ TEST(ML_SVMSGD, train0)
|
|||||||
weights.create(1, varCount, CV_32FC1);
|
weights.create(1, varCount, CV_32FC1);
|
||||||
weights.at<float>(0) = 1;
|
weights.at<float>(0) = 1;
|
||||||
weights.at<float>(1) = 0;
|
weights.at<float>(1) = 0;
|
||||||
|
cv::RNG rng(1);
|
||||||
float shift = 5;
|
float shift = rng.uniform(-varCount, varCount);
|
||||||
|
|
||||||
CV_SVMSGDTrainTest test(weights, shift);
|
CV_SVMSGDTrainTest test(weights, shift);
|
||||||
test.safe_run();
|
test.safe_run();
|
||||||
@ -157,7 +169,7 @@ TEST(ML_SVMSGD, train1)
|
|||||||
cv::RNG rng(0);
|
cv::RNG rng(0);
|
||||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||||
|
|
||||||
float shift = rng.uniform(-5.f, 5.f);
|
float shift = rng.uniform(-varCount, varCount);
|
||||||
|
|
||||||
CV_SVMSGDTrainTest test(weights, shift);
|
CV_SVMSGDTrainTest test(weights, shift);
|
||||||
test.safe_run();
|
test.safe_run();
|
||||||
@ -175,8 +187,8 @@ TEST(ML_SVMSGD, train2)
|
|||||||
cv::RNG rng(0);
|
cv::RNG rng(0);
|
||||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||||
|
|
||||||
float shift = rng.uniform(-1000.f, 1000.f);
|
float shift = rng.uniform(-varCount, varCount);
|
||||||
|
|
||||||
CV_SVMSGDTrainTest test(weights, shift);
|
CV_SVMSGDTrainTest test(weights,shift);
|
||||||
test.safe_run();
|
test.safe_run();
|
||||||
}
|
}
|
||||||
|
@ -12,10 +12,8 @@ using namespace cv::ml;
|
|||||||
struct Data
|
struct Data
|
||||||
{
|
{
|
||||||
Mat img;
|
Mat img;
|
||||||
Mat samples;
|
Mat samples; //Set of train samples. Contains points on image
|
||||||
Mat responses;
|
Mat responses; //Set of responses for train samples
|
||||||
RNG rng;
|
|
||||||
//Point points[2];
|
|
||||||
|
|
||||||
Data()
|
Data()
|
||||||
{
|
{
|
||||||
@ -24,24 +22,36 @@ struct Data
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool doTrain(const Mat samples,const Mat responses, Mat &weights, float &shift);
|
//Train with SVMSGD algorithm
|
||||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2]);
|
//(samples, responses) is a train set
|
||||||
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
|
//weights is a required vector for decision function of SVMSGD algorithm
|
||||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments);
|
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);
|
||||||
|
|
||||||
|
//function finds two points for drawing line (wx = 0)
|
||||||
|
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2], int width, int height);
|
||||||
|
|
||||||
|
// function finds cross point of line (wx = 0) and segment ( (y = HEIGHT, 0 <= x <= WIDTH) or (x = WIDTH, 0 <= y <= HEIGHT) )
|
||||||
|
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
|
||||||
|
|
||||||
|
//segments' initialization ( (y = HEIGHT, 0 <= x <= WIDTH) and (x = WIDTH, 0 <= y <= HEIGHT) )
|
||||||
|
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);
|
||||||
|
|
||||||
|
//redraw points' set and line (wx = 0)
|
||||||
void redraw(Data data, const Point points[2]);
|
void redraw(Data data, const Point points[2]);
|
||||||
void addPointsRetrainAndRedraw(Data &data, int x, int y);
|
|
||||||
|
//add point in train set, train SVMSGD algorithm and draw results on image
|
||||||
|
void addPointRetrainAndRedraw(Data &data, int x, int y);
|
||||||
|
|
||||||
|
|
||||||
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
|
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
|
||||||
{
|
{
|
||||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||||
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
||||||
svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 50000, 0.0000001));
|
svmsgd->setTermCriteria(TermCriteria(TermCriteria::EPS, 0, 0.00000001));
|
||||||
svmsgd->setLambda(0.01);
|
svmsgd->setLambda(0.00000001);
|
||||||
svmsgd->setGamma0(1);
|
|
||||||
// svmsgd->setC(5);
|
|
||||||
|
|
||||||
cv::Ptr<TrainData> train_data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
|
|
||||||
|
cv::Ptr<TrainData> train_data = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
|
||||||
svmsgd->train( train_data );
|
svmsgd->train( train_data );
|
||||||
|
|
||||||
if (svmsgd->isTrained())
|
if (svmsgd->isTrained())
|
||||||
@ -49,36 +59,39 @@ bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift
|
|||||||
weights = svmsgd->getWeights();
|
weights = svmsgd->getWeights();
|
||||||
shift = svmsgd->getShift();
|
shift = svmsgd->getShift();
|
||||||
|
|
||||||
std::cout << weights << std::endl;
|
|
||||||
std::cout << shift << std::endl;
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
|
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
|
||||||
{
|
{
|
||||||
int x = 0;
|
int x = 0;
|
||||||
int y = 0;
|
int y = 0;
|
||||||
//с (0,0) всё плохо
|
int xMin = std::min(segment.first.x, segment.second.x);
|
||||||
if (segment.first.x == segment.second.x && weights.at<float>(1) != 0)
|
int xMax = std::max(segment.first.x, segment.second.x);
|
||||||
|
int yMin = std::min(segment.first.y, segment.second.y);
|
||||||
|
int yMax = std::max(segment.first.y, segment.second.y);
|
||||||
|
|
||||||
|
CV_Assert(xMin == xMax || yMin == yMax);
|
||||||
|
|
||||||
|
if (xMin == xMax && weights.at<float>(1) != 0)
|
||||||
{
|
{
|
||||||
x = segment.first.x;
|
x = xMin;
|
||||||
y = -(weights.at<float>(0) * x + shift) / weights.at<float>(1);
|
y = std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1));
|
||||||
if (y >= 0 && y <= HEIGHT)
|
if (y >= yMin && y <= yMax)
|
||||||
{
|
{
|
||||||
crossPoint.x = x;
|
crossPoint.x = x;
|
||||||
crossPoint.y = y;
|
crossPoint.y = y;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (segment.first.y == segment.second.y && weights.at<float>(0) != 0)
|
else if (yMin == yMax && weights.at<float>(0) != 0)
|
||||||
{
|
{
|
||||||
y = segment.first.y;
|
y = yMin;
|
||||||
x = - (weights.at<float>(1) * y + shift) / weights.at<float>(0);
|
x = std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0));
|
||||||
if (x >= 0 && x <= WIDTH)
|
if (x >= xMin && x <= xMax)
|
||||||
{
|
{
|
||||||
crossPoint.x = x;
|
crossPoint.x = x;
|
||||||
crossPoint.y = y;
|
crossPoint.y = y;
|
||||||
@ -88,7 +101,7 @@ bool findCrossPoint(const Mat &weights, float shift, const std::pair<Point,Point
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2])
|
bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2], int width, int height)
|
||||||
{
|
{
|
||||||
if (weights.empty())
|
if (weights.empty())
|
||||||
{
|
{
|
||||||
@ -97,42 +110,43 @@ bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2])
|
|||||||
|
|
||||||
int foundPointsCount = 0;
|
int foundPointsCount = 0;
|
||||||
std::vector<std::pair<Point,Point> > segments;
|
std::vector<std::pair<Point,Point> > segments;
|
||||||
fillSegments(segments);
|
fillSegments(segments, width, height);
|
||||||
|
|
||||||
for (int i = 0; i < 4; i++)
|
for (uint i = 0; i < segments.size(); i++)
|
||||||
{
|
{
|
||||||
if (findCrossPoint(weights, shift, segments[i], points[foundPointsCount]))
|
if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
|
||||||
foundPointsCount++;
|
foundPointsCount++;
|
||||||
if (foundPointsCount > 2)
|
if (foundPointsCount >= 2)
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments)
|
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
|
||||||
{
|
{
|
||||||
std::pair<Point,Point> curSegment;
|
std::pair<Point,Point> currentSegment;
|
||||||
|
|
||||||
curSegment.first = Point(0,0);
|
currentSegment.first = Point(width, 0);
|
||||||
curSegment.second = Point(0,HEIGHT);
|
currentSegment.second = Point(width, height);
|
||||||
segments.push_back(curSegment);
|
segments.push_back(currentSegment);
|
||||||
|
|
||||||
curSegment.first = Point(0,0);
|
currentSegment.first = Point(0, height);
|
||||||
curSegment.second = Point(WIDTH,0);
|
currentSegment.second = Point(width, height);
|
||||||
segments.push_back(curSegment);
|
segments.push_back(currentSegment);
|
||||||
|
|
||||||
curSegment.first = Point(WIDTH,0);
|
currentSegment.first = Point(0, 0);
|
||||||
curSegment.second = Point(WIDTH,HEIGHT);
|
currentSegment.second = Point(width, 0);
|
||||||
segments.push_back(curSegment);
|
segments.push_back(currentSegment);
|
||||||
|
|
||||||
curSegment.first = Point(0,HEIGHT);
|
currentSegment.first = Point(0, 0);
|
||||||
curSegment.second = Point(WIDTH,HEIGHT);
|
currentSegment.second = Point(0, height);
|
||||||
segments.push_back(curSegment);
|
segments.push_back(currentSegment);
|
||||||
}
|
}
|
||||||
|
|
||||||
void redraw(Data data, const Point points[2])
|
void redraw(Data data, const Point points[2])
|
||||||
{
|
{
|
||||||
data.img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
|
data.img.setTo(0);
|
||||||
Point center;
|
Point center;
|
||||||
int radius = 3;
|
int radius = 3;
|
||||||
Scalar color;
|
Scalar color;
|
||||||
@ -143,48 +157,26 @@ void redraw(Data data, const Point points[2])
|
|||||||
color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
|
color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
|
||||||
circle(data.img, center, radius, color, 5);
|
circle(data.img, center, radius, color, 5);
|
||||||
}
|
}
|
||||||
line(data.img, points[0],points[1],cv::Scalar(1,255,1));
|
line(data.img, points[0], points[1],cv::Scalar(1,255,1));
|
||||||
|
|
||||||
imshow("Train svmsgd", data.img);
|
imshow("Train svmsgd", data.img);
|
||||||
}
|
}
|
||||||
|
|
||||||
void addPointsRetrainAndRedraw(Data &data, int x, int y)
|
void addPointRetrainAndRedraw(Data &data, int x, int y)
|
||||||
{
|
{
|
||||||
|
|
||||||
Mat currentSample(1, 2, CV_32F);
|
Mat currentSample(1, 2, CV_32F);
|
||||||
//start
|
|
||||||
/*
|
|
||||||
Mat _weights;
|
|
||||||
_weights.create(1, 2, CV_32FC1);
|
|
||||||
_weights.at<float>(0) = 1;
|
|
||||||
_weights.at<float>(1) = -1;
|
|
||||||
|
|
||||||
int _x, _y;
|
|
||||||
|
|
||||||
for (int i=0;i<199;i++)
|
|
||||||
{
|
|
||||||
_x = data.rng.uniform(0,800);
|
|
||||||
_y = data.rng.uniform(0,500);*/
|
|
||||||
currentSample.at<float>(0,0) = x;
|
currentSample.at<float>(0,0) = x;
|
||||||
currentSample.at<float>(0,1) = y;
|
currentSample.at<float>(0,1) = y;
|
||||||
//if (currentSample.dot(_weights) > 0)
|
|
||||||
//data.responses.push_back(1);
|
|
||||||
// else data.responses.push_back(-1);
|
|
||||||
|
|
||||||
//finish
|
|
||||||
data.samples.push_back(currentSample);
|
data.samples.push_back(currentSample);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Mat weights(1, 2, CV_32F);
|
Mat weights(1, 2, CV_32F);
|
||||||
float shift = 0;
|
float shift = 0;
|
||||||
|
|
||||||
if (doTrain(data.samples, data.responses, weights, shift))
|
if (doTrain(data.samples, data.responses, weights, shift))
|
||||||
{
|
{
|
||||||
Point points[2];
|
Point points[2];
|
||||||
shift = 0;
|
findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
|
||||||
|
|
||||||
findPointsForLine(weights, shift, points);
|
|
||||||
|
|
||||||
redraw(data, points);
|
redraw(data, points);
|
||||||
}
|
}
|
||||||
@ -199,13 +191,13 @@ static void onMouse( int event, int x, int y, int, void* pData)
|
|||||||
{
|
{
|
||||||
case CV_EVENT_LBUTTONUP:
|
case CV_EVENT_LBUTTONUP:
|
||||||
data.responses.push_back(1);
|
data.responses.push_back(1);
|
||||||
addPointsRetrainAndRedraw(data, x, y);
|
addPointRetrainAndRedraw(data, x, y);
|
||||||
|
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case CV_EVENT_RBUTTONDOWN:
|
case CV_EVENT_RBUTTONDOWN:
|
||||||
data.responses.push_back(-1);
|
data.responses.push_back(-1);
|
||||||
addPointsRetrainAndRedraw(data, x, y);
|
addPointRetrainAndRedraw(data, x, y);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,14 +205,10 @@ static void onMouse( int event, int x, int y, int, void* pData)
|
|||||||
|
|
||||||
int main()
|
int main()
|
||||||
{
|
{
|
||||||
|
|
||||||
Data data;
|
Data data;
|
||||||
|
|
||||||
setMouseCallback( "Train svmsgd", onMouse, &data );
|
setMouseCallback( "Train svmsgd", onMouse, &data );
|
||||||
waitKey();
|
waitKey();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user