Added margin type, added tests with different scales of features.
Also fixed documentation, refactored sample.
This commit is contained in:
parent
acd74037b3
commit
bfdca05f25
@ -1504,27 +1504,78 @@ public:
|
||||
/*!
|
||||
@brief Stochastic Gradient Descent SVM classifier
|
||||
|
||||
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach, as presented in @cite bottou2010large.
|
||||
SVMSGD provides a fast and easy-to-use implementation of the SVM classifier using the Stochastic Gradient Descent approach,
|
||||
as presented in @cite bottou2010large.
|
||||
The gradient descent show amazing performance for large-scale problems, reducing the computing time.
|
||||
|
||||
The classifier has 5 parameters. These are
|
||||
- model type,
|
||||
- margin type,
|
||||
- \f$\lambda\f$ (strength of restrictions on outliers),
|
||||
- \f$\gamma_0\f$ (initial step size),
|
||||
- \f$c\f$ (power coefficient for decreasing of step size),
|
||||
- and termination criteria.
|
||||
|
||||
The model type may have one of the following values: \ref SGD and \ref ASGD.
|
||||
|
||||
First, create the SVMSGD object. Set parametrs of model (type, lambda, gamma0, c) using the functions setType, setLambda, setGamma0 and setC or the function setOptimalParametrs.
|
||||
Recommended model type is ASGD.
|
||||
- \ref SGD is the classic version of SVMSGD classifier: every next step is calculated by the formula
|
||||
\f[w_{t+1} = w_t - \gamma(t) \frac{dQ_i}{dw} |_{w = w_t}\f]
|
||||
where
|
||||
- \f$w_t\f$ is the weights vector for decision function at step \f$t\f$,
|
||||
- \f$\gamma(t)\f$ is the step size of model parameters at the iteration \f$t\f$, it is decreased on each step by the formula
|
||||
\f$\gamma(t) = \gamma_0 (1 + \lambda \gamma_0 t) ^ {-c}\f$
|
||||
- \f$Q_i\f$ is the target functional from SVM task for sample with number \f$i\f$, this sample is chosen stochastically on each step of the algorithm.
|
||||
|
||||
Then the SVM model can be trained using the train features and the correspondent labels.
|
||||
- \ref ASGD is Average Stochastic Gradient Descent SVM Classifier. ASGD classifier averages weights vector on each step of algorithm by the formula
|
||||
\f$\widehat{w}_{t+1} = \frac{t}{1+t}\widehat{w}_{t} + \frac{1}{1+t}w_{t+1}\f$
|
||||
|
||||
After that, the label of a new feature vector can be predicted using the predict function.
|
||||
The recommended model type is ASGD (following @cite bottou2010large).
|
||||
|
||||
The margin type may have one of the following values: \ref SOFT_MARGIN or \ref HARD_MARGIN.
|
||||
|
||||
- You should use \ref HARD_MARGIN type, if you have linearly separable sets.
|
||||
- You should use \ref SOFT_MARGIN type, if you have non-linearly separable sets or sets with outliers.
|
||||
- In the general case (if you know nothing about linearly separability of your sets), use SOFT_MARGIN.
|
||||
|
||||
The other parameters may be described as follows:
|
||||
- \f$\lambda\f$ parameter is responsible for weights decreasing at each step and for the strength of restrictions on outliers
|
||||
(the less the parameter, the less probability that an outlier will be ignored).
|
||||
Recommended value for SGD model is 0.0001, for ASGD model is 0.00001.
|
||||
|
||||
- \f$\gamma_0\f$ parameter is the initial value for the step size \f$\gamma(t)\f$.
|
||||
You will have to find the best \f$\gamma_0\f$ for your problem.
|
||||
|
||||
- \f$c\f$ is the power parameter for \f$\gamma(t)\f$ decreasing by the formula, mentioned above.
|
||||
Recommended value for SGD model is 1, for ASGD model is 0.75.
|
||||
|
||||
- Termination criteria can be TermCriteria::COUNT, TermCriteria::EPS or TermCriteria::COUNT + TermCriteria::EPS.
|
||||
You will have to find the best termination criteria for your problem.
|
||||
|
||||
Note that the parameters \f$\lambda\f$, \f$\gamma_0\f$, and \f$c\f$ should be positive.
|
||||
|
||||
To use SVMSGD algorithm do as follows:
|
||||
|
||||
- first, create the SVMSGD object.
|
||||
|
||||
- then set parameters (model type, margin type, \f$\lambda\f$, \f$\gamma_0\f$, \f$c\f$) using the functions
|
||||
setSvmsgdType(), setMarginType(), setLambda(), setGamma0(), and setC(), or the function setOptimalParameters().
|
||||
|
||||
- then the SVM model can be trained using the train features and the correspondent labels by the method train().
|
||||
|
||||
- after that, the label of a new feature vector can be predicted using the method predict().
|
||||
|
||||
@code
|
||||
// Initialize object
|
||||
SVMSGD SvmSgd;
|
||||
// Create empty object
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
|
||||
// Set parameters
|
||||
svmsgd->setOptimalParameters();
|
||||
|
||||
// Train the Stochastic Gradient Descent SVM
|
||||
SvmSgd.train(trainFeatures, labels);
|
||||
SvmSgd->train(trainData);
|
||||
|
||||
// Predict label for the new feature vector (1xM)
|
||||
predictedLabel = SvmSgd.predict(newFeatureVector);
|
||||
// Predict labels for the new samples
|
||||
svmsgd->predict(samples, responses);
|
||||
@endcode
|
||||
|
||||
*/
|
||||
@ -1537,9 +1588,17 @@ public:
|
||||
ASGD is often the preferable choice. */
|
||||
enum SvmsgdType
|
||||
{
|
||||
ILLEGAL_VALUE,
|
||||
SGD, //!Stochastic Gradient Descent
|
||||
ASGD //!Average Stochastic Gradient Descent
|
||||
ILLEGAL_SVMSGD_TYPE,
|
||||
SGD, //!< Stochastic Gradient Descent
|
||||
ASGD //!< Average Stochastic Gradient Descent
|
||||
};
|
||||
|
||||
/** Margin type.*/
|
||||
enum MarginType
|
||||
{
|
||||
ILLEGAL_MARGIN_TYPE,
|
||||
SOFT_MARGIN, //!< General case, suits to the case of non-linearly separable sets, allows outliers.
|
||||
HARD_MARGIN //!< More accurate for the case of linearly separable sets.
|
||||
};
|
||||
|
||||
/**
|
||||
@ -1553,49 +1612,59 @@ public:
|
||||
CV_WRAP virtual float getShift() = 0;
|
||||
|
||||
|
||||
/** Creates empty model.
|
||||
/** @brief Creates empty model.
|
||||
Use StatModel::train to train the model. Since %SVMSGD has several parameters, you may want to
|
||||
find the best parameters for your problem or use setOptimalParameters() to set some default parameters.
|
||||
*/
|
||||
CV_WRAP static Ptr<SVMSGD> create();
|
||||
|
||||
/** Function sets optimal parameters values for chosen SVM SGD model.
|
||||
/** @brief Function sets optimal parameters values for chosen SVM SGD model.
|
||||
* If chosen type is ASGD, function sets the following values for parameters of model:
|
||||
* lambda = 0.00001;
|
||||
* gamma0 = 0.05;
|
||||
* c = 0.75;
|
||||
* \f$\lambda = 0.00001\f$;
|
||||
* \f$\gamma_0 = 0.05\f$;
|
||||
* \f$c = 0.75\f$;
|
||||
* termCrit.maxCount = 100000;
|
||||
* termCrit.epsilon = 0.00001;
|
||||
*
|
||||
* If SGD:
|
||||
* lambda = 0.0001;
|
||||
* gamma0 = 0.05;
|
||||
* c = 1;
|
||||
* \f$\lambda = 0.0001\f$;
|
||||
* \f$\gamma_0 = 0.05\f$;
|
||||
* \f$c = 1\f$;
|
||||
* termCrit.maxCount = 100000;
|
||||
* termCrit.epsilon = 0.00001;
|
||||
* @param type is the type of SVMSGD classifier. Legal values are SvmsgdType::SGD and SvmsgdType::ASGD.
|
||||
* @param svmsgdType is the type of SVMSGD classifier. Legal values are SvmsgdType::SGD and SvmsgdType::ASGD.
|
||||
* Recommended value is SvmsgdType::ASGD (by default).
|
||||
* @param marginType is the type of margin constraint. Legal values are MarginType::SOFT_MARGIN and MarginType::HARD_MARGIN.
|
||||
* Default value is MarginType::SOFT_MARGIN.
|
||||
*/
|
||||
CV_WRAP virtual void setOptimalParameters(int type = ASGD) = 0;
|
||||
CV_WRAP virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN) = 0;
|
||||
|
||||
/** %Algorithm type, one of SVMSGD::SvmsgdType. */
|
||||
/** @brief %Algorithm type, one of SVMSGD::SvmsgdType. */
|
||||
/** @see setAlgorithmType */
|
||||
CV_WRAP virtual int getType() const = 0;
|
||||
CV_WRAP virtual int getSvmsgdType() const = 0;
|
||||
/** @copybrief getAlgorithmType @see getAlgorithmType */
|
||||
CV_WRAP virtual void setType(int type) = 0;
|
||||
CV_WRAP virtual void setSvmsgdType(int svmsgdType) = 0;
|
||||
|
||||
/** Parameter _Lambda_ of a %SVMSGD optimization problem. Default value is 0. */
|
||||
/** @brief %Margin type, one of SVMSGD::MarginType. */
|
||||
/** @see setMarginType */
|
||||
CV_WRAP virtual int getMarginType() const = 0;
|
||||
/** @copybrief getMarginType @see getMarginType */
|
||||
CV_WRAP virtual void setMarginType(int marginType) = 0;
|
||||
|
||||
|
||||
/** @brief Parameter \f$\lambda\f$ of a %SVMSGD optimization problem. Default value is 0. */
|
||||
/** @see setLambda */
|
||||
CV_WRAP virtual float getLambda() const = 0;
|
||||
/** @copybrief getLambda @see getLambda */
|
||||
CV_WRAP virtual void setLambda(float lambda) = 0;
|
||||
|
||||
/** Parameter _Gamma0_ of a %SVMSGD optimization problem. Default value is 0. */
|
||||
/** @brief Parameter \f$\gamma_0\f$ of a %SVMSGD optimization problem. Default value is 0. */
|
||||
/** @see setGamma0 */
|
||||
CV_WRAP virtual float getGamma0() const = 0;
|
||||
/** @copybrief getGamma0 @see getGamma0 */
|
||||
CV_WRAP virtual void setGamma0(float gamma0) = 0;
|
||||
|
||||
/** Parameter _C_ of a %SVMSGD optimization problem. Default value is 0. */
|
||||
/** @brief Parameter \f$c\f$ of a %SVMSGD optimization problem. Default value is 0. */
|
||||
/** @see setC */
|
||||
CV_WRAP virtual float getC() const = 0;
|
||||
/** @copybrief getC @see getC */
|
||||
|
@ -42,7 +42,6 @@
|
||||
|
||||
#include "precomp.hpp"
|
||||
#include "limits"
|
||||
//#include "math.h"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@ -76,9 +75,9 @@ public:
|
||||
|
||||
virtual void clear();
|
||||
|
||||
virtual void write(FileStorage& fs) const;
|
||||
virtual void write(FileStorage &fs) const;
|
||||
|
||||
virtual void read(const FileNode& fn);
|
||||
virtual void read(const FileNode &fn);
|
||||
|
||||
virtual Mat getWeights(){ return weights_; }
|
||||
|
||||
@ -88,11 +87,15 @@ public:
|
||||
|
||||
virtual String getDefaultName() const {return "opencv_ml_svmsgd";}
|
||||
|
||||
virtual void setOptimalParameters(int type = ASGD);
|
||||
virtual void setOptimalParameters(int svmsgdType = ASGD, int marginType = SOFT_MARGIN);
|
||||
|
||||
virtual int getType() const;
|
||||
virtual int getSvmsgdType() const;
|
||||
|
||||
virtual void setType(int type);
|
||||
virtual void setSvmsgdType(int svmsgdType);
|
||||
|
||||
virtual int getMarginType() const;
|
||||
|
||||
virtual void setMarginType(int marginType);
|
||||
|
||||
CV_IMPL_PROPERTY(float, Lambda, params.lambda)
|
||||
CV_IMPL_PROPERTY(float, Gamma0, params.gamma0)
|
||||
@ -100,21 +103,21 @@ public:
|
||||
CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)
|
||||
|
||||
private:
|
||||
void updateWeights(InputArray sample, bool isFirstClass, float gamma, Mat weights);
|
||||
void updateWeights(InputArray sample, bool isFirstClass, float gamma, Mat &weights);
|
||||
|
||||
std::pair<bool,bool> areClassesEmpty(Mat responses);
|
||||
|
||||
void writeParams( FileStorage& fs ) const;
|
||||
void writeParams( FileStorage &fs ) const;
|
||||
|
||||
void readParams( const FileNode& fn );
|
||||
void readParams( const FileNode &fn );
|
||||
|
||||
static inline bool isFirstClass(float val) { return val > 0; }
|
||||
|
||||
static void normalizeSamples(Mat &matrix, Mat &multiplier, Mat &average);
|
||||
static void normalizeSamples(Mat &matrix, Mat &average, float &multiplier);
|
||||
|
||||
float calcShift(InputArray _samples, InputArray _responses) const;
|
||||
|
||||
static void makeExtendedTrainSamples(const Mat trainSamples, Mat &extendedTrainSamples, Mat &multiplier);
|
||||
static void makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier);
|
||||
|
||||
|
||||
|
||||
@ -130,6 +133,7 @@ private:
|
||||
float c;
|
||||
TermCriteria termCrit;
|
||||
SvmsgdType svmsgdType;
|
||||
MarginType marginType;
|
||||
};
|
||||
|
||||
SVMSGDParams params;
|
||||
@ -160,7 +164,7 @@ std::pair<bool,bool> SVMSGDImpl::areClassesEmpty(Mat responses)
|
||||
return emptyInClasses;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &multiplier, Mat &average)
|
||||
void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &average, float &multiplier)
|
||||
{
|
||||
int featuresCount = samples.cols;
|
||||
int samplesCount = samples.rows;
|
||||
@ -176,37 +180,25 @@ void SVMSGDImpl::normalizeSamples(Mat &samples, Mat &multiplier, Mat &average)
|
||||
samples.row(sampleIndex) -= average;
|
||||
}
|
||||
|
||||
Mat featureNorm(1, featuresCount, samples.type());
|
||||
for (int featureIndex = 0; featureIndex < featuresCount; featureIndex++)
|
||||
{
|
||||
featureNorm.at<float>(featureIndex) = norm(samples.col(featureIndex));
|
||||
}
|
||||
double normValue = norm(samples);
|
||||
|
||||
multiplier = sqrt(samplesCount) / featureNorm;
|
||||
for (int sampleIndex = 0; sampleIndex < samplesCount; sampleIndex++)
|
||||
{
|
||||
samples.row(sampleIndex) = samples.row(sampleIndex).mul(multiplier);
|
||||
}
|
||||
multiplier = sqrt(samples.total()) / normValue;
|
||||
|
||||
samples *= multiplier;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::makeExtendedTrainSamples(const Mat trainSamples, Mat &extendedTrainSamples, Mat &multiplier)
|
||||
void SVMSGDImpl::makeExtendedTrainSamples(const Mat &trainSamples, Mat &extendedTrainSamples, Mat &average, float &multiplier)
|
||||
{
|
||||
Mat normalisedTrainSamples = trainSamples.clone();
|
||||
int samplesCount = normalisedTrainSamples.rows;
|
||||
|
||||
Mat average;
|
||||
|
||||
normalizeSamples(normalisedTrainSamples, multiplier, average);
|
||||
normalizeSamples(normalisedTrainSamples, average, multiplier);
|
||||
|
||||
Mat onesCol = Mat::ones(samplesCount, 1, CV_32F);
|
||||
cv::hconcat(normalisedTrainSamples, onesCol, extendedTrainSamples);
|
||||
|
||||
//cout << "SVMSGDImpl::makeExtendedTrainSamples average: \n" << average << endl;
|
||||
//cout << "SVMSGDImpl::makeExtendedTrainSamples multiplier: \n" << multiplier << endl;
|
||||
}
|
||||
|
||||
|
||||
void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma, Mat weights)
|
||||
void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma, Mat& weights)
|
||||
{
|
||||
Mat sample = _sample.getMat();
|
||||
|
||||
@ -226,7 +218,7 @@ void SVMSGDImpl::updateWeights(InputArray _sample, bool firstClass, float gamma,
|
||||
|
||||
float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||
{
|
||||
float distance_to_classes[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
|
||||
float distanceToClasses[2] = { std::numeric_limits<float>::max(), std::numeric_limits<float>::max() };
|
||||
|
||||
Mat trainSamples = _samples.getMat();
|
||||
int trainSamplesCount = trainSamples.rows;
|
||||
@ -241,36 +233,29 @@ float SVMSGDImpl::calcShift(InputArray _samples, InputArray _responses) const
|
||||
bool firstClass = isFirstClass(trainResponses.at<float>(samplesIndex));
|
||||
int index = firstClass ? 0:1;
|
||||
float signToMul = firstClass ? 1 : -1;
|
||||
float cur_distance = dotProduct * signToMul;
|
||||
float curDistance = dotProduct * signToMul;
|
||||
|
||||
if (cur_distance < distance_to_classes[index])
|
||||
if (curDistance < distanceToClasses[index])
|
||||
{
|
||||
distance_to_classes[index] = cur_distance;
|
||||
distanceToClasses[index] = curDistance;
|
||||
}
|
||||
}
|
||||
|
||||
return -(distance_to_classes[0] - distance_to_classes[1]) / 2.f;
|
||||
return -(distanceToClasses[0] - distanceToClasses[1]) / 2.f;
|
||||
}
|
||||
|
||||
bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
{
|
||||
//cout << "SVMSGDImpl::train begin" << endl;
|
||||
clear();
|
||||
CV_Assert( isClassifier() ); //toDo: consider
|
||||
|
||||
Mat trainSamples = data->getTrainSamples();
|
||||
|
||||
//cout << "SVMSGDImpl::train trainSamples: \n" << trainSamples << endl;
|
||||
|
||||
int featureCount = trainSamples.cols;
|
||||
Mat trainResponses = data->getTrainResponses(); // (trainSamplesCount x 1) matrix
|
||||
|
||||
//cout << "SVMSGDImpl::train trainresponses: \n" << trainResponses << endl;
|
||||
|
||||
std::pair<bool,bool> areEmpty = areClassesEmpty(trainResponses);
|
||||
|
||||
//cout << "SVMSGDImpl::train areEmpty" << areEmpty.first << "," << areEmpty.second << endl;
|
||||
|
||||
if ( areEmpty.first && areEmpty.second )
|
||||
{
|
||||
return false;
|
||||
@ -283,10 +268,9 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
}
|
||||
|
||||
Mat extendedTrainSamples;
|
||||
Mat multiplier;
|
||||
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, multiplier);
|
||||
|
||||
//cout << "SVMSGDImpl::train extendedTrainSamples: \n" << extendedTrainSamples << endl;
|
||||
Mat average;
|
||||
float multiplier = 0;
|
||||
makeExtendedTrainSamples(trainSamples, extendedTrainSamples, average, multiplier);
|
||||
|
||||
int extendedTrainSamplesCount = extendedTrainSamples.rows;
|
||||
int extendedFeatureCount = extendedTrainSamples.cols;
|
||||
@ -301,6 +285,7 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
|
||||
RNG rng(0);
|
||||
|
||||
CV_Assert (params.termCrit.type & TermCriteria::COUNT || params.termCrit.type & TermCriteria::EPS);
|
||||
int maxCount = (params.termCrit.type & TermCriteria::COUNT) ? params.termCrit.maxCount : INT_MAX;
|
||||
double epsilon = (params.termCrit.type & TermCriteria::EPS) ? params.termCrit.epsilon : 0;
|
||||
|
||||
@ -336,17 +321,20 @@ bool SVMSGDImpl::train(const Ptr<TrainData>& data, int)
|
||||
extendedWeights = averageExtendedWeights;
|
||||
}
|
||||
|
||||
//cout << "SVMSGDImpl::train extendedWeights: \n" << extendedWeights << endl;
|
||||
|
||||
Rect roi(0, 0, featureCount, 1);
|
||||
weights_ = extendedWeights(roi);
|
||||
weights_ = weights_.mul(1/multiplier);
|
||||
weights_ *= multiplier;
|
||||
|
||||
//cout << "SVMSGDImpl::train weights: \n" << weights_ << endl;
|
||||
CV_Assert(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN);
|
||||
|
||||
shift_ = calcShift(trainSamples, trainResponses);
|
||||
|
||||
//cout << "SVMSGDImpl::train shift = " << shift_ << endl;
|
||||
if (params.marginType == SOFT_MARGIN)
|
||||
{
|
||||
shift_ = extendedWeights.at<float>(featureCount) - weights_.dot(average);
|
||||
}
|
||||
else
|
||||
{
|
||||
shift_ = calcShift(trainSamples, trainResponses);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -385,6 +373,8 @@ float SVMSGDImpl::predict( InputArray _samples, OutputArray _results, int ) cons
|
||||
bool SVMSGDImpl::isClassifier() const
|
||||
{
|
||||
return (params.svmsgdType == SGD || params.svmsgdType == ASGD)
|
||||
&&
|
||||
(params.marginType == SOFT_MARGIN || params.marginType == HARD_MARGIN)
|
||||
&&
|
||||
(params.lambda > 0) && (params.gamma0 > 0) && (params.c >= 0);
|
||||
}
|
||||
@ -417,15 +407,32 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
case ASGD:
|
||||
SvmsgdTypeStr = "ASGD";
|
||||
break;
|
||||
case ILLEGAL_VALUE:
|
||||
SvmsgdTypeStr = format("Uknown_%d", params.svmsgdType);
|
||||
case ILLEGAL_SVMSGD_TYPE:
|
||||
SvmsgdTypeStr = format("Unknown_%d", params.svmsgdType);
|
||||
default:
|
||||
std::cout << "params.svmsgdType isn't initialized" << std::endl;
|
||||
}
|
||||
|
||||
|
||||
fs << "svmsgdType" << SvmsgdTypeStr;
|
||||
|
||||
String marginTypeStr;
|
||||
|
||||
switch (params.marginType)
|
||||
{
|
||||
case SOFT_MARGIN:
|
||||
marginTypeStr = "SOFT_MARGIN";
|
||||
break;
|
||||
case HARD_MARGIN:
|
||||
marginTypeStr = "HARD_MARGIN";
|
||||
break;
|
||||
case ILLEGAL_MARGIN_TYPE:
|
||||
marginTypeStr = format("Unknown_%d", params.marginType);
|
||||
default:
|
||||
std::cout << "params.marginType isn't initialized" << std::endl;
|
||||
}
|
||||
|
||||
fs << "marginType" << marginTypeStr;
|
||||
|
||||
fs << "lambda" << params.lambda;
|
||||
fs << "gamma0" << params.gamma0;
|
||||
fs << "c" << params.c;
|
||||
@ -438,8 +445,6 @@ void SVMSGDImpl::writeParams( FileStorage& fs ) const
|
||||
fs << "}";
|
||||
}
|
||||
|
||||
|
||||
|
||||
void SVMSGDImpl::read(const FileNode& fn)
|
||||
{
|
||||
clear();
|
||||
@ -455,13 +460,23 @@ void SVMSGDImpl::readParams( const FileNode& fn )
|
||||
String svmsgdTypeStr = (String)fn["svmsgdType"];
|
||||
SvmsgdType svmsgdType =
|
||||
svmsgdTypeStr == "SGD" ? SGD :
|
||||
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_VALUE;
|
||||
svmsgdTypeStr == "ASGD" ? ASGD : ILLEGAL_SVMSGD_TYPE;
|
||||
|
||||
if( svmsgdType == ILLEGAL_VALUE )
|
||||
if( svmsgdType == ILLEGAL_SVMSGD_TYPE )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid SVMSGD type" );
|
||||
|
||||
params.svmsgdType = svmsgdType;
|
||||
|
||||
String marginTypeStr = (String)fn["marginType"];
|
||||
MarginType marginType =
|
||||
marginTypeStr == "SOFT_MARGIN" ? SOFT_MARGIN :
|
||||
marginTypeStr == "HARD_MARGIN" ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
|
||||
if( marginType == ILLEGAL_MARGIN_TYPE )
|
||||
CV_Error( CV_StsParseError, "Missing or invalid margin type" );
|
||||
|
||||
params.marginType = marginType;
|
||||
|
||||
CV_Assert ( fn["lambda"].isReal() );
|
||||
params.lambda = (float)fn["lambda"];
|
||||
|
||||
@ -494,7 +509,8 @@ SVMSGDImpl::SVMSGDImpl()
|
||||
{
|
||||
clear();
|
||||
|
||||
params.svmsgdType = ILLEGAL_VALUE;
|
||||
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
|
||||
params.marginType = ILLEGAL_MARGIN_TYPE;
|
||||
|
||||
// Parameters for learning
|
||||
params.lambda = 0; // regularization
|
||||
@ -505,26 +521,28 @@ SVMSGDImpl::SVMSGDImpl()
|
||||
params.termCrit = _termCrit;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setOptimalParameters(int type)
|
||||
void SVMSGDImpl::setOptimalParameters(int svmsgdType, int marginType)
|
||||
{
|
||||
switch (type)
|
||||
switch (svmsgdType)
|
||||
{
|
||||
case SGD:
|
||||
params.svmsgdType = SGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
params.lambda = 0.0001;
|
||||
params.gamma0 = 0.05;
|
||||
params.c = 1;
|
||||
params.termCrit.maxCount = 100000;
|
||||
params.termCrit.epsilon = 0.00001;
|
||||
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
|
||||
break;
|
||||
|
||||
case ASGD:
|
||||
params.svmsgdType = ASGD;
|
||||
params.marginType = (marginType == SOFT_MARGIN) ? SOFT_MARGIN :
|
||||
(marginType == HARD_MARGIN) ? HARD_MARGIN : ILLEGAL_MARGIN_TYPE;
|
||||
params.lambda = 0.00001;
|
||||
params.gamma0 = 0.05;
|
||||
params.c = 0.75;
|
||||
params.termCrit.maxCount = 100000;
|
||||
params.termCrit.epsilon = 0.00001;
|
||||
params.termCrit = TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 100000, 0.00001);
|
||||
break;
|
||||
|
||||
default:
|
||||
@ -532,7 +550,7 @@ void SVMSGDImpl::setOptimalParameters(int type)
|
||||
}
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setType(int type)
|
||||
void SVMSGDImpl::setSvmsgdType(int type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
@ -543,13 +561,33 @@ void SVMSGDImpl::setType(int type)
|
||||
params.svmsgdType = ASGD;
|
||||
break;
|
||||
default:
|
||||
params.svmsgdType = ILLEGAL_VALUE;
|
||||
params.svmsgdType = ILLEGAL_SVMSGD_TYPE;
|
||||
}
|
||||
}
|
||||
|
||||
int SVMSGDImpl::getType() const
|
||||
int SVMSGDImpl::getSvmsgdType() const
|
||||
{
|
||||
return params.svmsgdType;
|
||||
}
|
||||
|
||||
void SVMSGDImpl::setMarginType(int type)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case HARD_MARGIN:
|
||||
params.marginType = HARD_MARGIN;
|
||||
break;
|
||||
case SOFT_MARGIN:
|
||||
params.marginType = SOFT_MARGIN;
|
||||
break;
|
||||
default:
|
||||
params.marginType = ILLEGAL_MARGIN_TYPE;
|
||||
}
|
||||
}
|
||||
|
||||
int SVMSGDImpl::getMarginType() const
|
||||
{
|
||||
return params.marginType;
|
||||
}
|
||||
} //ml
|
||||
} //cv
|
||||
|
@ -199,10 +199,19 @@ int str_to_svmsgd_type( String& str )
|
||||
return SVMSGD::SGD;
|
||||
if ( !str.compare("ASGD") )
|
||||
return SVMSGD::ASGD;
|
||||
CV_Error( CV_StsBadArg, "incorrect boost type string" );
|
||||
CV_Error( CV_StsBadArg, "incorrect svmsgd type string" );
|
||||
return -1;
|
||||
}
|
||||
|
||||
int str_to_margin_type( String& str )
|
||||
{
|
||||
if ( !str.compare("SOFT_MARGIN") )
|
||||
return SVMSGD::SOFT_MARGIN;
|
||||
if ( !str.compare("HARD_MARGIN") )
|
||||
return SVMSGD::HARD_MARGIN;
|
||||
CV_Error( CV_StsBadArg, "incorrect svmsgd margin type string" );
|
||||
return -1;
|
||||
}
|
||||
// ---------------------------------- MLBaseTest ---------------------------------------------------
|
||||
|
||||
CV_MLBaseTest::CV_MLBaseTest(const char* _modelName)
|
||||
@ -452,10 +461,16 @@ int CV_MLBaseTest::train( int testCaseIdx )
|
||||
{
|
||||
String svmsgdTypeStr;
|
||||
modelParamsNode["svmsgdType"] >> svmsgdTypeStr;
|
||||
Ptr<SVMSGD> m = SVMSGD::create();
|
||||
int type = str_to_svmsgd_type( svmsgdTypeStr );
|
||||
m->setType(type);
|
||||
//m->setType(str_to_svmsgd_type( svmsgdTypeStr ));
|
||||
|
||||
Ptr<SVMSGD> m = SVMSGD::create();
|
||||
int svmsgdType = str_to_svmsgd_type( svmsgdTypeStr );
|
||||
m->setSvmsgdType(svmsgdType);
|
||||
|
||||
String marginTypeStr;
|
||||
modelParamsNode["marginType"] >> marginTypeStr;
|
||||
int marginType = str_to_margin_type( marginTypeStr );
|
||||
m->setMarginType(marginType);
|
||||
|
||||
m->setLambda(modelParamsNode["lambda"]);
|
||||
m->setGamma0(modelParamsNode["gamma0"]);
|
||||
m->setC(modelParamsNode["c"]);
|
||||
|
@ -270,11 +270,7 @@ TEST(ML_DTree, legacy_load) { CV_LegacyTest test(CV_DTREE, "_abalone.xml;_mushro
|
||||
TEST(ML_NBayes, legacy_load) { CV_LegacyTest test(CV_NBAYES, "_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_SVM, legacy_load) { CV_LegacyTest test(CV_SVM, "_poletelecomm.xml;_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_RTrees, legacy_load) { CV_LegacyTest test(CV_RTREES, "_waveform.xml"); test.safe_run(); }
|
||||
TEST(ML_SVMSGD, legacy_load)
|
||||
{
|
||||
CV_LegacyTest test(CV_SVMSGD, "_waveform.xml");
|
||||
test.safe_run();
|
||||
}
|
||||
TEST(ML_SVMSGD, legacy_load) { CV_LegacyTest test(CV_SVMSGD, "_waveform.xml"); test.safe_run(); }
|
||||
|
||||
/*TEST(ML_SVM, throw_exception_when_save_untrained_model)
|
||||
{
|
||||
|
@ -52,45 +52,99 @@ using cv::ml::TrainData;
|
||||
class CV_SVMSGDTrainTest : public cvtest::BaseTest
|
||||
{
|
||||
public:
|
||||
CV_SVMSGDTrainTest(Mat _weights, float shift);
|
||||
enum TrainDataType
|
||||
{
|
||||
UNIFORM_SAME_SCALE,
|
||||
UNIFORM_DIFFERENT_SCALES
|
||||
};
|
||||
|
||||
CV_SVMSGDTrainTest(Mat _weights, float shift, TrainDataType type, double precision = 0.01);
|
||||
private:
|
||||
virtual void run( int start_from );
|
||||
float decisionFunction(Mat sample, Mat weights, float shift);
|
||||
static float decisionFunction(const Mat &sample, const Mat &weights, float shift);
|
||||
void makeTrainData(Mat weights, float shift);
|
||||
void makeTestData(Mat weights, float shift);
|
||||
void generateSameScaleData(Mat &samples);
|
||||
void generateDifferentScalesData(Mat &samples, float shift);
|
||||
|
||||
TrainDataType type;
|
||||
double precision;
|
||||
cv::Ptr<TrainData> data;
|
||||
cv::Mat testSamples;
|
||||
cv::Mat testResponses;
|
||||
static const int TEST_VALUE_LIMIT = 500;
|
||||
};
|
||||
|
||||
CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift)
|
||||
void CV_SVMSGDTrainTest::generateSameScaleData(Mat &samples)
|
||||
{
|
||||
float lowerLimit = -TEST_VALUE_LIMIT;
|
||||
float upperLimit = TEST_VALUE_LIMIT;
|
||||
cv::RNG rng(0);
|
||||
rng.fill(samples, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
}
|
||||
|
||||
void CV_SVMSGDTrainTest::generateDifferentScalesData(Mat &samples, float shift)
|
||||
{
|
||||
int featureCount = samples.cols;
|
||||
|
||||
float lowerLimit = -TEST_VALUE_LIMIT;
|
||||
float upperLimit = TEST_VALUE_LIMIT;
|
||||
cv::RNG rng(10);
|
||||
|
||||
|
||||
for (int featureIndex = 0; featureIndex < featureCount; featureIndex++)
|
||||
{
|
||||
int crit = rng.uniform(0, 2);
|
||||
|
||||
if (crit > 0)
|
||||
{
|
||||
rng.fill(samples.col(featureIndex), RNG::UNIFORM, lowerLimit - shift, upperLimit - shift);
|
||||
}
|
||||
else
|
||||
{
|
||||
rng.fill(samples.col(featureIndex), RNG::UNIFORM, lowerLimit/10, upperLimit/10);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CV_SVMSGDTrainTest::makeTrainData(Mat weights, float shift)
|
||||
{
|
||||
int datasize = 100000;
|
||||
int varCount = weights.cols;
|
||||
cv::Mat samples = cv::Mat::zeros( datasize, varCount, CV_32FC1 );
|
||||
cv::Mat responses = cv::Mat::zeros( datasize, 1, CV_32FC1 );
|
||||
cv::RNG rng(0);
|
||||
int featureCount = weights.cols;
|
||||
cv::Mat samples = cv::Mat::zeros(datasize, featureCount, CV_32FC1);
|
||||
cv::Mat responses = cv::Mat::zeros(datasize, 1, CV_32FC1);
|
||||
|
||||
switch(type)
|
||||
{
|
||||
case UNIFORM_SAME_SCALE:
|
||||
generateSameScaleData(samples);
|
||||
break;
|
||||
case UNIFORM_DIFFERENT_SCALES:
|
||||
generateDifferentScalesData(samples, shift);
|
||||
break;
|
||||
default:
|
||||
CV_Error(CV_StsBadArg, "Unknown train data type");
|
||||
}
|
||||
|
||||
for (int sampleIndex = 0; sampleIndex < datasize; sampleIndex++)
|
||||
{
|
||||
responses.at<float>(sampleIndex) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1;
|
||||
}
|
||||
|
||||
data = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
|
||||
}
|
||||
|
||||
void CV_SVMSGDTrainTest::makeTestData(Mat weights, float shift)
|
||||
{
|
||||
int testSamplesCount = 100000;
|
||||
int featureCount = weights.cols;
|
||||
|
||||
float lowerLimit = -TEST_VALUE_LIMIT;
|
||||
float upperLimit = TEST_VALUE_LIMIT;
|
||||
|
||||
cv::RNG rng(0);
|
||||
|
||||
rng.fill(samples, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
for (int sampleIndex = 0; sampleIndex < datasize; sampleIndex++)
|
||||
{
|
||||
responses.at<float>( sampleIndex ) = decisionFunction(samples.row(sampleIndex), weights, shift) > 0 ? 1 : -1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
std::cout << "real weights\n" << weights/norm(weights) << "\n" << std::endl;
|
||||
std::cout << "real shift \n" << shift/norm(weights) << "\n" << std::endl;
|
||||
|
||||
data = TrainData::create( samples, cv::ml::ROW_SAMPLE, responses );
|
||||
|
||||
int testSamplesCount = 100000;
|
||||
|
||||
testSamples.create(testSamplesCount, varCount, CV_32FC1);
|
||||
testSamples.create(testSamplesCount, featureCount, CV_32FC1);
|
||||
rng.fill(testSamples, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
testResponses.create(testSamplesCount, 1, CV_32FC1);
|
||||
|
||||
@ -100,12 +154,24 @@ CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift)
|
||||
}
|
||||
}
|
||||
|
||||
CV_SVMSGDTrainTest::CV_SVMSGDTrainTest(Mat weights, float shift, TrainDataType _type, double _precision)
|
||||
{
|
||||
type = _type;
|
||||
precision = _precision;
|
||||
makeTrainData(weights, shift);
|
||||
makeTestData(weights, shift);
|
||||
}
|
||||
|
||||
float CV_SVMSGDTrainTest::decisionFunction(const Mat &sample, const Mat &weights, float shift)
|
||||
{
|
||||
return sample.dot(weights) + shift;
|
||||
}
|
||||
|
||||
void CV_SVMSGDTrainTest::run( int /*start_from*/ )
|
||||
{
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
|
||||
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
||||
svmsgd->setTermCriteria(TermCriteria(TermCriteria::EPS, 0, 0.00005));
|
||||
svmsgd->setOptimalParameters();
|
||||
|
||||
svmsgd->train(data);
|
||||
|
||||
@ -118,77 +184,106 @@ void CV_SVMSGDTrainTest::run( int /*start_from*/ )
|
||||
|
||||
for (int i = 0; i < testSamplesCount; i++)
|
||||
{
|
||||
if (responses.at<float>(i) * testResponses.at<float>(i) < 0 )
|
||||
if (responses.at<float>(i) * testResponses.at<float>(i) < 0)
|
||||
errCount++;
|
||||
}
|
||||
|
||||
|
||||
float normW = norm(svmsgd->getWeights());
|
||||
|
||||
std::cout << "found weights\n" << svmsgd->getWeights()/normW << "\n" << std::endl;
|
||||
std::cout << "found shift \n" << svmsgd->getShift()/normW << "\n" << std::endl;
|
||||
|
||||
float err = (float)errCount / testSamplesCount;
|
||||
std::cout << "err " << err << std::endl;
|
||||
|
||||
if ( err > 0.01 )
|
||||
if ( err > precision )
|
||||
{
|
||||
ts->set_failed_test_info( cvtest::TS::FAIL_BAD_ACCURACY );
|
||||
ts->set_failed_test_info(cvtest::TS::FAIL_BAD_ACCURACY);
|
||||
}
|
||||
}
|
||||
|
||||
float CV_SVMSGDTrainTest::decisionFunction(Mat sample, Mat weights, float shift)
|
||||
|
||||
void makeWeightsAndShift(int featureCount, Mat &weights, float &shift)
|
||||
{
|
||||
return sample.dot(weights) + shift;
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, train0)
|
||||
{
|
||||
int varCount = 2;
|
||||
|
||||
Mat weights;
|
||||
weights.create(1, varCount, CV_32FC1);
|
||||
weights.at<float>(0) = 1;
|
||||
weights.at<float>(1) = 0;
|
||||
cv::RNG rng(1);
|
||||
float shift = rng.uniform(-varCount, varCount);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, train1)
|
||||
{
|
||||
int varCount = 5;
|
||||
|
||||
Mat weights;
|
||||
weights.create(1, varCount, CV_32FC1);
|
||||
|
||||
float lowerLimit = -1;
|
||||
float upperLimit = 1;
|
||||
weights.create(1, featureCount, CV_32FC1);
|
||||
cv::RNG rng(0);
|
||||
double lowerLimit = -1;
|
||||
double upperLimit = 1;
|
||||
|
||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
|
||||
float shift = rng.uniform(-varCount, varCount);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift);
|
||||
test.safe_run();
|
||||
shift = rng.uniform(-featureCount, featureCount);
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, train2)
|
||||
|
||||
TEST(ML_SVMSGD, trainSameScale2)
|
||||
{
|
||||
int varCount = 100;
|
||||
int featureCount = 2;
|
||||
|
||||
Mat weights;
|
||||
weights.create(1, varCount, CV_32FC1);
|
||||
|
||||
float lowerLimit = -1;
|
||||
float upperLimit = 1;
|
||||
cv::RNG rng(0);
|
||||
rng.fill(weights, RNG::UNIFORM, lowerLimit, upperLimit);
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
float shift = rng.uniform(-varCount, varCount);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights,shift);
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainSameScale5)
|
||||
{
|
||||
int featureCount = 5;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainSameScale100)
|
||||
{
|
||||
int featureCount = 100;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_SAME_SCALE);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainDifferentScales2)
|
||||
{
|
||||
int featureCount = 2;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.01);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainDifferentScales5)
|
||||
{
|
||||
int featureCount = 5;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.05);
|
||||
test.safe_run();
|
||||
}
|
||||
|
||||
TEST(ML_SVMSGD, trainDifferentScales100)
|
||||
{
|
||||
int featureCount = 100;
|
||||
|
||||
Mat weights;
|
||||
|
||||
float shift = 0;
|
||||
makeWeightsAndShift(featureCount, weights, shift);
|
||||
|
||||
CV_SVMSGDTrainTest test(weights, shift, CV_SVMSGDTrainTest::UNIFORM_DIFFERENT_SCALES, 0.10);
|
||||
test.safe_run();
|
||||
}
|
||||
|
@ -40,16 +40,13 @@ void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int
|
||||
void redraw(Data data, const Point points[2]);
|
||||
|
||||
//add point in train set, train SVMSGD algorithm and draw results on image
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y);
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);
|
||||
|
||||
|
||||
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
|
||||
{
|
||||
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
|
||||
svmsgd->setOptimalParameters(SVMSGD::ASGD);
|
||||
svmsgd->setTermCriteria(TermCriteria(TermCriteria::EPS, 0, 0.00000001));
|
||||
svmsgd->setLambda(0.00000001);
|
||||
|
||||
svmsgd->setOptimalParameters();
|
||||
|
||||
cv::Ptr<TrainData> train_data = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
|
||||
svmsgd->train( train_data );
|
||||
@ -64,6 +61,27 @@ bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift
|
||||
return false;
|
||||
}
|
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
|
||||
{
|
||||
std::pair<Point,Point> currentSegment;
|
||||
|
||||
currentSegment.first = Point(width, 0);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, height);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(width, 0);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(0, height);
|
||||
segments.push_back(currentSegment);
|
||||
}
|
||||
|
||||
|
||||
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
|
||||
{
|
||||
@ -123,27 +141,6 @@ bool findPointsForLine(const Mat &weights, float shift, Point (&points)[2], int
|
||||
return true;
|
||||
}
|
||||
|
||||
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
|
||||
{
|
||||
std::pair<Point,Point> currentSegment;
|
||||
|
||||
currentSegment.first = Point(width, 0);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, height);
|
||||
currentSegment.second = Point(width, height);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(width, 0);
|
||||
segments.push_back(currentSegment);
|
||||
|
||||
currentSegment.first = Point(0, 0);
|
||||
currentSegment.second = Point(0, height);
|
||||
segments.push_back(currentSegment);
|
||||
}
|
||||
|
||||
void redraw(Data data, const Point points[2])
|
||||
{
|
||||
data.img.setTo(0);
|
||||
@ -162,19 +159,20 @@ void redraw(Data data, const Point points[2])
|
||||
imshow("Train svmsgd", data.img);
|
||||
}
|
||||
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y)
|
||||
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
|
||||
{
|
||||
Mat currentSample(1, 2, CV_32F);
|
||||
|
||||
currentSample.at<float>(0,0) = x;
|
||||
currentSample.at<float>(0,1) = y;
|
||||
data.samples.push_back(currentSample);
|
||||
data.responses.push_back(response);
|
||||
|
||||
Mat weights(1, 2, CV_32F);
|
||||
float shift = 0;
|
||||
|
||||
if (doTrain(data.samples, data.responses, weights, shift))
|
||||
{
|
||||
{
|
||||
Point points[2];
|
||||
findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
|
||||
|
||||
@ -189,15 +187,12 @@ static void onMouse( int event, int x, int y, int, void* pData)
|
||||
|
||||
switch( event )
|
||||
{
|
||||
case CV_EVENT_LBUTTONUP:
|
||||
data.responses.push_back(1);
|
||||
addPointRetrainAndRedraw(data, x, y);
|
||||
|
||||
case CV_EVENT_LBUTTONUP:
|
||||
addPointRetrainAndRedraw(data, x, y, 1);
|
||||
break;
|
||||
|
||||
case CV_EVENT_RBUTTONDOWN:
|
||||
data.responses.push_back(-1);
|
||||
addPointRetrainAndRedraw(data, x, y);
|
||||
addPointRetrainAndRedraw(data, x, y, -1);
|
||||
break;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user